1use std::collections::VecDeque;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::Duration;
9
10use dashmap::DashMap;
11use parking_lot::RwLock;
12
13use super::statistics::QueryExecution;
14use super::{CostReport, UserCost, AgentCost};
15
/// Coarse category describing what a SQL statement is being used for.
///
/// Values are produced by `QueryClassifier::classify`. The `Display`
/// implementation renders each variant as a lowercase `snake_case` label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryIntent {
    /// A `SELECT` that matched no more specific category.
    Retrieval,

    /// An `INSERT`, `UPDATE` or `DELETE` that matched no more specific category.
    Storage,

    /// A read or write mentioning an embedding table, or any statement using
    /// a vector-distance function or operator.
    Embedding,

    /// DDL: `CREATE`, `ALTER`, `DROP` or `TRUNCATE`.
    Schema,

    /// Transaction control: `BEGIN`, `COMMIT`, `ROLLBACK`,
    /// `START TRANSACTION` or `SAVEPOINT`.
    Transaction,

    /// Session and maintenance statements: `SET`, `SHOW`, `EXPLAIN`,
    /// `ANALYZE` or `VACUUM`.
    Utility,

    /// A `SELECT` mentioning a RAG table.
    RagRetrieval,

    /// An `INSERT` or `UPDATE` mentioning a RAG table.
    RagIndexing,

    /// Any statement mentioning an agent-memory table.
    AgentMemory,

    /// Nothing above matched.
    Unknown,
}

impl std::fmt::Display for QueryIntent {
    /// Writes the lowercase label for the variant; width, fill and alignment
    /// flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            QueryIntent::Retrieval => "retrieval",
            QueryIntent::Storage => "storage",
            QueryIntent::Embedding => "embedding",
            QueryIntent::Schema => "schema",
            QueryIntent::Transaction => "transaction",
            QueryIntent::Utility => "utility",
            QueryIntent::RagRetrieval => "rag_retrieval",
            QueryIntent::RagIndexing => "rag_indexing",
            QueryIntent::AgentMemory => "agent_memory",
            QueryIntent::Unknown => "unknown",
        };
        f.write_str(label)
    }
}
66
67pub struct QueryClassifier {
69 embedding_tables: Vec<String>,
71
72 rag_tables: Vec<String>,
74
75 memory_tables: Vec<String>,
77}
78
79impl QueryClassifier {
80 pub fn new() -> Self {
82 Self {
83 embedding_tables: vec![
84 "embeddings".to_string(),
85 "vectors".to_string(),
86 "embedding".to_string(),
87 "vector_store".to_string(),
88 ],
89 rag_tables: vec![
90 "documents".to_string(),
91 "chunks".to_string(),
92 "doc_chunks".to_string(),
93 "knowledge_base".to_string(),
94 "context".to_string(),
95 ],
96 memory_tables: vec![
97 "memory".to_string(),
98 "agent_memory".to_string(),
99 "conversation_history".to_string(),
100 "chat_history".to_string(),
101 "sessions".to_string(),
102 ],
103 }
104 }
105
106 pub fn with_patterns(
108 embedding_tables: Vec<String>,
109 rag_tables: Vec<String>,
110 memory_tables: Vec<String>,
111 ) -> Self {
112 Self {
113 embedding_tables,
114 rag_tables,
115 memory_tables,
116 }
117 }
118
119 pub fn classify(&self, query: &str) -> QueryIntent {
121 let upper = query.trim().to_uppercase();
122 let lower = query.to_lowercase();
123
124 if upper.starts_with("BEGIN")
126 || upper.starts_with("COMMIT")
127 || upper.starts_with("ROLLBACK")
128 || upper.starts_with("START TRANSACTION")
129 || upper.starts_with("SAVEPOINT")
130 {
131 return QueryIntent::Transaction;
132 }
133
134 if upper.starts_with("SET")
136 || upper.starts_with("SHOW")
137 || upper.starts_with("EXPLAIN")
138 || upper.starts_with("ANALYZE")
139 || upper.starts_with("VACUUM")
140 {
141 return QueryIntent::Utility;
142 }
143
144 if upper.starts_with("CREATE")
146 || upper.starts_with("ALTER")
147 || upper.starts_with("DROP")
148 || upper.starts_with("TRUNCATE")
149 {
150 return QueryIntent::Schema;
151 }
152
153 if self.matches_table_pattern(&lower, &self.rag_tables) {
156 if upper.starts_with("SELECT") {
157 return QueryIntent::RagRetrieval;
158 } else if upper.starts_with("INSERT") || upper.starts_with("UPDATE") {
159 return QueryIntent::RagIndexing;
160 }
161 }
162
163 if self.matches_table_pattern(&lower, &self.embedding_tables) {
165 if upper.starts_with("SELECT") {
166 return QueryIntent::Embedding;
167 } else if upper.starts_with("INSERT") || upper.starts_with("UPDATE") {
168 return QueryIntent::Embedding;
169 }
170 }
171
172 if self.matches_table_pattern(&lower, &self.memory_tables) {
174 return QueryIntent::AgentMemory;
175 }
176
177 if lower.contains("cosine_similarity")
179 || lower.contains("l2_distance")
180 || lower.contains("inner_product")
181 || lower.contains("<->") || lower.contains("<=>") {
184 return QueryIntent::Embedding;
185 }
186
187 if upper.starts_with("SELECT") {
189 return QueryIntent::Retrieval;
190 }
191
192 if upper.starts_with("INSERT")
193 || upper.starts_with("UPDATE")
194 || upper.starts_with("DELETE")
195 {
196 return QueryIntent::Storage;
197 }
198
199 QueryIntent::Unknown
200 }
201
202 fn matches_table_pattern(&self, query: &str, patterns: &[String]) -> bool {
204 for pattern in patterns {
205 if query.contains(pattern) {
206 return true;
207 }
208 }
209 false
210 }
211
212 pub fn add_embedding_pattern(&mut self, pattern: impl Into<String>) {
214 self.embedding_tables.push(pattern.into());
215 }
216
217 pub fn add_rag_pattern(&mut self, pattern: impl Into<String>) {
219 self.rag_tables.push(pattern.into());
220 }
221
222 pub fn add_memory_pattern(&mut self, pattern: impl Into<String>) {
224 self.memory_tables.push(pattern.into());
225 }
226}
227
228impl Default for QueryClassifier {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
/// Lock-free counters for RAG retrieval and indexing activity.
///
/// All counters are plain `AtomicU64`s updated with `Ordering::Relaxed`, so a
/// reader may observe a count and its matching time total from slightly
/// different moments.
#[derive(Default)]
pub struct RagAnalytics {
    /// Number of retrieval operations recorded.
    retrieval_count: AtomicU64,
    /// Total retrieval time, in microseconds.
    retrieval_time_us: AtomicU64,
    /// Number of indexing operations recorded.
    indexing_count: AtomicU64,
    /// Total indexing time, in microseconds.
    indexing_time_us: AtomicU64,
    /// NOTE(review): initialised and reset, but nothing in this type ever
    /// increments or reads it; confirm whether `record_indexing` should.
    documents_indexed: AtomicU64,
    /// Total number of chunks reported to `record_indexing`.
    chunks_created: AtomicU64,
}

impl RagAnalytics {
    /// Creates a tracker with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Converts a duration to whole microseconds, clamping to `u64::MAX`
    /// instead of silently dropping the high bits of the `u128` value.
    fn saturating_micros(duration: Duration) -> u64 {
        duration.as_micros().min(u128::from(u64::MAX)) as u64
    }

    /// Records one retrieval that took `duration`.
    pub fn record_retrieval(&self, duration: Duration) {
        self.retrieval_count.fetch_add(1, Ordering::Relaxed);
        self.retrieval_time_us
            .fetch_add(Self::saturating_micros(duration), Ordering::Relaxed);
    }

    /// Records one indexing operation that took `duration` and produced
    /// `chunks` chunks.
    pub fn record_indexing(&self, duration: Duration, chunks: u64) {
        self.indexing_count.fetch_add(1, Ordering::Relaxed);
        self.indexing_time_us
            .fetch_add(Self::saturating_micros(duration), Ordering::Relaxed);
        self.chunks_created.fetch_add(chunks, Ordering::Relaxed);
    }

    /// Returns `(retrieval count, total retrieval time)`.
    pub fn retrieval_stats(&self) -> (u64, Duration) {
        let count = self.retrieval_count.load(Ordering::Relaxed);
        let micros = self.retrieval_time_us.load(Ordering::Relaxed);
        (count, Duration::from_micros(micros))
    }

    /// Returns `(indexing count, total indexing time, chunks created)`.
    pub fn indexing_stats(&self) -> (u64, Duration, u64) {
        let count = self.indexing_count.load(Ordering::Relaxed);
        let micros = self.indexing_time_us.load(Ordering::Relaxed);
        let chunks = self.chunks_created.load(Ordering::Relaxed);
        (count, Duration::from_micros(micros), chunks)
    }

    /// Sets every counter back to zero.
    ///
    /// The stores are independent, so a concurrent recorder can land between
    /// them.
    pub fn reset(&self) {
        let counters = [
            &self.retrieval_count,
            &self.retrieval_time_us,
            &self.indexing_count,
            &self.indexing_time_us,
            &self.documents_indexed,
            &self.chunks_created,
        ];
        for counter in counters.iter() {
            counter.store(0, Ordering::Relaxed);
        }
    }
}
309
/// One query recorded as a step of a workflow trace.
#[derive(Debug, Clone)]
pub struct WorkflowStep {
    /// Position of the step within its workflow. `WorkflowTracer::record_step`
    /// sets it to the number of steps already recorded, so the first step is 0.
    pub index: usize,
    /// Query text, copied from the execution record.
    pub query: String,
    /// Duration copied from the execution record.
    pub duration: Duration,
    /// Wall-clock time at which the step was recorded, in nanoseconds since
    /// the Unix epoch.
    pub timestamp_nanos: u64,
    /// Classification assigned to `query` when the step was recorded.
    pub intent: QueryIntent,
    /// Row count copied from the execution record.
    /// NOTE(review): presumably rows returned or affected; confirm against
    /// `QueryExecution`.
    pub rows: usize,
    /// Error copied from the execution record; `WorkflowTrace::error_count`
    /// counts steps where this is `Some`.
    pub error: Option<String>,
}
328
329#[derive(Debug, Clone)]
331pub struct WorkflowTrace {
332 pub workflow_id: String,
334 pub start_nanos: u64,
336 pub end_nanos: Option<u64>,
338 pub steps: Vec<WorkflowStep>,
340 pub total_duration: Duration,
342 pub user: String,
344 pub agent_id: Option<String>,
346}
347
348impl WorkflowTrace {
349 pub fn new(workflow_id: impl Into<String>, user: impl Into<String>) -> Self {
351 Self {
352 workflow_id: workflow_id.into(),
353 start_nanos: now_nanos(),
354 end_nanos: None,
355 steps: Vec::new(),
356 total_duration: Duration::ZERO,
357 user: user.into(),
358 agent_id: None,
359 }
360 }
361
362 pub fn add_step(&mut self, step: WorkflowStep) {
364 self.steps.push(step);
365 self.update_duration();
366 }
367
368 pub fn complete(&mut self) {
370 self.end_nanos = Some(now_nanos());
371 self.update_duration();
372 }
373
374 fn update_duration(&mut self) {
376 self.total_duration = self.steps.iter().map(|s| s.duration).sum();
377 }
378
379 pub fn is_complete(&self) -> bool {
381 self.end_nanos.is_some()
382 }
383
384 pub fn step_count(&self) -> usize {
386 self.steps.len()
387 }
388
389 pub fn error_count(&self) -> usize {
391 self.steps.iter().filter(|s| s.error.is_some()).count()
392 }
393}
394
395pub struct WorkflowTracer {
397 workflows: DashMap<String, WorkflowTrace>,
399 completed: RwLock<VecDeque<WorkflowTrace>>,
401 max_completed: usize,
403 total_workflows: AtomicU64,
405}
406
407impl WorkflowTracer {
408 pub fn new() -> Self {
410 Self::with_max_completed(100)
411 }
412
413 pub fn with_max_completed(max: usize) -> Self {
415 Self {
416 workflows: DashMap::new(),
417 completed: RwLock::new(VecDeque::new()),
418 max_completed: max,
419 total_workflows: AtomicU64::new(0),
420 }
421 }
422
423 pub fn record_step(&self, workflow_id: &str, execution: &QueryExecution) {
425 let classifier = QueryClassifier::new();
426 let intent = classifier.classify(&execution.query);
427
428 let mut workflow = self.workflows.entry(workflow_id.to_string()).or_insert_with(|| {
429 self.total_workflows.fetch_add(1, Ordering::Relaxed);
430 WorkflowTrace::new(workflow_id, &execution.user)
431 });
432
433 let step = WorkflowStep {
434 index: workflow.steps.len(),
435 query: execution.query.clone(),
436 duration: execution.duration,
437 timestamp_nanos: now_nanos(),
438 intent,
439 rows: execution.rows,
440 error: execution.error.clone(),
441 };
442
443 workflow.add_step(step);
444 }
445
446 pub fn complete_workflow(&self, workflow_id: &str) {
448 if let Some((_, mut workflow)) = self.workflows.remove(workflow_id) {
449 workflow.complete();
450
451 let mut completed = self.completed.write();
452 completed.push_back(workflow);
453
454 while completed.len() > self.max_completed {
455 completed.pop_front();
456 }
457 }
458 }
459
460 pub fn get_workflow(&self, workflow_id: &str) -> Option<WorkflowTrace> {
462 self.workflows.get(workflow_id).map(|w| w.clone())
463 }
464
465 pub fn recent(&self, limit: usize) -> Vec<WorkflowTrace> {
467 self.completed
468 .read()
469 .iter()
470 .rev()
471 .take(limit)
472 .cloned()
473 .collect()
474 }
475
476 pub fn active_count(&self) -> usize {
478 self.workflows.len()
479 }
480
481 pub fn total_count(&self) -> u64 {
483 self.total_workflows.load(Ordering::Relaxed)
484 }
485
486 pub fn reset(&self) {
488 self.workflows.clear();
489 self.completed.write().clear();
490 self.total_workflows.store(0, Ordering::Relaxed);
491 }
492}
493
494impl Default for WorkflowTracer {
495 fn default() -> Self {
496 Self::new()
497 }
498}
499
/// Query count and accumulated execution time for one user or agent.
struct UserCostTracker {
    /// Number of queries recorded.
    queries: AtomicU64,
    /// Total recorded execution time, in microseconds.
    time_us: AtomicU64,
}

impl UserCostTracker {
    /// Creates a tracker with both counters at zero.
    fn new() -> Self {
        Self {
            queries: AtomicU64::new(0),
            time_us: AtomicU64::new(0),
        }
    }

    /// Records one query that took `duration`.
    fn record(&self, duration: Duration) {
        // `as_micros` is a u128; clamp rather than let an `as` cast silently
        // drop the high bits.
        let micros = duration.as_micros().min(u128::from(u64::MAX)) as u64;

        self.queries.fetch_add(1, Ordering::Relaxed);
        self.time_us.fetch_add(micros, Ordering::Relaxed);
    }
}
520
521pub struct CostAttribution {
523 users: DashMap<String, UserCostTracker>,
525 agents: DashMap<String, UserCostTracker>,
527 total_queries: AtomicU64,
529 total_time_us: AtomicU64,
531 cost_per_query_second: f64,
533}
534
535impl CostAttribution {
536 pub fn new() -> Self {
538 Self {
539 users: DashMap::new(),
540 agents: DashMap::new(),
541 total_queries: AtomicU64::new(0),
542 total_time_us: AtomicU64::new(0),
543 cost_per_query_second: 0.0001,
544 }
545 }
546
547 pub fn set_cost_rate(&mut self, rate: f64) {
549 self.cost_per_query_second = rate;
550 }
551
552 pub fn record(&self, execution: &QueryExecution) {
554 self.total_queries.fetch_add(1, Ordering::Relaxed);
555 self.total_time_us
556 .fetch_add(execution.duration.as_micros() as u64, Ordering::Relaxed);
557
558 self.users
560 .entry(execution.user.clone())
561 .or_insert_with(UserCostTracker::new)
562 .record(execution.duration);
563
564 if let Some(ref workflow_id) = execution.workflow_id {
566 let agent_id = workflow_id
568 .split('-')
569 .take(2)
570 .collect::<Vec<_>>()
571 .join("-");
572
573 self.agents
574 .entry(agent_id)
575 .or_insert_with(UserCostTracker::new)
576 .record(execution.duration);
577 }
578 }
579
580 pub fn report(&self) -> CostReport {
582 let total_queries = self.total_queries.load(Ordering::Relaxed);
583 let total_time_us = self.total_time_us.load(Ordering::Relaxed);
584 let total_time_seconds = total_time_us as f64 / 1_000_000.0;
585 let estimated_cost = total_time_seconds * self.cost_per_query_second;
586
587 let by_user: Vec<_> = self
588 .users
589 .iter()
590 .map(|entry| {
591 let queries = entry.value().queries.load(Ordering::Relaxed);
592 let time_us = entry.value().time_us.load(Ordering::Relaxed);
593 let time_seconds = time_us as f64 / 1_000_000.0;
594
595 UserCost {
596 user: entry.key().clone(),
597 queries,
598 time_seconds,
599 cost_usd: time_seconds * self.cost_per_query_second,
600 }
601 })
602 .collect();
603
604 let by_agent: Vec<_> = self
605 .agents
606 .iter()
607 .map(|entry| {
608 let queries = entry.value().queries.load(Ordering::Relaxed);
609 let time_us = entry.value().time_us.load(Ordering::Relaxed);
610 let time_seconds = time_us as f64 / 1_000_000.0;
611
612 AgentCost {
613 agent_id: entry.key().clone(),
614 queries,
615 time_seconds,
616 cost_usd: time_seconds * self.cost_per_query_second,
617 }
618 })
619 .collect();
620
621 CostReport {
622 total_queries,
623 total_time_seconds,
624 estimated_cost_usd: estimated_cost,
625 by_user,
626 by_agent,
627 }
628 }
629
630 pub fn reset(&self) {
632 self.users.clear();
633 self.agents.clear();
634 self.total_queries.store(0, Ordering::Relaxed);
635 self.total_time_us.store(0, Ordering::Relaxed);
636 }
637}
638
639impl Default for CostAttribution {
640 fn default() -> Self {
641 Self::new()
642 }
643}
644
/// Current wall-clock time as nanoseconds since the Unix epoch.
///
/// Returns 0 when the system clock reads earlier than the epoch. The `u128`
/// nanosecond count is narrowed to `u64` with an `as` cast.
fn now_nanos() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos() as u64,
        Err(_) => 0,
    }
}
651
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks each `(query, expected intent)` pair against a default classifier.
    fn assert_intents(cases: &[(&str, QueryIntent)]) {
        let classifier = QueryClassifier::new();
        for &(query, expected) in cases {
            assert_eq!(classifier.classify(query), expected);
        }
    }

    #[test]
    fn test_query_classifier_basic() {
        assert_intents(&[
            ("SELECT * FROM users", QueryIntent::Retrieval),
            ("INSERT INTO users VALUES (1)", QueryIntent::Storage),
            ("UPDATE users SET name = 'Bob'", QueryIntent::Storage),
            ("DELETE FROM users WHERE id = 1", QueryIntent::Storage),
        ]);
    }

    #[test]
    fn test_query_classifier_transaction() {
        assert_intents(&[
            ("BEGIN", QueryIntent::Transaction),
            ("COMMIT", QueryIntent::Transaction),
            ("ROLLBACK", QueryIntent::Transaction),
            ("START TRANSACTION", QueryIntent::Transaction),
        ]);
    }

    #[test]
    fn test_query_classifier_schema() {
        assert_intents(&[
            ("CREATE TABLE foo (id INT)", QueryIntent::Schema),
            ("ALTER TABLE foo ADD COLUMN bar TEXT", QueryIntent::Schema),
            ("DROP TABLE foo", QueryIntent::Schema),
        ]);
    }

    #[test]
    fn test_query_classifier_embedding() {
        assert_intents(&[
            (
                "SELECT * FROM embeddings WHERE id = 1",
                QueryIntent::Embedding,
            ),
            (
                "INSERT INTO vectors (embedding) VALUES (?)",
                QueryIntent::Embedding,
            ),
            (
                "SELECT * FROM items ORDER BY embedding <-> '[1,2,3]'",
                QueryIntent::Embedding,
            ),
        ]);
    }

    #[test]
    fn test_query_classifier_rag() {
        assert_intents(&[
            (
                "SELECT * FROM documents WHERE topic = 'AI'",
                QueryIntent::RagRetrieval,
            ),
            (
                "INSERT INTO chunks (content, embedding) VALUES (?, ?)",
                QueryIntent::RagIndexing,
            ),
        ]);
    }

    #[test]
    fn test_query_classifier_agent_memory() {
        assert_intents(&[
            (
                "SELECT * FROM agent_memory WHERE session_id = ?",
                QueryIntent::AgentMemory,
            ),
            (
                "INSERT INTO conversation_history (message) VALUES (?)",
                QueryIntent::AgentMemory,
            ),
        ]);
    }

    #[test]
    fn test_workflow_tracer() {
        let tracer = WorkflowTracer::new();
        let execution =
            QueryExecution::new("SELECT 1", Duration::from_millis(5)).with_user("alice");

        for _ in 0..2 {
            tracer.record_step("workflow-1", &execution);
        }

        let active = tracer.get_workflow("workflow-1").unwrap();
        assert_eq!(active.step_count(), 2);
        assert_eq!(active.user, "alice");

        tracer.complete_workflow("workflow-1");
        assert!(tracer.get_workflow("workflow-1").is_none());

        let history = tracer.recent(10);
        assert_eq!(history.len(), 1);
        assert!(history[0].is_complete());
    }

    #[test]
    fn test_cost_attribution() {
        let cost = CostAttribution::new();
        let execution =
            QueryExecution::new("SELECT 1", Duration::from_secs(1)).with_user("alice");

        for _ in 0..2 {
            cost.record(&execution);
        }

        let report = cost.report();
        assert_eq!(report.total_queries, 2);
        assert!((report.total_time_seconds - 2.0).abs() < 0.001);
        assert!(report
            .by_user
            .iter()
            .any(|u| u.user == "alice" && u.queries == 2));
    }

    #[test]
    fn test_rag_analytics() {
        let rag = RagAnalytics::new();

        for millis in [50u64, 30].iter() {
            rag.record_retrieval(Duration::from_millis(*millis));
        }
        rag.record_indexing(Duration::from_millis(100), 5);

        assert_eq!(rag.retrieval_stats(), (2, Duration::from_millis(80)));
        assert_eq!(rag.indexing_stats(), (1, Duration::from_millis(100), 5));
    }

    #[test]
    fn test_intent_display() {
        let expected = [
            (QueryIntent::Retrieval, "retrieval"),
            (QueryIntent::RagRetrieval, "rag_retrieval"),
            (QueryIntent::AgentMemory, "agent_memory"),
        ];
        for &(intent, label) in expected.iter() {
            assert_eq!(intent.to_string(), label);
        }
    }
}