kizzasi_logic/online_learning.rs

//! Online Constraint Learning
//!
//! This module provides algorithms for learning constraints from streaming data:
//! - Incremental constraint refinement
//! - Anomaly-based constraint discovery
//! - Constraint parameter tuning from feedback
//! - Active learning for constraint boundaries

use crate::{LinearConstraint, LogicResult};
use scirs2_core::ndarray::Array1;
use std::collections::VecDeque;

/// Online learner for refining constraints from streaming data
#[derive(Debug, Clone)]
pub struct OnlineConstraintLearner {
    /// Current constraint estimate
    constraint: LinearConstraint,
    /// Learning rate for parameter updates
    #[allow(dead_code)]
    learning_rate: f32,
    /// Historical data buffer
    data_buffer: VecDeque<(Array1<f32>, bool)>, // (sample, is_feasible)
    /// Maximum buffer size
    #[allow(dead_code)]
    max_buffer_size: usize,
    /// Number of updates performed
    update_count: usize,
}

impl OnlineConstraintLearner {
    /// Create a new online constraint learner
    pub fn new(
        initial_constraint: LinearConstraint,
        learning_rate: f32,
        max_buffer_size: usize,
    ) -> Self {
        Self {
            constraint: initial_constraint,
            learning_rate,
            data_buffer: VecDeque::new(),
            max_buffer_size,
            update_count: 0,
        }
    }

    /// Add a new observation and update the constraint
    pub fn observe(&mut self, sample: Array1<f32>, is_feasible: bool) -> LogicResult<()> {
        // Add to buffer
        self.data_buffer.push_back((sample.clone(), is_feasible));
        if self.data_buffer.len() > self.max_buffer_size {
            self.data_buffer.pop_front();
        }

        // Update constraint using stochastic gradient descent
        self.refine_constraint(&sample, is_feasible)?;
        self.update_count += 1;

        Ok(())
    }

    /// Refine constraint based on a single observation
    fn refine_constraint(&mut self, sample: &Array1<f32>, is_feasible: bool) -> LogicResult<()> {
        let sample_slice = sample.as_slice().unwrap_or(&[]);
        let current_satisfied = self.constraint.check(sample_slice);

        // If prediction matches observation, no update needed
        if current_satisfied == is_feasible {
            return Ok(());
        }

        // Compute gradient for constraint refinement
        // For a·x <= b, we adjust 'a' and 'b' to better fit the data
        let violation = self.constraint.violation(sample_slice);

        // Update using perceptron-like rule
        let update_scale = if is_feasible {
            // Sample should be feasible but is violated: loosen constraint
            self.learning_rate * violation
        } else {
            // Sample should be infeasible but is satisfied: tighten constraint
            -self.learning_rate
        };

        // Placeholder: the constraint is not yet rebuilt here. In practice both the
        // coefficients and the bound would be updated; `perceptron_update` below
        // sketches one way such an update could look.
        let _ = update_scale; // TODO: apply the coefficient/bound update

        Ok(())
    }
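
    /// Perceptron-style update for a constraint of the form `a·x <= b`.
    ///
    /// `refine_constraint` above leaves the actual update as a placeholder; this is
    /// a hedged sketch of one possible rule. It operates on plain slices because
    /// `LinearConstraint` exposes no mutable access in this module, and it is not
    /// called anywhere; it exists purely for illustration.
    #[allow(dead_code)]
    fn perceptron_update(
        coeffs: &mut [f32],
        bound: &mut f32,
        sample: &[f32],
        is_feasible: bool,
        learning_rate: f32,
    ) {
        // Misclassified feasible sample (a·x > b): push `a` away from the sample and
        // raise `b`. Misclassified infeasible sample: do the opposite.
        let sign = if is_feasible { -1.0 } else { 1.0 };
        for (c, x) in coeffs.iter_mut().zip(sample) {
            *c += sign * learning_rate * x;
        }
        *bound -= sign * learning_rate;
    }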

    /// Get the current learned constraint
    pub fn get_constraint(&self) -> &LinearConstraint {
        &self.constraint
    }

    /// Get number of updates performed
    pub fn update_count(&self) -> usize {
        self.update_count
    }

    /// Evaluate confidence in current constraint
    pub fn confidence(&self) -> f32 {
        if self.data_buffer.is_empty() {
            return 0.0;
        }

        // Compute accuracy on buffered data
        let correct = self
            .data_buffer
            .iter()
            .filter(|(sample, is_feasible)| {
                let satisfied = self.constraint.check(sample.as_slice().unwrap_or(&[]));
                satisfied == *is_feasible
            })
            .count();

        correct as f32 / self.data_buffer.len() as f32
    }
}

/// Anomaly detector for discovering new constraints
#[derive(Debug, Clone)]
pub struct AnomalyBasedConstraintDiscovery {
    /// Historical normal samples
    normal_samples: VecDeque<Array1<f32>>,
    /// Maximum number of normal samples to keep
    max_samples: usize,
    /// Anomaly threshold (number of standard deviations)
    anomaly_threshold: f32,
    /// Discovered constraints
    discovered_constraints: Vec<LinearConstraint>,
}

impl AnomalyBasedConstraintDiscovery {
    /// Create a new anomaly-based constraint discoverer
    pub fn new(max_samples: usize, anomaly_threshold: f32) -> Self {
        Self {
            normal_samples: VecDeque::new(),
            max_samples,
            anomaly_threshold,
            discovered_constraints: Vec::new(),
        }
    }

    /// Add a normal sample for baseline estimation
    pub fn add_normal_sample(&mut self, sample: Array1<f32>) {
        self.normal_samples.push_back(sample);
        if self.normal_samples.len() > self.max_samples {
            self.normal_samples.pop_front();
        }
    }

    /// Check if a sample is anomalous and potentially discover new constraint
    pub fn detect_anomaly(&mut self, sample: &Array1<f32>) -> bool {
        if self.normal_samples.len() < 2 {
            return false; // Not enough data for detection
        }

        // Compute statistics of normal samples
        let dim = sample.len();
        let n = self.normal_samples.len();

        // Compute mean and std dev for each dimension
        let mut is_anomalous = false;

        for d in 0..dim {
            let mean: f32 = self.normal_samples.iter().map(|s| s[d]).sum::<f32>() / n as f32;

            let variance: f32 = self
                .normal_samples
                .iter()
                .map(|s| (s[d] - mean).powi(2))
                .sum::<f32>()
                / n as f32;

            let std_dev = variance.sqrt();

            // Check if sample is outside threshold
            let z_score = (sample[d] - mean).abs() / (std_dev + 1e-8);
            if z_score > self.anomaly_threshold {
                is_anomalous = true;

                // Discover an upper-bound constraint: x[d] <= mean + threshold * std_dev
                // (only the upper side is learned here)
                self.discover_bound_constraint(d, mean, std_dev);
            }
        }

        is_anomalous
    }

    /// Discover a bound constraint for a dimension
    fn discover_bound_constraint(&mut self, dim: usize, mean: f32, std_dev: f32) {
        // Create constraint: x[dim] <= mean + threshold * std_dev
        let upper_bound = mean + self.anomaly_threshold * std_dev;
        // Coefficients cover dimensions 0..=dim only; the single 1.0 selects x[dim].
        let mut coeffs = vec![0.0; dim + 1];
        coeffs[dim] = 1.0;

        let constraint = LinearConstraint::less_eq(coeffs, upper_bound);

        // Check if we already have a similar constraint
        let is_duplicate = self.discovered_constraints.iter().any(|c| {
            c.coefficients().len() == constraint.coefficients().len()
                && c.coefficients()
                    .iter()
                    .zip(constraint.coefficients().iter())
                    .all(|(a, b)| (a - b).abs() < 0.1)
        });

        if !is_duplicate {
            self.discovered_constraints.push(constraint);
        }
    }

    /// Get discovered constraints
    pub fn discovered_constraints(&self) -> &[LinearConstraint] {
        &self.discovered_constraints
    }

    /// Get number of discovered constraints
    pub fn num_discovered(&self) -> usize {
        self.discovered_constraints.len()
    }
}
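
/// Incremental mean/variance accumulator (Welford's algorithm): a hedged sketch of
/// how `AnomalyBasedConstraintDiscovery::detect_anomaly`, which recomputes the
/// per-dimension statistics over the whole buffer on every call, could be made O(1)
/// per observation. It does not handle eviction from a sliding window and is not
/// wired into the detector; it exists purely for illustration.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
struct RunningStats {
    count: usize,
    mean: f32,
    /// Sum of squared deviations from the current mean.
    m2: f32,
}

#[allow(dead_code)]
impl RunningStats {
    /// Fold one observation into the running statistics (Welford's update).
    fn update(&mut self, x: f32) {
        self.count += 1;
        let delta = x - self.mean;
        self.mean += delta / self.count as f32;
        self.m2 += delta * (x - self.mean);
    }

    /// Running mean of the observations seen so far.
    fn mean(&self) -> f32 {
        self.mean
    }

    /// Population standard deviation (matching the divide-by-n used in `detect_anomaly`).
    fn std_dev(&self) -> f32 {
        if self.count == 0 {
            0.0
        } else {
            (self.m2 / self.count as f32).sqrt()
        }
    }
}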

/// Active learner for exploring constraint boundaries
#[derive(Debug, Clone)]
pub struct ActiveConstraintBoundaryLearner {
    /// Current constraint estimate
    constraint: LinearConstraint,
    /// Samples near the boundary (uncertain region)
    boundary_samples: Vec<(Array1<f32>, Option<bool>)>, // (sample, label if known)
    /// Uncertainty threshold for boundary detection
    uncertainty_threshold: f32,
    /// Maximum number of boundary samples to track
    max_boundary_samples: usize,
}

impl ActiveConstraintBoundaryLearner {
    /// Create a new active boundary learner
    pub fn new(
        initial_constraint: LinearConstraint,
        uncertainty_threshold: f32,
        max_boundary_samples: usize,
    ) -> Self {
        Self {
            constraint: initial_constraint,
            boundary_samples: Vec::new(),
            uncertainty_threshold,
            max_boundary_samples,
        }
    }

    /// Get the most informative sample to query (closest to boundary)
    pub fn query_next(&self) -> Option<Array1<f32>> {
        // Find unlabeled sample closest to decision boundary
        self.boundary_samples
            .iter()
            .filter(|(_, label)| label.is_none())
            .min_by(|(s1, _), (s2, _)| {
                let v1 = self
                    .constraint
                    .violation(s1.as_slice().unwrap_or(&[]))
                    .abs();
                let v2 = self
                    .constraint
                    .violation(s2.as_slice().unwrap_or(&[]))
                    .abs();
                v1.partial_cmp(&v2).unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(s, _)| s.clone())
    }

    /// Add a labeled sample
    pub fn add_labeled_sample(&mut self, sample: Array1<f32>, is_feasible: bool) {
        let violation = self
            .constraint
            .violation(sample.as_slice().unwrap_or(&[]))
            .abs();

        // Add to boundary samples if near boundary
        if violation < self.uncertainty_threshold {
            self.boundary_samples.push((sample, Some(is_feasible)));
            if self.boundary_samples.len() > self.max_boundary_samples {
                self.boundary_samples.remove(0);
            }
        }
    }

    /// Add an unlabeled sample for potential querying
    pub fn add_unlabeled_sample(&mut self, sample: Array1<f32>) {
        let violation = self
            .constraint
            .violation(sample.as_slice().unwrap_or(&[]))
            .abs();

        if violation < self.uncertainty_threshold {
            self.boundary_samples.push((sample, None));
            if self.boundary_samples.len() > self.max_boundary_samples {
                self.boundary_samples.remove(0);
            }
        }
    }

    /// Refine constraint based on labeled boundary samples
    pub fn refine(&mut self) -> LogicResult<()> {
        // Use labeled boundary samples to refine constraint
        // This is a simplified version - in practice would use SVM or similar
        let labeled: Vec<_> = self
            .boundary_samples
            .iter()
            .filter_map(|(s, l)| l.map(|label| (s, label)))
            .collect();

        if labeled.len() < 2 {
            return Ok(()); // Not enough data
        }

        // Placeholder: no refinement is applied yet. A margin-based update would be
        // typical here; `midpoint_bound` below sketches a minimal one-dimensional
        // version that keeps the coefficient direction fixed and re-places the bound.
        Ok(())
    }
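
    /// Hedged sketch of the refinement step `refine` leaves unimplemented: keep the
    /// coefficient direction fixed and, for a constraint of the form `a·x <= b`,
    /// place the bound midway between the largest feasible projection and the
    /// smallest infeasible projection among labeled boundary samples (the `labeled`
    /// vector built in `refine` has exactly this shape). The coefficients are taken
    /// as a plain slice because `LinearConstraint` exposes no mutable access here;
    /// the function is not called anywhere and exists purely for illustration.
    #[allow(dead_code)]
    fn midpoint_bound(coeffs: &[f32], labeled: &[(&Array1<f32>, bool)]) -> Option<f32> {
        let project = |s: &Array1<f32>| -> f32 {
            coeffs.iter().zip(s.iter()).map(|(a, x)| a * x).sum()
        };

        // Largest projection among feasible samples, smallest among infeasible ones.
        let max_feasible = labeled
            .iter()
            .filter(|(_, l)| *l)
            .map(|&(s, _)| project(s))
            .fold(f32::NEG_INFINITY, f32::max);
        let min_infeasible = labeled
            .iter()
            .filter(|(_, l)| !*l)
            .map(|&(s, _)| project(s))
            .fold(f32::INFINITY, f32::min);

        if max_feasible.is_finite() && min_infeasible.is_finite() {
            // Split the margin between the two classes along the constraint direction.
            Some((max_feasible + min_infeasible) / 2.0)
        } else {
            None
        }
    }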

    /// Get the current constraint
    pub fn get_constraint(&self) -> &LinearConstraint {
        &self.constraint
    }

    /// Get number of boundary samples
    pub fn num_boundary_samples(&self) -> usize {
        self.boundary_samples.len()
    }

    /// Get number of unlabeled samples
    pub fn num_unlabeled(&self) -> usize {
        self.boundary_samples
            .iter()
            .filter(|(_, l)| l.is_none())
            .count()
    }
}

/// Feedback-based constraint tuner
#[derive(Debug, Clone)]
pub struct FeedbackConstraintTuner {
    /// Current constraint
    constraint: LinearConstraint,
    /// Feedback history (violation amount, user satisfaction)
    feedback_history: Vec<(f32, f32)>, // (violation, satisfaction in [0, 1])
    /// Adaptation rate
    #[allow(dead_code)]
    adaptation_rate: f32,
    /// Target satisfaction level
    target_satisfaction: f32,
}

impl FeedbackConstraintTuner {
    /// Create a new feedback-based tuner
    pub fn new(
        initial_constraint: LinearConstraint,
        adaptation_rate: f32,
        target_satisfaction: f32,
    ) -> Self {
        Self {
            constraint: initial_constraint,
            feedback_history: Vec::new(),
            adaptation_rate,
            target_satisfaction,
        }
    }

    /// Add user feedback for a sample
    pub fn add_feedback(&mut self, sample: &Array1<f32>, satisfaction: f32) -> LogicResult<()> {
        let violation = self.constraint.violation(sample.as_slice().unwrap_or(&[]));
        self.feedback_history.push((violation, satisfaction));

        // Tune constraint based on feedback
        self.tune()?;

        Ok(())
    }

    /// Tune constraint based on accumulated feedback
    fn tune(&mut self) -> LogicResult<()> {
        if self.feedback_history.len() < 5 {
            return Ok(()); // Need more data
        }

        // Compute average satisfaction
        let avg_satisfaction: f32 = self.feedback_history.iter().map(|(_, s)| s).sum::<f32>()
            / self.feedback_history.len() as f32;

        // If satisfaction is below target, adjust constraint
        let satisfaction_gap = self.target_satisfaction - avg_satisfaction;

        if satisfaction_gap.abs() > 0.1 {
            // Significant gap: the constraint's tightness should be adjusted. A
            // positive gap (satisfaction below target) means the constraint is too
            // strict; `adjusted_bound` below sketches a minimal adjustment rule.
            let _ = satisfaction_gap; // TODO: apply the adjustment
        }

        Ok(())
    }
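
    /// Hedged sketch of the adjustment `tune` leaves unimplemented, assuming a
    /// constraint of the form `a·x <= b`: a positive satisfaction gap (average
    /// satisfaction below target) loosens the constraint by raising the bound, a
    /// negative gap tightens it. The bound is passed as a plain value because
    /// `LinearConstraint` exposes no mutable access here; the function is not called
    /// anywhere and exists purely for illustration.
    #[allow(dead_code)]
    fn adjusted_bound(current_bound: f32, satisfaction_gap: f32, adaptation_rate: f32) -> f32 {
        current_bound + adaptation_rate * satisfaction_gap
    }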

    /// Get the current constraint
    pub fn get_constraint(&self) -> &LinearConstraint {
        &self.constraint
    }

    /// Get the average satisfaction over all accumulated feedback
    pub fn average_satisfaction(&self) -> f32 {
        if self.feedback_history.is_empty() {
            return 0.0;
        }

        self.feedback_history.iter().map(|(_, s)| s).sum::<f32>()
            / self.feedback_history.len() as f32
    }

    /// Get number of feedback samples
    pub fn num_feedback_samples(&self) -> usize {
        self.feedback_history.len()
    }
}

/// Unified online learning system combining multiple strategies
#[derive(Debug, Clone)]
pub struct OnlineLearningSystem {
    /// Incremental learner
    incremental_learner: OnlineConstraintLearner,
    /// Anomaly detector
    anomaly_detector: AnomalyBasedConstraintDiscovery,
    /// Active learner
    active_learner: ActiveConstraintBoundaryLearner,
    /// Feedback tuner
    feedback_tuner: FeedbackConstraintTuner,
    /// Enable/disable each component
    use_incremental: bool,
    use_anomaly: bool,
    use_active: bool,
    use_feedback: bool,
}

impl OnlineLearningSystem {
    /// Create a new online learning system
    pub fn new(initial_constraint: LinearConstraint) -> Self {
        Self {
            incremental_learner: OnlineConstraintLearner::new(
                initial_constraint.clone(),
                0.01,
                1000,
            ),
            anomaly_detector: AnomalyBasedConstraintDiscovery::new(1000, 3.0),
            active_learner: ActiveConstraintBoundaryLearner::new(
                initial_constraint.clone(),
                0.1,
                100,
            ),
            feedback_tuner: FeedbackConstraintTuner::new(initial_constraint, 0.01, 0.8),
            use_incremental: true,
            use_anomaly: true,
            use_active: true,
            use_feedback: true,
        }
    }

    /// Process a new labeled sample
    pub fn process_labeled_sample(
        &mut self,
        sample: Array1<f32>,
        is_feasible: bool,
    ) -> LogicResult<()> {
        if self.use_incremental {
            self.incremental_learner
                .observe(sample.clone(), is_feasible)?;
        }

        if self.use_active {
            self.active_learner
                .add_labeled_sample(sample.clone(), is_feasible);
        }

        if is_feasible && self.use_anomaly {
            self.anomaly_detector.add_normal_sample(sample);
        }

        Ok(())
    }

    /// Process an unlabeled sample (for anomaly detection and active learning)
    pub fn process_unlabeled_sample(&mut self, sample: Array1<f32>) {
        if self.use_anomaly {
            self.anomaly_detector.detect_anomaly(&sample);
        }

        if self.use_active {
            self.active_learner.add_unlabeled_sample(sample);
        }
    }

    /// Add user feedback
    pub fn add_feedback(&mut self, sample: &Array1<f32>, satisfaction: f32) -> LogicResult<()> {
        if self.use_feedback {
            self.feedback_tuner.add_feedback(sample, satisfaction)?;
        }
        Ok(())
    }

    /// Get the current best constraint estimate
    pub fn get_best_constraint(&self) -> &LinearConstraint {
        // Use the incremental learner's constraint as primary
        // Could be enhanced to ensemble multiple learned constraints
        self.incremental_learner.get_constraint()
    }

    /// Get confidence in current constraint
    pub fn confidence(&self) -> f32 {
        self.incremental_learner.confidence()
    }

    /// Get discovered anomaly-based constraints
    pub fn discovered_constraints(&self) -> &[LinearConstraint] {
        self.anomaly_detector.discovered_constraints()
    }

    /// Get next sample to query (active learning)
    pub fn query_next(&self) -> Option<Array1<f32>> {
        if self.use_active {
            self.active_learner.query_next()
        } else {
            None
        }
    }

    /// Enable/disable the incremental learner.
    pub fn set_use_incremental(&mut self, use_it: bool) {
        self.use_incremental = use_it;
    }

    /// Enable/disable anomaly-based constraint discovery.
    pub fn set_use_anomaly(&mut self, use_it: bool) {
        self.use_anomaly = use_it;
    }

    /// Enable/disable active boundary learning.
    pub fn set_use_active(&mut self, use_it: bool) {
        self.use_active = use_it;
    }

    /// Enable/disable feedback-based tuning.
    pub fn set_use_feedback(&mut self, use_it: bool) {
        self.use_feedback = use_it;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_online_learner_basic() -> LogicResult<()> {
        let constraint = LinearConstraint::less_eq(vec![1.0], 5.0);
        let mut learner = OnlineConstraintLearner::new(constraint, 0.1, 100);

        // Add some observations
        learner.observe(Array1::from_vec(vec![3.0]), true)?; // Feasible
        learner.observe(Array1::from_vec(vec![7.0]), false)?; // Infeasible

        assert_eq!(learner.update_count(), 2);
        assert!(learner.confidence() > 0.0);

        Ok(())
    }

    #[test]
    fn test_anomaly_detection() {
        let mut detector = AnomalyBasedConstraintDiscovery::new(100, 3.0);

        // Add normal samples
        for _ in 0..20 {
            detector.add_normal_sample(Array1::from_vec(vec![5.0, 10.0]));
        }

        // Detect anomaly
        let is_anomaly = detector.detect_anomaly(&Array1::from_vec(vec![50.0, 100.0]));
        assert!(is_anomaly);
    }

    #[test]
    fn test_active_learning() {
        let constraint = LinearConstraint::less_eq(vec![1.0], 5.0);
        let mut learner = ActiveConstraintBoundaryLearner::new(constraint, 1.0, 100);

        learner.add_unlabeled_sample(Array1::from_vec(vec![4.9])); // Near boundary
        learner.add_unlabeled_sample(Array1::from_vec(vec![10.0])); // Far from boundary

        assert_eq!(learner.num_unlabeled(), 1); // Only near-boundary sample added
    }

    #[test]
    fn test_feedback_tuner() -> LogicResult<()> {
        let constraint = LinearConstraint::less_eq(vec![1.0], 5.0);
        let mut tuner = FeedbackConstraintTuner::new(constraint, 0.1, 0.8);

        tuner.add_feedback(&Array1::from_vec(vec![3.0]), 0.9)?; // High satisfaction
        tuner.add_feedback(&Array1::from_vec(vec![4.0]), 0.7)?; // Medium satisfaction

        assert_eq!(tuner.num_feedback_samples(), 2);
        assert!(tuner.average_satisfaction() > 0.0);

        Ok(())
    }

    #[test]
    fn test_online_learning_system() -> LogicResult<()> {
        let constraint = LinearConstraint::less_eq(vec![1.0], 5.0);
        let mut system = OnlineLearningSystem::new(constraint);

        system.process_labeled_sample(Array1::from_vec(vec![3.0]), true)?;
        system.process_unlabeled_sample(Array1::from_vec(vec![4.5]));
        system.add_feedback(&Array1::from_vec(vec![3.5]), 0.9)?;

        assert!(system.confidence() > 0.0);

        Ok(())
    }
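
    #[test]
    fn test_query_then_label_loop() -> LogicResult<()> {
        // Hedged usage sketch: feed unlabeled samples, ask the system which one to
        // label next, then return the answer as a labeled observation. The
        // feasibility rule below (x <= 5.0) is assumed purely for this example.
        let constraint = LinearConstraint::less_eq(vec![1.0], 5.0);
        let mut system = OnlineLearningSystem::new(constraint);

        system.process_unlabeled_sample(Array1::from_vec(vec![4.95]));
        system.process_unlabeled_sample(Array1::from_vec(vec![5.05]));

        // Both samples sit within the active learner's uncertainty band, so there is
        // something to query.
        let query = system.query_next();
        assert!(query.is_some());

        if let Some(sample) = query {
            let is_feasible = sample[0] <= 5.0;
            system.process_labeled_sample(sample, is_feasible)?;
        }

        Ok(())
    }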
}