use std::borrow::Cow;
use std::collections::VecDeque;

use scirs2_core::ndarray::Array1;

use crate::{LinearConstraint, LogicResult};

13#[derive(Debug, Clone)]
15pub struct OnlineConstraintLearner {
16 constraint: LinearConstraint,
18 #[allow(dead_code)]
20 learning_rate: f32,
21 data_buffer: VecDeque<(Array1<f32>, bool)>, #[allow(dead_code)]
25 max_buffer_size: usize,
26 update_count: usize,
28}
29
30impl OnlineConstraintLearner {
31 pub fn new(
33 initial_constraint: LinearConstraint,
34 learning_rate: f32,
35 max_buffer_size: usize,
36 ) -> Self {
37 Self {
38 constraint: initial_constraint,
39 learning_rate,
40 data_buffer: VecDeque::new(),
41 max_buffer_size,
42 update_count: 0,
43 }
44 }
45
46 pub fn observe(&mut self, sample: Array1<f32>, is_feasible: bool) -> LogicResult<()> {
48 self.data_buffer.push_back((sample.clone(), is_feasible));
50 if self.data_buffer.len() > self.max_buffer_size {
51 self.data_buffer.pop_front();
52 }
53
54 self.refine_constraint(&sample, is_feasible)?;
56 self.update_count += 1;
57
58 Ok(())
59 }
60
61 fn refine_constraint(&mut self, sample: &Array1<f32>, is_feasible: bool) -> LogicResult<()> {
63 let sample_slice = sample.as_slice().unwrap_or(&[]);
64 let current_satisfied = self.constraint.check(sample_slice);
65
66 if current_satisfied == is_feasible {
68 return Ok(());
69 }
70
71 let violation = self.constraint.violation(sample_slice);
74
75 let update_scale = if is_feasible {
77 self.learning_rate * violation
79 } else {
80 -self.learning_rate
82 };
83
84 let _ = update_scale; Ok(())
89 }
90
91 pub fn get_constraint(&self) -> &LinearConstraint {
93 &self.constraint
94 }
95
96 pub fn update_count(&self) -> usize {
98 self.update_count
99 }
100
101 pub fn confidence(&self) -> f32 {
103 if self.data_buffer.is_empty() {
104 return 0.0;
105 }
106
107 let correct = self
109 .data_buffer
110 .iter()
111 .filter(|(sample, is_feasible)| {
112 let satisfied = self.constraint.check(sample.as_slice().unwrap_or(&[]));
113 satisfied == *is_feasible
114 })
115 .count();
116
117 correct as f32 / self.data_buffer.len() as f32
118 }
119}
120
121#[derive(Debug, Clone)]
123pub struct AnomalyBasedConstraintDiscovery {
124 normal_samples: VecDeque<Array1<f32>>,
126 max_samples: usize,
128 anomaly_threshold: f32,
130 discovered_constraints: Vec<LinearConstraint>,
132}
133
134impl AnomalyBasedConstraintDiscovery {
135 pub fn new(max_samples: usize, anomaly_threshold: f32) -> Self {
137 Self {
138 normal_samples: VecDeque::new(),
139 max_samples,
140 anomaly_threshold,
141 discovered_constraints: Vec::new(),
142 }
143 }
144
145 pub fn add_normal_sample(&mut self, sample: Array1<f32>) {
147 self.normal_samples.push_back(sample);
148 if self.normal_samples.len() > self.max_samples {
149 self.normal_samples.pop_front();
150 }
151 }
152
153 pub fn detect_anomaly(&mut self, sample: &Array1<f32>) -> bool {
155 if self.normal_samples.len() < 2 {
156 return false; }
158
159 let dim = sample.len();
161 let n = self.normal_samples.len();
162
163 let mut is_anomalous = false;
165
166 for d in 0..dim {
167 let mean: f32 = self.normal_samples.iter().map(|s| s[d]).sum::<f32>() / n as f32;
168
169 let variance: f32 = self
170 .normal_samples
171 .iter()
172 .map(|s| (s[d] - mean).powi(2))
173 .sum::<f32>()
174 / n as f32;
175
176 let std_dev = variance.sqrt();
177
178 let z_score = (sample[d] - mean).abs() / (std_dev + 1e-8);
180 if z_score > self.anomaly_threshold {
181 is_anomalous = true;
182
183 self.discover_bound_constraint(d, mean, std_dev);
185 }
186 }
187
188 is_anomalous
189 }
190
191 fn discover_bound_constraint(&mut self, dim: usize, mean: f32, std_dev: f32) {
193 let upper_bound = mean + self.anomaly_threshold * std_dev;
195 let mut coeffs = vec![0.0; dim + 1];
196 coeffs[dim] = 1.0;
197
198 let constraint = LinearConstraint::less_eq(coeffs, upper_bound);
199
200 let is_duplicate = self.discovered_constraints.iter().any(|c| {
202 c.coefficients().len() == constraint.coefficients().len()
203 && c.coefficients()
204 .iter()
205 .zip(constraint.coefficients().iter())
206 .all(|(a, b)| (a - b).abs() < 0.1)
207 });
208
209 if !is_duplicate {
210 self.discovered_constraints.push(constraint);
211 }
212 }
213
214 pub fn discovered_constraints(&self) -> &[LinearConstraint] {
216 &self.discovered_constraints
217 }
218
219 pub fn num_discovered(&self) -> usize {
221 self.discovered_constraints.len()
222 }
223}
224
225#[derive(Debug, Clone)]
227pub struct ActiveConstraintBoundaryLearner {
228 constraint: LinearConstraint,
230 boundary_samples: Vec<(Array1<f32>, Option<bool>)>, uncertainty_threshold: f32,
234 max_boundary_samples: usize,
236}
237
238impl ActiveConstraintBoundaryLearner {
239 pub fn new(
241 initial_constraint: LinearConstraint,
242 uncertainty_threshold: f32,
243 max_boundary_samples: usize,
244 ) -> Self {
245 Self {
246 constraint: initial_constraint,
247 boundary_samples: Vec::new(),
248 uncertainty_threshold,
249 max_boundary_samples,
250 }
251 }
252
253 pub fn query_next(&self) -> Option<Array1<f32>> {
255 self.boundary_samples
257 .iter()
258 .filter(|(_, label)| label.is_none())
259 .min_by(|(s1, _), (s2, _)| {
260 let v1 = self
261 .constraint
262 .violation(s1.as_slice().unwrap_or(&[]))
263 .abs();
264 let v2 = self
265 .constraint
266 .violation(s2.as_slice().unwrap_or(&[]))
267 .abs();
268 v1.partial_cmp(&v2).unwrap_or(std::cmp::Ordering::Equal)
269 })
270 .map(|(s, _)| s.clone())
271 }
272
273 pub fn add_labeled_sample(&mut self, sample: Array1<f32>, is_feasible: bool) {
275 let violation = self
276 .constraint
277 .violation(sample.as_slice().unwrap_or(&[]))
278 .abs();
279
280 if violation < self.uncertainty_threshold {
282 self.boundary_samples.push((sample, Some(is_feasible)));
283 if self.boundary_samples.len() > self.max_boundary_samples {
284 self.boundary_samples.remove(0);
285 }
286 }
287 }
288
289 pub fn add_unlabeled_sample(&mut self, sample: Array1<f32>) {
291 let violation = self
292 .constraint
293 .violation(sample.as_slice().unwrap_or(&[]))
294 .abs();
295
296 if violation < self.uncertainty_threshold {
297 self.boundary_samples.push((sample, None));
298 if self.boundary_samples.len() > self.max_boundary_samples {
299 self.boundary_samples.remove(0);
300 }
301 }
302 }
303
304 pub fn refine(&mut self) -> LogicResult<()> {
306 let labeled: Vec<_> = self
309 .boundary_samples
310 .iter()
311 .filter_map(|(s, l)| l.map(|label| (s, label)))
312 .collect();
313
314 if labeled.len() < 2 {
315 return Ok(()); }
317
318 Ok(())
321 }
322
323 pub fn get_constraint(&self) -> &LinearConstraint {
325 &self.constraint
326 }
327
328 pub fn num_boundary_samples(&self) -> usize {
330 self.boundary_samples.len()
331 }
332
333 pub fn num_unlabeled(&self) -> usize {
335 self.boundary_samples
336 .iter()
337 .filter(|(_, l)| l.is_none())
338 .count()
339 }
340}
341
342#[derive(Debug, Clone)]
344pub struct FeedbackConstraintTuner {
345 constraint: LinearConstraint,
347 feedback_history: Vec<(f32, f32)>, #[allow(dead_code)]
351 adaptation_rate: f32,
352 target_satisfaction: f32,
354}
355
356impl FeedbackConstraintTuner {
357 pub fn new(
359 initial_constraint: LinearConstraint,
360 adaptation_rate: f32,
361 target_satisfaction: f32,
362 ) -> Self {
363 Self {
364 constraint: initial_constraint,
365 feedback_history: Vec::new(),
366 adaptation_rate,
367 target_satisfaction,
368 }
369 }
370
371 pub fn add_feedback(&mut self, sample: &Array1<f32>, satisfaction: f32) -> LogicResult<()> {
373 let violation = self.constraint.violation(sample.as_slice().unwrap_or(&[]));
374 self.feedback_history.push((violation, satisfaction));
375
376 self.tune()?;
378
379 Ok(())
380 }
381
382 fn tune(&mut self) -> LogicResult<()> {
384 if self.feedback_history.len() < 5 {
385 return Ok(()); }
387
388 let avg_satisfaction: f32 = self.feedback_history.iter().map(|(_, s)| s).sum::<f32>()
390 / self.feedback_history.len() as f32;
391
392 let satisfaction_gap = self.target_satisfaction - avg_satisfaction;
394
395 if satisfaction_gap.abs() > 0.1 {
396 let _ = satisfaction_gap; }
401
402 Ok(())
403 }
404
405 pub fn get_constraint(&self) -> &LinearConstraint {
407 &self.constraint
408 }
409
410 pub fn average_satisfaction(&self) -> f32 {
412 if self.feedback_history.is_empty() {
413 return 0.0;
414 }
415
416 self.feedback_history.iter().map(|(_, s)| s).sum::<f32>()
417 / self.feedback_history.len() as f32
418 }
419
420 pub fn num_feedback_samples(&self) -> usize {
422 self.feedback_history.len()
423 }
424}
425
426#[derive(Debug, Clone)]
428pub struct OnlineLearningSystem {
429 incremental_learner: OnlineConstraintLearner,
431 anomaly_detector: AnomalyBasedConstraintDiscovery,
433 active_learner: ActiveConstraintBoundaryLearner,
435 feedback_tuner: FeedbackConstraintTuner,
437 use_incremental: bool,
439 use_anomaly: bool,
440 use_active: bool,
441 use_feedback: bool,
442}
443
444impl OnlineLearningSystem {
445 pub fn new(initial_constraint: LinearConstraint) -> Self {
447 Self {
448 incremental_learner: OnlineConstraintLearner::new(
449 initial_constraint.clone(),
450 0.01,
451 1000,
452 ),
453 anomaly_detector: AnomalyBasedConstraintDiscovery::new(1000, 3.0),
454 active_learner: ActiveConstraintBoundaryLearner::new(
455 initial_constraint.clone(),
456 0.1,
457 100,
458 ),
459 feedback_tuner: FeedbackConstraintTuner::new(initial_constraint, 0.01, 0.8),
460 use_incremental: true,
461 use_anomaly: true,
462 use_active: true,
463 use_feedback: true,
464 }
465 }
466
467 pub fn process_labeled_sample(
469 &mut self,
470 sample: Array1<f32>,
471 is_feasible: bool,
472 ) -> LogicResult<()> {
473 if self.use_incremental {
474 self.incremental_learner
475 .observe(sample.clone(), is_feasible)?;
476 }
477
478 if self.use_active {
479 self.active_learner
480 .add_labeled_sample(sample.clone(), is_feasible);
481 }
482
483 if is_feasible && self.use_anomaly {
484 self.anomaly_detector.add_normal_sample(sample);
485 }
486
487 Ok(())
488 }
489
490 pub fn process_unlabeled_sample(&mut self, sample: Array1<f32>) {
492 if self.use_anomaly {
493 self.anomaly_detector.detect_anomaly(&sample);
494 }
495
496 if self.use_active {
497 self.active_learner.add_unlabeled_sample(sample);
498 }
499 }
500
501 pub fn add_feedback(&mut self, sample: &Array1<f32>, satisfaction: f32) -> LogicResult<()> {
503 if self.use_feedback {
504 self.feedback_tuner.add_feedback(sample, satisfaction)?;
505 }
506 Ok(())
507 }
508
509 pub fn get_best_constraint(&self) -> &LinearConstraint {
511 self.incremental_learner.get_constraint()
514 }
515
516 pub fn confidence(&self) -> f32 {
518 self.incremental_learner.confidence()
519 }
520
521 pub fn discovered_constraints(&self) -> &[LinearConstraint] {
523 self.anomaly_detector.discovered_constraints()
524 }
525
526 pub fn query_next(&self) -> Option<Array1<f32>> {
528 if self.use_active {
529 self.active_learner.query_next()
530 } else {
531 None
532 }
533 }
534
535 pub fn set_use_incremental(&mut self, use_it: bool) {
537 self.use_incremental = use_it;
538 }
539
540 pub fn set_use_anomaly(&mut self, use_it: bool) {
541 self.use_anomaly = use_it;
542 }
543
544 pub fn set_use_active(&mut self, use_it: bool) {
545 self.use_active = use_it;
546 }
547
548 pub fn set_use_feedback(&mut self, use_it: bool) {
549 self.use_feedback = use_it;
550 }
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional sample holding `x`.
    fn point(x: f32) -> Array1<f32> {
        Array1::from_vec(vec![x])
    }

    /// The constraint `x <= 5` shared by most tests.
    fn at_most_five() -> LinearConstraint {
        LinearConstraint::less_eq(vec![1.0], 5.0)
    }

    #[test]
    fn test_online_learner_basic() -> LogicResult<()> {
        let mut learner = OnlineConstraintLearner::new(at_most_five(), 0.1, 100);

        // One observation on each side of the bound.
        for (x, feasible) in [(3.0, true), (7.0, false)] {
            learner.observe(point(x), feasible)?;
        }

        assert_eq!(learner.update_count(), 2);
        assert!(learner.confidence() > 0.0);
        Ok(())
    }

    #[test]
    fn test_anomaly_detection() {
        let mut detector = AnomalyBasedConstraintDiscovery::new(100, 3.0);
        (0..20).for_each(|_| detector.add_normal_sample(Array1::from_vec(vec![5.0, 10.0])));

        // The normal window is constant, so a point ten times larger is a clear outlier.
        let outlier = Array1::from_vec(vec![50.0, 100.0]);
        assert!(detector.detect_anomaly(&outlier));
    }

    #[test]
    fn test_active_learning() {
        let mut learner = ActiveConstraintBoundaryLearner::new(at_most_five(), 1.0, 100);

        // 4.9 is close to the bound; 10.0 is far from it.
        learner.add_unlabeled_sample(point(4.9));
        learner.add_unlabeled_sample(point(10.0));

        assert_eq!(learner.num_unlabeled(), 1);
    }

    #[test]
    fn test_feedback_tuner() -> LogicResult<()> {
        let mut tuner = FeedbackConstraintTuner::new(at_most_five(), 0.1, 0.8);

        tuner.add_feedback(&point(3.0), 0.9)?;
        tuner.add_feedback(&point(4.0), 0.7)?;

        assert_eq!(tuner.num_feedback_samples(), 2);
        assert!(tuner.average_satisfaction() > 0.0);
        Ok(())
    }

    #[test]
    fn test_online_learning_system() -> LogicResult<()> {
        let mut system = OnlineLearningSystem::new(at_most_five());

        system.process_labeled_sample(point(3.0), true)?;
        system.process_unlabeled_sample(point(4.5));
        system.add_feedback(&point(3.5), 0.9)?;

        assert!(system.confidence() > 0.0);
        Ok(())
    }
}