kizzasi_logic/
timevarying.rs

//! Time-varying Constraints
//!
//! This module provides constraints that change over time, including:
//! - Scheduled constraint changes
//! - State-dependent constraint activation
//! - Predictive constraint adaptation
//! - Temporal constraint interpolation
8
9use crate::{LogicError, LogicResult, ViolationComputable};
10use scirs2_core::ndarray::{Array1, Array2};
11use serde::{Deserialize, Serialize};
12
/// A constraint that varies over time
///
/// Pairs a base constraint with a list of [`ParameterUpdate`] keyframes and a
/// clock. `schedule_update` adds keyframes, `advance_time` moves the clock, and
/// the private `get_current_update` helper resolves the update for the current
/// time.
#[derive(Debug, Clone)]
pub struct TimeVaryingConstraint<C: ViolationComputable> {
    /// Name of the constraint
    // Stored at construction; nothing visible in this module reads it back,
    // hence the lint allowance.
    #[allow(dead_code)]
    name: String,
    /// Base constraint that gets modified
    // NOTE(review): no visible code applies an update to this constraint yet,
    // which is why the field is marked as dead code.
    #[allow(dead_code)]
    base_constraint: C,
    /// Scheduled parameter changes (time, update) sorted by time
    ///
    /// `schedule_update` inserts each entry in front of the first entry with a
    /// strictly later time, which keeps the list in ascending order for
    /// non-NaN times.
    schedule: Vec<(f32, ParameterUpdate)>,
    /// Current time; starts at `0.0` and is changed only by `advance_time`
    current_time: f32,
    /// Interpolation mode used when two neighbouring keyframes are blended
    // Only read by the (currently unused) private interpolation helper.
    #[allow(dead_code)]
    interpolation: InterpolationMode,
}
30
/// Parameter update for constraint modification
///
/// Every field is optional; `None` means this update does not specify that
/// aspect. When two updates are blended, `scale` and `offset` are interpolated
/// numerically, while `replacement` is taken whole from one side or the other.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterUpdate {
    /// Scale factor for constraint tightness (1.0 = no change)
    pub scale: Option<f32>,
    /// Additive offset for constraint bounds
    // NOTE(review): presumably one entry per bound being shifted. The visible
    // code never applies the offset, so confirm the expected length with the
    // consumer.
    pub offset: Option<Array1<f32>>,
    /// Completely replace constraint parameters
    pub replacement: Option<ConstraintParams>,
}
41
/// Constraint parameters that can be updated
///
/// Carried inside [`ParameterUpdate::replacement`]. In the code visible here
/// these values are only stored and cloned; nothing evaluates them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConstraintParams {
    /// Linear constraint: A*x <= b
    // NOTE(review): presumably `a` is (m x n) and `b` has length m; no shape
    // check is visible in this module, so confirm with the consumer.
    Linear { a: Array2<f32>, b: Array1<f32> },
    /// Quadratic constraint: x^T Q x + c^T x <= d
    Quadratic {
        /// Quadratic term `Q`
        q: Array2<f32>,
        /// Linear term `c`
        c: Array1<f32>,
        /// Scalar right-hand side `d`
        d: f32,
    },
    /// Box constraint: lower <= x <= upper
    Box {
        /// Per-component lower bound
        lower: Array1<f32>,
        /// Per-component upper bound
        upper: Array1<f32>,
    },
}
59
/// Interpolation mode for smooth transitions
///
/// Each mode maps a raw progress value `a` (0.0 at the first endpoint, 1.0 at
/// the second) to the blend weight given to the second endpoint.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum InterpolationMode {
    /// No interpolation, step changes: the weight is 0 while `a < 1.0` and 1
    /// otherwise
    Step,
    /// Linear interpolation between keyframes: the weight is `a` itself
    Linear,
    /// Smooth cubic interpolation: the weight is `a * a * (3 - 2 * a)`
    Cubic,
    /// Exponential decay/growth: the weight is `1 - exp(-rate * a)`
    ///
    /// For a finite `rate` this weight does not reach exactly 1 at `a = 1.0`.
    Exponential { rate: f32 },
}
72
73impl<C: ViolationComputable + Clone> TimeVaryingConstraint<C> {
74    /// Create a new time-varying constraint
75    pub fn new(
76        name: impl Into<String>,
77        base_constraint: C,
78        interpolation: InterpolationMode,
79    ) -> Self {
80        Self {
81            name: name.into(),
82            base_constraint,
83            schedule: Vec::new(),
84            current_time: 0.0,
85            interpolation,
86        }
87    }
88
89    /// Schedule a parameter update at a specific time
90    pub fn schedule_update(&mut self, time: f32, update: ParameterUpdate) {
91        // Insert in sorted order
92        let pos = self.schedule.iter().position(|(t, _)| *t > time);
93        match pos {
94            Some(idx) => self.schedule.insert(idx, (time, update)),
95            None => self.schedule.push((time, update)),
96        }
97    }
98
99    /// Advance time and update constraint parameters
100    pub fn advance_time(&mut self, time: f32) -> LogicResult<()> {
101        if time < self.current_time {
102            return Err(LogicError::InvalidInput(
103                "Cannot go backward in time".to_string(),
104            ));
105        }
106        self.current_time = time;
107        Ok(())
108    }
109
110    /// Get current time
111    pub fn current_time(&self) -> f32 {
112        self.current_time
113    }
114
115    /// Get interpolated parameter update for current time
116    #[allow(dead_code)]
117    fn get_current_update(&self) -> Option<ParameterUpdate> {
118        // Find surrounding keyframes
119        let before_idx = self
120            .schedule
121            .iter()
122            .rposition(|(t, _)| *t <= self.current_time);
123        let after_idx = self
124            .schedule
125            .iter()
126            .position(|(t, _)| *t >= self.current_time);
127
128        match (before_idx, after_idx) {
129            (Some(i1), Some(i2)) if i1 != i2 => {
130                let (t1, u1) = &self.schedule[i1];
131                let (t2, u2) = &self.schedule[i2];
132                // Interpolate between keyframes
133                let alpha = (self.current_time - t1) / (t2 - t1);
134                Some(self.interpolate_updates(u1, u2, alpha))
135            }
136            (Some(i), None) | (None, Some(i)) => Some(self.schedule[i].1.clone()),
137            _ => None,
138        }
139    }
140
141    /// Interpolate between two parameter updates
142    #[allow(dead_code)]
143    fn interpolate_updates(
144        &self,
145        u1: &ParameterUpdate,
146        u2: &ParameterUpdate,
147        alpha: f32,
148    ) -> ParameterUpdate {
149        let alpha = match self.interpolation {
150            InterpolationMode::Step => {
151                if alpha < 1.0 {
152                    0.0
153                } else {
154                    1.0
155                }
156            }
157            InterpolationMode::Linear => alpha,
158            InterpolationMode::Cubic => {
159                // Smooth step function
160                alpha * alpha * (3.0 - 2.0 * alpha)
161            }
162            InterpolationMode::Exponential { rate } => 1.0 - (-rate * alpha).exp(),
163        };
164
165        ParameterUpdate {
166            scale: match (u1.scale, u2.scale) {
167                (Some(s1), Some(s2)) => Some(s1 + (s2 - s1) * alpha),
168                (Some(s), None) | (None, Some(s)) => Some(s),
169                _ => None,
170            },
171            offset: match (&u1.offset, &u2.offset) {
172                (Some(o1), Some(o2)) => Some(o1 + &(o2 - o1) * alpha),
173                (Some(o), None) | (None, Some(o)) => Some(o.clone()),
174                _ => None,
175            },
176            replacement: if alpha < 0.5 {
177                u1.replacement.clone()
178            } else {
179                u2.replacement.clone()
180            },
181        }
182    }
183}
184
/// State-dependent constraint activation
///
/// Wraps a constraint together with an [`ActivationFunction`]. The activation
/// flag is recomputed only when `update_activation` is called; the
/// `*_if_active` methods consult the stored flag rather than re-evaluating it.
#[derive(Debug, Clone)]
pub struct StateDependentConstraint<C: ViolationComputable> {
    /// Name of the constraint
    // Stored at construction; nothing visible in this module reads it back.
    #[allow(dead_code)]
    name: String,
    /// The constraint to apply when active
    constraint: C,
    /// Activation function: returns true if constraint should be active
    activation_fn: ActivationFunction,
    /// Current activation state; `false` until the first `update_activation`
    is_active: bool,
}
198
/// Activation function type
///
/// Describes when a [`StateDependentConstraint`] switches on. All threshold
/// comparisons are strict (`>`).
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    /// Activate when state norm exceeds threshold (Euclidean norm of the
    /// whole state vector)
    NormThreshold { threshold: f32 },
    /// Activate when specific state component exceeds threshold
    ///
    /// The comparison uses the component's absolute value. An `index` outside
    /// the state vector leaves the constraint inactive.
    ComponentThreshold { index: usize, threshold: f32 },
    /// Activate when state enters a region
    ///
    /// Bounds are inclusive and compared component by component. Components
    /// are paired positionally, so only indices present in both the state and
    /// the bound vector are compared.
    RegionBased {
        /// Inclusive per-component lower bound
        lower: Array1<f32>,
        /// Inclusive per-component upper bound
        upper: Array1<f32>,
    },
    /// Activate based on state velocity (rate of change)
    ///
    /// No state history is kept at the moment: the current evaluation
    /// activates when any component's absolute value exceeds `threshold`.
    VelocityBased { threshold: f32 },
    /// Custom activation function (a plain function pointer, not a closure
    /// with captured state)
    Custom(fn(&Array1<f32>) -> bool),
}
216
217impl<C: ViolationComputable + Clone> StateDependentConstraint<C> {
218    /// Create a new state-dependent constraint
219    pub fn new(name: impl Into<String>, constraint: C, activation_fn: ActivationFunction) -> Self {
220        Self {
221            name: name.into(),
222            constraint,
223            activation_fn,
224            is_active: false,
225        }
226    }
227
228    /// Update activation state based on current system state
229    pub fn update_activation(&mut self, state: &Array1<f32>) -> bool {
230        self.is_active = match &self.activation_fn {
231            ActivationFunction::NormThreshold { threshold } => {
232                let norm = state.iter().map(|x| x * x).sum::<f32>().sqrt();
233                norm > *threshold
234            }
235            ActivationFunction::ComponentThreshold { index, threshold } => state
236                .get(*index)
237                .map(|x| x.abs() > *threshold)
238                .unwrap_or(false),
239            ActivationFunction::RegionBased { lower, upper } => {
240                state.iter().zip(lower.iter()).all(|(x, l)| x >= l)
241                    && state.iter().zip(upper.iter()).all(|(x, u)| x <= u)
242            }
243            ActivationFunction::VelocityBased { threshold } => {
244                // This requires historical state; simplified version
245                state.iter().any(|x| x.abs() > *threshold)
246            }
247            ActivationFunction::Custom(f) => f(state),
248        };
249        self.is_active
250    }
251
252    /// Check if constraint is currently active
253    pub fn is_active(&self) -> bool {
254        self.is_active
255    }
256
257    /// Check constraint if active
258    pub fn check_if_active(&self, state: &Array1<f32>) -> bool {
259        if self.is_active {
260            self.constraint.check(state.as_slice().unwrap_or(&[]))
261        } else {
262            true // Inactive constraints are trivially satisfied
263        }
264    }
265
266    /// Get violation if active
267    pub fn violation_if_active(&self, state: &Array1<f32>) -> f32 {
268        if self.is_active {
269            self.constraint.violation(state.as_slice().unwrap_or(&[]))
270        } else {
271            0.0
272        }
273    }
274}
275
/// Predictive constraint adaptation
///
/// Evaluates a base constraint along a predicted trajectory and maintains a
/// tightness multiplier that grows when violations are predicted and shrinks
/// otherwise.
// NOTE(review): the multiplier is only exposed through `tightness()`; no
// visible code applies it to `base_constraint`.
#[derive(Debug, Clone)]
pub struct PredictiveConstraintAdapter<C: ViolationComputable> {
    /// Name of the adapter
    // Stored at construction; nothing visible in this module reads it back.
    #[allow(dead_code)]
    name: String,
    /// Base constraint
    base_constraint: C,
    /// Prediction horizon (steps ahead); caps how many trajectory states
    /// `predict_violations` evaluates
    horizon: usize,
    /// Historical violations for learning; `adapt` appends one mean per call
    /// and drops the oldest entry once more than 100 are stored
    violation_history: Vec<f32>,
    /// Adaptation rate (how quickly to adjust)
    adaptation_rate: f32,
    /// Current tightness multiplier; starts at `1.0` and is clamped to
    /// `[0.5, 2.0]` by `adapt`
    tightness: f32,
}
293
294impl<C: ViolationComputable + Clone> PredictiveConstraintAdapter<C> {
295    /// Create a new predictive constraint adapter
296    pub fn new(
297        name: impl Into<String>,
298        base_constraint: C,
299        horizon: usize,
300        adaptation_rate: f32,
301    ) -> Self {
302        Self {
303            name: name.into(),
304            base_constraint,
305            horizon,
306            violation_history: Vec::new(),
307            adaptation_rate,
308            tightness: 1.0,
309        }
310    }
311
312    /// Predict future violations based on trajectory
313    pub fn predict_violations(&self, trajectory: &[Array1<f32>]) -> Vec<f32> {
314        let mut violations = Vec::new();
315        for state in trajectory.iter().take(self.horizon) {
316            let viol = self
317                .base_constraint
318                .violation(state.as_slice().unwrap_or(&[]));
319            violations.push(viol);
320        }
321        violations
322    }
323
324    /// Adapt constraint based on predicted violations
325    pub fn adapt(&mut self, predicted_violations: &[f32]) -> LogicResult<()> {
326        // Calculate mean predicted violation
327        let mean_violation = if predicted_violations.is_empty() {
328            0.0
329        } else {
330            predicted_violations.iter().sum::<f32>() / predicted_violations.len() as f32
331        };
332
333        // Record in history
334        self.violation_history.push(mean_violation);
335        if self.violation_history.len() > 100 {
336            self.violation_history.remove(0);
337        }
338
339        // Adapt tightness: tighten if violations predicted, loosen if safe
340        if mean_violation > 0.0 {
341            // Tighten constraint
342            self.tightness *= 1.0 + self.adaptation_rate * mean_violation;
343        } else {
344            // Gradually loosen if no violations
345            self.tightness *= 1.0 - self.adaptation_rate * 0.1;
346        }
347
348        // Keep tightness in reasonable range
349        self.tightness = self.tightness.clamp(0.5, 2.0);
350
351        Ok(())
352    }
353
354    /// Get current tightness multiplier
355    pub fn tightness(&self) -> f32 {
356        self.tightness
357    }
358
359    /// Get violation history
360    pub fn violation_history(&self) -> &[f32] {
361        &self.violation_history
362    }
363}
364
/// Temporal constraint interpolation
///
/// Blends the violation values of two constraints. The blend weight is
/// derived from `alpha` through the configured [`InterpolationMode`].
#[derive(Debug, Clone)]
pub struct ConstraintInterpolator<C: ViolationComputable> {
    /// Name of the interpolator
    // Stored at construction; nothing visible in this module reads it back.
    #[allow(dead_code)]
    name: String,
    /// Start constraint (full weight when the blend weight is 0)
    start_constraint: C,
    /// End constraint (full weight when the blend weight is 1)
    end_constraint: C,
    /// Interpolation parameter (0.0 to 1.0); starts at `0.0` and is validated
    /// by `set_alpha`
    alpha: f32,
    /// Interpolation mode
    mode: InterpolationMode,
}
380
381impl<C: ViolationComputable + Clone> ConstraintInterpolator<C> {
382    /// Create a new constraint interpolator
383    pub fn new(
384        name: impl Into<String>,
385        start_constraint: C,
386        end_constraint: C,
387        mode: InterpolationMode,
388    ) -> Self {
389        Self {
390            name: name.into(),
391            start_constraint,
392            end_constraint,
393            alpha: 0.0,
394            mode,
395        }
396    }
397
398    /// Set interpolation parameter (0.0 = start, 1.0 = end)
399    pub fn set_alpha(&mut self, alpha: f32) -> LogicResult<()> {
400        if !(0.0..=1.0).contains(&alpha) {
401            return Err(LogicError::InvalidInput(
402                "Alpha must be in [0, 1]".to_string(),
403            ));
404        }
405        self.alpha = alpha;
406        Ok(())
407    }
408
409    /// Get current interpolation parameter
410    pub fn alpha(&self) -> f32 {
411        self.alpha
412    }
413
414    /// Compute interpolated violation
415    pub fn violation(&self, state: &Array1<f32>) -> f32 {
416        let v1 = self
417            .start_constraint
418            .violation(state.as_slice().unwrap_or(&[]));
419        let v2 = self
420            .end_constraint
421            .violation(state.as_slice().unwrap_or(&[]));
422
423        let alpha = match self.mode {
424            InterpolationMode::Step => {
425                if self.alpha < 1.0 {
426                    0.0
427                } else {
428                    1.0
429                }
430            }
431            InterpolationMode::Linear => self.alpha,
432            InterpolationMode::Cubic => self.alpha * self.alpha * (3.0 - 2.0 * self.alpha),
433            InterpolationMode::Exponential { rate } => 1.0 - (-rate * self.alpha).exp(),
434        };
435
436        v1 * (1.0 - alpha) + v2 * alpha
437    }
438
439    /// Check interpolated constraint
440    pub fn check(&self, state: &Array1<f32>) -> bool {
441        self.violation(state) <= 0.0
442    }
443}
444
/// Manager for multiple time-varying constraints
///
/// Groups state-dependent constraints, predictive adapters and interpolators
/// that share one constraint type `C`, plus a global clock.
#[derive(Debug, Clone)]
pub struct TimeVaryingConstraintSet<C: ViolationComputable> {
    /// Collection of state-dependent constraints
    state_dependent: Vec<StateDependentConstraint<C>>,
    /// Collection of predictive adapters
    // NOTE(review): adapters are counted by `num_active` but are not consulted
    // by `check_all` or `total_violation` in the visible code.
    predictive: Vec<PredictiveConstraintAdapter<C>>,
    /// Collection of interpolators
    interpolators: Vec<ConstraintInterpolator<C>>,
    /// Current global time; starts at `0.0` and is written by `advance_time`
    current_time: f32,
}
457
458impl<C: ViolationComputable + Clone> TimeVaryingConstraintSet<C> {
459    /// Create a new constraint set
460    pub fn new() -> Self {
461        Self {
462            state_dependent: Vec::new(),
463            predictive: Vec::new(),
464            interpolators: Vec::new(),
465            current_time: 0.0,
466        }
467    }
468
469    /// Add a state-dependent constraint
470    pub fn add_state_dependent(&mut self, constraint: StateDependentConstraint<C>) {
471        self.state_dependent.push(constraint);
472    }
473
474    /// Add a predictive adapter
475    pub fn add_predictive(&mut self, adapter: PredictiveConstraintAdapter<C>) {
476        self.predictive.push(adapter);
477    }
478
479    /// Add an interpolator
480    pub fn add_interpolator(&mut self, interpolator: ConstraintInterpolator<C>) {
481        self.interpolators.push(interpolator);
482    }
483
484    /// Advance global time
485    pub fn advance_time(&mut self, time: f32) -> LogicResult<()> {
486        self.current_time = time;
487        Ok(())
488    }
489
490    /// Update all state-dependent activations
491    pub fn update_activations(&mut self, state: &Array1<f32>) {
492        for constraint in &mut self.state_dependent {
493            constraint.update_activation(state);
494        }
495    }
496
497    /// Get number of active constraints
498    pub fn num_active(&self) -> usize {
499        self.state_dependent
500            .iter()
501            .filter(|c| c.is_active())
502            .count()
503            + self.predictive.len()
504            + self.interpolators.len()
505    }
506
507    /// Check all constraints
508    pub fn check_all(&self, state: &Array1<f32>) -> bool {
509        // Check state-dependent constraints
510        for constraint in &self.state_dependent {
511            if !constraint.check_if_active(state) {
512                return false;
513            }
514        }
515
516        // Check interpolators
517        for interpolator in &self.interpolators {
518            if !interpolator.check(state) {
519                return false;
520            }
521        }
522
523        true
524    }
525
526    /// Compute total violation across all constraints
527    pub fn total_violation(&self, state: &Array1<f32>) -> f32 {
528        let mut total = 0.0;
529
530        // State-dependent violations
531        for constraint in &self.state_dependent {
532            total += constraint.violation_if_active(state).max(0.0);
533        }
534
535        // Interpolator violations
536        for interpolator in &self.interpolators {
537            total += interpolator.violation(state).max(0.0);
538        }
539
540        total
541    }
542}
543
544impl<C: ViolationComputable + Clone> Default for TimeVaryingConstraintSet<C> {
545    fn default() -> Self {
546        Self::new()
547    }
548}
549
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearConstraint;

    /// The norm threshold switches the constraint on only once the state's
    /// Euclidean norm exceeds it.
    #[test]
    fn test_state_dependent_activation() {
        // x <= 1.0
        let base = LinearConstraint::less_eq(vec![1.0], 1.0);

        let mut sdc = StateDependentConstraint::new(
            "test",
            base,
            ActivationFunction::NormThreshold { threshold: 5.0 },
        );

        let state = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let active = sdc.update_activation(&state);
        assert!(!active); // norm ≈ 3.74 < 5.0

        let state2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let active2 = sdc.update_activation(&state2);
        assert!(active2); // norm ≈ 7.07 > 5.0
    }

    /// A predicted violation inside the horizon must raise the tightness
    /// above its initial value of 1.0.
    #[test]
    fn test_predictive_adaptation() {
        // x <= 1.0
        let base = LinearConstraint::less_eq(vec![1.0], 1.0);

        let mut adapter = PredictiveConstraintAdapter::new("test", base, 5, 0.1);

        let trajectory = vec![
            Array1::from_vec(vec![0.5]),
            Array1::from_vec(vec![0.8]),
            Array1::from_vec(vec![1.2]), // violation
        ];

        let violations = adapter.predict_violations(&trajectory);
        assert_eq!(violations.len(), 3);

        adapter
            .adapt(&violations)
            .expect("adapt should succeed for a finite trajectory");
        // Strict comparison: `>= 1.0` would also pass if nothing tightened.
        assert!(adapter.tightness() > 1.0);
        assert_eq!(adapter.violation_history().len(), 1);
    }

    /// Linear interpolation at alpha = 0.5 blends the two violations equally.
    #[test]
    fn test_constraint_interpolation() -> LogicResult<()> {
        // x <= 1.0 and x <= 2.0
        let start = LinearConstraint::less_eq(vec![1.0], 1.0);
        let end = LinearConstraint::less_eq(vec![1.0], 2.0);

        let mut interp = ConstraintInterpolator::new("test", start, end, InterpolationMode::Linear);

        interp.set_alpha(0.5)?;
        assert_eq!(interp.alpha(), 0.5);

        let state = Array1::from_vec(vec![1.5]);
        let violation = interp.violation(&state);
        // x=1.5: start violation = 1.5-1.0=0.5, end violation = max(1.5-2.0, 0)=0
        // interpolated: 0.5 * 0.5 + 0.5 * 0 = 0.25
        assert!((0.0..=0.5).contains(&violation));

        Ok(())
    }

    /// `num_active` follows the activation state of the stored constraints.
    #[test]
    fn test_constraint_set() {
        let mut set = TimeVaryingConstraintSet::new();

        // x <= 1.0
        let base = LinearConstraint::less_eq(vec![1.0], 1.0);
        let sdc = StateDependentConstraint::new(
            "state_dep",
            base,
            ActivationFunction::NormThreshold { threshold: 5.0 },
        );

        set.add_state_dependent(sdc);

        let state = Array1::from_vec(vec![1.0, 2.0]);
        set.update_activations(&state);

        assert_eq!(set.num_active(), 0); // Not active due to low norm

        let state2 = Array1::from_vec(vec![5.0, 5.0]);
        set.update_activations(&state2);
        assert_eq!(set.num_active(), 1); // Should be active now
    }
}