1use scirs2_core::ndarray::Array1;
10use serde::{Deserialize, Serialize};
11use std::collections::VecDeque;
12
13#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
15pub enum DerivativeOrder {
16 First,
18 Second,
20 Third,
22 Custom(usize),
24}
25
26impl DerivativeOrder {
27 pub fn order(&self) -> usize {
29 match self {
30 Self::First => 1,
31 Self::Second => 2,
32 Self::Third => 3,
33 Self::Custom(n) => *n,
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
40pub struct DerivativeConstraint {
41 name: String,
43 order: DerivativeOrder,
45 dt: f32,
47 max_magnitude: f32,
49 history: VecDeque<(f32, Array1<f32>)>, max_history: usize,
53}
54
55impl DerivativeConstraint {
56 pub fn new(
58 name: impl Into<String>,
59 order: DerivativeOrder,
60 dt: f32,
61 max_magnitude: f32,
62 ) -> Self {
63 let max_history = order.order() + 2;
64 Self {
65 name: name.into(),
66 order,
67 dt,
68 max_magnitude,
69 history: VecDeque::new(),
70 max_history,
71 }
72 }
73
74 pub fn observe(&mut self, time: f32, value: Array1<f32>) {
76 self.history.push_back((time, value));
77 if self.history.len() > self.max_history {
78 self.history.pop_front();
79 }
80 }
81
82 fn compute_derivative(&self) -> Option<Array1<f32>> {
84 let n = self.order.order();
85 if self.history.len() < n + 1 {
86 return None; }
88
89 let values: Vec<_> = self.history.iter().rev().take(n + 1).collect();
92
93 match n {
94 1 => {
95 let (t0, v0) = values[0];
97 let (t1, v1) = values[1];
98 let dt = t0 - t1;
99 Some((v0 - v1) / dt)
100 }
101 2 => {
102 let (t0, v0) = values[0];
104 let (t1, v1) = values[1];
105 let (t2, v2) = values[2];
106 let dt = (t0 - t1 + t1 - t2) / 2.0;
107 Some(((v0 - v1) - (v1 - v2)) / (dt * dt))
108 }
109 3 => {
110 if values.len() < 4 {
112 return None;
113 }
114 let (_, v0) = values[0];
115 let (_, v1) = values[1];
116 let (_, v2) = values[2];
117 let (_, v3) = values[3];
118 Some((v0 - &(v1 * 3.0) + &(v2 * 3.0) - v3) / (self.dt * self.dt * self.dt))
119 }
120 _ => None, }
122 }
123
124 pub fn check(&self) -> bool {
126 if let Some(derivative) = self.compute_derivative() {
127 let magnitude = derivative.iter().map(|x| x * x).sum::<f32>().sqrt();
128 magnitude <= self.max_magnitude
129 } else {
130 true }
132 }
133
134 pub fn violation(&self) -> f32 {
136 if let Some(derivative) = self.compute_derivative() {
137 let magnitude = derivative.iter().map(|x| x * x).sum::<f32>().sqrt();
138 (magnitude - self.max_magnitude).max(0.0)
139 } else {
140 0.0
141 }
142 }
143
144 pub fn get_derivative(&self) -> Option<Array1<f32>> {
146 self.compute_derivative()
147 }
148
149 pub fn name(&self) -> &str {
151 &self.name
152 }
153
154 pub fn reset(&mut self) {
156 self.history.clear();
157 }
158}
159
160#[derive(Debug, Clone)]
162pub struct IntegralConstraint {
163 name: String,
165 window_duration: f32,
167 max_integral: f32,
169 min_integral: f32,
171 history: VecDeque<(f32, Array1<f32>)>, }
174
175impl IntegralConstraint {
176 pub fn new(
178 name: impl Into<String>,
179 window_duration: f32,
180 min_integral: f32,
181 max_integral: f32,
182 ) -> Self {
183 Self {
184 name: name.into(),
185 window_duration,
186 max_integral,
187 min_integral,
188 history: VecDeque::new(),
189 }
190 }
191
192 pub fn observe(&mut self, time: f32, value: Array1<f32>) {
194 self.history.push_back((time, value));
195
196 let cutoff_time = time - self.window_duration;
198 while let Some((t, _)) = self.history.front() {
199 if *t < cutoff_time {
200 self.history.pop_front();
201 } else {
202 break;
203 }
204 }
205 }
206
207 fn compute_integral(&self) -> Option<Array1<f32>> {
209 if self.history.len() < 2 {
210 return None;
211 }
212
213 let dim = self.history[0].1.len();
214 let mut integral = Array1::zeros(dim);
215
216 for i in 0..self.history.len() - 1 {
217 let (t1, v1) = &self.history[i];
218 let (t2, v2) = &self.history[i + 1];
219 let dt = t2 - t1;
220 integral += &((v1 + v2) * (dt / 2.0));
222 }
223
224 Some(integral)
225 }
226
227 pub fn check(&self) -> bool {
229 if let Some(integral) = self.compute_integral() {
230 integral
231 .iter()
232 .all(|&x| x >= self.min_integral && x <= self.max_integral)
233 } else {
234 true
235 }
236 }
237
238 pub fn violation(&self) -> f32 {
240 if let Some(integral) = self.compute_integral() {
241 let mut total_violation = 0.0;
242 for &x in integral.iter() {
243 if x < self.min_integral {
244 total_violation += self.min_integral - x;
245 } else if x > self.max_integral {
246 total_violation += x - self.max_integral;
247 }
248 }
249 total_violation
250 } else {
251 0.0
252 }
253 }
254
255 pub fn get_integral(&self) -> Option<Array1<f32>> {
257 self.compute_integral()
258 }
259
260 pub fn name(&self) -> &str {
262 &self.name
263 }
264
265 pub fn reset(&mut self) {
267 self.history.clear();
268 }
269}
270
271#[derive(Debug, Clone)]
274pub struct DifferentialAlgebraicConstraint {
275 name: String,
277 constraint_fn: fn(&Array1<f32>, &Array1<f32>, f32) -> Array1<f32>,
279 tolerance: f32,
281 history: VecDeque<(f32, Array1<f32>)>,
283 #[allow(dead_code)]
285 dt: f32,
286}
287
288impl DifferentialAlgebraicConstraint {
289 pub fn new(
291 name: impl Into<String>,
292 constraint_fn: fn(&Array1<f32>, &Array1<f32>, f32) -> Array1<f32>,
293 tolerance: f32,
294 dt: f32,
295 ) -> Self {
296 Self {
297 name: name.into(),
298 constraint_fn,
299 tolerance,
300 history: VecDeque::new(),
301 dt,
302 }
303 }
304
305 pub fn observe(&mut self, time: f32, value: Array1<f32>) {
307 self.history.push_back((time, value));
308 if self.history.len() > 2 {
309 self.history.pop_front();
310 }
311 }
312
313 fn compute_derivative(&self) -> Option<Array1<f32>> {
315 if self.history.len() < 2 {
316 return None;
317 }
318
319 let (t1, v1) = &self.history[0];
320 let (t2, v2) = &self.history[1];
321 let dt = t2 - t1;
322 Some((v2 - v1) / dt)
323 }
324
325 pub fn check(&self) -> bool {
327 if self.history.is_empty() {
328 return true;
329 }
330
331 let (t, x) = &self.history[self.history.len() - 1];
332
333 if let Some(dx_dt) = self.compute_derivative() {
334 let residual = (self.constraint_fn)(x, &dx_dt, *t);
335 let residual_norm = residual.iter().map(|r| r * r).sum::<f32>().sqrt();
336 residual_norm <= self.tolerance
337 } else {
338 true
339 }
340 }
341
342 pub fn violation(&self) -> f32 {
344 if self.history.is_empty() {
345 return 0.0;
346 }
347
348 let (t, x) = &self.history[self.history.len() - 1];
349
350 if let Some(dx_dt) = self.compute_derivative() {
351 let residual = (self.constraint_fn)(x, &dx_dt, *t);
352 let residual_norm = residual.iter().map(|r| r * r).sum::<f32>().sqrt();
353 (residual_norm - self.tolerance).max(0.0)
354 } else {
355 0.0
356 }
357 }
358
359 pub fn name(&self) -> &str {
361 &self.name
362 }
363
364 pub fn reset(&mut self) {
366 self.history.clear();
367 }
368}
369
370#[derive(Debug, Clone)]
372pub struct PathIntegralConstraint {
373 name: String,
375 cost_fn: fn(&Array1<f32>, &Array1<f32>, f32) -> f32,
377 max_cost: f32,
379 trajectory: VecDeque<(f32, Array1<f32>)>,
381}
382
383impl PathIntegralConstraint {
384 pub fn new(
386 name: impl Into<String>,
387 cost_fn: fn(&Array1<f32>, &Array1<f32>, f32) -> f32,
388 max_cost: f32,
389 ) -> Self {
390 Self {
391 name: name.into(),
392 cost_fn,
393 max_cost,
394 trajectory: VecDeque::new(),
395 }
396 }
397
398 pub fn observe(&mut self, time: f32, state: Array1<f32>) {
400 self.trajectory.push_back((time, state));
401 }
402
403 fn compute_path_integral(&self) -> f32 {
405 if self.trajectory.len() < 2 {
406 return 0.0;
407 }
408
409 let mut total_cost = 0.0;
410
411 for i in 0..self.trajectory.len() - 1 {
412 let (t1, x1) = &self.trajectory[i];
413 let (t2, x2) = &self.trajectory[i + 1];
414
415 let dt = t2 - t1;
416 let dx_dt = (x2 - x1) / dt;
417
418 let t_mid = (t1 + t2) / 2.0;
420 let x_mid = (x1 + x2) / 2.0;
421
422 let cost = (self.cost_fn)(&x_mid, &dx_dt, t_mid);
423 total_cost += cost * dt;
424 }
425
426 total_cost
427 }
428
429 pub fn check(&self) -> bool {
431 let cost = self.compute_path_integral();
432 cost <= self.max_cost
433 }
434
435 pub fn violation(&self) -> f32 {
437 let cost = self.compute_path_integral();
438 (cost - self.max_cost).max(0.0)
439 }
440
441 pub fn get_path_cost(&self) -> f32 {
443 self.compute_path_integral()
444 }
445
446 pub fn name(&self) -> &str {
448 &self.name
449 }
450
451 pub fn reset(&mut self) {
453 self.trajectory.clear();
454 }
455
456 pub fn trajectory_length(&self) -> usize {
458 self.trajectory.len()
459 }
460}
461
462#[derive(Debug, Clone)]
464pub struct DifferentialConstraintSet {
465 derivative_constraints: Vec<DerivativeConstraint>,
467 integral_constraints: Vec<IntegralConstraint>,
469 dae_constraints: Vec<DifferentialAlgebraicConstraint>,
471 path_integral_constraints: Vec<PathIntegralConstraint>,
473}
474
475impl DifferentialConstraintSet {
476 pub fn new() -> Self {
478 Self {
479 derivative_constraints: Vec::new(),
480 integral_constraints: Vec::new(),
481 dae_constraints: Vec::new(),
482 path_integral_constraints: Vec::new(),
483 }
484 }
485
486 pub fn add_derivative(&mut self, constraint: DerivativeConstraint) {
488 self.derivative_constraints.push(constraint);
489 }
490
491 pub fn add_integral(&mut self, constraint: IntegralConstraint) {
493 self.integral_constraints.push(constraint);
494 }
495
496 pub fn add_dae(&mut self, constraint: DifferentialAlgebraicConstraint) {
498 self.dae_constraints.push(constraint);
499 }
500
501 pub fn add_path_integral(&mut self, constraint: PathIntegralConstraint) {
503 self.path_integral_constraints.push(constraint);
504 }
505
506 pub fn observe(&mut self, time: f32, state: Array1<f32>) {
508 for constraint in &mut self.derivative_constraints {
509 constraint.observe(time, state.clone());
510 }
511 for constraint in &mut self.integral_constraints {
512 constraint.observe(time, state.clone());
513 }
514 for constraint in &mut self.dae_constraints {
515 constraint.observe(time, state.clone());
516 }
517 for constraint in &mut self.path_integral_constraints {
518 constraint.observe(time, state.clone());
519 }
520 }
521
522 pub fn check_all(&self) -> bool {
524 self.derivative_constraints.iter().all(|c| c.check())
525 && self.integral_constraints.iter().all(|c| c.check())
526 && self.dae_constraints.iter().all(|c| c.check())
527 && self.path_integral_constraints.iter().all(|c| c.check())
528 }
529
530 pub fn total_violation(&self) -> f32 {
532 let mut total = 0.0;
533 for c in &self.derivative_constraints {
534 total += c.violation();
535 }
536 for c in &self.integral_constraints {
537 total += c.violation();
538 }
539 for c in &self.dae_constraints {
540 total += c.violation();
541 }
542 for c in &self.path_integral_constraints {
543 total += c.violation();
544 }
545 total
546 }
547
548 pub fn reset(&mut self) {
550 for c in &mut self.derivative_constraints {
551 c.reset();
552 }
553 for c in &mut self.integral_constraints {
554 c.reset();
555 }
556 for c in &mut self.dae_constraints {
557 c.reset();
558 }
559 for c in &mut self.path_integral_constraints {
560 c.reset();
561 }
562 }
563
564 pub fn num_constraints(&self) -> usize {
566 self.derivative_constraints.len()
567 + self.integral_constraints.len()
568 + self.dae_constraints.len()
569 + self.path_integral_constraints.len()
570 }
571}
572
573impl Default for DifferentialConstraintSet {
574 fn default() -> Self {
575 Self::new()
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional sample holding `v`.
    fn sample(v: f32) -> Array1<f32> {
        Array1::from_vec(vec![v])
    }

    #[test]
    fn test_derivative_constraint() {
        let mut velocity_limit =
            DerivativeConstraint::new("velocity_limit", DerivativeOrder::First, 0.1, 10.0);

        velocity_limit.observe(0.0, sample(0.0));
        // Slope (0.5 - 0.0) / 0.1 = 5, within the limit of 10.
        velocity_limit.observe(0.1, sample(0.5));
        assert!(velocity_limit.check());

        // Slope (2.0 - 0.5) / 0.1 = 15, above the limit.
        velocity_limit.observe(0.2, sample(2.0));
        assert!(!velocity_limit.check());
    }

    #[test]
    fn test_integral_constraint() {
        let mut energy_limit = IntegralConstraint::new("energy_limit", 1.0, 0.0, 100.0);

        for (t, v) in [(0.0, 10.0), (0.5, 20.0), (1.0, 10.0)] {
            energy_limit.observe(t, sample(v));
        }

        // Two trapezoids of 7.5 each: integral 15, inside [0, 100].
        assert!(energy_limit.check());
        assert!(energy_limit.get_integral().is_some());
    }

    #[test]
    fn test_dae_constraint() {
        fn dae_fn(x: &Array1<f32>, dx_dt: &Array1<f32>, _t: f32) -> Array1<f32> {
            x + dx_dt
        }

        let mut dae = DifferentialAlgebraicConstraint::new("simple_dae", dae_fn, 1.0, 0.1);

        dae.observe(0.0, sample(1.0));
        dae.observe(0.1, sample(0.9));
        // Residual 0.9 + (-1.0) = -0.1, norm 0.1 <= tolerance 1.0.
        assert!(dae.check());
    }

    #[test]
    fn test_path_integral_constraint() {
        fn cost_fn(_x: &Array1<f32>, dx_dt: &Array1<f32>, _t: f32) -> f32 {
            dx_dt.iter().map(|v| v * v).sum()
        }

        let mut min_energy = PathIntegralConstraint::new("min_energy", cost_fn, 100.0);

        for (t, x) in [(0.0, 0.0), (0.1, 1.0), (0.2, 2.0)] {
            min_energy.observe(t, sample(x));
        }

        // Each segment: velocity 10, cost 100 * 0.1 = 10; total 20 <= 100.
        assert!(min_energy.check());
        assert_eq!(min_energy.trajectory_length(), 3);
    }

    #[test]
    fn test_differential_constraint_set() {
        let mut set = DifferentialConstraintSet::new();
        set.add_derivative(DerivativeConstraint::new(
            "velocity",
            DerivativeOrder::First,
            0.1,
            10.0,
        ));

        set.observe(0.0, sample(0.0));
        set.observe(0.1, sample(0.5));

        assert!(set.check_all());
        assert_eq!(set.num_constraints(), 1);
    }
}