1use crate::types::*;
8use ahash::RandomState;
9use dashmap::DashMap;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, info};
14use uuid::Uuid;
15
16#[derive(Clone)]
18pub struct IntentClassifier {
19 training_data: Arc<RwLock<Vec<TrainingExample>>>,
21
22 vocabulary: Arc<DashMap<String, usize, RandomState>>,
24
25 intent_patterns: Arc<DashMap<IntentId, Vec<String>, RandomState>>,
27
28 config: ClassifierConfig,
30}
31
32impl IntentClassifier {
33 pub async fn new() -> Result<Self> {
35 Self::with_config(ClassifierConfig::default()).await
36 }
37
38 pub async fn with_config(config: ClassifierConfig) -> Result<Self> {
40 let classifier = Self {
41 training_data: Arc::new(RwLock::new(Vec::new())),
42 vocabulary: Arc::new(DashMap::with_hasher(RandomState::new())),
43 intent_patterns: Arc::new(DashMap::with_hasher(RandomState::new())),
44 config,
45 };
46
47 classifier.load_bootstrap_data().await?;
49
50 if classifier.config.debug_mode {
51 info!("Intent classifier initialized with {} dimensions", classifier.config.feature_dimensions);
52 }
53
54 Ok(classifier)
55 }
56
57 pub async fn predict_intent(&self, text: &str) -> Result<IntentPrediction> {
59 let start_time = std::time::Instant::now();
60
61 if self.config.debug_mode {
62 debug!("Classifying intent for text: '{}'", text);
63 }
64
65 if let Some(exact_match) = self.find_exact_match(text).await? {
67 return Ok(exact_match);
68 }
69
70 let features = self.extract_features(text).await?;
72
73 let intent_scores = self.calculate_intent_scores(&features).await?;
75
76 let (best_intent, best_confidence) = self.find_best_intent(&intent_scores)?;
78
79 let alternative_intents = self.get_alternative_intents(&intent_scores, &best_intent);
81
82 let reasoning = self.generate_reasoning(text, &best_intent, &features).await;
84
85 let prediction = IntentPrediction {
86 intent: best_intent,
87 confidence: best_confidence,
88 alternative_intents,
89 reasoning,
90 };
91
92 if self.config.debug_mode {
93 let elapsed = start_time.elapsed();
94 info!("Intent prediction: {} (confidence: {:.3}, time: {:?})",
95 prediction.intent, prediction.confidence.value(), elapsed);
96 }
97
98 Ok(prediction)
99 }
100
101 pub async fn classify(&self, request: ClassificationRequest) -> Result<ClassificationResponse> {
103 let start_time = std::time::Instant::now();
104 let request_id = Uuid::new_v4();
105
106 let mut prediction = self.predict_intent(&request.text).await?;
107
108 if !request.include_alternatives {
110 prediction.alternative_intents.clear();
111 }
112
113 if !request.include_reasoning {
114 prediction.reasoning = String::new();
115 }
116
117 let processing_time_ms = start_time.elapsed().as_millis() as f64;
118
119 Ok(ClassificationResponse {
120 prediction,
121 processing_time_ms,
122 request_id,
123 })
124 }
125
126 pub async fn add_training_example(&self, example: TrainingExample) -> Result<()> {
128 if example.text.trim().is_empty() {
130 return Err(IntentError::InvalidParameter {
131 parameter: "text".to_string(),
132 message: "Training example text cannot be empty".to_string(),
133 });
134 }
135
136 if !(0.0..=1.0).contains(&example.confidence) {
137 return Err(IntentError::InvalidParameter {
138 parameter: "confidence".to_string(),
139 message: format!("Confidence must be between 0.0 and 1.0, got {}", example.confidence),
140 });
141 }
142
143 {
145 let mut training_data = self.training_data.write().await;
146 training_data.push(example.clone());
147 }
148
149 self.update_intent_patterns(&example.intent, &example.text).await?;
151
152 self.update_vocabulary(&example.text).await;
154
155 if self.config.debug_mode {
156 info!("Added training example: '{}' -> {}", example.text, example.intent);
157 }
158
159 Ok(())
160 }
161
162 pub async fn add_feedback(&self, feedback: IntentFeedback) -> Result<()> {
164 if self.config.debug_mode {
165 info!("Adding feedback: '{}' -> {} (predicted: {}, satisfaction: {})",
166 feedback.text, feedback.actual_intent, feedback.predicted_intent, feedback.satisfaction_score);
167 }
168
169 let confidence = feedback.satisfaction_score / 5.0; let example = TrainingExample {
172 text: feedback.text,
173 intent: feedback.actual_intent,
174 confidence,
175 source: TrainingSource::UserFeedback,
176 };
177
178 self.add_training_example(example).await?;
179
180 if self.should_retrain().await {
182 self.retrain().await?;
183 }
184
185 Ok(())
186 }
187
188 pub async fn get_stats(&self) -> ClassifierStats {
190 let training_data = self.training_data.read().await;
191
192 ClassifierStats {
193 training_examples: training_data.len(),
194 vocabulary_size: self.vocabulary.len(),
195 intent_count: self.intent_patterns.len(),
196 feedback_examples: training_data
197 .iter()
198 .filter(|e| matches!(e.source, TrainingSource::UserFeedback))
199 .count(),
200 last_updated: Some(chrono::Utc::now()),
201 }
202 }
203
204 pub async fn export_training_data(&self) -> Result<String> {
206 let training_data = self.training_data.read().await;
207 serde_json::to_string_pretty(&*training_data)
208 .map_err(IntentError::SerializationError)
209 }
210
211 pub async fn import_training_data(&self, json_data: &str) -> Result<()> {
213 let examples: Vec<TrainingExample> = serde_json::from_str(json_data)
214 .map_err(IntentError::SerializationError)?;
215
216 for example in examples {
217 self.add_training_example(example).await?;
218 }
219
220 Ok(())
221 }
222
223 pub async fn clear_training_data(&self) -> Result<()> {
225 {
226 let mut training_data = self.training_data.write().await;
227 training_data.clear();
228 }
229
230 self.vocabulary.clear();
231 self.intent_patterns.clear();
232
233 self.load_bootstrap_data().await?;
235
236 if self.config.debug_mode {
237 info!("Cleared all training data and reloaded bootstrap data");
238 }
239
240 Ok(())
241 }
242
243 async fn find_exact_match(&self, text: &str) -> Result<Option<IntentPrediction>> {
245 let training_data = self.training_data.read().await;
246
247 for example in training_data.iter() {
248 if example.text == text {
249 let confidence = Confidence::new(example.confidence)
250 .unwrap_or_else(|_| Confidence::default());
251
252 return Ok(Some(IntentPrediction {
253 intent: example.intent.clone(),
254 confidence,
255 alternative_intents: vec![],
256 reasoning: "Exact match found in training data".to_string(),
257 }));
258 }
259 }
260
261 Ok(None)
262 }
263
264 async fn extract_features(&self, text: &str) -> Result<FeatureVector> {
266 let cleaned_text = self.preprocess_text(text);
267
268 let text_features = self.extract_text_features(&cleaned_text).await?;
270
271 let context_features = self.extract_context_features(&cleaned_text);
273
274 let mut metadata = HashMap::new();
276 metadata.insert("text_length".to_string(), cleaned_text.len() as f64);
277 metadata.insert("word_count".to_string(), cleaned_text.split_whitespace().count() as f64);
278
279 Ok(FeatureVector {
280 text_features,
281 context_features,
282 metadata,
283 })
284 }
285
286 async fn calculate_intent_scores(&self, features: &FeatureVector) -> Result<HashMap<IntentId, f64>> {
288 let mut scores = HashMap::new();
289
290 for entry in self.intent_patterns.iter() {
291 let (intent, pattern_texts) = (entry.key(), entry.value());
292 let mut intent_score: f64 = 0.0;
293
294 for pattern_text in pattern_texts {
296 let pattern_features = self.extract_text_features(pattern_text).await?;
297 let similarity = self.cosine_similarity(&features.text_features, &pattern_features);
298 intent_score = intent_score.max(similarity);
299 }
300
301 intent_score += self.calculate_context_boost(intent, features);
303
304 scores.insert(intent.clone(), intent_score.min(1.0));
305 }
306
307 if scores.values().all(|&score| score < self.config.min_confidence_threshold) {
309 let fallback_scores = self.rule_based_classification(features).await;
310 for (intent, score) in fallback_scores {
311 scores.entry(intent).or_insert(score);
312 }
313 }
314
315 Ok(scores)
316 }
317
318 fn find_best_intent(&self, scores: &HashMap<IntentId, f64>) -> Result<(IntentId, Confidence)> {
320 let (best_intent, best_score) = scores
321 .iter()
322 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
323 .ok_or_else(|| IntentError::ClassificationFailed("No intents found".to_string()))?;
324
325 let confidence = Confidence::new(*best_score)
326 .unwrap_or_else(|_| Confidence::default());
327
328 Ok((best_intent.clone(), confidence))
329 }
330
331 fn get_alternative_intents(&self, scores: &HashMap<IntentId, f64>, best_intent: &IntentId) -> Vec<(IntentId, Confidence)> {
333 let mut alternatives: Vec<(IntentId, Confidence)> = scores
334 .iter()
335 .filter(|(intent, _)| *intent != best_intent)
336 .filter_map(|(intent, score)| {
337 Confidence::new(*score)
338 .ok()
339 .map(|confidence| (intent.clone(), confidence))
340 })
341 .collect();
342
343 alternatives.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
345 alternatives.truncate(3); alternatives
348 }
349
350 async fn generate_reasoning(&self, _text: &str, intent: &IntentId, features: &FeatureVector) -> String {
352 if let Some(intent_patterns) = self.intent_patterns.get(intent) {
353 if let Some(best_pattern) = intent_patterns.first() {
354 return format!(
355 "Classified as '{}' based on similarity to pattern: '{}' (using {} text features)",
356 intent, best_pattern, features.text_features.len()
357 );
358 }
359 }
360
361 format!("Classified as '{}' using rule-based analysis", intent)
362 }
363
364 async fn load_bootstrap_data(&self) -> Result<()> {
366 let bootstrap_examples = self.get_bootstrap_examples();
367
368 for (text, intent_str) in bootstrap_examples {
369 let example = TrainingExample {
370 text: text.to_string(),
371 intent: IntentId::from(intent_str),
372 confidence: 1.0,
373 source: TrainingSource::Bootstrap,
374 };
375
376 self.add_training_example(example).await?;
377 }
378
379 if self.config.debug_mode {
380 info!("Loaded {} bootstrap training examples", self.get_bootstrap_examples().len());
381 }
382
383 Ok(())
384 }
385
386 fn get_bootstrap_examples(&self) -> Vec<(&'static str, &'static str)> {
388 vec![
389 ("merge these JSON files together", "data_merge"),
391 ("combine multiple JSON documents", "data_merge"),
392 ("join several data files into one", "data_merge"),
393 ("consolidate JSON objects", "data_merge"),
394 ("split this large JSON file", "data_split"),
395 ("break apart this data into smaller pieces", "data_split"),
396 ("divide this file into multiple parts", "data_split"),
397 ("convert JSON to CSV format", "data_transform"),
398 ("transform this data structure", "data_transform"),
399 ("change the format of this file", "data_transform"),
400 ("analyze this dataset for patterns", "data_analyze"),
401 ("examine the data for insights", "data_analyze"),
402 ("what trends do you see in this data", "data_analyze"),
403 ("give me statistics about this data", "data_analyze"),
404
405 ("read the contents of this file", "file_read"),
407 ("load this document", "file_read"),
408 ("open and parse this file", "file_read"),
409 ("save this data to a file", "file_write"),
410 ("write this content to disk", "file_write"),
411 ("create a new file with this data", "file_write"),
412 ("convert PDF to markdown", "file_convert"),
413 ("change this file format", "file_convert"),
414 ("export as different format", "file_convert"),
415 ("compare these two files", "file_compare"),
416 ("what's different between these documents", "file_compare"),
417 ("find differences in these files", "file_compare"),
418
419 ("make an API request to this URL", "network_request"),
421 ("call this REST endpoint", "network_request"),
422 ("send HTTP request", "network_request"),
423 ("download this file from the internet", "network_download"),
424 ("fetch data from this URL", "network_download"),
425 ("retrieve file from web", "network_download"),
426 ("check if this website is up", "network_monitor"),
427 ("monitor API endpoint", "network_monitor"),
428 ("test connectivity to server", "network_monitor"),
429
430 ("extract text from this document", "extraction"),
432 ("pull out specific information", "extraction"),
433 ("get the important parts from this", "extraction"),
434 ("validate this data against schema", "validation"),
435 ("check if this data is correct", "validation"),
436 ("verify the format of this file", "validation"),
437 ("generate a report from this data", "generation"),
438 ("create summary of this information", "generation"),
439 ("produce documentation", "generation"),
440 ("classify this content", "classification"),
441 ("categorize this data", "classification"),
442 ("determine the type of this file", "classification"),
443
444 ("analyze this code for issues", "code_analyze"),
446 ("review this source code", "code_analyze"),
447 ("check code quality", "code_analyze"),
448 ("process this text document", "text_process"),
449 ("clean up this text", "text_process"),
450 ("parse natural language", "text_process"),
451 ]
452 }
453
454 fn preprocess_text(&self, text: &str) -> String {
456 text.to_lowercase()
457 .chars()
458 .filter(|c| c.is_alphanumeric() || c.is_whitespace())
459 .collect::<String>()
460 .split_whitespace()
461 .collect::<Vec<_>>()
462 .join(" ")
463 }
464
465 async fn extract_text_features(&self, text: &str) -> Result<Vec<f64>> {
467 let mut features = vec![0.0; self.config.feature_dimensions];
468
469 let words: Vec<&str> = text.split_whitespace().collect();
470 let word_count = words.len() as f64;
471
472 if word_count == 0.0 {
473 return Ok(features);
474 }
475
476 for word in words {
478 if let Some(index) = self.vocabulary.get(word) {
479 if *index < features.len() {
480 features[*index] += 1.0 / word_count;
481 }
482 }
483 }
484
485 Ok(features)
486 }
487
488 fn extract_context_features(&self, text: &str) -> Vec<f64> {
490 vec![
491 text.len() as f64 / 100.0, text.split_whitespace().count() as f64 / 20.0, if text.contains('?') { 1.0 } else { 0.0 }, if text.contains("file") { 1.0 } else { 0.0 }, if text.contains("data") { 1.0 } else { 0.0 }, ]
497 }
498
499 fn cosine_similarity(&self, a: &[f64], b: &[f64]) -> f64 {
501 if a.len() != b.len() {
502 return 0.0;
503 }
504
505 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
506 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
507 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
508
509 if norm_a == 0.0 || norm_b == 0.0 {
510 0.0
511 } else {
512 dot_product / (norm_a * norm_b)
513 }
514 }
515
516 fn calculate_context_boost(&self, intent: &IntentId, features: &FeatureVector) -> f64 {
518 let mut boost = 0.0;
519
520 if intent.0.contains("file") && features.context_features.get(3).unwrap_or(&0.0) > &0.0 {
521 boost += 0.1;
522 }
523
524 if intent.0.contains("data") && features.context_features.get(4).unwrap_or(&0.0) > &0.0 {
525 boost += 0.1;
526 }
527
528 boost
529 }
530
531 async fn rule_based_classification(&self, _features: &FeatureVector) -> HashMap<IntentId, f64> {
533 let mut scores = HashMap::new();
534 scores.insert(IntentId::from("general_processing"), 0.5);
535 scores
536 }
537
538 async fn update_intent_patterns(&self, intent: &IntentId, text: &str) -> Result<()> {
540 self.intent_patterns
541 .entry(intent.clone())
542 .or_insert_with(Vec::new)
543 .push(text.to_string());
544 Ok(())
545 }
546
547 async fn update_vocabulary(&self, text: &str) {
549 for word in text.split_whitespace() {
550 let vocab_len = self.vocabulary.len();
551 if vocab_len < self.config.max_vocabulary_size && !self.vocabulary.contains_key(word) {
552 self.vocabulary.insert(word.to_string(), vocab_len);
553 }
554 }
555 }
556
557 async fn should_retrain(&self) -> bool {
559 let training_data = self.training_data.read().await;
560 let feedback_count = training_data
561 .iter()
562 .filter(|example| matches!(example.source, TrainingSource::UserFeedback))
563 .count();
564
565 feedback_count >= self.config.retraining_threshold
566 }
567
568 async fn retrain(&self) -> Result<()> {
570 if self.config.debug_mode {
571 info!("Retraining intent classification model");
572 }
573
574 let training_data = self.training_data.read().await;
578
579 self.vocabulary.clear();
581 for example in training_data.iter() {
582 self.update_vocabulary(&example.text).await;
583 }
584
585 if self.config.debug_mode {
586 info!("Model retraining completed. Vocabulary size: {}", self.vocabulary.len());
587 }
588
589 Ok(())
590 }
591}
592
593impl Default for IntentClassifier {
594 fn default() -> Self {
595 Self {
598 training_data: Arc::new(RwLock::new(Vec::new())),
599 vocabulary: Arc::new(DashMap::with_hasher(RandomState::new())),
600 intent_patterns: Arc::new(DashMap::with_hasher(RandomState::new())),
601 config: ClassifierConfig::default(),
602 }
603 }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a programmatic training example for the tests below.
    fn example(text: &str, intent: &str, confidence: f64) -> TrainingExample {
        TrainingExample {
            text: text.to_string(),
            intent: IntentId::from(intent),
            confidence,
            source: TrainingSource::Programmatic,
        }
    }

    #[tokio::test]
    async fn test_intent_classification() {
        let classifier = IntentClassifier::new().await.unwrap();

        let prediction = classifier
            .predict_intent("merge these JSON files together")
            .await
            .unwrap();

        assert_eq!(prediction.intent.0, "data_merge");
        assert!(prediction.confidence.value() > 0.5);
    }

    #[tokio::test]
    async fn test_feedback_learning() {
        let classifier = IntentClassifier::new().await.unwrap();
        let before = classifier.get_stats().await;

        let feedback = IntentFeedback {
            text: "combine data files".to_string(),
            predicted_intent: IntentId::from("data_transform"),
            actual_intent: IntentId::from("data_merge"),
            satisfaction_score: 5.0,
            notes: None,
            timestamp: chrono::Utc::now(),
        };

        classifier.add_feedback(feedback).await.unwrap();

        let stats = classifier.get_stats().await;
        assert_eq!(stats.feedback_examples, before.feedback_examples + 1);
        assert_eq!(stats.training_examples, before.training_examples + 1);

        // The corrected label is now learned verbatim, with confidence 5.0 / 5.0.
        let prediction = classifier.predict_intent("combine data files").await.unwrap();
        assert_eq!(prediction.intent.0, "data_merge");
        assert!((prediction.confidence.value() - 1.0).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_training_data_export_import() {
        let classifier = IntentClassifier::new().await.unwrap();
        let bootstrap_count = classifier.get_stats().await.training_examples;

        classifier
            .add_training_example(example("test example", "test_intent", 0.9))
            .await
            .unwrap();

        let exported = classifier.export_training_data().await.unwrap();

        let new_classifier = IntentClassifier::new().await.unwrap();
        new_classifier.import_training_data(&exported).await.unwrap();

        // The fresh classifier keeps its own bootstrap set and gains every exported example.
        let stats = new_classifier.get_stats().await;
        assert_eq!(stats.training_examples, 2 * bootstrap_count + 1);

        let prediction = new_classifier.predict_intent("test example").await.unwrap();
        assert_eq!(prediction.intent.0, "test_intent");
    }

    #[tokio::test]
    async fn test_rejects_invalid_training_examples() {
        let classifier = IntentClassifier::new().await.unwrap();
        let before = classifier.get_stats().await.training_examples;

        assert!(classifier
            .add_training_example(example("   ", "blank_intent", 0.5))
            .await
            .is_err());
        assert!(classifier
            .add_training_example(example("valid text", "range_intent", 1.5))
            .await
            .is_err());
        assert!(classifier
            .add_training_example(example("valid text", "range_intent", -0.1))
            .await
            .is_err());
        assert!(classifier.import_training_data("not json").await.is_err());

        assert_eq!(classifier.get_stats().await.training_examples, before);
    }

    #[tokio::test]
    async fn test_clear_training_data_restores_bootstrap() {
        let classifier = IntentClassifier::new().await.unwrap();
        let bootstrap = classifier.get_stats().await;

        classifier
            .add_training_example(example("temporary example text", "temporary_intent", 0.7))
            .await
            .unwrap();
        classifier.clear_training_data().await.unwrap();

        let after = classifier.get_stats().await;
        assert_eq!(after.training_examples, bootstrap.training_examples);
        assert_eq!(after.intent_count, bootstrap.intent_count);
        assert_eq!(after.vocabulary_size, bootstrap.vocabulary_size);
    }

    #[test]
    fn test_cosine_similarity_edge_cases() {
        let classifier = IntentClassifier::default();

        let same = classifier.test_cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]);
        assert!((same - 1.0).abs() < 1e-12);
        assert_eq!(classifier.test_cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(classifier.test_cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(classifier.test_cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn test_preprocess_and_context_features() {
        let classifier = IntentClassifier::default();

        assert_eq!(
            classifier.test_preprocess_text("  Merge, THESE   files! "),
            "merge these files"
        );

        let features = classifier.test_extract_context_features("is this a file?");
        assert_eq!(features.len(), 5);
        assert_eq!(features[2], 1.0);
        assert_eq!(features[3], 1.0);
        assert_eq!(features[4], 0.0);
    }
}
664
/// Test-only entry points exposing private helpers to the unit tests.
#[cfg(test)]
impl IntentClassifier {
    /// Forwards to the private cosine-similarity helper.
    pub fn test_cosine_similarity(&self, a: &[f64], b: &[f64]) -> f64 {
        Self::cosine_similarity(self, a, b)
    }

    /// Forwards to the private context-feature extractor.
    pub fn test_extract_context_features(&self, text: &str) -> Vec<f64> {
        Self::extract_context_features(self, text)
    }

    /// Forwards to the private text normaliser.
    pub fn test_preprocess_text(&self, text: &str) -> String {
        Self::preprocess_text(self, text)
    }
}