mockforge_data/
drift_learning.rs

1//! Drift Learning System
2//!
3//! This module extends the DataDriftEngine with learning capabilities that allow
4//! mocks to gradually learn from recorded traffic and adapt their behavior.
5//!
6//! Features:
7//! - Traffic pattern learning from recorded requests
8//! - Persona behavior adaptation based on request patterns
9//! - Configurable learning rate and sensitivity
10//! - Opt-in per persona/endpoint learning
11
12use crate::drift::{DataDriftConfig, DataDriftEngine};
13use crate::Result;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::sync::RwLock;
20
/// Configuration for the drift learning system.
///
/// Every field carries a serde default, so a partially specified (or empty)
/// configuration still deserializes. Learning as a whole stays off unless
/// `enabled` is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningConfig {
    /// Master switch for drift learning. Defaults to `false` (opt-in).
    #[serde(default)]
    pub enabled: bool,

    /// Learning mode. Defaults to `LearningMode::Behavioral`.
    #[serde(default)]
    pub mode: LearningMode,

    /// Learning rate (0.0 to 1.0) - how quickly mocks learn from patterns.
    ///
    /// Multiplied by a pattern's confidence to form the weight used when a
    /// learned value is blended into generated data. Defaults to 0.2.
    #[serde(default = "default_learning_rate")]
    pub sensitivity: f64,

    /// Decay rate (0.0 to 1.0) - drift resets if upstream patterns reverse.
    ///
    /// Defaults to 0.05.
    /// NOTE(review): no code visible in this module reads this field - confirm where decay is applied.
    #[serde(default = "default_decay_rate")]
    pub decay: f64,

    /// Minimum number of samples before learning starts.
    ///
    /// Persona behavior analysis yields no pattern for a persona with fewer
    /// recorded events than this. Defaults to 10.
    #[serde(default = "default_min_samples")]
    pub min_samples: usize,

    /// Update interval for learning. Defaults to 60 seconds.
    #[serde(default = "default_update_interval")]
    pub update_interval: Duration,

    /// Enable persona adaptation.
    ///
    /// When set, `DriftLearningEngine::new` creates a persona behavior learner.
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub persona_adaptation: bool,

    /// Enable traffic pattern mirroring.
    ///
    /// When set, `DriftLearningEngine::new` creates a traffic pattern learner.
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub traffic_mirroring: bool,

    /// Per-endpoint learning switches (endpoint pattern -> enabled). Empty by default.
    #[serde(default)]
    pub endpoint_learning: HashMap<String, bool>,

    /// Per-persona learning switches (persona_id -> enabled). Empty by default.
    ///
    /// A persona mapped to `false` has its behavior events ignored by the persona
    /// learner; a persona absent from the map is still recorded whenever
    /// learning is enabled.
    #[serde(default)]
    pub persona_learning: HashMap<String, bool>,
}
64
/// Serde default for `LearningConfig::sensitivity`: a conservative 20% learning rate.
fn default_learning_rate() -> f64 {
    const CONSERVATIVE_LEARNING_RATE: f64 = 0.2;
    CONSERVATIVE_LEARNING_RATE
}
68
/// Serde default for `LearningConfig::decay`: a 5% decay rate.
fn default_decay_rate() -> f64 {
    const DECAY_RATE: f64 = 0.05;
    DECAY_RATE
}
72
/// Serde default for `LearningConfig::min_samples`: ten samples before learning starts.
fn default_min_samples() -> usize {
    const MIN_SAMPLES: usize = 10;
    MIN_SAMPLES
}
76
/// Serde default for `LearningConfig::update_interval`: one update per minute.
fn default_update_interval() -> Duration {
    const SECONDS_PER_UPDATE: u64 = 60;
    Duration::from_secs(SECONDS_PER_UPDATE)
}
80
/// Serde default helper for boolean flags that are on unless explicitly disabled.
fn default_true() -> bool {
    const ON_BY_DEFAULT: bool = true;
    ON_BY_DEFAULT
}
84
85impl Default for LearningConfig {
86    fn default() -> Self {
87        Self {
88            enabled: false, // Opt-in by default
89            mode: LearningMode::Behavioral,
90            sensitivity: default_learning_rate(),
91            decay: default_decay_rate(),
92            min_samples: default_min_samples(),
93            update_interval: default_update_interval(),
94            persona_adaptation: true,
95            traffic_mirroring: true,
96            endpoint_learning: HashMap::new(),
97            persona_learning: HashMap::new(),
98        }
99    }
100}
101
/// Learning mode selected in [`LearningConfig`].
///
/// Variants are (de)serialized in snake_case: `behavioral`, `statistical`,
/// `hybrid`. `Behavioral` is the default.
/// NOTE(review): no code visible in this module branches on the mode - confirm where it is consumed.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum LearningMode {
    /// Behavioral learning - adapts to behavior patterns (the default)
    #[default]
    Behavioral,
    /// Statistical learning - adapts to statistical patterns
    Statistical,
    /// Hybrid - combines behavioral and statistical
    Hybrid,
}
114
/// Drift Learning Engine
///
/// Wraps a [`DataDriftEngine`] and layers learned patterns on top of its output:
/// `apply_drift_with_learning` first applies the base drift and then blends in
/// cached patterns when learning is enabled.
pub struct DriftLearningEngine {
    /// Base drift engine that produces the un-learned drift
    drift_engine: DataDriftEngine,
    /// Learning configuration held by the engine itself
    learning_config: LearningConfig,
    /// Traffic pattern learner; `Some` only when `traffic_mirroring` was set at construction
    traffic_learner: Option<Arc<RwLock<TrafficPatternLearner>>>,
    /// Persona behavior learner; `Some` only when `persona_adaptation` was set at construction
    persona_learner: Option<Arc<RwLock<PersonaBehaviorLearner>>>,
    /// Learned patterns cache, keyed by string; starts empty
    learned_patterns: Arc<RwLock<HashMap<String, LearnedPattern>>>,
}
130
/// A pattern learned from traffic or persona-behavior analysis.
///
/// Produced by the learners in this module and stored in the engine's pattern
/// cache, from which `apply_drift_with_learning` reads it.
#[derive(Debug, Clone)]
pub struct LearnedPattern {
    /// Pattern identifier (the learners derive it from the endpoint, sequence or persona)
    pub pattern_id: String,
    /// Pattern type; decides how (and whether) the pattern is applied to data
    pub pattern_type: PatternType,
    /// Learned parameters as JSON values keyed by parameter name
    pub parameters: HashMap<String, Value>,
    /// Confidence score (0.0 to 1.0); patterns below 0.5 are skipped when applying drift
    pub confidence: f64,
    /// Sample count used for learning
    pub sample_count: usize,
    /// Last updated timestamp (UTC)
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
147
/// Kind of a [`LearnedPattern`].
///
/// Only `PersonaBehavior` patterns modify data in `apply_drift_with_learning`;
/// the other kinds are matched there but have empty handlers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PatternType {
    /// Latency pattern
    Latency,
    /// Error rate pattern
    ErrorRate,
    /// Request sequence pattern
    RequestSequence,
    /// Persona behavior pattern (its parameters are blended into generated data)
    PersonaBehavior,
}
160
161impl DriftLearningEngine {
162    /// Create a new drift learning engine
163    pub fn new(drift_config: DataDriftConfig, learning_config: LearningConfig) -> Result<Self> {
164        let drift_engine = DataDriftEngine::new(drift_config)?;
165
166        let traffic_learner = if learning_config.traffic_mirroring {
167            Some(Arc::new(RwLock::new(TrafficPatternLearner::new(learning_config.clone())?)))
168        } else {
169            None
170        };
171
172        let persona_learner = if learning_config.persona_adaptation {
173            Some(Arc::new(RwLock::new(PersonaBehaviorLearner::new(learning_config.clone())?)))
174        } else {
175            None
176        };
177
178        Ok(Self {
179            drift_engine,
180            learning_config,
181            traffic_learner,
182            persona_learner,
183            learned_patterns: Arc::new(RwLock::new(HashMap::new())),
184        })
185    }
186
187    /// Get the base drift engine
188    pub fn drift_engine(&self) -> &DataDriftEngine {
189        &self.drift_engine
190    }
191
192    /// Get learning configuration
193    pub fn learning_config(&self) -> &LearningConfig {
194        &self.learning_config
195    }
196
197    /// Update learning configuration
198    pub fn update_learning_config(&mut self, config: LearningConfig) -> Result<()> {
199        self.learning_config = config;
200        Ok(())
201    }
202
203    /// Get learned patterns
204    pub async fn get_learned_patterns(&self) -> HashMap<String, LearnedPattern> {
205        self.learned_patterns.read().await.clone()
206    }
207
208    /// Apply drift with learning
209    pub async fn apply_drift_with_learning(&self, data: Value) -> Result<Value> {
210        // First apply base drift
211        let mut data = self.drift_engine.apply_drift(data).await?;
212
213        // Then apply learned patterns if learning is enabled
214        if !self.learning_config.enabled {
215            return Ok(data);
216        }
217
218        // Apply learned patterns
219        let patterns = self.learned_patterns.read().await;
220        for (pattern_id, pattern) in patterns.iter() {
221            // Check if pattern should be applied based on confidence and decay
222            if pattern.confidence < 0.5 {
223                continue; // Low confidence, skip
224            }
225
226            // Apply pattern based on type
227            match pattern.pattern_type {
228                PatternType::Latency => {
229                    // Latency patterns are handled separately
230                }
231                PatternType::ErrorRate => {
232                    // Error rate patterns are handled separately
233                }
234                PatternType::RequestSequence => {
235                    // Request sequence patterns affect persona behavior
236                }
237                PatternType::PersonaBehavior => {
238                    // Persona behavior patterns affect data generation
239                    if let Some(obj) = data.as_object_mut() {
240                        for (key, value) in &pattern.parameters {
241                            if let Some(existing) = obj.get(key) {
242                                // Blend learned value with existing value
243                                let blended =
244                                    self.blend_values(existing, value, pattern.confidence)?;
245                                obj.insert(key.clone(), blended);
246                            }
247                        }
248                    }
249                }
250            }
251        }
252
253        Ok(data)
254    }
255
256    /// Blend two values based on confidence
257    fn blend_values(&self, existing: &Value, learned: &Value, confidence: f64) -> Result<Value> {
258        // Simple blending: existing * (1 - confidence * sensitivity) + learned * (confidence * sensitivity)
259        let weight = confidence * self.learning_config.sensitivity;
260
261        match (existing, learned) {
262            (Value::Number(existing_num), Value::Number(learned_num)) => {
263                if let (Some(existing_f64), Some(learned_f64)) =
264                    (existing_num.as_f64(), learned_num.as_f64())
265                {
266                    let blended = existing_f64 * (1.0 - weight) + learned_f64 * weight;
267                    Ok(Value::from(blended))
268                } else {
269                    Ok(existing.clone())
270                }
271            }
272            _ => Ok(existing.clone()), // For non-numeric, keep existing
273        }
274    }
275}
276
/// Traffic Pattern Learner
///
/// Intended to analyze recorded traffic to detect patterns and trends. In this
/// module the recorder-backed analysis is disabled (to break a circular
/// dependency), so its public detection methods currently return no patterns.
pub struct TrafficPatternLearner {
    /// Learning configuration (a copy owned by this learner)
    config: LearningConfig,
    /// Pattern window for analysis; set to one hour by `new`
    pattern_window: Duration,
    /// Detected patterns, keyed by string; starts empty
    patterns: HashMap<String, TrafficPattern>,
}
288
/// Traffic pattern detected from analysis.
///
/// Private bookkeeping record stored in `TrafficPatternLearner::patterns`.
/// NOTE(review): nothing visible in this module constructs or reads this type - confirm it is still needed.
#[derive(Debug, Clone)]
struct TrafficPattern {
    /// Pattern identifier
    pattern_id: String,
    /// Pattern type
    pattern_type: PatternType,
    /// Pattern parameters as JSON values keyed by parameter name
    parameters: HashMap<String, Value>,
    /// Sample count
    sample_count: usize,
    /// First seen timestamp (UTC)
    first_seen: chrono::DateTime<chrono::Utc>,
    /// Last seen timestamp (UTC)
    last_seen: chrono::DateTime<chrono::Utc>,
}
305
306impl TrafficPatternLearner {
307    /// Create a new traffic pattern learner
308    pub fn new(config: LearningConfig) -> Result<Self> {
309        Ok(Self {
310            config,
311            pattern_window: Duration::from_secs(3600), // 1 hour window
312            patterns: HashMap::new(),
313        })
314    }
315
316    /// Analyze traffic patterns from recorded requests
317    ///
318    /// NOTE: This method is disabled to break circular dependencies.
319    /// The recorder integration has been moved to a higher-level crate.
320    pub async fn analyze_traffic_patterns(
321        &mut self,
322        _database: &dyn std::any::Any, // Use Any to avoid dependency on mockforge-recorder
323    ) -> Result<Vec<LearnedPattern>> {
324        // Disabled to break circular dependency
325        Ok(Vec::new())
326    }
327
328    /// Internal method to detect latency patterns from requests
329    ///
330    /// NOTE: This method is disabled to break circular dependency.
331    /// The recorder integration has been moved to a higher-level crate.
332    #[allow(dead_code)]
333    pub async fn detect_latency_patterns_from_requests(
334        &self,
335        _requests: &[serde_json::Value],
336    ) -> Result<Vec<LearnedPattern>> {
337        // Disabled to break circular dependency
338        Ok(Vec::new())
339    }
340
341    // NOTE: The following code is disabled to break circular dependency
342    // Original implementation would process requests here
343    /*
344    fn _detect_latency_patterns_original(
345        &self,
346        requests: &[serde_json::Value],
347    ) -> Result<Vec<LearnedPattern>> {
348        use std::collections::HashMap;
349        use chrono::Utc;
350
351        // Group requests by endpoint/method
352        let mut endpoint_latencies: HashMap<String, Vec<i64>> = HashMap::new();
353
354        for request in requests {
355            if let Some(duration) = request.duration_ms {
356                let key = format!("{} {}", request.method, request.path);
357                endpoint_latencies.entry(key).or_insert_with(Vec::new).push(duration);
358            }
359        }
360
361        let mut patterns = Vec::new();
362
363        for (endpoint_key, latencies) in endpoint_latencies {
364            if latencies.len() < 10 {
365                // Need at least 10 samples for meaningful analysis
366                continue;
367            }
368
369            // Calculate statistics
370            let sum: i64 = latencies.iter().sum();
371            let count = latencies.len();
372            let avg_latency = sum as f64 / count as f64;
373
374            // Calculate percentiles
375            let mut sorted = latencies.clone();
376            sorted.sort();
377            let p50 = sorted[sorted.len() / 2];
378            let p95 = sorted[(sorted.len() * 95) / 100];
379            let p99 = sorted[(sorted.len() * 99) / 100];
380
381            // Detect if latency is increasing (trend analysis)
382            let recent_avg = if latencies.len() >= 20 {
383                let recent: Vec<i64> = latencies.iter().rev().take(10).copied().collect();
384                let recent_sum: i64 = recent.iter().sum();
385                recent_sum as f64 / recent.len() as f64
386            } else {
387                avg_latency
388            };
389
390            let latency_trend = if recent_avg > avg_latency * 1.2 {
391                "increasing"
392            } else if recent_avg < avg_latency * 0.8 {
393                "decreasing"
394            } else {
395                "stable"
396            };
397
398            // Create pattern if there's significant variation or trend
399            if p99 > p50 * 2.0 || latency_trend != "stable" {
400                let mut parameters = HashMap::new();
401                parameters.insert("endpoint".to_string(), serde_json::json!(endpoint_key));
402                parameters.insert("avg_latency_ms".to_string(), serde_json::json!(avg_latency));
403                parameters.insert("p50_ms".to_string(), serde_json::json!(p50));
404                parameters.insert("p95_ms".to_string(), serde_json::json!(p95));
405                parameters.insert("p99_ms".to_string(), serde_json::json!(p99));
406                parameters.insert("sample_count".to_string(), serde_json::json!(count));
407                parameters.insert("trend".to_string(), serde_json::json!(latency_trend));
408
409                // Confidence based on sample size
410                let confidence = (count as f64 / 100.0).min(1.0);
411
412                patterns.push(LearnedPattern {
413                    pattern_id: format!("latency_{}", endpoint_key.replace('/', "_").replace(' ', "_")),
414                    pattern_type: PatternType::Latency,
415                    parameters,
416                    confidence,
417                    sample_count: count,
418                    last_updated: Utc::now(),
419                });
420            }
421        }
422
423        Ok(patterns)
424    }
425    */
426
427    /// Internal method to detect error rate patterns from requests
428    /// NOTE: Disabled to break circular dependency
429    #[allow(dead_code)]
430    async fn detect_error_patterns_internal(
431        &self,
432        _requests: &[serde_json::Value],
433    ) -> Result<Vec<LearnedPattern>> {
434        use chrono::Utc;
435        use std::collections::HashMap;
436
437        // Disabled to break circular dependency
438        let _requests = _requests;
439        let endpoint_errors: HashMap<String, (usize, usize)> = HashMap::new(); // (total, errors)
440
441        // Disabled - would iterate over requests here
442        /*
443        for request in requests {
444            let key = format!("{} {}", request.method, request.path);
445            let entry = endpoint_errors.entry(key).or_insert((0, 0));
446            entry.0 += 1;
447
448            // Consider 4xx and 5xx as errors
449            if let Some(status) = request.status_code {
450                if status >= 400 {
451                    entry.1 += 1;
452                }
453            }
454        }
455        */
456
457        let mut patterns = Vec::new();
458
459        for (endpoint_key, (total, errors)) in endpoint_errors {
460            if total < 20 {
461                // Need at least 20 samples for meaningful analysis
462                continue;
463            }
464
465            let error_rate = errors as f64 / total as f64;
466
467            // Create pattern if error rate is significant (>5%) or increasing
468            if error_rate > 0.05 {
469                let mut parameters = HashMap::new();
470                parameters.insert("endpoint".to_string(), serde_json::json!(endpoint_key));
471                parameters.insert("error_rate".to_string(), serde_json::json!(error_rate));
472                parameters.insert("total_requests".to_string(), serde_json::json!(total));
473                parameters.insert("error_count".to_string(), serde_json::json!(errors));
474
475                // Confidence based on sample size and error rate
476                let confidence = ((total as f64 / 100.0).min(1.0) * error_rate * 10.0).min(1.0);
477
478                patterns.push(LearnedPattern {
479                    pattern_id: format!("error_rate_{}", endpoint_key.replace(['/', ' '], "_")),
480                    pattern_type: PatternType::ErrorRate,
481                    parameters,
482                    confidence,
483                    sample_count: total,
484                    last_updated: Utc::now(),
485                });
486            }
487        }
488
489        Ok(patterns)
490    }
491
492    /// Internal method to detect request sequence patterns
493    /// NOTE: Disabled to break circular dependency
494    #[allow(dead_code)]
495    async fn detect_sequence_patterns_internal(
496        &self,
497        _requests: &[serde_json::Value],
498    ) -> Result<Vec<LearnedPattern>> {
499        use chrono::Utc;
500        use std::collections::HashMap;
501
502        // Disabled to break circular dependency
503        let _requests = _requests;
504        if _requests.len() < 50 {
505            // Need sufficient data for sequence detection
506            return Ok(Vec::new());
507        }
508
509        // Disabled - would process requests here
510        let trace_sequences: HashMap<Option<String>, Vec<String>> = HashMap::new();
511
512        /*
513        for request in requests {
514            let trace_id = request.trace_id.clone();
515            let endpoint_key = format!("{} {}", request.method, request.path);
516            trace_sequences
517                .entry(trace_id)
518                .or_insert_with(Vec::new)
519                .push(endpoint_key);
520        }
521        */
522
523        // Find common sequences (patterns that appear multiple times)
524        let mut sequence_counts: HashMap<String, usize> = HashMap::new();
525
526        for sequence in trace_sequences.values() {
527            if sequence.len() >= 2 {
528                // Create sequence signature (first 3 endpoints)
529                let signature: Vec<String> = sequence.iter().take(3).cloned().collect();
530                let signature_str = signature.join(" -> ");
531                *sequence_counts.entry(signature_str).or_insert(0) += 1;
532            }
533        }
534
535        let mut patterns = Vec::new();
536
537        for (sequence_str, count) in sequence_counts {
538            if count >= 5 {
539                // Pattern appears at least 5 times
540                let mut parameters = HashMap::new();
541                parameters.insert("sequence".to_string(), serde_json::json!(sequence_str));
542                parameters.insert("occurrence_count".to_string(), serde_json::json!(count));
543
544                // Confidence based on occurrence frequency
545                let confidence = (count as f64 / 20.0).min(1.0);
546
547                patterns.push(LearnedPattern {
548                    pattern_id: format!(
549                        "sequence_{}",
550                        sequence_str.replace(['/', ' '], "_").replace("->", "_")
551                    ),
552                    pattern_type: PatternType::RequestSequence,
553                    parameters,
554                    confidence,
555                    sample_count: count,
556                    last_updated: Utc::now(),
557                });
558            }
559        }
560
561        Ok(patterns)
562    }
563
564    /// Detect latency patterns
565    ///
566    /// This method is a convenience wrapper that requires a database.
567    /// Use `analyze_traffic_patterns` with a RecorderDatabase for full analysis.
568    pub async fn detect_latency_patterns(&mut self) -> Result<Vec<LearnedPattern>> {
569        // This method now requires database access - use analyze_traffic_patterns instead
570        Ok(Vec::new())
571    }
572
573    /// Detect error rate patterns
574    ///
575    /// This method is a convenience wrapper that requires a database.
576    /// Use `analyze_traffic_patterns` with a RecorderDatabase for full analysis.
577    pub async fn detect_error_patterns(&mut self) -> Result<Vec<LearnedPattern>> {
578        // This method now requires database access - use analyze_traffic_patterns instead
579        Ok(Vec::new())
580    }
581}
582
/// Persona Behavior Learner
///
/// Records per-persona behavior events and derives learned patterns from them,
/// which can then be written back to persona profiles as traits.
pub struct PersonaBehaviorLearner {
    /// Learning configuration (a copy owned by this learner)
    config: LearningConfig,
    /// Behavior history (persona_id -> behavior events, oldest first, capped at 1000 per persona)
    behavior_history: HashMap<String, Vec<BehaviorEvent>>,
}
592
/// A single behavior event recorded for a persona.
#[derive(Debug, Clone)]
pub struct BehaviorEvent {
    /// Event timestamp (UTC)
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Event type; this is what persona behavior analysis inspects
    pub event_type: BehaviorEventType,
    /// Additional event data as JSON values keyed by name
    pub data: HashMap<String, Value>,
}
603
/// Kind of a [`BehaviorEvent`].
///
/// Persona behavior analysis currently looks only at `RequestFailed` events and
/// the `Request` event that directly follows them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BehaviorEventType {
    /// Request made to an endpoint
    Request {
        /// Endpoint path
        endpoint: String,
        /// HTTP method
        method: String,
    },
    /// Request failed
    RequestFailed {
        /// Endpoint path
        endpoint: String,
        /// HTTP status code
        status_code: u16,
    },
    /// Request succeeded after failure
    RequestSucceededAfterFailure {
        /// Endpoint path
        endpoint: String,
    },
    /// Pattern detected
    PatternDetected {
        /// Pattern identifier
        pattern: String,
    },
}
632
633impl PersonaBehaviorLearner {
634    /// Create a new persona behavior learner
635    pub fn new(config: LearningConfig) -> Result<Self> {
636        Ok(Self {
637            config,
638            behavior_history: HashMap::new(),
639        })
640    }
641
642    /// Record a behavior event for a persona
643    pub fn record_event(&mut self, persona_id: String, event: BehaviorEvent) {
644        if !self.config.enabled {
645            return;
646        }
647
648        // Check if persona learning is enabled for this persona
649        if let Some(&enabled) = self.config.persona_learning.get(&persona_id) {
650            if !enabled {
651                return; // Learning disabled for this persona
652            }
653        }
654
655        let events = self.behavior_history.entry(persona_id).or_default();
656        events.push(event);
657
658        // Keep only recent events (last 1000)
659        if events.len() > 1000 {
660            events.remove(0);
661        }
662    }
663
664    /// Analyze behavior patterns for a persona
665    pub async fn analyze_persona_behavior(
666        &self,
667        persona_id: &str,
668    ) -> Result<Option<LearnedPattern>> {
669        if !self.config.enabled {
670            return Ok(None);
671        }
672
673        let events = match self.behavior_history.get(persona_id) {
674            Some(events) => events,
675            None => return Ok(None),
676        };
677
678        if events.len() < self.config.min_samples {
679            return Ok(None); // Not enough samples
680        }
681
682        // Analyze patterns
683        // Example: If persona repeatedly requests /checkout after failure, learn this pattern
684        let mut checkout_after_failure_count = 0;
685        let mut total_failures = 0;
686
687        for i in 1..events.len() {
688            if let BehaviorEventType::RequestFailed { .. } = events[i - 1].event_type {
689                total_failures += 1;
690                if let BehaviorEventType::Request { endpoint, .. } = &events[i].event_type {
691                    if endpoint.contains("/checkout") {
692                        checkout_after_failure_count += 1;
693                    }
694                }
695            }
696        }
697
698        if total_failures > 0 && checkout_after_failure_count as f64 / total_failures as f64 > 0.5 {
699            // Pattern detected: persona requests /checkout after failure > 50% of the time
700            let mut parameters = HashMap::new();
701            parameters.insert("retry_checkout_after_failure".to_string(), Value::from(true));
702            parameters.insert(
703                "retry_probability".to_string(),
704                Value::from(checkout_after_failure_count as f64 / total_failures as f64),
705            );
706
707            return Ok(Some(LearnedPattern {
708                pattern_id: format!("persona_{}_checkout_retry", persona_id),
709                pattern_type: PatternType::PersonaBehavior,
710                parameters,
711                confidence: (checkout_after_failure_count as f64 / total_failures as f64).min(1.0),
712                sample_count: total_failures,
713                last_updated: chrono::Utc::now(),
714            }));
715        }
716
717        Ok(None)
718    }
719
720    /// Get behavior history for a persona
721    pub fn get_behavior_history(&self, persona_id: &str) -> Option<&Vec<BehaviorEvent>> {
722        self.behavior_history.get(persona_id)
723    }
724
725    /// Apply learned patterns to a persona in PersonaRegistry
726    ///
727    /// This method should be called periodically to update persona profiles
728    /// based on learned behavior patterns.
729    pub async fn apply_learned_patterns_to_persona(
730        &self,
731        persona_id: &str,
732        persona_registry: &crate::PersonaRegistry,
733    ) -> Result<()> {
734        if !self.config.enabled {
735            return Ok(());
736        }
737
738        // Analyze behavior for this persona
739        if let Some(pattern) = self.analyze_persona_behavior(persona_id).await? {
740            // Convert learned pattern parameters to traits
741            let mut learned_traits = std::collections::HashMap::new();
742            for (key, value) in &pattern.parameters {
743                let trait_key = format!("learned_{}", key);
744                let trait_value = if let Some(s) = value.as_str() {
745                    s.to_string()
746                } else if let Some(n) = value.as_f64() {
747                    n.to_string()
748                } else if let Some(b) = value.as_bool() {
749                    b.to_string()
750                } else {
751                    value.to_string()
752                };
753                learned_traits.insert(trait_key, trait_value);
754            }
755
756            // Update persona traits in registry
757            if !learned_traits.is_empty() {
758                persona_registry.update_persona(persona_id, learned_traits)?;
759            }
760        }
761
762        Ok(())
763    }
764}
765
// Unit tests for the learning configuration defaults, engine construction and
// the persona behavior learner's minimum-sample guard.
#[cfg(test)]
mod tests {
    use super::*;

    // The default configuration must keep learning off and use the documented
    // default sensitivity and sample threshold.
    #[test]
    fn test_learning_config_default() {
        let config = LearningConfig::default();
        assert!(!config.enabled); // Opt-in by default
        assert_eq!(config.sensitivity, 0.2);
        assert_eq!(config.min_samples, 10);
    }

    // Constructing an engine from default configurations must succeed.
    #[test]
    fn test_drift_learning_engine_creation() {
        let drift_config = DataDriftConfig::new();
        let learning_config = LearningConfig::default();
        let engine = DriftLearningEngine::new(drift_config, learning_config);
        assert!(engine.is_ok());
    }

    // With learning enabled but only two recorded events (below the default
    // `min_samples` of 10), analysis must not report a pattern.
    #[tokio::test]
    async fn test_persona_behavior_learner() {
        let config = LearningConfig {
            enabled: true,
            persona_adaptation: true,
            ..Default::default()
        };
        let mut learner = PersonaBehaviorLearner::new(config).unwrap();

        // Record failure
        learner.record_event(
            "persona-1".to_string(),
            BehaviorEvent {
                timestamp: chrono::Utc::now(),
                event_type: BehaviorEventType::RequestFailed {
                    endpoint: "/api/checkout".to_string(),
                    status_code: 500,
                },
                data: HashMap::new(),
            },
        );

        // Record checkout request after failure
        learner.record_event(
            "persona-1".to_string(),
            BehaviorEvent {
                timestamp: chrono::Utc::now(),
                event_type: BehaviorEventType::Request {
                    endpoint: "/api/checkout".to_string(),
                    method: "POST".to_string(),
                },
                data: HashMap::new(),
            },
        );

        // Analyze (won't find pattern with only 2 samples, need min_samples)
        let pattern = learner.analyze_persona_behavior("persona-1").await.unwrap();
        assert!(pattern.is_none()); // Not enough samples
    }
}