kizzasi_logic/
differential_constraints.rs

//! Differential Constraints
//!
//! This module provides constraints on time derivatives and integrals of a
//! vector-valued signal observed at discrete timestamps:
//! - Higher-order derivative constraints
//! - Integral constraints over time windows
//! - Differential-algebraic constraints
//! - Path integral constraints
8
9use scirs2_core::ndarray::Array1;
10use serde::{Deserialize, Serialize};
11use std::collections::VecDeque;
12
13/// Order of derivative
14#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
15pub enum DerivativeOrder {
16    /// First derivative (velocity)
17    First,
18    /// Second derivative (acceleration)
19    Second,
20    /// Third derivative (jerk)
21    Third,
22    /// Higher order
23    Custom(usize),
24}
25
26impl DerivativeOrder {
27    /// Get numeric order
28    pub fn order(&self) -> usize {
29        match self {
30            Self::First => 1,
31            Self::Second => 2,
32            Self::Third => 3,
33            Self::Custom(n) => *n,
34        }
35    }
36}
37
/// Higher-order derivative constraint.
///
/// Bounds the Euclidean norm of a finite-difference estimate of a time
/// derivative. Samples are supplied through `observe`, and only the most recent
/// ones are retained in `history`.
#[derive(Debug, Clone)]
pub struct DerivativeConstraint {
    /// Name of the constraint, returned by `name()`.
    name: String,
    /// Order of the derivative being bounded.
    order: DerivativeOrder,
    /// Nominal time step supplied to `new`.
    dt: f32,
    /// Upper bound on the Euclidean norm of the derivative.
    max_magnitude: f32,
    /// Recent observations as `(time, value)` pairs, oldest at the front.
    history: VecDeque<(f32, Array1<f32>)>, // (time, value)
    /// Maximum number of entries retained in `history` (set by `new`).
    max_history: usize,
}
54
55impl DerivativeConstraint {
56    /// Create a new derivative constraint
57    pub fn new(
58        name: impl Into<String>,
59        order: DerivativeOrder,
60        dt: f32,
61        max_magnitude: f32,
62    ) -> Self {
63        let max_history = order.order() + 2;
64        Self {
65            name: name.into(),
66            order,
67            dt,
68            max_magnitude,
69            history: VecDeque::new(),
70            max_history,
71        }
72    }
73
74    /// Add a new observation
75    pub fn observe(&mut self, time: f32, value: Array1<f32>) {
76        self.history.push_back((time, value));
77        if self.history.len() > self.max_history {
78            self.history.pop_front();
79        }
80    }
81
82    /// Compute numerical derivative of given order
83    fn compute_derivative(&self) -> Option<Array1<f32>> {
84        let n = self.order.order();
85        if self.history.len() < n + 1 {
86            return None; // Not enough data
87        }
88
89        // Use finite differences to compute nth derivative
90        // For simplicity, use backward differences
91        let values: Vec<_> = self.history.iter().rev().take(n + 1).collect();
92
93        match n {
94            1 => {
95                // First derivative: (v[0] - v[1]) / dt
96                let (t0, v0) = values[0];
97                let (t1, v1) = values[1];
98                let dt = t0 - t1;
99                Some((v0 - v1) / dt)
100            }
101            2 => {
102                // Second derivative: ((v[0] - v[1]) - (v[1] - v[2])) / dt^2
103                let (t0, v0) = values[0];
104                let (t1, v1) = values[1];
105                let (t2, v2) = values[2];
106                let dt = (t0 - t1 + t1 - t2) / 2.0;
107                Some(((v0 - v1) - (v1 - v2)) / (dt * dt))
108            }
109            3 => {
110                // Third derivative (jerk)
111                if values.len() < 4 {
112                    return None;
113                }
114                let (_, v0) = values[0];
115                let (_, v1) = values[1];
116                let (_, v2) = values[2];
117                let (_, v3) = values[3];
118                Some((v0 - &(v1 * 3.0) + &(v2 * 3.0) - v3) / (self.dt * self.dt * self.dt))
119            }
120            _ => None, // Higher orders not implemented
121        }
122    }
123
124    /// Check if derivative constraint is satisfied
125    pub fn check(&self) -> bool {
126        if let Some(derivative) = self.compute_derivative() {
127            let magnitude = derivative.iter().map(|x| x * x).sum::<f32>().sqrt();
128            magnitude <= self.max_magnitude
129        } else {
130            true // Not enough data, trivially satisfied
131        }
132    }
133
134    /// Compute violation amount
135    pub fn violation(&self) -> f32 {
136        if let Some(derivative) = self.compute_derivative() {
137            let magnitude = derivative.iter().map(|x| x * x).sum::<f32>().sqrt();
138            (magnitude - self.max_magnitude).max(0.0)
139        } else {
140            0.0
141        }
142    }
143
144    /// Get current derivative estimate
145    pub fn get_derivative(&self) -> Option<Array1<f32>> {
146        self.compute_derivative()
147    }
148
149    /// Get name
150    pub fn name(&self) -> &str {
151        &self.name
152    }
153
154    /// Reset history
155    pub fn reset(&mut self) {
156        self.history.clear();
157    }
158}
159
/// Integral constraint over a sliding time window.
///
/// Each component of the time integral of the observed signal must lie in
/// `[min_integral, max_integral]`.
#[derive(Debug, Clone)]
pub struct IntegralConstraint {
    /// Name of the constraint, returned by `name()`.
    name: String,
    /// Length of the sliding window. On each `observe(time, ..)`, samples with
    /// timestamps earlier than `time - window_duration` are discarded.
    window_duration: f32,
    /// Upper bound applied to every component of the integral.
    max_integral: f32,
    /// Lower bound applied to every component of the integral.
    min_integral: f32,
    /// Observations inside the window as `(time, value)` pairs, oldest at the front.
    history: VecDeque<(f32, Array1<f32>)>, // (time, value)
}
174
175impl IntegralConstraint {
176    /// Create a new integral constraint
177    pub fn new(
178        name: impl Into<String>,
179        window_duration: f32,
180        min_integral: f32,
181        max_integral: f32,
182    ) -> Self {
183        Self {
184            name: name.into(),
185            window_duration,
186            max_integral,
187            min_integral,
188            history: VecDeque::new(),
189        }
190    }
191
192    /// Add a new observation
193    pub fn observe(&mut self, time: f32, value: Array1<f32>) {
194        self.history.push_back((time, value));
195
196        // Remove values outside the window
197        let cutoff_time = time - self.window_duration;
198        while let Some((t, _)) = self.history.front() {
199            if *t < cutoff_time {
200                self.history.pop_front();
201            } else {
202                break;
203            }
204        }
205    }
206
207    /// Compute integral using trapezoidal rule
208    fn compute_integral(&self) -> Option<Array1<f32>> {
209        if self.history.len() < 2 {
210            return None;
211        }
212
213        let dim = self.history[0].1.len();
214        let mut integral = Array1::zeros(dim);
215
216        for i in 0..self.history.len() - 1 {
217            let (t1, v1) = &self.history[i];
218            let (t2, v2) = &self.history[i + 1];
219            let dt = t2 - t1;
220            // Trapezoidal rule: (v1 + v2) / 2 * dt
221            integral += &((v1 + v2) * (dt / 2.0));
222        }
223
224        Some(integral)
225    }
226
227    /// Check if integral constraint is satisfied
228    pub fn check(&self) -> bool {
229        if let Some(integral) = self.compute_integral() {
230            integral
231                .iter()
232                .all(|&x| x >= self.min_integral && x <= self.max_integral)
233        } else {
234            true
235        }
236    }
237
238    /// Compute violation amount
239    pub fn violation(&self) -> f32 {
240        if let Some(integral) = self.compute_integral() {
241            let mut total_violation = 0.0;
242            for &x in integral.iter() {
243                if x < self.min_integral {
244                    total_violation += self.min_integral - x;
245                } else if x > self.max_integral {
246                    total_violation += x - self.max_integral;
247                }
248            }
249            total_violation
250        } else {
251            0.0
252        }
253    }
254
255    /// Get current integral estimate
256    pub fn get_integral(&self) -> Option<Array1<f32>> {
257        self.compute_integral()
258    }
259
260    /// Get name
261    pub fn name(&self) -> &str {
262        &self.name
263    }
264
265    /// Reset history
266    pub fn reset(&mut self) {
267        self.history.clear();
268    }
269}
270
/// Differential-algebraic constraint (DAE).
///
/// Represents constraints of the form `F(x, dx/dt, t) = 0`, evaluated at the
/// most recent observation. The constraint holds when the Euclidean norm of the
/// residual returned by `constraint_fn` is at most `tolerance`.
#[derive(Debug, Clone)]
pub struct DifferentialAlgebraicConstraint {
    /// Name of the constraint, returned by `name()`.
    name: String,
    /// Residual function `F(x, dx/dt, t)`; a zero residual means the constraint holds.
    constraint_fn: fn(&Array1<f32>, &Array1<f32>, f32) -> Array1<f32>,
    /// Maximum allowed Euclidean norm of the residual.
    tolerance: f32,
    /// The most recent observations (at most two) as `(time, value)` pairs.
    history: VecDeque<(f32, Array1<f32>)>,
    /// Nominal time step supplied to `new`.
    // The allow keeps the build warning-free if no method reads this field.
    #[allow(dead_code)]
    dt: f32,
}
287
288impl DifferentialAlgebraicConstraint {
289    /// Create a new DAE constraint
290    pub fn new(
291        name: impl Into<String>,
292        constraint_fn: fn(&Array1<f32>, &Array1<f32>, f32) -> Array1<f32>,
293        tolerance: f32,
294        dt: f32,
295    ) -> Self {
296        Self {
297            name: name.into(),
298            constraint_fn,
299            tolerance,
300            history: VecDeque::new(),
301            dt,
302        }
303    }
304
305    /// Add observation
306    pub fn observe(&mut self, time: f32, value: Array1<f32>) {
307        self.history.push_back((time, value));
308        if self.history.len() > 2 {
309            self.history.pop_front();
310        }
311    }
312
313    /// Compute current derivative estimate
314    fn compute_derivative(&self) -> Option<Array1<f32>> {
315        if self.history.len() < 2 {
316            return None;
317        }
318
319        let (t1, v1) = &self.history[0];
320        let (t2, v2) = &self.history[1];
321        let dt = t2 - t1;
322        Some((v2 - v1) / dt)
323    }
324
325    /// Check constraint
326    pub fn check(&self) -> bool {
327        if self.history.is_empty() {
328            return true;
329        }
330
331        let (t, x) = &self.history[self.history.len() - 1];
332
333        if let Some(dx_dt) = self.compute_derivative() {
334            let residual = (self.constraint_fn)(x, &dx_dt, *t);
335            let residual_norm = residual.iter().map(|r| r * r).sum::<f32>().sqrt();
336            residual_norm <= self.tolerance
337        } else {
338            true
339        }
340    }
341
342    /// Compute violation
343    pub fn violation(&self) -> f32 {
344        if self.history.is_empty() {
345            return 0.0;
346        }
347
348        let (t, x) = &self.history[self.history.len() - 1];
349
350        if let Some(dx_dt) = self.compute_derivative() {
351            let residual = (self.constraint_fn)(x, &dx_dt, *t);
352            let residual_norm = residual.iter().map(|r| r * r).sum::<f32>().sqrt();
353            (residual_norm - self.tolerance).max(0.0)
354        } else {
355            0.0
356        }
357    }
358
359    /// Get name
360    pub fn name(&self) -> &str {
361        &self.name
362    }
363
364    /// Reset
365    pub fn reset(&mut self) {
366        self.history.clear();
367    }
368}
369
/// Path integral constraint for trajectory optimization.
///
/// Integrates a scalar running cost `cost(x, dx/dt, t)` over the observed
/// trajectory. The constraint holds while that total stays at or below `max_cost`.
#[derive(Debug, Clone)]
pub struct PathIntegralConstraint {
    /// Name of the constraint, returned by `name()`.
    name: String,
    /// Cost function: cost(x, dx/dt, t) -> scalar
    cost_fn: fn(&Array1<f32>, &Array1<f32>, f32) -> f32,
    /// Maximum allowed path integral (total cost).
    max_cost: f32,
    /// Every observed `(time, state)` point. `observe` only appends, so this
    /// grows until `reset` clears it.
    trajectory: VecDeque<(f32, Array1<f32>)>,
}
382
383impl PathIntegralConstraint {
384    /// Create a new path integral constraint
385    pub fn new(
386        name: impl Into<String>,
387        cost_fn: fn(&Array1<f32>, &Array1<f32>, f32) -> f32,
388        max_cost: f32,
389    ) -> Self {
390        Self {
391            name: name.into(),
392            cost_fn,
393            max_cost,
394            trajectory: VecDeque::new(),
395        }
396    }
397
398    /// Add trajectory point
399    pub fn observe(&mut self, time: f32, state: Array1<f32>) {
400        self.trajectory.push_back((time, state));
401    }
402
403    /// Compute path integral
404    fn compute_path_integral(&self) -> f32 {
405        if self.trajectory.len() < 2 {
406            return 0.0;
407        }
408
409        let mut total_cost = 0.0;
410
411        for i in 0..self.trajectory.len() - 1 {
412            let (t1, x1) = &self.trajectory[i];
413            let (t2, x2) = &self.trajectory[i + 1];
414
415            let dt = t2 - t1;
416            let dx_dt = (x2 - x1) / dt;
417
418            // Evaluate cost at midpoint
419            let t_mid = (t1 + t2) / 2.0;
420            let x_mid = (x1 + x2) / 2.0;
421
422            let cost = (self.cost_fn)(&x_mid, &dx_dt, t_mid);
423            total_cost += cost * dt;
424        }
425
426        total_cost
427    }
428
429    /// Check constraint
430    pub fn check(&self) -> bool {
431        let cost = self.compute_path_integral();
432        cost <= self.max_cost
433    }
434
435    /// Compute violation
436    pub fn violation(&self) -> f32 {
437        let cost = self.compute_path_integral();
438        (cost - self.max_cost).max(0.0)
439    }
440
441    /// Get current path cost
442    pub fn get_path_cost(&self) -> f32 {
443        self.compute_path_integral()
444    }
445
446    /// Get name
447    pub fn name(&self) -> &str {
448        &self.name
449    }
450
451    /// Reset trajectory
452    pub fn reset(&mut self) {
453        self.trajectory.clear();
454    }
455
456    /// Get trajectory length
457    pub fn trajectory_length(&self) -> usize {
458        self.trajectory.len()
459    }
460}
461
/// Constraint set for differential constraints.
///
/// Groups the four kinds of differential constraints so they receive the same
/// observations and can be checked, scored and reset together.
#[derive(Debug, Clone)]
pub struct DifferentialConstraintSet {
    /// Bounds on finite-difference derivative estimates.
    derivative_constraints: Vec<DerivativeConstraint>,
    /// Bounds on windowed integrals.
    integral_constraints: Vec<IntegralConstraint>,
    /// Residual constraints of the form `F(x, dx/dt, t) = 0`.
    dae_constraints: Vec<DifferentialAlgebraicConstraint>,
    /// Bounds on accumulated trajectory cost.
    path_integral_constraints: Vec<PathIntegralConstraint>,
}
474
475impl DifferentialConstraintSet {
476    /// Create a new differential constraint set
477    pub fn new() -> Self {
478        Self {
479            derivative_constraints: Vec::new(),
480            integral_constraints: Vec::new(),
481            dae_constraints: Vec::new(),
482            path_integral_constraints: Vec::new(),
483        }
484    }
485
486    /// Add a derivative constraint
487    pub fn add_derivative(&mut self, constraint: DerivativeConstraint) {
488        self.derivative_constraints.push(constraint);
489    }
490
491    /// Add an integral constraint
492    pub fn add_integral(&mut self, constraint: IntegralConstraint) {
493        self.integral_constraints.push(constraint);
494    }
495
496    /// Add a DAE constraint
497    pub fn add_dae(&mut self, constraint: DifferentialAlgebraicConstraint) {
498        self.dae_constraints.push(constraint);
499    }
500
501    /// Add a path integral constraint
502    pub fn add_path_integral(&mut self, constraint: PathIntegralConstraint) {
503        self.path_integral_constraints.push(constraint);
504    }
505
506    /// Observe new state at time t
507    pub fn observe(&mut self, time: f32, state: Array1<f32>) {
508        for constraint in &mut self.derivative_constraints {
509            constraint.observe(time, state.clone());
510        }
511        for constraint in &mut self.integral_constraints {
512            constraint.observe(time, state.clone());
513        }
514        for constraint in &mut self.dae_constraints {
515            constraint.observe(time, state.clone());
516        }
517        for constraint in &mut self.path_integral_constraints {
518            constraint.observe(time, state.clone());
519        }
520    }
521
522    /// Check all constraints
523    pub fn check_all(&self) -> bool {
524        self.derivative_constraints.iter().all(|c| c.check())
525            && self.integral_constraints.iter().all(|c| c.check())
526            && self.dae_constraints.iter().all(|c| c.check())
527            && self.path_integral_constraints.iter().all(|c| c.check())
528    }
529
530    /// Compute total violation
531    pub fn total_violation(&self) -> f32 {
532        let mut total = 0.0;
533        for c in &self.derivative_constraints {
534            total += c.violation();
535        }
536        for c in &self.integral_constraints {
537            total += c.violation();
538        }
539        for c in &self.dae_constraints {
540            total += c.violation();
541        }
542        for c in &self.path_integral_constraints {
543            total += c.violation();
544        }
545        total
546    }
547
548    /// Reset all constraints
549    pub fn reset(&mut self) {
550        for c in &mut self.derivative_constraints {
551            c.reset();
552        }
553        for c in &mut self.integral_constraints {
554            c.reset();
555        }
556        for c in &mut self.dae_constraints {
557            c.reset();
558        }
559        for c in &mut self.path_integral_constraints {
560            c.reset();
561        }
562    }
563
564    /// Get number of constraints
565    pub fn num_constraints(&self) -> usize {
566        self.derivative_constraints.len()
567            + self.integral_constraints.len()
568            + self.dae_constraints.len()
569            + self.path_integral_constraints.len()
570    }
571}
572
573impl Default for DifferentialConstraintSet {
574    fn default() -> Self {
575        Self::new()
576    }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing `f32` finite-difference and quadrature results.
    const EPS: f32 = 1e-4;

    /// Assert that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_derivative_constraint() {
        let mut constraint =
            DerivativeConstraint::new("velocity_limit", DerivativeOrder::First, 0.1, 10.0);

        // A single sample cannot be differentiated: trivially satisfied.
        constraint.observe(0.0, Array1::from_vec(vec![0.0]));
        assert!(constraint.get_derivative().is_none());
        assert!(constraint.check());

        constraint.observe(0.1, Array1::from_vec(vec![0.5])); // velocity = 5.0
        assert!(constraint.check()); // 5.0 <= 10.0
        assert_eq!(constraint.violation(), 0.0);
        assert_close(constraint.get_derivative().unwrap()[0], 5.0);

        constraint.observe(0.2, Array1::from_vec(vec![2.0])); // velocity = 15.0
        assert!(!constraint.check()); // 15.0 > 10.0
        assert_close(constraint.violation(), 5.0);

        constraint.reset();
        assert!(constraint.get_derivative().is_none());
        assert!(constraint.check());
    }

    #[test]
    fn test_second_derivative() {
        // x(t) = t^2 has constant acceleration 2.
        let mut constraint =
            DerivativeConstraint::new("accel_limit", DerivativeOrder::Second, 1.0, 1.0);
        for step in 0..3 {
            let t = step as f32;
            constraint.observe(t, Array1::from_vec(vec![t * t]));
        }

        assert_close(constraint.get_derivative().unwrap()[0], 2.0);
        assert!(!constraint.check());
        assert_close(constraint.violation(), 1.0);
    }

    #[test]
    fn test_third_derivative() {
        // x(t) = t^3 has constant jerk 6; samples are spaced by the nominal dt.
        let mut constraint =
            DerivativeConstraint::new("jerk_limit", DerivativeOrder::Third, 1.0, 10.0);
        for step in 0..4 {
            let t = step as f32;
            constraint.observe(t, Array1::from_vec(vec![t * t * t]));
        }

        assert_close(constraint.get_derivative().unwrap()[0], 6.0);
        assert!(constraint.check());
    }

    #[test]
    fn test_integral_constraint() {
        let mut constraint = IntegralConstraint::new("energy_limit", 1.0, 0.0, 100.0);

        constraint.observe(0.0, Array1::from_vec(vec![10.0]));
        constraint.observe(0.5, Array1::from_vec(vec![20.0]));
        constraint.observe(1.0, Array1::from_vec(vec![10.0]));

        assert!(constraint.check());
        // Trapezoids: (10 + 20) / 2 * 0.5 + (20 + 10) / 2 * 0.5 = 15.
        assert_close(constraint.get_integral().unwrap()[0], 15.0);
        assert_eq!(constraint.violation(), 0.0);
    }

    #[test]
    fn test_integral_window_eviction_and_bounds() {
        let mut constraint = IntegralConstraint::new("window", 1.0, 0.0, 5.0);

        constraint.observe(0.0, Array1::from_vec(vec![10.0]));
        constraint.observe(1.0, Array1::from_vec(vec![10.0]));
        // Integral over [0, 1] is 10, which exceeds the upper bound of 5.
        assert!(!constraint.check());
        assert_close(constraint.violation(), 5.0);

        // At t = 2.5 the cutoff is 1.5, so both earlier samples are evicted.
        constraint.observe(2.5, Array1::from_vec(vec![10.0]));
        assert!(constraint.get_integral().is_none());
        assert!(constraint.check());
        assert_eq!(constraint.violation(), 0.0);
    }

    #[test]
    fn test_dae_constraint() {
        // Simple DAE: x + dx/dt = 0
        fn dae_fn(x: &Array1<f32>, dx_dt: &Array1<f32>, _t: f32) -> Array1<f32> {
            x + dx_dt
        }

        let mut constraint = DifferentialAlgebraicConstraint::new("simple_dae", dae_fn, 1.0, 0.1);

        // Nothing observed yet: trivially satisfied.
        assert!(constraint.check());
        assert_eq!(constraint.violation(), 0.0);

        constraint.observe(0.0, Array1::from_vec(vec![1.0]));
        constraint.observe(0.1, Array1::from_vec(vec![0.9])); // dx/dt = -1.0, satisfies x + dx/dt ≈ 0

        assert!(constraint.check());

        // x = 1.0 with dx/dt = +1.0 gives residual 2.0, exceeding the tolerance by 1.0.
        constraint.observe(0.2, Array1::from_vec(vec![1.0]));
        assert!(!constraint.check());
        assert_close(constraint.violation(), 1.0);
    }

    #[test]
    fn test_path_integral_constraint() {
        // Cost function: ||dx/dt||^2
        fn cost_fn(_x: &Array1<f32>, dx_dt: &Array1<f32>, _t: f32) -> f32 {
            dx_dt.iter().map(|v| v * v).sum()
        }

        let mut constraint = PathIntegralConstraint::new("min_energy", cost_fn, 100.0);

        constraint.observe(0.0, Array1::from_vec(vec![0.0]));
        constraint.observe(0.1, Array1::from_vec(vec![1.0]));
        constraint.observe(0.2, Array1::from_vec(vec![2.0]));

        assert!(constraint.check());
        assert_eq!(constraint.trajectory_length(), 3);
        // Each segment has dx/dt = 10, so it costs 100 * 0.1 = 10.
        assert_close(constraint.get_path_cost(), 20.0);
        assert_eq!(constraint.violation(), 0.0);

        constraint.reset();
        assert_eq!(constraint.trajectory_length(), 0);
        assert_eq!(constraint.get_path_cost(), 0.0);
    }

    #[test]
    fn test_differential_constraint_set() {
        let mut set = DifferentialConstraintSet::new();
        assert_eq!(DifferentialConstraintSet::default().num_constraints(), 0);

        set.add_derivative(DerivativeConstraint::new(
            "velocity",
            DerivativeOrder::First,
            0.1,
            10.0,
        ));

        set.observe(0.0, Array1::from_vec(vec![0.0]));
        set.observe(0.1, Array1::from_vec(vec![0.5]));

        assert!(set.check_all());
        assert_eq!(set.num_constraints(), 1);
        assert_eq!(set.total_violation(), 0.0);

        set.observe(0.2, Array1::from_vec(vec![2.0])); // velocity = 15.0
        assert!(!set.check_all());
        assert_close(set.total_violation(), 5.0);

        set.reset();
        assert!(set.check_all());
        assert_eq!(set.total_violation(), 0.0);
        assert_eq!(set.num_constraints(), 1);
    }
}