1use deke_types::{SRobotPath, SRobotQ};
2
3use crate::reqpath::DirectedOption;
4
5pub enum TransitionCost<'a, const N: usize> {
10 JointWeighted(SRobotQ<N, f64>),
12 Custom(&'a dyn Fn(&SRobotQ<N, f64>, &SRobotQ<N, f64>) -> f64),
15}
16
17impl<const N: usize> TransitionCost<'_, N> {
18 pub(crate) fn eval(&self, from: &SRobotQ<N, f64>, to: &SRobotQ<N, f64>) -> f64 {
19 match self {
20 TransitionCost::JointWeighted(w) => weighted_distance(from, to, w),
21 TransitionCost::Custom(f) => f(from, to),
22 }
23 }
24}
25
26pub fn weighted_distance<const N: usize>(
29 a: &SRobotQ<N, f64>,
30 b: &SRobotQ<N, f64>,
31 w: &SRobotQ<N, f64>,
32) -> f64 {
33 w.0.iter()
34 .zip(a.0.iter())
35 .zip(b.0.iter())
36 .map(|((&wi, &ai), &bi)| {
37 let d = wi * (ai - bi);
38 d * d
39 })
40 .sum::<f64>()
41 .sqrt()
42}
43
44fn traversal_cost<const N: usize>(path: &SRobotPath<N, f64>, cost: &TransitionCost<N>) -> f64 {
47 path.segments().map(|(a, b)| cost.eval(a, b)).sum()
48}
49
/// Precomputed cost tables for a set of `m` directed options, as filled in by
/// [`build_matrices`].
pub(crate) struct CostMatrices {
    /// Row-major `m * m` table: entry `i * m + j` is the cost of moving from
    /// option `i`'s `path.last()` to option `j`'s `path.first()`, plus the
    /// traversal cost of option `j`.
    pub transition: Vec<f64>,
    /// Per-option cost of moving from the start configuration to the option's
    /// `path.first()`, plus that option's traversal cost.
    pub start: Vec<f64>,
    /// Per-option cost of moving from the option's `path.last()` to the end
    /// configuration; `0.0` for every option when no end configuration was
    /// supplied.
    pub end: Vec<f64>,
}
67
68pub(crate) fn build_matrices<const N: usize>(
69 options: &[DirectedOption<N>],
70 cost: &TransitionCost<N>,
71 start_q: &SRobotQ<N, f64>,
72 end_q: Option<&SRobotQ<N, f64>>,
73) -> CostMatrices {
74 let m = options.len();
75 let traversal: Vec<f64> = options
76 .iter()
77 .map(|o| traversal_cost(&o.path, cost))
78 .collect();
79
80 let mut transition = vec![0.0_f64; m * m];
81 for (i, oi) in options.iter().enumerate() {
82 let from = oi.path.last();
83 let row = &mut transition[i * m..(i + 1) * m];
84 for (j, cell) in row.iter_mut().enumerate() {
85 *cell = cost.eval(from, options[j].path.first()) + traversal[j];
86 }
87 }
88
89 let start = options
90 .iter()
91 .enumerate()
92 .map(|(i, o)| cost.eval(start_q, o.path.first()) + traversal[i])
93 .collect();
94
95 let end = options
96 .iter()
97 .map(|o| end_q.map_or(0.0, |e| cost.eval(o.path.last(), e)))
98 .collect();
99
100 CostMatrices {
101 transition,
102 start,
103 end,
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point costs.
    const TOLERANCE: f64 = 1e-12;

    /// Builds a configuration from a plain joint array.
    fn q<const N: usize>(a: [f64; N]) -> SRobotQ<N, f64> {
        SRobotQ(a)
    }

    /// Returns `true` when `actual` lies within `TOLERANCE` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    #[test]
    fn weighted_distance_applies_weights() {
        let weights = q([2.0, 1.0]);
        let zero = q([0.0, 0.0]);

        // A unit step along a joint costs exactly that joint's weight.
        let along_first = weighted_distance(&zero, &q([1.0, 0.0]), &weights);
        assert!(close(along_first, 2.0));

        let along_second = weighted_distance(&zero, &q([0.0, 1.0]), &weights);
        assert!(close(along_second, 1.0));
    }

    #[test]
    fn traversal_sums_weighted_segments() {
        // Two segments: one unit along joint 0 (cost 2), then one along joint 1 (cost 1).
        let waypoints = vec![q([0.0, 0.0]), q([1.0, 0.0]), q([1.0, 1.0])];
        let path = SRobotPath::<2, f64>::try_new(waypoints).unwrap();
        let cost = TransitionCost::JointWeighted(q([2.0, 1.0]));
        assert!(close(traversal_cost(&path, &cost), 3.0));
    }
}