1use std::collections::VecDeque;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::Duration;
9
10use dashmap::DashMap;
11use parking_lot::RwLock;
12
13use super::statistics::QueryExecution;
14use super::{AgentCost, CostReport, UserCost};
15
/// Coarse category assigned to a SQL statement by [`QueryClassifier::classify`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryIntent {
    /// `SELECT` that matched no special table pattern or vector marker.
    Retrieval,

    /// `INSERT`, `UPDATE` or `DELETE` that matched no special table pattern.
    Storage,

    /// `SELECT`/`INSERT`/`UPDATE` on an embedding table, or any statement
    /// containing a vector-distance function or operator.
    Embedding,

    /// `CREATE`, `ALTER`, `DROP` or `TRUNCATE`.
    Schema,

    /// `BEGIN`, `COMMIT`, `ROLLBACK`, `START TRANSACTION` or `SAVEPOINT`.
    Transaction,

    /// `SET`, `SHOW`, `EXPLAIN`, `ANALYZE` or `VACUUM`.
    Utility,

    /// `SELECT` that mentions a RAG table pattern.
    RagRetrieval,

    /// `INSERT` or `UPDATE` that mentions a RAG table pattern.
    RagIndexing,

    /// Statement that mentions an agent-memory table pattern.
    AgentMemory,

    /// None of the classification rules matched.
    Unknown,
}

impl QueryIntent {
    /// Returns the snake_case label for this intent (the text `Display` prints).
    pub const fn as_str(self) -> &'static str {
        match self {
            QueryIntent::Retrieval => "retrieval",
            QueryIntent::Storage => "storage",
            QueryIntent::Embedding => "embedding",
            QueryIntent::Schema => "schema",
            QueryIntent::Transaction => "transaction",
            QueryIntent::Utility => "utility",
            QueryIntent::RagRetrieval => "rag_retrieval",
            QueryIntent::RagIndexing => "rag_indexing",
            QueryIntent::AgentMemory => "agent_memory",
            QueryIntent::Unknown => "unknown",
        }
    }
}

impl std::fmt::Display for QueryIntent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/alignment/precision flags (e.g. `{:>12}`), which
        // `write!(f, "literal")` silently ignores.
        f.pad(self.as_str())
    }
}

/// Heuristic classifier that maps a SQL string to a [`QueryIntent`].
///
/// Classification looks at the leading keyword of the statement and at
/// case-insensitive *substring* matches of the configured table patterns
/// anywhere in the query text.
#[derive(Debug, Clone)]
pub struct QueryClassifier {
    /// Lower-cased substrings that mark a query as embedding-related.
    embedding_tables: Vec<String>,

    /// Lower-cased substrings that mark a query as RAG-related.
    rag_tables: Vec<String>,

    /// Lower-cased substrings that mark a query as agent-memory-related.
    memory_tables: Vec<String>,
}

impl QueryClassifier {
    /// Creates a classifier with the built-in table patterns.
    pub fn new() -> Self {
        fn owned(names: &[&str]) -> Vec<String> {
            names.iter().map(|name| (*name).to_owned()).collect()
        }

        Self {
            embedding_tables: owned(&["embeddings", "vectors", "embedding", "vector_store"]),
            rag_tables: owned(&[
                "documents",
                "chunks",
                "doc_chunks",
                "knowledge_base",
                "context",
            ]),
            memory_tables: owned(&[
                "memory",
                "agent_memory",
                "conversation_history",
                "chat_history",
                "sessions",
            ]),
        }
    }

    /// Creates a classifier that uses only the supplied patterns.
    ///
    /// Patterns are lower-cased so they match case-insensitively; empty
    /// patterns are dropped because they would match every query.
    pub fn with_patterns(
        embedding_tables: Vec<String>,
        rag_tables: Vec<String>,
        memory_tables: Vec<String>,
    ) -> Self {
        Self {
            embedding_tables: Self::normalize_all(embedding_tables),
            rag_tables: Self::normalize_all(rag_tables),
            memory_tables: Self::normalize_all(memory_tables),
        }
    }

    /// Classifies `query`. Rules are evaluated in this order; first hit wins:
    ///
    /// 1. transaction keywords, 2. utility keywords, 3. schema keywords,
    /// 4. RAG table pattern (`SELECT` → retrieval, `INSERT`/`UPDATE` → indexing),
    /// 5. embedding table pattern with `SELECT`/`INSERT`/`UPDATE`,
    /// 6. agent-memory table pattern (any statement),
    /// 7. vector-distance function/operator, 8. plain `SELECT`,
    /// 9. plain `INSERT`/`UPDATE`/`DELETE`, otherwise [`QueryIntent::Unknown`].
    pub fn classify(&self, query: &str) -> QueryIntent {
        // Keyword checks use the trimmed upper-case form; substring checks use
        // the lower-case form (whitespace is irrelevant for `contains`).
        let upper = query.trim().to_uppercase();
        let lower = query.to_lowercase();

        if Self::starts_with_any(
            &upper,
            &["BEGIN", "COMMIT", "ROLLBACK", "START TRANSACTION", "SAVEPOINT"],
        ) {
            return QueryIntent::Transaction;
        }

        if Self::starts_with_any(&upper, &["SET", "SHOW", "EXPLAIN", "ANALYZE", "VACUUM"]) {
            return QueryIntent::Utility;
        }

        if Self::starts_with_any(&upper, &["CREATE", "ALTER", "DROP", "TRUNCATE"]) {
            return QueryIntent::Schema;
        }

        let is_select = upper.starts_with("SELECT");
        let is_insert_or_update = Self::starts_with_any(&upper, &["INSERT", "UPDATE"]);

        if self.matches_table_pattern(&lower, &self.rag_tables) {
            if is_select {
                return QueryIntent::RagRetrieval;
            }
            if is_insert_or_update {
                return QueryIntent::RagIndexing;
            }
            // Other statements (e.g. DELETE) fall through to the later rules.
        }

        if (is_select || is_insert_or_update)
            && self.matches_table_pattern(&lower, &self.embedding_tables)
        {
            return QueryIntent::Embedding;
        }

        if self.matches_table_pattern(&lower, &self.memory_tables) {
            return QueryIntent::AgentMemory;
        }

        const VECTOR_MARKERS: [&str; 5] = [
            "cosine_similarity",
            "l2_distance",
            "inner_product",
            "<->",
            "<=>",
        ];
        if VECTOR_MARKERS.iter().any(|marker| lower.contains(*marker)) {
            return QueryIntent::Embedding;
        }

        if is_select {
            return QueryIntent::Retrieval;
        }

        if is_insert_or_update || upper.starts_with("DELETE") {
            return QueryIntent::Storage;
        }

        QueryIntent::Unknown
    }

    /// Returns `true` when `text` begins with any of `keywords`.
    fn starts_with_any(text: &str, keywords: &[&str]) -> bool {
        keywords.iter().any(|keyword| text.starts_with(*keyword))
    }

    /// Returns `true` when the (lower-cased) `query` contains any non-empty pattern.
    fn matches_table_pattern(&self, query: &str, patterns: &[String]) -> bool {
        patterns
            .iter()
            .any(|pattern| !pattern.is_empty() && query.contains(pattern.as_str()))
    }

    /// Lower-cases a pattern; returns `None` for an empty one.
    fn normalize(pattern: &str) -> Option<String> {
        if pattern.is_empty() {
            None
        } else {
            Some(pattern.to_lowercase())
        }
    }

    /// Applies [`Self::normalize`] to every pattern, dropping empty ones.
    fn normalize_all(patterns: Vec<String>) -> Vec<String> {
        patterns
            .iter()
            .filter_map(|pattern| Self::normalize(pattern))
            .collect()
    }

    /// Adds an embedding table pattern (matched case-insensitively; empty is ignored).
    pub fn add_embedding_pattern(&mut self, pattern: impl Into<String>) {
        self.embedding_tables
            .extend(Self::normalize(&pattern.into()));
    }

    /// Adds a RAG table pattern (matched case-insensitively; empty is ignored).
    pub fn add_rag_pattern(&mut self, pattern: impl Into<String>) {
        self.rag_tables.extend(Self::normalize(&pattern.into()));
    }

    /// Adds an agent-memory table pattern (matched case-insensitively; empty is ignored).
    pub fn add_memory_pattern(&mut self, pattern: impl Into<String>) {
        self.memory_tables.extend(Self::normalize(&pattern.into()));
    }
}

impl Default for QueryClassifier {
    fn default() -> Self {
        Self::new()
    }
}
233
/// Lock-free counters for RAG retrieval and indexing activity.
///
/// All counters use relaxed atomics, so a reader may observe the individual
/// counters at slightly different points in time.
#[derive(Debug, Default)]
pub struct RagAnalytics {
    /// Number of retrieval operations recorded.
    retrieval_count: AtomicU64,
    /// Accumulated retrieval time in microseconds.
    retrieval_time_us: AtomicU64,
    /// Number of indexing operations recorded.
    indexing_count: AtomicU64,
    /// Accumulated indexing time in microseconds.
    indexing_time_us: AtomicU64,
    /// Cleared by `reset`; no method in this type increments it.
    documents_indexed: AtomicU64,
    /// Accumulated number of chunks reported to `record_indexing`.
    chunks_created: AtomicU64,
}

impl RagAnalytics {
    /// Creates a tracker with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Converts a duration to whole microseconds, saturating at `u64::MAX`
    /// instead of silently wrapping the way a bare `as u64` cast would.
    fn saturating_micros(duration: Duration) -> u64 {
        // After clamping to `u64::MAX` the cast is lossless.
        duration.as_micros().min(u128::from(u64::MAX)) as u64
    }

    /// Records one retrieval that took `duration`.
    pub fn record_retrieval(&self, duration: Duration) {
        self.retrieval_count.fetch_add(1, Ordering::Relaxed);
        self.retrieval_time_us
            .fetch_add(Self::saturating_micros(duration), Ordering::Relaxed);
    }

    /// Records one indexing operation that took `duration` and produced `chunks` chunks.
    pub fn record_indexing(&self, duration: Duration, chunks: u64) {
        self.indexing_count.fetch_add(1, Ordering::Relaxed);
        self.indexing_time_us
            .fetch_add(Self::saturating_micros(duration), Ordering::Relaxed);
        self.chunks_created.fetch_add(chunks, Ordering::Relaxed);
    }

    /// Returns `(retrieval count, total retrieval time)`.
    pub fn retrieval_stats(&self) -> (u64, Duration) {
        let count = self.retrieval_count.load(Ordering::Relaxed);
        let time = Duration::from_micros(self.retrieval_time_us.load(Ordering::Relaxed));
        (count, time)
    }

    /// Returns `(indexing count, total indexing time, chunks created)`.
    pub fn indexing_stats(&self) -> (u64, Duration, u64) {
        let count = self.indexing_count.load(Ordering::Relaxed);
        let time = Duration::from_micros(self.indexing_time_us.load(Ordering::Relaxed));
        let chunks = self.chunks_created.load(Ordering::Relaxed);
        (count, time, chunks)
    }

    /// Sets every counter back to zero.
    pub fn reset(&self) {
        let counters = [
            &self.retrieval_count,
            &self.retrieval_time_us,
            &self.indexing_count,
            &self.indexing_time_us,
            &self.documents_indexed,
            &self.chunks_created,
        ];
        for counter in counters.iter() {
            counter.store(0, Ordering::Relaxed);
        }
    }
}
309
310#[derive(Debug, Clone)]
312pub struct WorkflowStep {
313 pub index: usize,
315 pub query: String,
317 pub duration: Duration,
319 pub timestamp_nanos: u64,
321 pub intent: QueryIntent,
323 pub rows: usize,
325 pub error: Option<String>,
327}
328
329#[derive(Debug, Clone)]
331pub struct WorkflowTrace {
332 pub workflow_id: String,
334 pub start_nanos: u64,
336 pub end_nanos: Option<u64>,
338 pub steps: Vec<WorkflowStep>,
340 pub total_duration: Duration,
342 pub user: String,
344 pub agent_id: Option<String>,
346}
347
348impl WorkflowTrace {
349 pub fn new(workflow_id: impl Into<String>, user: impl Into<String>) -> Self {
351 Self {
352 workflow_id: workflow_id.into(),
353 start_nanos: now_nanos(),
354 end_nanos: None,
355 steps: Vec::new(),
356 total_duration: Duration::ZERO,
357 user: user.into(),
358 agent_id: None,
359 }
360 }
361
362 pub fn add_step(&mut self, step: WorkflowStep) {
364 self.steps.push(step);
365 self.update_duration();
366 }
367
368 pub fn complete(&mut self) {
370 self.end_nanos = Some(now_nanos());
371 self.update_duration();
372 }
373
374 fn update_duration(&mut self) {
376 self.total_duration = self.steps.iter().map(|s| s.duration).sum();
377 }
378
379 pub fn is_complete(&self) -> bool {
381 self.end_nanos.is_some()
382 }
383
384 pub fn step_count(&self) -> usize {
386 self.steps.len()
387 }
388
389 pub fn error_count(&self) -> usize {
391 self.steps.iter().filter(|s| s.error.is_some()).count()
392 }
393}
394
395pub struct WorkflowTracer {
397 workflows: DashMap<String, WorkflowTrace>,
399 completed: RwLock<VecDeque<WorkflowTrace>>,
401 max_completed: usize,
403 total_workflows: AtomicU64,
405}
406
407impl WorkflowTracer {
408 pub fn new() -> Self {
410 Self::with_max_completed(100)
411 }
412
413 pub fn with_max_completed(max: usize) -> Self {
415 Self {
416 workflows: DashMap::new(),
417 completed: RwLock::new(VecDeque::new()),
418 max_completed: max,
419 total_workflows: AtomicU64::new(0),
420 }
421 }
422
423 pub fn record_step(&self, workflow_id: &str, execution: &QueryExecution) {
425 let classifier = QueryClassifier::new();
426 let intent = classifier.classify(&execution.query);
427
428 let mut workflow = self
429 .workflows
430 .entry(workflow_id.to_string())
431 .or_insert_with(|| {
432 self.total_workflows.fetch_add(1, Ordering::Relaxed);
433 WorkflowTrace::new(workflow_id, &execution.user)
434 });
435
436 let step = WorkflowStep {
437 index: workflow.steps.len(),
438 query: execution.query.clone(),
439 duration: execution.duration,
440 timestamp_nanos: now_nanos(),
441 intent,
442 rows: execution.rows,
443 error: execution.error.clone(),
444 };
445
446 workflow.add_step(step);
447 }
448
449 pub fn complete_workflow(&self, workflow_id: &str) {
451 if let Some((_, mut workflow)) = self.workflows.remove(workflow_id) {
452 workflow.complete();
453
454 let mut completed = self.completed.write();
455 completed.push_back(workflow);
456
457 while completed.len() > self.max_completed {
458 completed.pop_front();
459 }
460 }
461 }
462
463 pub fn get_workflow(&self, workflow_id: &str) -> Option<WorkflowTrace> {
465 self.workflows.get(workflow_id).map(|w| w.clone())
466 }
467
468 pub fn recent(&self, limit: usize) -> Vec<WorkflowTrace> {
470 self.completed
471 .read()
472 .iter()
473 .rev()
474 .take(limit)
475 .cloned()
476 .collect()
477 }
478
479 pub fn active_count(&self) -> usize {
481 self.workflows.len()
482 }
483
484 pub fn total_count(&self) -> u64 {
486 self.total_workflows.load(Ordering::Relaxed)
487 }
488
489 pub fn reset(&self) {
491 self.workflows.clear();
492 self.completed.write().clear();
493 self.total_workflows.store(0, Ordering::Relaxed);
494 }
495}
496
497impl Default for WorkflowTracer {
498 fn default() -> Self {
499 Self::new()
500 }
501}
502
/// Per-key (user or agent) accumulator of query count and execution time.
struct UserCostTracker {
    /// Number of queries recorded.
    queries: AtomicU64,
    /// Accumulated execution time in microseconds.
    time_us: AtomicU64,
}

impl UserCostTracker {
    /// Creates a tracker with both counters at zero.
    fn new() -> Self {
        Self {
            queries: AtomicU64::new(0),
            time_us: AtomicU64::new(0),
        }
    }

    /// Adds one query that took `duration`.
    fn record(&self, duration: Duration) {
        // Clamp before narrowing so an oversized duration saturates at
        // `u64::MAX` instead of silently wrapping; the cast is then lossless.
        let micros = duration.as_micros().min(u128::from(u64::MAX)) as u64;

        self.queries.fetch_add(1, Ordering::Relaxed);
        self.time_us.fetch_add(micros, Ordering::Relaxed);
    }
}
523
524pub struct CostAttribution {
526 users: DashMap<String, UserCostTracker>,
528 agents: DashMap<String, UserCostTracker>,
530 total_queries: AtomicU64,
532 total_time_us: AtomicU64,
534 cost_per_query_second: f64,
536}
537
538impl CostAttribution {
539 pub fn new() -> Self {
541 Self {
542 users: DashMap::new(),
543 agents: DashMap::new(),
544 total_queries: AtomicU64::new(0),
545 total_time_us: AtomicU64::new(0),
546 cost_per_query_second: 0.0001,
547 }
548 }
549
550 pub fn set_cost_rate(&mut self, rate: f64) {
552 self.cost_per_query_second = rate;
553 }
554
555 pub fn record(&self, execution: &QueryExecution) {
557 self.total_queries.fetch_add(1, Ordering::Relaxed);
558 self.total_time_us
559 .fetch_add(execution.duration.as_micros() as u64, Ordering::Relaxed);
560
561 self.users
563 .entry(execution.user.clone())
564 .or_insert_with(UserCostTracker::new)
565 .record(execution.duration);
566
567 if let Some(ref workflow_id) = execution.workflow_id {
569 let agent_id = workflow_id.split('-').take(2).collect::<Vec<_>>().join("-");
571
572 self.agents
573 .entry(agent_id)
574 .or_insert_with(UserCostTracker::new)
575 .record(execution.duration);
576 }
577 }
578
579 pub fn report(&self) -> CostReport {
581 let total_queries = self.total_queries.load(Ordering::Relaxed);
582 let total_time_us = self.total_time_us.load(Ordering::Relaxed);
583 let total_time_seconds = total_time_us as f64 / 1_000_000.0;
584 let estimated_cost = total_time_seconds * self.cost_per_query_second;
585
586 let by_user: Vec<_> = self
587 .users
588 .iter()
589 .map(|entry| {
590 let queries = entry.value().queries.load(Ordering::Relaxed);
591 let time_us = entry.value().time_us.load(Ordering::Relaxed);
592 let time_seconds = time_us as f64 / 1_000_000.0;
593
594 UserCost {
595 user: entry.key().clone(),
596 queries,
597 time_seconds,
598 cost_usd: time_seconds * self.cost_per_query_second,
599 }
600 })
601 .collect();
602
603 let by_agent: Vec<_> = self
604 .agents
605 .iter()
606 .map(|entry| {
607 let queries = entry.value().queries.load(Ordering::Relaxed);
608 let time_us = entry.value().time_us.load(Ordering::Relaxed);
609 let time_seconds = time_us as f64 / 1_000_000.0;
610
611 AgentCost {
612 agent_id: entry.key().clone(),
613 queries,
614 time_seconds,
615 cost_usd: time_seconds * self.cost_per_query_second,
616 }
617 })
618 .collect();
619
620 CostReport {
621 total_queries,
622 total_time_seconds,
623 estimated_cost_usd: estimated_cost,
624 by_user,
625 by_agent,
626 }
627 }
628
629 pub fn reset(&self) {
631 self.users.clear();
632 self.agents.clear();
633 self.total_queries.store(0, Ordering::Relaxed);
634 self.total_time_us.store(0, Ordering::Relaxed);
635 }
636}
637
638impl Default for CostAttribution {
639 fn default() -> Self {
640 Self::new()
641 }
642}
643
/// Returns the current wall-clock time as nanoseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and
/// saturates at `u64::MAX` rather than wrapping if the value does not fit.
fn now_nanos() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        // Clamp before narrowing so the cast is lossless.
        Ok(elapsed) => elapsed.as_nanos().min(u128::from(u64::MAX)) as u64,
        Err(_) => 0,
    }
}
650
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(query, expected intent)` pair against the default classifier.
    fn assert_classified(cases: &[(&str, QueryIntent)]) {
        let classifier = QueryClassifier::new();
        for (sql, expected) in cases {
            assert_eq!(classifier.classify(sql), *expected);
        }
    }

    #[test]
    fn test_query_classifier_basic() {
        assert_classified(&[
            ("SELECT * FROM users", QueryIntent::Retrieval),
            ("INSERT INTO users VALUES (1)", QueryIntent::Storage),
            ("UPDATE users SET name = 'Bob'", QueryIntent::Storage),
            ("DELETE FROM users WHERE id = 1", QueryIntent::Storage),
        ]);
    }

    #[test]
    fn test_query_classifier_transaction() {
        assert_classified(&[
            ("BEGIN", QueryIntent::Transaction),
            ("COMMIT", QueryIntent::Transaction),
            ("ROLLBACK", QueryIntent::Transaction),
            ("START TRANSACTION", QueryIntent::Transaction),
        ]);
    }

    #[test]
    fn test_query_classifier_schema() {
        assert_classified(&[
            ("CREATE TABLE foo (id INT)", QueryIntent::Schema),
            ("ALTER TABLE foo ADD COLUMN bar TEXT", QueryIntent::Schema),
            ("DROP TABLE foo", QueryIntent::Schema),
        ]);
    }

    #[test]
    fn test_query_classifier_embedding() {
        assert_classified(&[
            ("SELECT * FROM embeddings WHERE id = 1", QueryIntent::Embedding),
            (
                "INSERT INTO vectors (embedding) VALUES (?)",
                QueryIntent::Embedding,
            ),
            (
                "SELECT * FROM items ORDER BY embedding <-> '[1,2,3]'",
                QueryIntent::Embedding,
            ),
        ]);
    }

    #[test]
    fn test_query_classifier_rag() {
        assert_classified(&[
            (
                "SELECT * FROM documents WHERE topic = 'AI'",
                QueryIntent::RagRetrieval,
            ),
            (
                "INSERT INTO chunks (content, embedding) VALUES (?, ?)",
                QueryIntent::RagIndexing,
            ),
        ]);
    }

    #[test]
    fn test_query_classifier_agent_memory() {
        assert_classified(&[
            (
                "SELECT * FROM agent_memory WHERE session_id = ?",
                QueryIntent::AgentMemory,
            ),
            (
                "INSERT INTO conversation_history (message) VALUES (?)",
                QueryIntent::AgentMemory,
            ),
        ]);
    }

    #[test]
    fn test_workflow_tracer() {
        let tracer = WorkflowTracer::new();
        let execution =
            QueryExecution::new("SELECT 1", Duration::from_millis(5)).with_user("alice");

        // Two steps recorded against the same id land in one active workflow.
        for _ in 0..2 {
            tracer.record_step("workflow-1", &execution);
        }

        let active = tracer.get_workflow("workflow-1").unwrap();
        assert_eq!(active.step_count(), 2);
        assert_eq!(active.user, "alice");

        // Completing moves the workflow from the active set to the history.
        tracer.complete_workflow("workflow-1");
        assert!(tracer.get_workflow("workflow-1").is_none());

        let history = tracer.recent(10);
        assert_eq!(history.len(), 1);
        assert!(history[0].is_complete());
    }

    #[test]
    fn test_cost_attribution() {
        let cost = CostAttribution::new();
        let execution = QueryExecution::new("SELECT 1", Duration::from_secs(1)).with_user("alice");

        for _ in 0..2 {
            cost.record(&execution);
        }

        let report = cost.report();
        assert_eq!(report.total_queries, 2);
        assert!((report.total_time_seconds - 2.0).abs() < 0.001);

        let alice_has_two_queries = report
            .by_user
            .iter()
            .any(|entry| entry.user == "alice" && entry.queries == 2);
        assert!(alice_has_two_queries);
    }

    #[test]
    fn test_rag_analytics() {
        let rag = RagAnalytics::new();

        for millis in [50u64, 30].iter() {
            rag.record_retrieval(Duration::from_millis(*millis));
        }
        rag.record_indexing(Duration::from_millis(100), 5);

        let (retrieval_count, retrieval_time) = rag.retrieval_stats();
        assert_eq!(retrieval_count, 2);
        assert_eq!(retrieval_time, Duration::from_millis(80));

        let (indexing_count, indexing_time, chunks) = rag.indexing_stats();
        assert_eq!(indexing_count, 1);
        assert_eq!(indexing_time, Duration::from_millis(100));
        assert_eq!(chunks, 5);
    }

    #[test]
    fn test_intent_display() {
        let expected_labels = [
            (QueryIntent::Retrieval, "retrieval"),
            (QueryIntent::RagRetrieval, "rag_retrieval"),
            (QueryIntent::AgentMemory, "agent_memory"),
        ];
        for (intent, label) in expected_labels.iter() {
            assert_eq!(intent.to_string(), *label);
        }
    }
}