1use crate::{LogicError, LogicResult, ViolationComputable};
10use scirs2_core::ndarray::{Array1, Array2};
11use serde::{Deserialize, Serialize};
12
/// A constraint paired with a time-indexed schedule of parameter updates.
///
/// Scheduled updates are stored sorted by time, and the constraint's clock only
/// moves forward (see the `advance_time` method).
#[derive(Debug, Clone)]
pub struct TimeVaryingConstraint<C: ViolationComputable> {
    /// Identifier supplied at construction.
    // Not read by any method in this file; retained for identification/debugging.
    #[allow(dead_code)]
    name: String,
    /// The constraint that the scheduled updates refer to.
    // NOTE(review): nothing in this file applies `schedule` to this constraint yet.
    #[allow(dead_code)]
    base_constraint: C,
    /// Scheduled `(time, update)` pairs, kept in ascending time order on insertion.
    schedule: Vec<(f32, ParameterUpdate)>,
    /// Current clock value; starts at `0.0`.
    current_time: f32,
    /// Easing applied when blending two neighbouring scheduled updates.
    #[allow(dead_code)]
    interpolation: InterpolationMode,
}
30
/// A partial set of parameter changes scheduled on a [`TimeVaryingConstraint`].
///
/// Each field is optional; `None` means this update does not specify that aspect.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterUpdate {
    /// Scalar factor; blended linearly between neighbouring updates.
    // NOTE(review): presumably a multiplicative scale on the constraint — confirm.
    pub scale: Option<f32>,
    /// Per-component offset vector; blended element-wise between neighbouring updates.
    pub offset: Option<Array1<f32>>,
    /// Complete replacement parameters; never blended, the nearer update wins.
    pub replacement: Option<ConstraintParams>,
}
41
/// Concrete parameters describing a constraint, used as a wholesale replacement
/// in a [`ParameterUpdate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConstraintParams {
    /// Linear constraint parameters.
    // NOTE(review): presumably `a · x <= b` — confirm against the constraint implementation.
    Linear { a: Array2<f32>, b: Array1<f32> },
    /// Quadratic constraint parameters.
    // NOTE(review): presumably `xᵀ q x + cᵀ x + d <= 0` — confirm.
    Quadratic {
        q: Array2<f32>,
        c: Array1<f32>,
        d: f32,
    },
    /// Per-component lower and upper bounds.
    Box {
        lower: Array1<f32>,
        upper: Array1<f32>,
    },
}
59
/// Easing curve that maps a normalised position `alpha` in `[0, 1]` to a blend weight.
///
/// Used both when blending scheduled updates of a [`TimeVaryingConstraint`] and by
/// [`ConstraintInterpolator`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum InterpolationMode {
    /// Weight is `0` until `alpha` reaches `1`, then `1` (hard switch at the end).
    Step,
    /// Weight equals `alpha`.
    Linear,
    /// Smoothstep: `3·alpha² − 2·alpha³` (zero slope at both ends).
    Cubic,
    /// Weight is `1 − exp(−rate · alpha)`.
    ///
    /// Note: at `alpha = 1` this is `1 − exp(−rate)`, so the end value is never fully reached.
    Exponential { rate: f32 },
}
72
73impl<C: ViolationComputable + Clone> TimeVaryingConstraint<C> {
74 pub fn new(
76 name: impl Into<String>,
77 base_constraint: C,
78 interpolation: InterpolationMode,
79 ) -> Self {
80 Self {
81 name: name.into(),
82 base_constraint,
83 schedule: Vec::new(),
84 current_time: 0.0,
85 interpolation,
86 }
87 }
88
89 pub fn schedule_update(&mut self, time: f32, update: ParameterUpdate) {
91 let pos = self.schedule.iter().position(|(t, _)| *t > time);
93 match pos {
94 Some(idx) => self.schedule.insert(idx, (time, update)),
95 None => self.schedule.push((time, update)),
96 }
97 }
98
99 pub fn advance_time(&mut self, time: f32) -> LogicResult<()> {
101 if time < self.current_time {
102 return Err(LogicError::InvalidInput(
103 "Cannot go backward in time".to_string(),
104 ));
105 }
106 self.current_time = time;
107 Ok(())
108 }
109
110 pub fn current_time(&self) -> f32 {
112 self.current_time
113 }
114
115 #[allow(dead_code)]
117 fn get_current_update(&self) -> Option<ParameterUpdate> {
118 let before_idx = self
120 .schedule
121 .iter()
122 .rposition(|(t, _)| *t <= self.current_time);
123 let after_idx = self
124 .schedule
125 .iter()
126 .position(|(t, _)| *t >= self.current_time);
127
128 match (before_idx, after_idx) {
129 (Some(i1), Some(i2)) if i1 != i2 => {
130 let (t1, u1) = &self.schedule[i1];
131 let (t2, u2) = &self.schedule[i2];
132 let alpha = (self.current_time - t1) / (t2 - t1);
134 Some(self.interpolate_updates(u1, u2, alpha))
135 }
136 (Some(i), None) | (None, Some(i)) => Some(self.schedule[i].1.clone()),
137 _ => None,
138 }
139 }
140
141 #[allow(dead_code)]
143 fn interpolate_updates(
144 &self,
145 u1: &ParameterUpdate,
146 u2: &ParameterUpdate,
147 alpha: f32,
148 ) -> ParameterUpdate {
149 let alpha = match self.interpolation {
150 InterpolationMode::Step => {
151 if alpha < 1.0 {
152 0.0
153 } else {
154 1.0
155 }
156 }
157 InterpolationMode::Linear => alpha,
158 InterpolationMode::Cubic => {
159 alpha * alpha * (3.0 - 2.0 * alpha)
161 }
162 InterpolationMode::Exponential { rate } => 1.0 - (-rate * alpha).exp(),
163 };
164
165 ParameterUpdate {
166 scale: match (u1.scale, u2.scale) {
167 (Some(s1), Some(s2)) => Some(s1 + (s2 - s1) * alpha),
168 (Some(s), None) | (None, Some(s)) => Some(s),
169 _ => None,
170 },
171 offset: match (&u1.offset, &u2.offset) {
172 (Some(o1), Some(o2)) => Some(o1 + &(o2 - o1) * alpha),
173 (Some(o), None) | (None, Some(o)) => Some(o.clone()),
174 _ => None,
175 },
176 replacement: if alpha < 0.5 {
177 u1.replacement.clone()
178 } else {
179 u2.replacement.clone()
180 },
181 }
182 }
183}
184
/// A constraint that is only enforced while an activation condition on the state holds.
///
/// Activation is cached: it is recomputed only by the `update_activation` method.
#[derive(Debug, Clone)]
pub struct StateDependentConstraint<C: ViolationComputable> {
    /// Identifier supplied at construction.
    // Not read by any method in this file; retained for identification/debugging.
    #[allow(dead_code)]
    name: String,
    /// The constraint evaluated while active.
    constraint: C,
    /// Condition that decides whether `constraint` is enforced.
    activation_fn: ActivationFunction,
    /// Result of the most recent activation update; `false` until the first update.
    is_active: bool,
}
198
/// Condition on a state vector that decides whether a [`StateDependentConstraint`] is active.
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    /// Active when the Euclidean norm of the state is strictly greater than `threshold`.
    NormThreshold { threshold: f32 },
    /// Active when `|state[index]| > threshold`; inactive if `index` is out of range.
    ComponentThreshold { index: usize, threshold: f32 },
    /// Active when every compared component lies within `[lower[i], upper[i]]`.
    ///
    /// Components are paired with `zip`, so only the overlapping prefix of the state
    /// and each bound vector is compared.
    RegionBased {
        lower: Array1<f32>,
        upper: Array1<f32>,
    },
    /// Active when any state component's absolute value is strictly greater than `threshold`.
    // NOTE(review): the whole state is checked; presumably callers pass a velocity vector — confirm.
    VelocityBased { threshold: f32 },
    /// Arbitrary predicate supplied as a plain function pointer (no captured environment).
    Custom(fn(&Array1<f32>) -> bool),
}
216
217impl<C: ViolationComputable + Clone> StateDependentConstraint<C> {
218 pub fn new(name: impl Into<String>, constraint: C, activation_fn: ActivationFunction) -> Self {
220 Self {
221 name: name.into(),
222 constraint,
223 activation_fn,
224 is_active: false,
225 }
226 }
227
228 pub fn update_activation(&mut self, state: &Array1<f32>) -> bool {
230 self.is_active = match &self.activation_fn {
231 ActivationFunction::NormThreshold { threshold } => {
232 let norm = state.iter().map(|x| x * x).sum::<f32>().sqrt();
233 norm > *threshold
234 }
235 ActivationFunction::ComponentThreshold { index, threshold } => state
236 .get(*index)
237 .map(|x| x.abs() > *threshold)
238 .unwrap_or(false),
239 ActivationFunction::RegionBased { lower, upper } => {
240 state.iter().zip(lower.iter()).all(|(x, l)| x >= l)
241 && state.iter().zip(upper.iter()).all(|(x, u)| x <= u)
242 }
243 ActivationFunction::VelocityBased { threshold } => {
244 state.iter().any(|x| x.abs() > *threshold)
246 }
247 ActivationFunction::Custom(f) => f(state),
248 };
249 self.is_active
250 }
251
252 pub fn is_active(&self) -> bool {
254 self.is_active
255 }
256
257 pub fn check_if_active(&self, state: &Array1<f32>) -> bool {
259 if self.is_active {
260 self.constraint.check(state.as_slice().unwrap_or(&[]))
261 } else {
262 true }
264 }
265
266 pub fn violation_if_active(&self, state: &Array1<f32>) -> f32 {
268 if self.is_active {
269 self.constraint.violation(state.as_slice().unwrap_or(&[]))
270 } else {
271 0.0
272 }
273 }
274}
275
/// Adjusts a tightness factor based on constraint violations predicted along a trajectory.
#[derive(Debug, Clone)]
pub struct PredictiveConstraintAdapter<C: ViolationComputable> {
    /// Identifier supplied at construction.
    // Not read by any method in this file; retained for identification/debugging.
    #[allow(dead_code)]
    name: String,
    /// Constraint evaluated on predicted states.
    base_constraint: C,
    /// Maximum number of trajectory states examined per prediction.
    horizon: usize,
    /// Mean predicted violation recorded by each adaptation step (bounded length).
    violation_history: Vec<f32>,
    /// Step size used when adjusting `tightness`.
    adaptation_rate: f32,
    /// Tightness factor; starts at `1.0` and is clamped after every adaptation step.
    // NOTE(review): not applied to `base_constraint` by anything visible in this file.
    tightness: f32,
}
293
294impl<C: ViolationComputable + Clone> PredictiveConstraintAdapter<C> {
295 pub fn new(
297 name: impl Into<String>,
298 base_constraint: C,
299 horizon: usize,
300 adaptation_rate: f32,
301 ) -> Self {
302 Self {
303 name: name.into(),
304 base_constraint,
305 horizon,
306 violation_history: Vec::new(),
307 adaptation_rate,
308 tightness: 1.0,
309 }
310 }
311
312 pub fn predict_violations(&self, trajectory: &[Array1<f32>]) -> Vec<f32> {
314 let mut violations = Vec::new();
315 for state in trajectory.iter().take(self.horizon) {
316 let viol = self
317 .base_constraint
318 .violation(state.as_slice().unwrap_or(&[]));
319 violations.push(viol);
320 }
321 violations
322 }
323
324 pub fn adapt(&mut self, predicted_violations: &[f32]) -> LogicResult<()> {
326 let mean_violation = if predicted_violations.is_empty() {
328 0.0
329 } else {
330 predicted_violations.iter().sum::<f32>() / predicted_violations.len() as f32
331 };
332
333 self.violation_history.push(mean_violation);
335 if self.violation_history.len() > 100 {
336 self.violation_history.remove(0);
337 }
338
339 if mean_violation > 0.0 {
341 self.tightness *= 1.0 + self.adaptation_rate * mean_violation;
343 } else {
344 self.tightness *= 1.0 - self.adaptation_rate * 0.1;
346 }
347
348 self.tightness = self.tightness.clamp(0.5, 2.0);
350
351 Ok(())
352 }
353
354 pub fn tightness(&self) -> f32 {
356 self.tightness
357 }
358
359 pub fn violation_history(&self) -> &[f32] {
361 &self.violation_history
362 }
363}
364
/// Blends the violations of two constraints according to a position `alpha` in `[0, 1]`.
#[derive(Debug, Clone)]
pub struct ConstraintInterpolator<C: ViolationComputable> {
    /// Identifier supplied at construction.
    // Not read by any method in this file; retained for identification/debugging.
    #[allow(dead_code)]
    name: String,
    /// Constraint weighted fully at `alpha = 0`.
    start_constraint: C,
    /// Constraint the blend moves toward as `alpha` increases.
    end_constraint: C,
    /// Current position; starts at `0.0` and is validated to stay in `[0, 1]`.
    alpha: f32,
    /// Easing applied to `alpha` before blending.
    mode: InterpolationMode,
}
380
381impl<C: ViolationComputable + Clone> ConstraintInterpolator<C> {
382 pub fn new(
384 name: impl Into<String>,
385 start_constraint: C,
386 end_constraint: C,
387 mode: InterpolationMode,
388 ) -> Self {
389 Self {
390 name: name.into(),
391 start_constraint,
392 end_constraint,
393 alpha: 0.0,
394 mode,
395 }
396 }
397
398 pub fn set_alpha(&mut self, alpha: f32) -> LogicResult<()> {
400 if !(0.0..=1.0).contains(&alpha) {
401 return Err(LogicError::InvalidInput(
402 "Alpha must be in [0, 1]".to_string(),
403 ));
404 }
405 self.alpha = alpha;
406 Ok(())
407 }
408
409 pub fn alpha(&self) -> f32 {
411 self.alpha
412 }
413
414 pub fn violation(&self, state: &Array1<f32>) -> f32 {
416 let v1 = self
417 .start_constraint
418 .violation(state.as_slice().unwrap_or(&[]));
419 let v2 = self
420 .end_constraint
421 .violation(state.as_slice().unwrap_or(&[]));
422
423 let alpha = match self.mode {
424 InterpolationMode::Step => {
425 if self.alpha < 1.0 {
426 0.0
427 } else {
428 1.0
429 }
430 }
431 InterpolationMode::Linear => self.alpha,
432 InterpolationMode::Cubic => self.alpha * self.alpha * (3.0 - 2.0 * self.alpha),
433 InterpolationMode::Exponential { rate } => 1.0 - (-rate * self.alpha).exp(),
434 };
435
436 v1 * (1.0 - alpha) + v2 * alpha
437 }
438
439 pub fn check(&self, state: &Array1<f32>) -> bool {
441 self.violation(state) <= 0.0
442 }
443}
444
/// A collection of state-dependent constraints, predictive adapters and interpolators
/// evaluated together.
#[derive(Debug, Clone)]
pub struct TimeVaryingConstraintSet<C: ViolationComputable> {
    /// Constraints enforced only while their activation condition holds.
    state_dependent: Vec<StateDependentConstraint<C>>,
    /// Predictive adapters held by the set.
    predictive: Vec<PredictiveConstraintAdapter<C>>,
    /// Interpolated constraints, always evaluated.
    interpolators: Vec<ConstraintInterpolator<C>>,
    /// Clock value set by `advance_time`; starts at `0.0`.
    current_time: f32,
}
457
458impl<C: ViolationComputable + Clone> TimeVaryingConstraintSet<C> {
459 pub fn new() -> Self {
461 Self {
462 state_dependent: Vec::new(),
463 predictive: Vec::new(),
464 interpolators: Vec::new(),
465 current_time: 0.0,
466 }
467 }
468
469 pub fn add_state_dependent(&mut self, constraint: StateDependentConstraint<C>) {
471 self.state_dependent.push(constraint);
472 }
473
474 pub fn add_predictive(&mut self, adapter: PredictiveConstraintAdapter<C>) {
476 self.predictive.push(adapter);
477 }
478
479 pub fn add_interpolator(&mut self, interpolator: ConstraintInterpolator<C>) {
481 self.interpolators.push(interpolator);
482 }
483
484 pub fn advance_time(&mut self, time: f32) -> LogicResult<()> {
486 self.current_time = time;
487 Ok(())
488 }
489
490 pub fn update_activations(&mut self, state: &Array1<f32>) {
492 for constraint in &mut self.state_dependent {
493 constraint.update_activation(state);
494 }
495 }
496
497 pub fn num_active(&self) -> usize {
499 self.state_dependent
500 .iter()
501 .filter(|c| c.is_active())
502 .count()
503 + self.predictive.len()
504 + self.interpolators.len()
505 }
506
507 pub fn check_all(&self, state: &Array1<f32>) -> bool {
509 for constraint in &self.state_dependent {
511 if !constraint.check_if_active(state) {
512 return false;
513 }
514 }
515
516 for interpolator in &self.interpolators {
518 if !interpolator.check(state) {
519 return false;
520 }
521 }
522
523 true
524 }
525
526 pub fn total_violation(&self, state: &Array1<f32>) -> f32 {
528 let mut total = 0.0;
529
530 for constraint in &self.state_dependent {
532 total += constraint.violation_if_active(state).max(0.0);
533 }
534
535 for interpolator in &self.interpolators {
537 total += interpolator.violation(state).max(0.0);
538 }
539
540 total
541 }
542}
543
544impl<C: ViolationComputable + Clone> Default for TimeVaryingConstraintSet<C> {
545 fn default() -> Self {
546 Self::new()
547 }
548}
549
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearConstraint;

    #[test]
    fn test_state_dependent_activation() {
        let base = LinearConstraint::less_eq(vec![1.0], 1.0);

        let mut sdc = StateDependentConstraint::new(
            "test",
            base,
            ActivationFunction::NormThreshold { threshold: 5.0 },
        );

        // ||(1, 2, 3)|| = sqrt(14) ≈ 3.74, below the threshold.
        let state = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let active = sdc.update_activation(&state);
        assert!(!active);
        assert!(!sdc.is_active());

        // An inactive constraint is never enforced.
        assert!(sdc.check_if_active(&state));
        assert_eq!(sdc.violation_if_active(&state), 0.0);

        // ||(3, 4, 5)|| = sqrt(50) ≈ 7.07, above the threshold.
        let state2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let active2 = sdc.update_activation(&state2);
        assert!(active2);
        assert!(sdc.is_active());
    }

    #[test]
    fn test_predictive_adaptation() -> LogicResult<()> {
        let base = LinearConstraint::less_eq(vec![1.0], 1.0);

        let mut adapter = PredictiveConstraintAdapter::new("test", base, 5, 0.1);

        let trajectory = vec![
            Array1::from_vec(vec![0.5]),
            Array1::from_vec(vec![0.8]),
            // Exceeds the bound of 1.0.
            Array1::from_vec(vec![1.2]),
        ];

        let violations = adapter.predict_violations(&trajectory);
        assert_eq!(violations.len(), 3);

        adapter.adapt(&violations)?;
        assert!(adapter.tightness() >= 1.0);
        assert_eq!(adapter.violation_history().len(), 1);

        // The horizon caps how many states are examined.
        let short = PredictiveConstraintAdapter::new(
            "short",
            LinearConstraint::less_eq(vec![1.0], 1.0),
            2,
            0.1,
        );
        assert_eq!(short.predict_violations(&trajectory).len(), 2);

        Ok(())
    }

    #[test]
    fn test_constraint_interpolation() -> LogicResult<()> {
        let start = LinearConstraint::less_eq(vec![1.0], 1.0);
        let end = LinearConstraint::less_eq(vec![1.0], 2.0);

        let mut interp = ConstraintInterpolator::new("test", start, end, InterpolationMode::Linear);

        interp.set_alpha(0.5)?;
        assert_eq!(interp.alpha(), 0.5);

        // Out-of-range values are rejected and leave alpha unchanged.
        assert!(interp.set_alpha(1.5).is_err());
        assert_eq!(interp.alpha(), 0.5);

        let state = Array1::from_vec(vec![1.5]);
        let violation = interp.violation(&state);
        assert!((0.0..=0.5).contains(&violation));

        Ok(())
    }

    #[test]
    fn test_time_varying_schedule_and_clock() {
        let base = LinearConstraint::less_eq(vec![1.0], 1.0);
        let mut tvc = TimeVaryingConstraint::new("tv", base, InterpolationMode::Linear);

        let update = |scale: f32| ParameterUpdate {
            scale: Some(scale),
            offset: None,
            replacement: None,
        };
        tvc.schedule_update(2.0, update(2.0));
        tvc.schedule_update(1.0, update(1.0));
        tvc.schedule_update(3.0, update(3.0));

        // Updates are kept sorted by time regardless of insertion order.
        let times: Vec<f32> = tvc.schedule.iter().map(|(t, _)| *t).collect();
        assert_eq!(times, vec![1.0, 2.0, 3.0]);

        assert!(tvc.advance_time(1.5).is_ok());
        assert_eq!(tvc.current_time(), 1.5);

        // Halfway between the updates at 1.0 and 2.0, linear blending gives 1.5.
        let current = tvc
            .get_current_update()
            .expect("schedule is non-empty");
        let scale = current.scale.expect("both neighbours define a scale");
        assert!((scale - 1.5).abs() < 1e-6);

        // The clock never moves backward.
        assert!(tvc.advance_time(1.0).is_err());
        assert_eq!(tvc.current_time(), 1.5);
    }

    #[test]
    fn test_constraint_set() {
        let mut set = TimeVaryingConstraintSet::new();

        let base = LinearConstraint::less_eq(vec![1.0], 1.0);
        let sdc = StateDependentConstraint::new(
            "state_dep",
            base,
            ActivationFunction::NormThreshold { threshold: 5.0 },
        );

        set.add_state_dependent(sdc);

        // ||(1, 2)|| ≈ 2.24: below the threshold, nothing active.
        let state = Array1::from_vec(vec![1.0, 2.0]);
        set.update_activations(&state);
        assert_eq!(set.num_active(), 0);

        // ||(5, 5)|| ≈ 7.07: above the threshold.
        let state2 = Array1::from_vec(vec![5.0, 5.0]);
        set.update_activations(&state2);
        assert_eq!(set.num_active(), 1);
    }
}