1use oximo_expr::Expr;
2
3use crate::set::{FromIndexKey, Set};
4
/// A collection of keys that [`sum_over`] can iterate to build a sum.
///
/// `K` is the key type handed to the summand closure. A [`Set`] can serve as a
/// domain for any `K: FromIndexKey` (each raw index key is decoded into `K`),
/// while slices, `Vec`s and fixed-size arrays yield their own elements.
pub trait SumDomain<K> {
    /// Returns an iterator over the domain's keys.
    ///
    /// The iterator borrows `self`, so the domain must outlive it.
    fn keys(&self) -> impl Iterator<Item = K> + '_;
}
15
16impl<K: FromIndexKey> SumDomain<K> for Set {
17 fn keys(&self) -> impl Iterator<Item = K> + '_ {
18 self.iter().map(|k| K::from_index_key(&k))
19 }
20}
21
22impl<K: Copy> SumDomain<K> for [K] {
23 fn keys(&self) -> impl Iterator<Item = K> + '_ {
24 self.iter().copied()
25 }
26}
27
28impl<K: Copy> SumDomain<K> for Vec<K> {
29 fn keys(&self) -> impl Iterator<Item = K> + '_ {
30 self.iter().copied()
31 }
32}
33
34impl<K: Copy, const N: usize> SumDomain<K> for [K; N] {
35 fn keys(&self) -> impl Iterator<Item = K> + '_ {
36 self.iter().copied()
37 }
38}
39
40pub fn sum_over<'a, K, D, F>(domain: &D, mut f: F) -> Expr<'a>
51where
52 D: SumDomain<K> + ?Sized,
53 F: FnMut(K) -> Expr<'a>,
54{
55 let mut iter = domain.keys();
56 let first = f(iter.next().expect("sum_over on empty domain"));
57 iter.fold(first, |acc, k| acc + f(k))
58}
59
#[cfg(test)]
mod tests {
    use oximo_expr::extract_linear;

    use super::*;
    use crate::model::Model;
    use crate::set::IndexKey;

    /// A `0..4` range set gives one unit-coefficient term per key.
    #[test]
    fn sum_over_scalar_set() {
        let model = Model::new("scalar");
        let domain = Set::range(0..4);
        let x = model.indexed_var("x", &domain).lb(0.0).build();

        let expr = sum_over(&domain, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, expr.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 4);
        for (_, coeff) in linear.coeffs.iter() {
            assert!((coeff - 1.0).abs() < f64::EPSILON);
        }
    }

    /// A product set hands each key to the closure as a `(String, String)` pair.
    #[test]
    fn sum_over_tuple_set() {
        let model = Model::new("tuple");
        let plants = Set::strings(["seattle", "san-diego"]);
        let markets = Set::strings(["nyc", "chicago", "topeka"]);
        let routes = &plants * &markets;
        let x = model.indexed_var("x", &routes).lb(0.0).build();

        let expr = sum_over(&routes, |(plant, market): (String, String)| x[(plant, market)]);
        let arena = model.arena();
        let linear = extract_linear(&arena, expr.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 6);
    }

    /// Nesting one `sum_over` inside another covers the full product (2 x 3 terms).
    #[test]
    fn nested_sum_over_double_sum() {
        let model = Model::new("nested");
        let plants = Set::strings(["a", "b"]);
        let markets = Set::strings(["x", "y", "z"]);
        let routes = &plants * &markets;
        let x = model.indexed_var("x", &routes).lb(0.0).build();

        let expr = sum_over(&plants, |plant: String| {
            sum_over(&markets, |market: String| x[(&plant, market)])
        });
        let arena = model.arena();
        let linear = extract_linear(&arena, expr.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 6);
    }

    /// The closure may take the undecoded `IndexKey` directly.
    #[test]
    fn sum_over_passes_raw_index_key() {
        let model = Model::new("rawkey");
        let domain = Set::range(0..3);
        let x = model.indexed_var("x", &domain).lb(0.0).build();

        let expr = sum_over(&domain, |key: IndexKey| x[key]);
        let arena = model.arena();
        let linear = extract_linear(&arena, expr.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 3);
    }

    /// A borrowed slice selects a subset of the variable's keys.
    #[test]
    fn sum_over_slice_of_usize() {
        let model = Model::new("slice");
        let domain = Set::range(0..5);
        let x = model.indexed_var("x", &domain).lb(0.0).build();

        let subset: &[usize] = &[0, 2, 4];
        let expr = sum_over(subset, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, expr.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 3);
    }

    /// A `Vec` can be passed by reference without slicing it first.
    #[test]
    fn sum_over_vec_of_usize() {
        let model = Model::new("vec");
        let domain = Set::range(0..5);
        let x = model.indexed_var("x", &domain).lb(0.0).build();

        let subset: Vec<usize> = vec![1, 3];
        let expr = sum_over(&subset, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, expr.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 2);
    }

    /// A fixed-size array can be passed by reference as a domain.
    #[test]
    fn sum_over_array_of_usize() {
        let model = Model::new("array");
        let domain = Set::range(0..5);
        let x = model.indexed_var("x", &domain).lb(0.0).build();

        let subset: [usize; 4] = [0, 1, 2, 3];
        let expr = sum_over(&subset, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, expr.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 4);
    }

    /// An empty set panics before the closure is ever invoked.
    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_set_panics() {
        let model = Model::new("empty");
        let nothing = Set::range(0..0);
        let _x = model.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let _ = sum_over(&nothing, |_: usize| panic!("closure should not run"));
    }

    /// An empty slice panics before the closure is ever invoked.
    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_slice_panics() {
        let model = Model::new("empty_slice");
        let _x = model.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let nothing: &[usize] = &[];
        let _ = sum_over(nothing, |_: usize| panic!("closure should not run"));
    }
}