// heliosdb_proxy/schema_routing/admin.rs

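//! Schema Routing Admin API
//!
//! Handlers for the REST endpoints that manage schema-aware routing. Each method on
//! `SchemaRoutingAdmin` documents the route it backs; request and response bodies are
//! plain serde types defined in this module.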
4
5use serde::{Deserialize, Serialize};
6use std::sync::atomic::Ordering;
7use std::sync::Arc;
8
9use super::registry::{ColumnSchema, StorageType, TableSchema};
10use super::router::SchemaAwareRouter;
11use super::{
12    AccessPattern, DataTemperature, LearningClassifier, SchemaDiscovery, SchemaRegistry,
13    SchemaRoutingMetrics, WorkloadType,
14};
15
16/// Admin API for schema routing
17pub struct SchemaRoutingAdmin {
18    pub registry: Arc<SchemaRegistry>,
19    pub router: Arc<SchemaAwareRouter>,
20    pub classifier: Arc<LearningClassifier>,
21    pub discovery: Arc<SchemaDiscovery>,
22    pub metrics: Arc<SchemaRoutingMetrics>,
23}
24
25impl SchemaRoutingAdmin {
26    /// Create a new admin API instance
27    pub fn new(
28        registry: Arc<SchemaRegistry>,
29        router: Arc<SchemaAwareRouter>,
30        classifier: Arc<LearningClassifier>,
31        discovery: Arc<SchemaDiscovery>,
32        metrics: Arc<SchemaRoutingMetrics>,
33    ) -> Self {
34        Self {
35            registry,
36            router,
37            classifier,
38            discovery,
39            metrics,
40        }
41    }
42
43    // =========================================================================
44    // TABLE ENDPOINTS
45    // =========================================================================
46
47    /// GET /schema/tables - List all registered tables
48    pub fn list_tables(&self) -> TablesResponse {
49        let tables = self.registry.list_tables();
50        TablesResponse {
51            tables: tables
52                .into_iter()
53                .map(|t| TableSummary {
54                    name: t.name.clone(),
55                    temperature: format!("{:?}", t.temperature),
56                    workload: format!("{:?}", t.workload),
57                    access_pattern: format!("{:?}", t.access_pattern),
58                    column_count: t.columns.len(),
59                    shard_key: t.shard_key.clone(),
60                    row_count_estimate: Some(t.estimated_rows),
61                })
62                .collect(),
63            total: self.registry.table_count(),
64        }
65    }
66
67    /// GET /schema/tables/:name - Get details for a specific table
68    pub fn get_table(&self, name: &str) -> Option<TableDetails> {
69        self.registry.get_table(name).map(|t| TableDetails {
70            name: t.name.clone(),
71            columns: t
72                .columns
73                .iter()
74                .map(|c| ColumnDetails {
75                    name: c.name.clone(),
76                    data_type: c.data_type.clone(),
77                    nullable: c.nullable,
78                    is_primary_key: c.is_primary_key,
79                    is_indexed: c.is_indexed,
80                    default_value: None, // ColumnSchema doesn't have default_value
81                    storage_type: Some(format!("{:?}", c.storage_type)),
82                })
83                .collect(),
84            temperature: format!("{:?}", t.temperature),
85            workload: format!("{:?}", t.workload),
86            access_pattern: format!("{:?}", t.access_pattern),
87            primary_key: t.primary_key.clone(),
88            shard_key: t.shard_key.clone(),
89            row_count_estimate: Some(t.estimated_rows),
90            size_bytes: Some(t.avg_row_size as u64 * t.estimated_rows),
91            partition_key: t.partition_key.as_ref().map(|p| format!("{:?}", p)),
92        })
93    }
94
95    /// POST /schema/tables - Register a new table
96    pub fn register_table(
97        &self,
98        request: RegisterTableRequest,
99    ) -> Result<TableDetails, AdminError> {
100        let temperature = DataTemperature::from_str(&request.temperature).ok_or_else(|| {
101            AdminError::InvalidInput(format!("Invalid temperature: {}", request.temperature))
102        })?;
103
104        let workload = WorkloadType::from_str(&request.workload).ok_or_else(|| {
105            AdminError::InvalidInput(format!("Invalid workload: {}", request.workload))
106        })?;
107
108        let access_pattern = parse_access_pattern(&request.access_pattern).ok_or_else(|| {
109            AdminError::InvalidInput(format!(
110                "Invalid access pattern: {}",
111                request.access_pattern
112            ))
113        })?;
114
115        let columns: Vec<ColumnSchema> = request
116            .columns
117            .iter()
118            .map(|c| ColumnSchema {
119                name: c.name.clone(),
120                data_type: c.data_type.clone(),
121                nullable: c.nullable,
122                is_primary_key: c.is_primary_key,
123                is_indexed: c.is_indexed.unwrap_or(false),
124                storage_type: StorageType::Row,
125            })
126            .collect();
127
128        let table = TableSchema {
129            name: request.name.clone(),
130            columns,
131            access_pattern,
132            temperature,
133            workload,
134            primary_key: request.primary_key.clone(),
135            shard_key: request.shard_key.clone(),
136            estimated_rows: request.row_count_estimate.unwrap_or(0),
137            avg_row_size: 0,
138            partition_key: None,
139            preferred_nodes: Vec::new(),
140        };
141
142        self.registry.register_table(table);
143
144        self.get_table(&request.name)
145            .ok_or_else(|| AdminError::InternalError("Failed to register table".to_string()))
146    }
147
148    /// DELETE /schema/tables/:name - Remove a table from routing
149    pub fn remove_table(&self, name: &str) -> Result<(), AdminError> {
150        if self.registry.get_table(name).is_none() {
151            return Err(AdminError::NotFound(format!("Table not found: {}", name)));
152        }
153        self.registry.remove_table(name);
154        Ok(())
155    }
156
157    // =========================================================================
158    // CLASSIFICATION ENDPOINTS
159    // =========================================================================
160
161    /// POST /schema/classify - Manually classify a table
162    pub fn classify_table(
163        &self,
164        request: ClassifyRequest,
165    ) -> Result<ClassificationResult, AdminError> {
166        let temperature = DataTemperature::from_str(&request.temperature).ok_or_else(|| {
167            AdminError::InvalidInput(format!("Invalid temperature: {}", request.temperature))
168        })?;
169
170        let workload = WorkloadType::from_str(&request.workload).ok_or_else(|| {
171            AdminError::InvalidInput(format!("Invalid workload: {}", request.workload))
172        })?;
173
174        // Get existing table
175        let mut table = self
176            .registry
177            .get_table(&request.table_name)
178            .ok_or_else(|| {
179                AdminError::NotFound(format!("Table not found: {}", request.table_name))
180            })?;
181
182        // Update classifications
183        let old_temperature = table.temperature;
184        let old_workload = table.workload;
185
186        table.temperature = temperature;
187        table.workload = workload;
188
189        // Re-register with new classification
190        self.registry.register_table(table);
191
192        Ok(ClassificationResult {
193            table_name: request.table_name,
194            previous_temperature: format!("{:?}", old_temperature),
195            new_temperature: format!("{:?}", temperature),
196            previous_workload: format!("{:?}", old_workload),
197            new_workload: format!("{:?}", workload),
198        })
199    }
200
201    /// GET /schema/classify/:table - Get classifier suggestions
202    pub fn get_classification_suggestion(
203        &self,
204        table_name: &str,
205    ) -> Result<ClassificationSuggestion, AdminError> {
206        // Get history from classifier
207        let history = self.classifier.get_history(table_name);
208
209        if history.is_none() {
210            return Err(AdminError::NotFound(format!(
211                "No query history for table: {}",
212                table_name
213            )));
214        }
215
216        let hist = history.expect("history checked above");
217        let query_count = hist.count();
218        let suggested_temp = self.classifier.suggest_temperature(table_name);
219        let suggested_workload = self.classifier.suggest_workload(table_name);
220        let confidence = self.classifier.get_confidence(table_name);
221
222        Ok(ClassificationSuggestion {
223            table_name: table_name.to_string(),
224            query_count,
225            suggested_temperature: suggested_temp.map(|t| format!("{:?}", t)),
226            suggested_workload: suggested_workload.map(|w| format!("{:?}", w)),
227            confidence: confidence.unwrap_or(0.0),
228            sample_size_sufficient: query_count >= 100,
229        })
230    }
231
232    // =========================================================================
233    // ANALYSIS ENDPOINTS
234    // =========================================================================
235
236    /// POST /schema/analyze - Analyze a query
237    pub fn analyze_query(&self, request: AnalyzeRequest) -> AnalysisResult {
238        use super::QueryAnalyzer;
239
240        let query = request.query;
241        let analyzer = QueryAnalyzer::new(self.registry.clone());
242        let analysis = analyzer.analyze(&query);
243
244        // Get primary access pattern from the list
245        let access_pattern = analysis
246            .access_patterns
247            .first()
248            .map(|p| format!("{:?}", p))
249            .unwrap_or_else(|| "Mixed".to_string());
250
251        let detected_workload = self
252            .classifier
253            .classify_query(&query)
254            .map(|w| format!("{:?}", w));
255
256        AnalysisResult {
257            query,
258            tables: analysis.tables.iter().map(|t| t.name.clone()).collect(),
259            access_pattern,
260            shard_keys: analysis
261                .shard_keys
262                .iter()
263                .map(|(k, v)| format!("{}={:?}", k, v))
264                .collect(),
265            is_read_only: analysis.is_read_only,
266            estimated_complexity: analysis.complexity,
267            estimated_selectivity: analysis.selectivity,
268            has_aggregation: analysis.has_aggregations,
269            has_join: analysis.has_joins,
270            has_subquery: analysis.has_subqueries,
271            columns: Vec::new(), // Not available in QueryAnalysis
272            detected_workload,
273        }
274    }
275
276    /// POST /schema/route - Get routing decision for a query (dry-run)
277    pub fn route_query(&self, request: RouteRequest) -> RouteResult {
278        let decision = self.router.route(&request.query);
279
280        RouteResult {
281            query: request.query,
282            target_type: format!("{:?}", decision.target),
283            reason: format!("{:?}", decision.reason),
284            preferred_node: decision.node_info.as_ref().map(|n| n.name.clone()),
285            alternative_nodes: Vec::new(), // Not available in current RoutingDecision
286            estimated_latency_ms: decision.node_info.as_ref().map(|n| n.current_latency_ms),
287        }
288    }
289
290    // =========================================================================
291    // ROUTING STATS ENDPOINTS
292    // =========================================================================
293
294    /// GET /schema/stats - Get overall routing statistics
295    pub fn get_stats(&self) -> RoutingStatsResponse {
296        let stats = self.metrics.get_routing_stats();
297
298        RoutingStatsResponse {
299            total_queries_routed: stats.total_queries.load(Ordering::Relaxed),
300            queries_to_primary: stats.primary_routes.load(Ordering::Relaxed),
301            queries_to_replica: stats.replica_routes.load(Ordering::Relaxed),
302            queries_scattered: stats.scatter_gather.load(Ordering::Relaxed),
303            avg_latency_ms: 0.0, // Not tracked globally in RoutingStats
304            cache_hit_rate: stats.classification_hit_rate(),
305        }
306    }
307
308    /// GET /schema/stats/tables - Get per-table statistics
309    pub fn get_table_stats(&self) -> Vec<TableStatsResponse> {
310        let stats = self.metrics.get_table_stats_for_admin();
311
312        stats
313            .into_iter()
314            .map(|(name, s)| TableStatsResponse {
315                table_name: name,
316                query_count: s.query_count,
317                avg_latency_ms: s.avg_latency_ms,
318                hit_rate: s.cache_hit_rate,
319                temperature: format!("{:?}", s.temperature),
320                workload: format!("{:?}", s.workload),
321            })
322            .collect()
323    }
324
325    /// GET /schema/stats/workloads - Get per-workload statistics
326    pub fn get_workload_stats(&self) -> Vec<WorkloadStatsResponse> {
327        let stats = self.metrics.get_workload_stats_for_admin();
328
329        stats
330            .into_iter()
331            .map(|(workload, s)| WorkloadStatsResponse {
332                workload: format!("{:?}", workload),
333                query_count: s.query_count,
334                avg_latency_ms: s.avg_latency_ms,
335                queries_to_primary: s.queries_to_primary,
336                queries_to_replica: s.queries_to_replica,
337            })
338            .collect()
339    }
340
341    // =========================================================================
342    // DISCOVERY ENDPOINTS
343    // =========================================================================
344
345    /// POST /schema/discover - Trigger schema discovery
346    pub async fn trigger_discovery(&self) -> Result<DiscoveryResult, AdminError> {
347        let tables = self
348            .discovery
349            .discover()
350            .await
351            .map_err(|e| AdminError::DiscoveryError(e.to_string()))?;
352
353        // Register discovered tables
354        for table in &tables {
355            self.registry.register_table(table.clone());
356        }
357
358        Ok(DiscoveryResult {
359            tables_discovered: tables.len(),
360            table_names: tables.iter().map(|t| t.name.clone()).collect(),
361        })
362    }
363
364    /// POST /schema/refresh - Refresh schema cache
365    pub async fn refresh_schema(&self) -> Result<RefreshResult, AdminError> {
366        self.discovery
367            .refresh()
368            .await
369            .map_err(|e| AdminError::DiscoveryError(e.to_string()))?;
370
371        Ok(RefreshResult {
372            success: true,
373            message: "Schema cache refreshed successfully".to_string(),
374        })
375    }
376
377    // =========================================================================
378    // AI/AGENT ENDPOINTS
379    // =========================================================================
380
381    /// GET /schema/ai/workloads - Get AI workload statistics
382    pub fn get_ai_workload_stats(&self) -> AIWorkloadStatsResponse {
383        let stats = self.metrics.get_ai_workload_stats();
384
385        AIWorkloadStatsResponse {
386            embedding_queries: stats.embedding_retrieval_count,
387            context_lookups: stats.context_lookup_count,
388            knowledge_base_queries: stats.knowledge_base_count,
389            tool_executions: stats.tool_execution_count,
390            total_ai_queries: stats.total_ai_queries(),
391            avg_vector_dimensions: stats.avg_vector_dimensions,
392        }
393    }
394
395    /// GET /schema/rag/stats - Get RAG pipeline statistics
396    pub fn get_rag_stats(&self) -> RAGStatsResponse {
397        let stats = self.metrics.get_rag_stats_for_admin();
398
399        RAGStatsResponse {
400            retrieval_count: stats.retrieval_count,
401            avg_retrieval_latency_ms: stats.avg_retrieval_latency_ms,
402            fetch_count: stats.fetch_count,
403            avg_fetch_latency_ms: stats.avg_fetch_latency_ms,
404            total_pipeline_executions: stats.total_pipeline_executions,
405            avg_total_latency_ms: stats.avg_total_latency_ms,
406        }
407    }
408}
409
// =============================================================================
// REQUEST/RESPONSE TYPES
// =============================================================================
413
414#[derive(Debug, Serialize)]
415pub struct TablesResponse {
416    pub tables: Vec<TableSummary>,
417    pub total: usize,
418}
419
420#[derive(Debug, Serialize)]
421pub struct TableSummary {
422    pub name: String,
423    pub temperature: String,
424    pub workload: String,
425    pub access_pattern: String,
426    pub column_count: usize,
427    pub shard_key: Option<String>,
428    pub row_count_estimate: Option<u64>,
429}
430
431#[derive(Debug, Serialize)]
432pub struct TableDetails {
433    pub name: String,
434    pub columns: Vec<ColumnDetails>,
435    pub temperature: String,
436    pub workload: String,
437    pub access_pattern: String,
438    pub primary_key: Vec<String>,
439    pub shard_key: Option<String>,
440    pub row_count_estimate: Option<u64>,
441    pub size_bytes: Option<u64>,
442    pub partition_key: Option<String>,
443}
444
445#[derive(Debug, Serialize)]
446pub struct ColumnDetails {
447    pub name: String,
448    pub data_type: String,
449    pub nullable: bool,
450    pub is_primary_key: bool,
451    pub is_indexed: bool,
452    pub default_value: Option<String>,
453    pub storage_type: Option<String>,
454}
455
456#[derive(Debug, Deserialize)]
457pub struct RegisterTableRequest {
458    pub name: String,
459    pub columns: Vec<ColumnRequest>,
460    pub temperature: String,
461    pub workload: String,
462    pub access_pattern: String,
463    pub primary_key: Vec<String>,
464    pub shard_key: Option<String>,
465    pub row_count_estimate: Option<u64>,
466}
467
468#[derive(Debug, Deserialize)]
469pub struct ColumnRequest {
470    pub name: String,
471    pub data_type: String,
472    pub nullable: bool,
473    pub is_primary_key: bool,
474    pub is_indexed: Option<bool>,
475    pub default_value: Option<String>,
476}
477
478#[derive(Debug, Deserialize)]
479pub struct ClassifyRequest {
480    pub table_name: String,
481    pub temperature: String,
482    pub workload: String,
483}
484
485#[derive(Debug, Serialize)]
486pub struct ClassificationResult {
487    pub table_name: String,
488    pub previous_temperature: String,
489    pub new_temperature: String,
490    pub previous_workload: String,
491    pub new_workload: String,
492}
493
494#[derive(Debug, Serialize)]
495pub struct ClassificationSuggestion {
496    pub table_name: String,
497    pub query_count: u64,
498    pub suggested_temperature: Option<String>,
499    pub suggested_workload: Option<String>,
500    pub confidence: f64,
501    pub sample_size_sufficient: bool,
502}
503
504#[derive(Debug, Deserialize)]
505pub struct AnalyzeRequest {
506    pub query: String,
507}
508
509#[derive(Debug, Serialize)]
510pub struct AnalysisResult {
511    pub query: String,
512    pub tables: Vec<String>,
513    pub access_pattern: String,
514    pub shard_keys: Vec<String>,
515    pub is_read_only: bool,
516    pub estimated_complexity: u32,
517    pub estimated_selectivity: f64,
518    pub has_aggregation: bool,
519    pub has_join: bool,
520    pub has_subquery: bool,
521    pub columns: Vec<String>,
522    pub detected_workload: Option<String>,
523}
524
525#[derive(Debug, Deserialize)]
526pub struct RouteRequest {
527    pub query: String,
528}
529
530#[derive(Debug, Serialize)]
531pub struct RouteResult {
532    pub query: String,
533    pub target_type: String,
534    pub reason: String,
535    pub preferred_node: Option<String>,
536    pub alternative_nodes: Vec<String>,
537    pub estimated_latency_ms: Option<u64>,
538}
539
540#[derive(Debug, Serialize)]
541pub struct RoutingStatsResponse {
542    pub total_queries_routed: u64,
543    pub queries_to_primary: u64,
544    pub queries_to_replica: u64,
545    pub queries_scattered: u64,
546    pub avg_latency_ms: f64,
547    pub cache_hit_rate: f64,
548}
549
550#[derive(Debug, Serialize)]
551pub struct TableStatsResponse {
552    pub table_name: String,
553    pub query_count: u64,
554    pub avg_latency_ms: f64,
555    pub hit_rate: f64,
556    pub temperature: String,
557    pub workload: String,
558}
559
560#[derive(Debug, Serialize)]
561pub struct WorkloadStatsResponse {
562    pub workload: String,
563    pub query_count: u64,
564    pub avg_latency_ms: f64,
565    pub queries_to_primary: u64,
566    pub queries_to_replica: u64,
567}
568
569#[derive(Debug, Serialize)]
570pub struct DiscoveryResult {
571    pub tables_discovered: usize,
572    pub table_names: Vec<String>,
573}
574
575#[derive(Debug, Serialize)]
576pub struct RefreshResult {
577    pub success: bool,
578    pub message: String,
579}
580
581#[derive(Debug, Serialize)]
582pub struct AIWorkloadStatsResponse {
583    pub embedding_queries: u64,
584    pub context_lookups: u64,
585    pub knowledge_base_queries: u64,
586    pub tool_executions: u64,
587    pub total_ai_queries: u64,
588    pub avg_vector_dimensions: f64,
589}
590
591#[derive(Debug, Serialize)]
592pub struct RAGStatsResponse {
593    pub retrieval_count: u64,
594    pub avg_retrieval_latency_ms: f64,
595    pub fetch_count: u64,
596    pub avg_fetch_latency_ms: f64,
597    pub total_pipeline_executions: u64,
598    pub avg_total_latency_ms: f64,
599}
600
// =============================================================================
// ERRORS
// =============================================================================
604
/// Errors returned by the schema-routing admin API.
///
/// Every variant carries a human-readable detail message. `Display` prefixes
/// that message with the error category.
#[derive(Debug)]
pub enum AdminError {
    /// The requested table, or its query history, does not exist.
    NotFound(String),
    /// A request field could not be parsed or validated.
    InvalidInput(String),
    /// Schema discovery or refresh failed; carries the underlying error text.
    DiscoveryError(String),
    /// An unexpected internal failure.
    InternalError(String),
}

impl std::fmt::Display for AdminError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AdminError::NotFound(detail) => f.write_fmt(format_args!("Not found: {}", detail)),
            AdminError::InvalidInput(detail) => {
                f.write_fmt(format_args!("Invalid input: {}", detail))
            }
            AdminError::DiscoveryError(detail) => {
                f.write_fmt(format_args!("Discovery error: {}", detail))
            }
            AdminError::InternalError(detail) => {
                f.write_fmt(format_args!("Internal error: {}", detail))
            }
        }
    }
}

impl std::error::Error for AdminError {}
625
// =============================================================================
// HELPER FUNCTIONS
// =============================================================================
629
630fn parse_access_pattern(s: &str) -> Option<AccessPattern> {
631    match s.to_uppercase().as_str() {
632        "POINTLOOKUP" | "POINT_LOOKUP" => Some(AccessPattern::PointLookup),
633        "RANGESCAN" | "RANGE_SCAN" => Some(AccessPattern::RangeScan),
634        "FULLSCAN" | "FULL_SCAN" => Some(AccessPattern::FullScan),
635        "VECTORSEARCH" | "VECTOR_SEARCH" => Some(AccessPattern::VectorSearch),
636        "TIMESERIESAPPEND" | "TIME_SERIES_APPEND" => Some(AccessPattern::TimeSeriesAppend),
637        "MIXED" => Some(AccessPattern::Mixed),
638        _ => None,
639    }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_access_pattern() {
        assert_eq!(
            parse_access_pattern("PointLookup"),
            Some(AccessPattern::PointLookup)
        );
        assert_eq!(
            parse_access_pattern("POINT_LOOKUP"),
            Some(AccessPattern::PointLookup)
        );
        assert_eq!(
            parse_access_pattern("RangeScan"),
            Some(AccessPattern::RangeScan)
        );
        assert_eq!(
            parse_access_pattern("VectorSearch"),
            Some(AccessPattern::VectorSearch)
        );
        assert_eq!(parse_access_pattern("Mixed"), Some(AccessPattern::Mixed));
        assert_eq!(parse_access_pattern("Invalid"), None);
    }

    /// Every variant must be reachable through both spellings, in any letter case.
    #[test]
    fn test_parse_access_pattern_covers_every_variant() {
        assert_eq!(parse_access_pattern("FullScan"), Some(AccessPattern::FullScan));
        assert_eq!(parse_access_pattern("FULL_SCAN"), Some(AccessPattern::FullScan));
        assert_eq!(
            parse_access_pattern("range_scan"),
            Some(AccessPattern::RangeScan)
        );
        assert_eq!(
            parse_access_pattern("vector_search"),
            Some(AccessPattern::VectorSearch)
        );
        assert_eq!(
            parse_access_pattern("TimeSeriesAppend"),
            Some(AccessPattern::TimeSeriesAppend)
        );
        assert_eq!(
            parse_access_pattern("time_series_append"),
            Some(AccessPattern::TimeSeriesAppend)
        );
        assert_eq!(parse_access_pattern("mixed"), Some(AccessPattern::Mixed));
    }

    /// Empty input and unsupported separators are rejected rather than guessed.
    #[test]
    fn test_parse_access_pattern_rejects_unknown_spellings() {
        assert_eq!(parse_access_pattern(""), None);
        assert_eq!(parse_access_pattern("point-lookup"), None);
        assert_eq!(parse_access_pattern("PointLookupX"), None);
    }

    #[test]
    fn test_admin_error_display() {
        let err = AdminError::NotFound("users".to_string());
        assert!(err.to_string().contains("Not found"));

        let err = AdminError::InvalidInput("bad temp".to_string());
        assert!(err.to_string().contains("Invalid input"));
    }

    /// Pin the exact user-facing message for every error variant.
    #[test]
    fn test_admin_error_display_exact_messages() {
        assert_eq!(
            AdminError::NotFound("users".to_string()).to_string(),
            "Not found: users"
        );
        assert_eq!(
            AdminError::InvalidInput("bad temp".to_string()).to_string(),
            "Invalid input: bad temp"
        );
        assert_eq!(
            AdminError::DiscoveryError("timeout".to_string()).to_string(),
            "Discovery error: timeout"
        );
        assert_eq!(
            AdminError::InternalError("boom".to_string()).to_string(),
            "Internal error: boom"
        );
    }
}