intent_classifier/
classifier.rs

1//! Intent Classification Library
2//!
3//! This module provides the main `IntentClassifier` struct and its implementation.
//! The classifier combines bag-of-words feature similarity (cosine similarity against
//! stored example phrases) with rule-based fallbacks to classify user intents from
//! natural language text.
6
7use crate::types::*;
8use ahash::RandomState;
9use dashmap::DashMap;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, info};
14use uuid::Uuid;
15
16/// Main intent classifier that handles training and prediction
17#[derive(Clone)]
18pub struct IntentClassifier {
19    /// Training data storage
20    training_data: Arc<RwLock<Vec<TrainingExample>>>,
21    
22    /// Vocabulary mapping words to indices
23    vocabulary: Arc<DashMap<String, usize, RandomState>>,
24    
25    /// Intent patterns for matching
26    intent_patterns: Arc<DashMap<IntentId, Vec<String>, RandomState>>,
27    
28    /// Configuration for the classifier
29    config: ClassifierConfig,
30}
31
32impl IntentClassifier {
33    /// Create a new intent classifier with default configuration
34    pub async fn new() -> Result<Self> {
35        Self::with_config(ClassifierConfig::default()).await
36    }
37    
38    /// Create a new intent classifier with custom configuration
39    pub async fn with_config(config: ClassifierConfig) -> Result<Self> {
40        let classifier = Self {
41            training_data: Arc::new(RwLock::new(Vec::new())),
42            vocabulary: Arc::new(DashMap::with_hasher(RandomState::new())),
43            intent_patterns: Arc::new(DashMap::with_hasher(RandomState::new())),
44            config,
45        };
46        
47        // Load bootstrap data
48        classifier.load_bootstrap_data().await?;
49        
50        if classifier.config.debug_mode {
51            info!("Intent classifier initialized with {} dimensions", classifier.config.feature_dimensions);
52        }
53        
54        Ok(classifier)
55    }
56    
57    /// Predict intent from natural language text
58    pub async fn predict_intent(&self, text: &str) -> Result<IntentPrediction> {
59        let start_time = std::time::Instant::now();
60        
61        if self.config.debug_mode {
62            debug!("Classifying intent for text: '{}'", text);
63        }
64        
65        // Check for exact matches first (for high-confidence bootstrap cases)
66        if let Some(exact_match) = self.find_exact_match(text).await? {
67            return Ok(exact_match);
68        }
69        
70        // Extract features from the text
71        let features = self.extract_features(text).await?;
72        
73        // Calculate scores for all known intents
74        let intent_scores = self.calculate_intent_scores(&features).await?;
75        
76        // Find the best intent
77        let (best_intent, best_confidence) = self.find_best_intent(&intent_scores)?;
78        
79        // Get alternative intents
80        let alternative_intents = self.get_alternative_intents(&intent_scores, &best_intent);
81        
82        // Generate reasoning
83        let reasoning = self.generate_reasoning(text, &best_intent, &features).await;
84        
85        let prediction = IntentPrediction {
86            intent: best_intent,
87            confidence: best_confidence,
88            alternative_intents,
89            reasoning,
90        };
91        
92        if self.config.debug_mode {
93            let elapsed = start_time.elapsed();
94            info!("Intent prediction: {} (confidence: {:.3}, time: {:?})", 
95                  prediction.intent, prediction.confidence.value(), elapsed);
96        }
97        
98        Ok(prediction)
99    }
100    
101    /// Classify text with additional request options
102    pub async fn classify(&self, request: ClassificationRequest) -> Result<ClassificationResponse> {
103        let start_time = std::time::Instant::now();
104        let request_id = Uuid::new_v4();
105        
106        let mut prediction = self.predict_intent(&request.text).await?;
107        
108        // Filter response based on request options
109        if !request.include_alternatives {
110            prediction.alternative_intents.clear();
111        }
112        
113        if !request.include_reasoning {
114            prediction.reasoning = String::new();
115        }
116        
117        let processing_time_ms = start_time.elapsed().as_millis() as f64;
118        
119        Ok(ClassificationResponse {
120            prediction,
121            processing_time_ms,
122            request_id,
123        })
124    }
125    
126    /// Add a training example
127    pub async fn add_training_example(&self, example: TrainingExample) -> Result<()> {
128        // Validate the example
129        if example.text.trim().is_empty() {
130            return Err(IntentError::InvalidParameter {
131                parameter: "text".to_string(),
132                message: "Training example text cannot be empty".to_string(),
133            });
134        }
135        
136        if !(0.0..=1.0).contains(&example.confidence) {
137            return Err(IntentError::InvalidParameter {
138                parameter: "confidence".to_string(),
139                message: format!("Confidence must be between 0.0 and 1.0, got {}", example.confidence),
140            });
141        }
142        
143        // Add to training data
144        {
145            let mut training_data = self.training_data.write().await;
146            training_data.push(example.clone());
147        }
148        
149        // Update patterns
150        self.update_intent_patterns(&example.intent, &example.text).await?;
151        
152        // Update vocabulary
153        self.update_vocabulary(&example.text).await;
154        
155        if self.config.debug_mode {
156            info!("Added training example: '{}' -> {}", example.text, example.intent);
157        }
158        
159        Ok(())
160    }
161    
162    /// Add user feedback to improve the classifier
163    pub async fn add_feedback(&self, feedback: IntentFeedback) -> Result<()> {
164        if self.config.debug_mode {
165            info!("Adding feedback: '{}' -> {} (predicted: {}, satisfaction: {})", 
166                  feedback.text, feedback.actual_intent, feedback.predicted_intent, feedback.satisfaction_score);
167        }
168        
169        // Convert feedback to training example
170        let confidence = feedback.satisfaction_score / 5.0; // Normalize to 0-1
171        let example = TrainingExample {
172            text: feedback.text,
173            intent: feedback.actual_intent,
174            confidence,
175            source: TrainingSource::UserFeedback,
176        };
177        
178        self.add_training_example(example).await?;
179        
180        // Check if retraining is needed
181        if self.should_retrain().await {
182            self.retrain().await?;
183        }
184        
185        Ok(())
186    }
187    
188    /// Get classifier statistics
189    pub async fn get_stats(&self) -> ClassifierStats {
190        let training_data = self.training_data.read().await;
191        
192        ClassifierStats {
193            training_examples: training_data.len(),
194            vocabulary_size: self.vocabulary.len(),
195            intent_count: self.intent_patterns.len(),
196            feedback_examples: training_data
197                .iter()
198                .filter(|e| matches!(e.source, TrainingSource::UserFeedback))
199                .count(),
200            last_updated: Some(chrono::Utc::now()),
201        }
202    }
203    
204    /// Export training data as JSON
205    pub async fn export_training_data(&self) -> Result<String> {
206        let training_data = self.training_data.read().await;
207        serde_json::to_string_pretty(&*training_data)
208            .map_err(IntentError::SerializationError)
209    }
210    
211    /// Import training data from JSON
212    pub async fn import_training_data(&self, json_data: &str) -> Result<()> {
213        let examples: Vec<TrainingExample> = serde_json::from_str(json_data)
214            .map_err(IntentError::SerializationError)?;
215        
216        for example in examples {
217            self.add_training_example(example).await?;
218        }
219        
220        Ok(())
221    }
222    
223    /// Clear all training data
224    pub async fn clear_training_data(&self) -> Result<()> {
225        {
226            let mut training_data = self.training_data.write().await;
227            training_data.clear();
228        }
229        
230        self.vocabulary.clear();
231        self.intent_patterns.clear();
232        
233        // Reload bootstrap data
234        self.load_bootstrap_data().await?;
235        
236        if self.config.debug_mode {
237            info!("Cleared all training data and reloaded bootstrap data");
238        }
239        
240        Ok(())
241    }
242    
243    /// Find exact match in training data
244    async fn find_exact_match(&self, text: &str) -> Result<Option<IntentPrediction>> {
245        let training_data = self.training_data.read().await;
246        
247        for example in training_data.iter() {
248            if example.text == text {
249                let confidence = Confidence::new(example.confidence)
250                    .unwrap_or_else(|_| Confidence::default());
251                
252                return Ok(Some(IntentPrediction {
253                    intent: example.intent.clone(),
254                    confidence,
255                    alternative_intents: vec![],
256                    reasoning: "Exact match found in training data".to_string(),
257                }));
258            }
259        }
260        
261        Ok(None)
262    }
263    
264    /// Extract features from text
265    async fn extract_features(&self, text: &str) -> Result<FeatureVector> {
266        let cleaned_text = self.preprocess_text(text);
267        
268        // Extract text features using simple bag-of-words approach
269        let text_features = self.extract_text_features(&cleaned_text).await?;
270        
271        // Extract context features
272        let context_features = self.extract_context_features(&cleaned_text);
273        
274        // Create metadata
275        let mut metadata = HashMap::new();
276        metadata.insert("text_length".to_string(), cleaned_text.len() as f64);
277        metadata.insert("word_count".to_string(), cleaned_text.split_whitespace().count() as f64);
278        
279        Ok(FeatureVector {
280            text_features,
281            context_features,
282            metadata,
283        })
284    }
285    
286    /// Calculate intent scores for the given features
287    async fn calculate_intent_scores(&self, features: &FeatureVector) -> Result<HashMap<IntentId, f64>> {
288        let mut scores = HashMap::new();
289        
290        for entry in self.intent_patterns.iter() {
291            let (intent, pattern_texts) = (entry.key(), entry.value());
292            let mut intent_score: f64 = 0.0;
293            
294            // Calculate similarity to known patterns
295            for pattern_text in pattern_texts {
296                let pattern_features = self.extract_text_features(pattern_text).await?;
297                let similarity = self.cosine_similarity(&features.text_features, &pattern_features);
298                intent_score = intent_score.max(similarity);
299            }
300            
301            // Add context boost
302            intent_score += self.calculate_context_boost(intent, features);
303            
304            scores.insert(intent.clone(), intent_score.min(1.0));
305        }
306        
307        // Apply rule-based fallback if no good matches
308        if scores.values().all(|&score| score < self.config.min_confidence_threshold) {
309            let fallback_scores = self.rule_based_classification(features).await;
310            for (intent, score) in fallback_scores {
311                scores.entry(intent).or_insert(score);
312            }
313        }
314        
315        Ok(scores)
316    }
317    
318    /// Find the best intent from scores
319    fn find_best_intent(&self, scores: &HashMap<IntentId, f64>) -> Result<(IntentId, Confidence)> {
320        let (best_intent, best_score) = scores
321            .iter()
322            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
323            .ok_or_else(|| IntentError::ClassificationFailed("No intents found".to_string()))?;
324        
325        let confidence = Confidence::new(*best_score)
326            .unwrap_or_else(|_| Confidence::default());
327        
328        Ok((best_intent.clone(), confidence))
329    }
330    
331    /// Get alternative intents from scores
332    fn get_alternative_intents(&self, scores: &HashMap<IntentId, f64>, best_intent: &IntentId) -> Vec<(IntentId, Confidence)> {
333        let mut alternatives: Vec<(IntentId, Confidence)> = scores
334            .iter()
335            .filter(|(intent, _)| *intent != best_intent)
336            .filter_map(|(intent, score)| {
337                Confidence::new(*score)
338                    .ok()
339                    .map(|confidence| (intent.clone(), confidence))
340            })
341            .collect();
342        
343        // Sort by confidence descending
344        alternatives.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
345        alternatives.truncate(3); // Keep top 3
346        
347        alternatives
348    }
349    
350    /// Generate human-readable reasoning
351    async fn generate_reasoning(&self, _text: &str, intent: &IntentId, features: &FeatureVector) -> String {
352        if let Some(intent_patterns) = self.intent_patterns.get(intent) {
353            if let Some(best_pattern) = intent_patterns.first() {
354                return format!(
355                    "Classified as '{}' based on similarity to pattern: '{}' (using {} text features)",
356                    intent, best_pattern, features.text_features.len()
357                );
358            }
359        }
360        
361        format!("Classified as '{}' using rule-based analysis", intent)
362    }
363    
364    /// Load bootstrap training data
365    async fn load_bootstrap_data(&self) -> Result<()> {
366        let bootstrap_examples = self.get_bootstrap_examples();
367        
368        for (text, intent_str) in bootstrap_examples {
369            let example = TrainingExample {
370                text: text.to_string(),
371                intent: IntentId::from(intent_str),
372                confidence: 1.0,
373                source: TrainingSource::Bootstrap,
374            };
375            
376            self.add_training_example(example).await?;
377        }
378        
379        if self.config.debug_mode {
380            info!("Loaded {} bootstrap training examples", self.get_bootstrap_examples().len());
381        }
382        
383        Ok(())
384    }
385    
386    /// Get bootstrap examples
387    fn get_bootstrap_examples(&self) -> Vec<(&'static str, &'static str)> {
388        vec![
389            // Data operations
390            ("merge these JSON files together", "data_merge"),
391            ("combine multiple JSON documents", "data_merge"),
392            ("join several data files into one", "data_merge"),
393            ("consolidate JSON objects", "data_merge"),
394            ("split this large JSON file", "data_split"),
395            ("break apart this data into smaller pieces", "data_split"),
396            ("divide this file into multiple parts", "data_split"),
397            ("convert JSON to CSV format", "data_transform"),
398            ("transform this data structure", "data_transform"),
399            ("change the format of this file", "data_transform"),
400            ("analyze this dataset for patterns", "data_analyze"),
401            ("examine the data for insights", "data_analyze"),
402            ("what trends do you see in this data", "data_analyze"),
403            ("give me statistics about this data", "data_analyze"),
404            
405            // File operations
406            ("read the contents of this file", "file_read"),
407            ("load this document", "file_read"),
408            ("open and parse this file", "file_read"),
409            ("save this data to a file", "file_write"),
410            ("write this content to disk", "file_write"),
411            ("create a new file with this data", "file_write"),
412            ("convert PDF to markdown", "file_convert"),
413            ("change this file format", "file_convert"),
414            ("export as different format", "file_convert"),
415            ("compare these two files", "file_compare"),
416            ("what's different between these documents", "file_compare"),
417            ("find differences in these files", "file_compare"),
418            
419            // Network operations
420            ("make an API request to this URL", "network_request"),
421            ("call this REST endpoint", "network_request"),
422            ("send HTTP request", "network_request"),
423            ("download this file from the internet", "network_download"),
424            ("fetch data from this URL", "network_download"),
425            ("retrieve file from web", "network_download"),
426            ("check if this website is up", "network_monitor"),
427            ("monitor API endpoint", "network_monitor"),
428            ("test connectivity to server", "network_monitor"),
429            
430            // Processing operations
431            ("extract text from this document", "extraction"),
432            ("pull out specific information", "extraction"),
433            ("get the important parts from this", "extraction"),
434            ("validate this data against schema", "validation"),
435            ("check if this data is correct", "validation"),
436            ("verify the format of this file", "validation"),
437            ("generate a report from this data", "generation"),
438            ("create summary of this information", "generation"),
439            ("produce documentation", "generation"),
440            ("classify this content", "classification"),
441            ("categorize this data", "classification"),
442            ("determine the type of this file", "classification"),
443            
444            // Code operations
445            ("analyze this code for issues", "code_analyze"),
446            ("review this source code", "code_analyze"),
447            ("check code quality", "code_analyze"),
448            ("process this text document", "text_process"),
449            ("clean up this text", "text_process"),
450            ("parse natural language", "text_process"),
451        ]
452    }
453    
454    /// Preprocess text for feature extraction
455    fn preprocess_text(&self, text: &str) -> String {
456        text.to_lowercase()
457            .chars()
458            .filter(|c| c.is_alphanumeric() || c.is_whitespace())
459            .collect::<String>()
460            .split_whitespace()
461            .collect::<Vec<_>>()
462            .join(" ")
463    }
464    
465    /// Extract text features using bag-of-words
466    async fn extract_text_features(&self, text: &str) -> Result<Vec<f64>> {
467        let mut features = vec![0.0; self.config.feature_dimensions];
468        
469        let words: Vec<&str> = text.split_whitespace().collect();
470        let word_count = words.len() as f64;
471        
472        if word_count == 0.0 {
473            return Ok(features);
474        }
475        
476        // Simple term frequency approach
477        for word in words {
478            if let Some(index) = self.vocabulary.get(word) {
479                if *index < features.len() {
480                    features[*index] += 1.0 / word_count;
481                }
482            }
483        }
484        
485        Ok(features)
486    }
487    
488    /// Extract context features
489    fn extract_context_features(&self, text: &str) -> Vec<f64> {
490        vec![
491            text.len() as f64 / 100.0, // Normalized text length
492            text.split_whitespace().count() as f64 / 20.0, // Normalized word count
493            if text.contains('?') { 1.0 } else { 0.0 }, // Question indicator
494            if text.contains("file") { 1.0 } else { 0.0 }, // File operation indicator
495            if text.contains("data") { 1.0 } else { 0.0 }, // Data operation indicator
496        ]
497    }
498    
499    /// Calculate cosine similarity between feature vectors
500    fn cosine_similarity(&self, a: &[f64], b: &[f64]) -> f64 {
501        if a.len() != b.len() {
502            return 0.0;
503        }
504        
505        let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
506        let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
507        let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
508        
509        if norm_a == 0.0 || norm_b == 0.0 {
510            0.0
511        } else {
512            dot_product / (norm_a * norm_b)
513        }
514    }
515    
516    /// Calculate context boost for intent scores
517    fn calculate_context_boost(&self, intent: &IntentId, features: &FeatureVector) -> f64 {
518        let mut boost = 0.0;
519        
520        if intent.0.contains("file") && features.context_features.get(3).unwrap_or(&0.0) > &0.0 {
521            boost += 0.1;
522        }
523        
524        if intent.0.contains("data") && features.context_features.get(4).unwrap_or(&0.0) > &0.0 {
525            boost += 0.1;
526        }
527        
528        boost
529    }
530    
531    /// Rule-based classification fallback
532    async fn rule_based_classification(&self, _features: &FeatureVector) -> HashMap<IntentId, f64> {
533        let mut scores = HashMap::new();
534        scores.insert(IntentId::from("general_processing"), 0.5);
535        scores
536    }
537    
538    /// Update intent patterns with new text
539    async fn update_intent_patterns(&self, intent: &IntentId, text: &str) -> Result<()> {
540        self.intent_patterns
541            .entry(intent.clone())
542            .or_insert_with(Vec::new)
543            .push(text.to_string());
544        Ok(())
545    }
546    
547    /// Update vocabulary with new words
548    async fn update_vocabulary(&self, text: &str) {
549        for word in text.split_whitespace() {
550            let vocab_len = self.vocabulary.len();
551            if vocab_len < self.config.max_vocabulary_size && !self.vocabulary.contains_key(word) {
552                self.vocabulary.insert(word.to_string(), vocab_len);
553            }
554        }
555    }
556    
557    /// Check if model should be retrained
558    async fn should_retrain(&self) -> bool {
559        let training_data = self.training_data.read().await;
560        let feedback_count = training_data
561            .iter()
562            .filter(|example| matches!(example.source, TrainingSource::UserFeedback))
563            .count();
564        
565        feedback_count >= self.config.retraining_threshold
566    }
567    
568    /// Retrain the model
569    async fn retrain(&self) -> Result<()> {
570        if self.config.debug_mode {
571            info!("Retraining intent classification model");
572        }
573        
574        // For now, just rebuild vocabulary and patterns
575        // In a more sophisticated implementation, this would retrain ML models
576        
577        let training_data = self.training_data.read().await;
578        
579        // Clear and rebuild vocabulary
580        self.vocabulary.clear();
581        for example in training_data.iter() {
582            self.update_vocabulary(&example.text).await;
583        }
584        
585        if self.config.debug_mode {
586            info!("Model retraining completed. Vocabulary size: {}", self.vocabulary.len());
587        }
588        
589        Ok(())
590    }
591}
592
593impl Default for IntentClassifier {
594    fn default() -> Self {
595        // Note: This can't be async, so we use a synchronous version
596        // In practice, users should use `IntentClassifier::new().await`
597        Self {
598            training_data: Arc::new(RwLock::new(Vec::new())),
599            vocabulary: Arc::new(DashMap::with_hasher(RandomState::new())),
600            intent_patterns: Arc::new(DashMap::with_hasher(RandomState::new())),
601            config: ClassifierConfig::default(),
602        }
603    }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a classifier seeded with the bootstrap examples.
    async fn seeded_classifier() -> IntentClassifier {
        IntentClassifier::new().await.unwrap()
    }

    #[tokio::test]
    async fn test_intent_classification() {
        let classifier = seeded_classifier().await;

        let prediction = classifier
            .predict_intent("merge these JSON files together")
            .await
            .unwrap();

        assert_eq!(prediction.intent.0, "data_merge");
        assert!(prediction.confidence.value() > 0.5);
    }

    #[tokio::test]
    async fn test_feedback_learning() {
        let classifier = seeded_classifier().await;

        classifier
            .add_feedback(IntentFeedback {
                text: "combine data files".to_string(),
                predicted_intent: IntentId::from("data_transform"),
                actual_intent: IntentId::from("data_merge"),
                satisfaction_score: 5.0,
                notes: None,
                timestamp: chrono::Utc::now(),
            })
            .await
            .unwrap();

        let feedback_examples = classifier.get_stats().await.feedback_examples;
        assert!(feedback_examples > 0);
    }

    #[tokio::test]
    async fn test_training_data_export_import() {
        let original = seeded_classifier().await;
        original
            .add_training_example(TrainingExample {
                text: "test example".to_string(),
                intent: IntentId::from("test_intent"),
                confidence: 0.9,
                source: TrainingSource::Programmatic,
            })
            .await
            .unwrap();

        let exported = original.export_training_data().await.unwrap();

        let restored = seeded_classifier().await;
        restored.import_training_data(&exported).await.unwrap();

        let training_examples = restored.get_stats().await.training_examples;
        assert!(training_examples > 0);
    }
}
664
/// Public shims over private helpers so tests can call them directly.
/// Compiled only under `cfg(test)`.
#[cfg(test)]
impl IntentClassifier {
    /// Forward to the private cosine-similarity helper.
    pub fn test_cosine_similarity(&self, a: &[f64], b: &[f64]) -> f64 {
        Self::cosine_similarity(self, a, b)
    }

    /// Forward to the private context-feature extractor.
    pub fn test_extract_context_features(&self, text: &str) -> Vec<f64> {
        Self::extract_context_features(self, text)
    }

    /// Forward to the private text normalizer.
    pub fn test_preprocess_text(&self, text: &str) -> String {
        Self::preprocess_text(self, text)
    }
}