mockforge_data/drift_learning.rs
1//! Drift Learning System
2//!
3//! This module extends the DataDriftEngine with learning capabilities that allow
4//! mocks to gradually learn from recorded traffic and adapt their behavior.
5//!
6//! Features:
7//! - Traffic pattern learning from recorded requests
8//! - Persona behavior adaptation based on request patterns
9//! - Configurable learning rate and sensitivity
10//! - Opt-in per persona/endpoint learning
11
12use crate::drift::{DataDriftConfig, DataDriftEngine};
13use crate::Result;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::sync::RwLock;
20
21/// Learning configuration
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct LearningConfig {
24 /// Enable drift learning
25 #[serde(default)]
26 pub enabled: bool,
27
28 /// Learning mode
29 #[serde(default)]
30 pub mode: LearningMode,
31
32 /// Learning rate (0.0 to 1.0) - how quickly mocks learn from patterns
33 #[serde(default = "default_learning_rate")]
34 pub sensitivity: f64,
35
36 /// Decay rate (0.0 to 1.0) - drift resets if upstream patterns reverse
37 #[serde(default = "default_decay_rate")]
38 pub decay: f64,
39
40 /// Minimum number of samples before learning starts
41 #[serde(default = "default_min_samples")]
42 pub min_samples: usize,
43
44 /// Update interval for learning
45 #[serde(default = "default_update_interval")]
46 pub update_interval: Duration,
47
48 /// Enable persona adaptation
49 #[serde(default = "default_true")]
50 pub persona_adaptation: bool,
51
52 /// Enable traffic pattern mirroring
53 #[serde(default = "default_true")]
54 pub traffic_mirroring: bool,
55
56 /// Per-endpoint opt-in learning (endpoint pattern -> enabled)
57 #[serde(default)]
58 pub endpoint_learning: HashMap<String, bool>,
59
60 /// Per-persona opt-in learning (persona_id -> enabled)
61 #[serde(default)]
62 pub persona_learning: HashMap<String, bool>,
63}
64
/// Serde/`Default` value for `LearningConfig::sensitivity`: a conservative 20% learning rate.
fn default_learning_rate() -> f64 {
    const CONSERVATIVE_RATE: f64 = 0.2;
    CONSERVATIVE_RATE
}
68
/// Serde/`Default` value for `LearningConfig::decay`: 5% per update.
fn default_decay_rate() -> f64 {
    const FIVE_PERCENT: f64 = 0.05;
    FIVE_PERCENT
}
72
/// Serde/`Default` value for `LearningConfig::min_samples`: ten samples before learning starts.
fn default_min_samples() -> usize {
    const WARMUP_SAMPLES: usize = 10;
    WARMUP_SAMPLES
}
76
/// Serde/`Default` value for `LearningConfig::update_interval`: one minute.
fn default_update_interval() -> Duration {
    Duration::from_millis(60_000)
}
80
/// Serde default helper for boolean fields that should be switched on unless configured otherwise.
fn default_true() -> bool {
    const ON: bool = true;
    ON
}
84
85impl Default for LearningConfig {
86 fn default() -> Self {
87 Self {
88 enabled: false, // Opt-in by default
89 mode: LearningMode::Behavioral,
90 sensitivity: default_learning_rate(),
91 decay: default_decay_rate(),
92 min_samples: default_min_samples(),
93 update_interval: default_update_interval(),
94 persona_adaptation: true,
95 traffic_mirroring: true,
96 endpoint_learning: HashMap::new(),
97 persona_learning: HashMap::new(),
98 }
99 }
100}
101
102/// Learning mode
103#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
104#[serde(rename_all = "snake_case")]
105pub enum LearningMode {
106 /// Behavioral learning - adapts to behavior patterns
107 #[default]
108 Behavioral,
109 /// Statistical learning - adapts to statistical patterns
110 Statistical,
111 /// Hybrid - combines behavioral and statistical
112 Hybrid,
113}
114
115/// Drift Learning Engine
116///
117/// Extends DataDriftEngine with learning capabilities from recorded traffic.
118pub struct DriftLearningEngine {
119 /// Base drift engine
120 drift_engine: DataDriftEngine,
121 /// Learning configuration
122 learning_config: LearningConfig,
123 /// Traffic pattern learner
124 traffic_learner: Option<Arc<RwLock<TrafficPatternLearner>>>,
125 /// Persona behavior learner
126 persona_learner: Option<Arc<RwLock<PersonaBehaviorLearner>>>,
127 /// Learned patterns cache
128 learned_patterns: Arc<RwLock<HashMap<String, LearnedPattern>>>,
129}
130
131/// Learned pattern from traffic analysis
132#[derive(Debug, Clone)]
133pub struct LearnedPattern {
134 /// Pattern identifier
135 pub pattern_id: String,
136 /// Pattern type
137 pub pattern_type: PatternType,
138 /// Learned parameters
139 pub parameters: HashMap<String, Value>,
140 /// Confidence score (0.0 to 1.0)
141 pub confidence: f64,
142 /// Sample count used for learning
143 pub sample_count: usize,
144 /// Last updated timestamp
145 pub last_updated: chrono::DateTime<chrono::Utc>,
146}
147
/// Category of a [`LearnedPattern`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PatternType {
    /// Response-latency pattern.
    Latency,
    /// Error-rate pattern.
    ErrorRate,
    /// Recurring sequence of requests.
    RequestSequence,
    /// Persona-specific behavior pattern.
    PersonaBehavior,
}
160
161impl DriftLearningEngine {
162 /// Create a new drift learning engine
163 pub fn new(drift_config: DataDriftConfig, learning_config: LearningConfig) -> Result<Self> {
164 let drift_engine = DataDriftEngine::new(drift_config)?;
165
166 let traffic_learner = if learning_config.traffic_mirroring {
167 Some(Arc::new(RwLock::new(TrafficPatternLearner::new(learning_config.clone())?)))
168 } else {
169 None
170 };
171
172 let persona_learner = if learning_config.persona_adaptation {
173 Some(Arc::new(RwLock::new(PersonaBehaviorLearner::new(learning_config.clone())?)))
174 } else {
175 None
176 };
177
178 Ok(Self {
179 drift_engine,
180 learning_config,
181 traffic_learner,
182 persona_learner,
183 learned_patterns: Arc::new(RwLock::new(HashMap::new())),
184 })
185 }
186
187 /// Get the base drift engine
188 pub fn drift_engine(&self) -> &DataDriftEngine {
189 &self.drift_engine
190 }
191
192 /// Get learning configuration
193 pub fn learning_config(&self) -> &LearningConfig {
194 &self.learning_config
195 }
196
197 /// Update learning configuration
198 pub fn update_learning_config(&mut self, config: LearningConfig) -> Result<()> {
199 self.learning_config = config;
200 Ok(())
201 }
202
203 /// Get learned patterns
204 pub async fn get_learned_patterns(&self) -> HashMap<String, LearnedPattern> {
205 self.learned_patterns.read().await.clone()
206 }
207
208 /// Apply drift with learning
209 pub async fn apply_drift_with_learning(&self, data: Value) -> Result<Value> {
210 // First apply base drift
211 let mut data = self.drift_engine.apply_drift(data).await?;
212
213 // Then apply learned patterns if learning is enabled
214 if !self.learning_config.enabled {
215 return Ok(data);
216 }
217
218 // Apply learned patterns
219 let patterns = self.learned_patterns.read().await;
220 for (pattern_id, pattern) in patterns.iter() {
221 // Check if pattern should be applied based on confidence and decay
222 if pattern.confidence < 0.5 {
223 continue; // Low confidence, skip
224 }
225
226 // Apply pattern based on type
227 match pattern.pattern_type {
228 PatternType::Latency => {
229 // Latency patterns are handled separately
230 }
231 PatternType::ErrorRate => {
232 // Error rate patterns are handled separately
233 }
234 PatternType::RequestSequence => {
235 // Request sequence patterns affect persona behavior
236 }
237 PatternType::PersonaBehavior => {
238 // Persona behavior patterns affect data generation
239 if let Some(obj) = data.as_object_mut() {
240 for (key, value) in &pattern.parameters {
241 if let Some(existing) = obj.get(key) {
242 // Blend learned value with existing value
243 let blended =
244 self.blend_values(existing, value, pattern.confidence)?;
245 obj.insert(key.clone(), blended);
246 }
247 }
248 }
249 }
250 }
251 }
252
253 Ok(data)
254 }
255
256 /// Blend two values based on confidence
257 fn blend_values(&self, existing: &Value, learned: &Value, confidence: f64) -> Result<Value> {
258 // Simple blending: existing * (1 - confidence * sensitivity) + learned * (confidence * sensitivity)
259 let weight = confidence * self.learning_config.sensitivity;
260
261 match (existing, learned) {
262 (Value::Number(existing_num), Value::Number(learned_num)) => {
263 if let (Some(existing_f64), Some(learned_f64)) =
264 (existing_num.as_f64(), learned_num.as_f64())
265 {
266 let blended = existing_f64 * (1.0 - weight) + learned_f64 * weight;
267 Ok(Value::from(blended))
268 } else {
269 Ok(existing.clone())
270 }
271 }
272 _ => Ok(existing.clone()), // For non-numeric, keep existing
273 }
274 }
275}
276
277/// Traffic Pattern Learner
278///
279/// Analyzes recorded traffic to detect patterns and trends.
280pub struct TrafficPatternLearner {
281 /// Learning configuration
282 config: LearningConfig,
283 /// Pattern window for analysis
284 pattern_window: Duration,
285 /// Detected patterns
286 patterns: HashMap<String, TrafficPattern>,
287}
288
289/// Traffic pattern detected from analysis
290#[derive(Debug, Clone)]
291struct TrafficPattern {
292 /// Pattern identifier
293 pattern_id: String,
294 /// Pattern type
295 pattern_type: PatternType,
296 /// Pattern parameters
297 parameters: HashMap<String, Value>,
298 /// Sample count
299 sample_count: usize,
300 /// First seen timestamp
301 first_seen: chrono::DateTime<chrono::Utc>,
302 /// Last seen timestamp
303 last_seen: chrono::DateTime<chrono::Utc>,
304}
305
306impl TrafficPatternLearner {
307 /// Create a new traffic pattern learner
308 pub fn new(config: LearningConfig) -> Result<Self> {
309 Ok(Self {
310 config,
311 pattern_window: Duration::from_secs(3600), // 1 hour window
312 patterns: HashMap::new(),
313 })
314 }
315
316 /// Analyze traffic patterns from recorded requests
317 ///
318 /// NOTE: This method is disabled to break circular dependencies.
319 /// The recorder integration has been moved to a higher-level crate.
320 pub async fn analyze_traffic_patterns(
321 &mut self,
322 _database: &dyn std::any::Any, // Use Any to avoid dependency on mockforge-recorder
323 ) -> Result<Vec<LearnedPattern>> {
324 // Disabled to break circular dependency
325 Ok(Vec::new())
326 }
327
328 /// Internal method to detect latency patterns from requests
329 ///
330 /// NOTE: This method is disabled to break circular dependency.
331 /// The recorder integration has been moved to a higher-level crate.
332 #[allow(dead_code)]
333 pub async fn detect_latency_patterns_from_requests(
334 &self,
335 _requests: &[serde_json::Value],
336 ) -> Result<Vec<LearnedPattern>> {
337 // Disabled to break circular dependency
338 Ok(Vec::new())
339 }
340
341 // NOTE: The following code is disabled to break circular dependency
342 // Original implementation would process requests here
343 /*
344 fn _detect_latency_patterns_original(
345 &self,
346 requests: &[serde_json::Value],
347 ) -> Result<Vec<LearnedPattern>> {
348 use std::collections::HashMap;
349 use chrono::Utc;
350
351 // Group requests by endpoint/method
352 let mut endpoint_latencies: HashMap<String, Vec<i64>> = HashMap::new();
353
354 for request in requests {
355 if let Some(duration) = request.duration_ms {
356 let key = format!("{} {}", request.method, request.path);
357 endpoint_latencies.entry(key).or_insert_with(Vec::new).push(duration);
358 }
359 }
360
361 let mut patterns = Vec::new();
362
363 for (endpoint_key, latencies) in endpoint_latencies {
364 if latencies.len() < 10 {
365 // Need at least 10 samples for meaningful analysis
366 continue;
367 }
368
369 // Calculate statistics
370 let sum: i64 = latencies.iter().sum();
371 let count = latencies.len();
372 let avg_latency = sum as f64 / count as f64;
373
374 // Calculate percentiles
375 let mut sorted = latencies.clone();
376 sorted.sort();
377 let p50 = sorted[sorted.len() / 2];
378 let p95 = sorted[(sorted.len() * 95) / 100];
379 let p99 = sorted[(sorted.len() * 99) / 100];
380
381 // Detect if latency is increasing (trend analysis)
382 let recent_avg = if latencies.len() >= 20 {
383 let recent: Vec<i64> = latencies.iter().rev().take(10).copied().collect();
384 let recent_sum: i64 = recent.iter().sum();
385 recent_sum as f64 / recent.len() as f64
386 } else {
387 avg_latency
388 };
389
390 let latency_trend = if recent_avg > avg_latency * 1.2 {
391 "increasing"
392 } else if recent_avg < avg_latency * 0.8 {
393 "decreasing"
394 } else {
395 "stable"
396 };
397
398 // Create pattern if there's significant variation or trend
399 if p99 > p50 * 2.0 || latency_trend != "stable" {
400 let mut parameters = HashMap::new();
401 parameters.insert("endpoint".to_string(), serde_json::json!(endpoint_key));
402 parameters.insert("avg_latency_ms".to_string(), serde_json::json!(avg_latency));
403 parameters.insert("p50_ms".to_string(), serde_json::json!(p50));
404 parameters.insert("p95_ms".to_string(), serde_json::json!(p95));
405 parameters.insert("p99_ms".to_string(), serde_json::json!(p99));
406 parameters.insert("sample_count".to_string(), serde_json::json!(count));
407 parameters.insert("trend".to_string(), serde_json::json!(latency_trend));
408
409 // Confidence based on sample size
410 let confidence = (count as f64 / 100.0).min(1.0);
411
412 patterns.push(LearnedPattern {
413 pattern_id: format!("latency_{}", endpoint_key.replace('/', "_").replace(' ', "_")),
414 pattern_type: PatternType::Latency,
415 parameters,
416 confidence,
417 sample_count: count,
418 last_updated: Utc::now(),
419 });
420 }
421 }
422
423 Ok(patterns)
424 }
425 */
426
427 /// Internal method to detect error rate patterns from requests
428 /// NOTE: Disabled to break circular dependency
429 #[allow(dead_code)]
430 async fn detect_error_patterns_internal(
431 &self,
432 _requests: &[serde_json::Value],
433 ) -> Result<Vec<LearnedPattern>> {
434 use chrono::Utc;
435 use std::collections::HashMap;
436
437 // Disabled to break circular dependency
438 let _requests = _requests;
439 let endpoint_errors: HashMap<String, (usize, usize)> = HashMap::new(); // (total, errors)
440
441 // Disabled - would iterate over requests here
442 /*
443 for request in requests {
444 let key = format!("{} {}", request.method, request.path);
445 let entry = endpoint_errors.entry(key).or_insert((0, 0));
446 entry.0 += 1;
447
448 // Consider 4xx and 5xx as errors
449 if let Some(status) = request.status_code {
450 if status >= 400 {
451 entry.1 += 1;
452 }
453 }
454 }
455 */
456
457 let mut patterns = Vec::new();
458
459 for (endpoint_key, (total, errors)) in endpoint_errors {
460 if total < 20 {
461 // Need at least 20 samples for meaningful analysis
462 continue;
463 }
464
465 let error_rate = errors as f64 / total as f64;
466
467 // Create pattern if error rate is significant (>5%) or increasing
468 if error_rate > 0.05 {
469 let mut parameters = HashMap::new();
470 parameters.insert("endpoint".to_string(), serde_json::json!(endpoint_key));
471 parameters.insert("error_rate".to_string(), serde_json::json!(error_rate));
472 parameters.insert("total_requests".to_string(), serde_json::json!(total));
473 parameters.insert("error_count".to_string(), serde_json::json!(errors));
474
475 // Confidence based on sample size and error rate
476 let confidence = ((total as f64 / 100.0).min(1.0) * error_rate * 10.0).min(1.0);
477
478 patterns.push(LearnedPattern {
479 pattern_id: format!("error_rate_{}", endpoint_key.replace(['/', ' '], "_")),
480 pattern_type: PatternType::ErrorRate,
481 parameters,
482 confidence,
483 sample_count: total,
484 last_updated: Utc::now(),
485 });
486 }
487 }
488
489 Ok(patterns)
490 }
491
492 /// Internal method to detect request sequence patterns
493 /// NOTE: Disabled to break circular dependency
494 #[allow(dead_code)]
495 async fn detect_sequence_patterns_internal(
496 &self,
497 _requests: &[serde_json::Value],
498 ) -> Result<Vec<LearnedPattern>> {
499 use chrono::Utc;
500 use std::collections::HashMap;
501
502 // Disabled to break circular dependency
503 let _requests = _requests;
504 if _requests.len() < 50 {
505 // Need sufficient data for sequence detection
506 return Ok(Vec::new());
507 }
508
509 // Disabled - would process requests here
510 let trace_sequences: HashMap<Option<String>, Vec<String>> = HashMap::new();
511
512 /*
513 for request in requests {
514 let trace_id = request.trace_id.clone();
515 let endpoint_key = format!("{} {}", request.method, request.path);
516 trace_sequences
517 .entry(trace_id)
518 .or_insert_with(Vec::new)
519 .push(endpoint_key);
520 }
521 */
522
523 // Find common sequences (patterns that appear multiple times)
524 let mut sequence_counts: HashMap<String, usize> = HashMap::new();
525
526 for sequence in trace_sequences.values() {
527 if sequence.len() >= 2 {
528 // Create sequence signature (first 3 endpoints)
529 let signature: Vec<String> = sequence.iter().take(3).cloned().collect();
530 let signature_str = signature.join(" -> ");
531 *sequence_counts.entry(signature_str).or_insert(0) += 1;
532 }
533 }
534
535 let mut patterns = Vec::new();
536
537 for (sequence_str, count) in sequence_counts {
538 if count >= 5 {
539 // Pattern appears at least 5 times
540 let mut parameters = HashMap::new();
541 parameters.insert("sequence".to_string(), serde_json::json!(sequence_str));
542 parameters.insert("occurrence_count".to_string(), serde_json::json!(count));
543
544 // Confidence based on occurrence frequency
545 let confidence = (count as f64 / 20.0).min(1.0);
546
547 patterns.push(LearnedPattern {
548 pattern_id: format!(
549 "sequence_{}",
550 sequence_str.replace(['/', ' '], "_").replace("->", "_")
551 ),
552 pattern_type: PatternType::RequestSequence,
553 parameters,
554 confidence,
555 sample_count: count,
556 last_updated: Utc::now(),
557 });
558 }
559 }
560
561 Ok(patterns)
562 }
563
564 /// Detect latency patterns
565 ///
566 /// This method is a convenience wrapper that requires a database.
567 /// Use `analyze_traffic_patterns` with a RecorderDatabase for full analysis.
568 pub async fn detect_latency_patterns(&mut self) -> Result<Vec<LearnedPattern>> {
569 // This method now requires database access - use analyze_traffic_patterns instead
570 Ok(Vec::new())
571 }
572
573 /// Detect error rate patterns
574 ///
575 /// This method is a convenience wrapper that requires a database.
576 /// Use `analyze_traffic_patterns` with a RecorderDatabase for full analysis.
577 pub async fn detect_error_patterns(&mut self) -> Result<Vec<LearnedPattern>> {
578 // This method now requires database access - use analyze_traffic_patterns instead
579 Ok(Vec::new())
580 }
581}
582
583/// Persona Behavior Learner
584///
585/// Adapts persona profiles based on request patterns.
586pub struct PersonaBehaviorLearner {
587 /// Learning configuration
588 config: LearningConfig,
589 /// Behavior history (persona_id -> behavior events)
590 behavior_history: HashMap<String, Vec<BehaviorEvent>>,
591}
592
593/// Behavior event for a persona
594#[derive(Debug, Clone)]
595pub struct BehaviorEvent {
596 /// Event timestamp
597 pub timestamp: chrono::DateTime<chrono::Utc>,
598 /// Event type
599 pub event_type: BehaviorEventType,
600 /// Event data
601 pub data: HashMap<String, Value>,
602}
603
/// Kinds of behavior that can be recorded for a persona.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BehaviorEventType {
    /// Persona issued a request to an endpoint.
    Request {
        /// Path of the endpoint.
        endpoint: String,
        /// HTTP method used.
        method: String,
    },
    /// Request from the persona failed.
    RequestFailed {
        /// Path of the endpoint.
        endpoint: String,
        /// HTTP status code returned.
        status_code: u16,
    },
    /// Request succeeded after an earlier failure.
    RequestSucceededAfterFailure {
        /// Path of the endpoint.
        endpoint: String,
    },
    /// Named pattern was detected.
    PatternDetected {
        /// Identifier of the detected pattern.
        pattern: String,
    },
}
632
633impl PersonaBehaviorLearner {
634 /// Create a new persona behavior learner
635 pub fn new(config: LearningConfig) -> Result<Self> {
636 Ok(Self {
637 config,
638 behavior_history: HashMap::new(),
639 })
640 }
641
642 /// Record a behavior event for a persona
643 pub fn record_event(&mut self, persona_id: String, event: BehaviorEvent) {
644 if !self.config.enabled {
645 return;
646 }
647
648 // Check if persona learning is enabled for this persona
649 if let Some(&enabled) = self.config.persona_learning.get(&persona_id) {
650 if !enabled {
651 return; // Learning disabled for this persona
652 }
653 }
654
655 let events = self.behavior_history.entry(persona_id).or_default();
656 events.push(event);
657
658 // Keep only recent events (last 1000)
659 if events.len() > 1000 {
660 events.remove(0);
661 }
662 }
663
664 /// Analyze behavior patterns for a persona
665 pub async fn analyze_persona_behavior(
666 &self,
667 persona_id: &str,
668 ) -> Result<Option<LearnedPattern>> {
669 if !self.config.enabled {
670 return Ok(None);
671 }
672
673 let events = match self.behavior_history.get(persona_id) {
674 Some(events) => events,
675 None => return Ok(None),
676 };
677
678 if events.len() < self.config.min_samples {
679 return Ok(None); // Not enough samples
680 }
681
682 // Analyze patterns
683 // Example: If persona repeatedly requests /checkout after failure, learn this pattern
684 let mut checkout_after_failure_count = 0;
685 let mut total_failures = 0;
686
687 for i in 1..events.len() {
688 if let BehaviorEventType::RequestFailed { .. } = events[i - 1].event_type {
689 total_failures += 1;
690 if let BehaviorEventType::Request { endpoint, .. } = &events[i].event_type {
691 if endpoint.contains("/checkout") {
692 checkout_after_failure_count += 1;
693 }
694 }
695 }
696 }
697
698 if total_failures > 0 && checkout_after_failure_count as f64 / total_failures as f64 > 0.5 {
699 // Pattern detected: persona requests /checkout after failure > 50% of the time
700 let mut parameters = HashMap::new();
701 parameters.insert("retry_checkout_after_failure".to_string(), Value::from(true));
702 parameters.insert(
703 "retry_probability".to_string(),
704 Value::from(checkout_after_failure_count as f64 / total_failures as f64),
705 );
706
707 return Ok(Some(LearnedPattern {
708 pattern_id: format!("persona_{}_checkout_retry", persona_id),
709 pattern_type: PatternType::PersonaBehavior,
710 parameters,
711 confidence: (checkout_after_failure_count as f64 / total_failures as f64).min(1.0),
712 sample_count: total_failures,
713 last_updated: chrono::Utc::now(),
714 }));
715 }
716
717 Ok(None)
718 }
719
720 /// Get behavior history for a persona
721 pub fn get_behavior_history(&self, persona_id: &str) -> Option<&Vec<BehaviorEvent>> {
722 self.behavior_history.get(persona_id)
723 }
724
725 /// Apply learned patterns to a persona in PersonaRegistry
726 ///
727 /// This method should be called periodically to update persona profiles
728 /// based on learned behavior patterns.
729 pub async fn apply_learned_patterns_to_persona(
730 &self,
731 persona_id: &str,
732 persona_registry: &crate::PersonaRegistry,
733 ) -> Result<()> {
734 if !self.config.enabled {
735 return Ok(());
736 }
737
738 // Analyze behavior for this persona
739 if let Some(pattern) = self.analyze_persona_behavior(persona_id).await? {
740 // Convert learned pattern parameters to traits
741 let mut learned_traits = std::collections::HashMap::new();
742 for (key, value) in &pattern.parameters {
743 let trait_key = format!("learned_{}", key);
744 let trait_value = if let Some(s) = value.as_str() {
745 s.to_string()
746 } else if let Some(n) = value.as_f64() {
747 n.to_string()
748 } else if let Some(b) = value.as_bool() {
749 b.to_string()
750 } else {
751 value.to_string()
752 };
753 learned_traits.insert(trait_key, trait_value);
754 }
755
756 // Update persona traits in registry
757 if !learned_traits.is_empty() {
758 persona_registry.update_persona(persona_id, learned_traits)?;
759 }
760 }
761
762 Ok(())
763 }
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap an event type in a `BehaviorEvent` stamped with the current time and no extra data.
    fn event_now(event_type: BehaviorEventType) -> BehaviorEvent {
        BehaviorEvent {
            timestamp: chrono::Utc::now(),
            event_type,
            data: HashMap::new(),
        }
    }

    #[test]
    fn test_learning_config_default() {
        let config = LearningConfig::default();
        // Learning must be opt-in.
        assert!(!config.enabled);
        assert_eq!(config.sensitivity, 0.2);
        assert_eq!(config.min_samples, 10);
    }

    #[test]
    fn test_drift_learning_engine_creation() {
        let result = DriftLearningEngine::new(DataDriftConfig::new(), LearningConfig::default());
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_persona_behavior_learner() {
        let mut learner = PersonaBehaviorLearner::new(LearningConfig {
            enabled: true,
            persona_adaptation: true,
            ..Default::default()
        })
        .unwrap();

        // A failed checkout followed by a retry of the checkout.
        for event_type in [
            BehaviorEventType::RequestFailed {
                endpoint: "/api/checkout".to_string(),
                status_code: 500,
            },
            BehaviorEventType::Request {
                endpoint: "/api/checkout".to_string(),
                method: "POST".to_string(),
            },
        ] {
            learner.record_event("persona-1".to_string(), event_now(event_type));
        }

        // Two events are below the default `min_samples`, so no pattern may be reported yet.
        let pattern = learner.analyze_persona_behavior("persona-1").await.unwrap();
        assert!(pattern.is_none());
    }
}