1use oximo_expr::Expr;
2
3use crate::set::{FromIndexKey, Set};
4
/// A domain whose keys can be enumerated, one key per term, when building an indexed sum.
///
/// `K` is the key type handed to the per-term closure: the sum builders in this module map
/// every key yielded by [`SumDomain::keys`] through that closure. Implementations provided
/// alongside this trait include [`Set`] (keys converted with [`FromIndexKey`]), slices,
/// `Vec`s and arrays of `Copy` keys, references to any domain, and `Range<usize>`,
/// `Range<i64>` and `Range<i32>`.
// Customizes the compiler error emitted when a type does not implement `SumDomain<K>`,
// steering the user toward matching the binding type to the domain's key type.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid index domain over key type `{K}`",
    label = "the domain's keys are not `{K}`",
    note = "the loop/closure binding type must match the domain's keys",
    note = "for a `Set<T>` write `for x in set` (the key type is inferred) or annotate `for x: T in set`. Integer ranges yield `usize`/`i64`/`i32`. A slice/`Vec`/array yields its element type"
)]
pub trait SumDomain<K> {
    /// Returns an iterator over this domain's keys; the iterator may borrow `self`.
    fn keys(&self) -> impl Iterator<Item = K> + '_;
}
21
22impl<K: FromIndexKey> SumDomain<K> for Set<K> {
27 fn keys(&self) -> impl Iterator<Item = K> + '_ {
28 self.iter().map(|k| K::from_index_key(&k))
29 }
30}
31
32impl<K: Copy> SumDomain<K> for [K] {
33 fn keys(&self) -> impl Iterator<Item = K> + '_ {
34 self.iter().copied()
35 }
36}
37
38impl<K: Copy> SumDomain<K> for Vec<K> {
39 fn keys(&self) -> impl Iterator<Item = K> + '_ {
40 self.iter().copied()
41 }
42}
43
44impl<K: Copy, const N: usize> SumDomain<K> for [K; N] {
45 fn keys(&self) -> impl Iterator<Item = K> + '_ {
46 self.iter().copied()
47 }
48}
49
50impl<K, D: SumDomain<K> + ?Sized> SumDomain<K> for &D {
53 fn keys(&self) -> impl Iterator<Item = K> + '_ {
54 (**self).keys()
55 }
56}
57
58impl SumDomain<usize> for std::ops::Range<usize> {
62 fn keys(&self) -> impl Iterator<Item = usize> + '_ {
63 self.clone()
64 }
65}
66
67impl SumDomain<i64> for std::ops::Range<i64> {
68 fn keys(&self) -> impl Iterator<Item = i64> + '_ {
69 self.clone()
70 }
71}
72
73impl SumDomain<i32> for std::ops::Range<i32> {
74 fn keys(&self) -> impl Iterator<Item = i32> + '_ {
75 self.clone()
76 }
77}
78
79#[deprecated(
90 since = "0.3.0",
91 note = "use the `sum!` macro, the builder API is scheduled for removal in 0.4.0"
92)]
93pub fn sum_over<'a, K, D, F>(domain: &D, f: F) -> Expr<'a>
94where
95 D: SumDomain<K> + ?Sized,
96 F: FnMut(K) -> Expr<'a>,
97{
98 __sum_over(domain, f)
99}
100
101#[doc(hidden)]
104pub fn __sum_over<'a, K, D, F>(domain: &D, f: F) -> Expr<'a>
105where
106 D: SumDomain<K> + ?Sized,
107 F: FnMut(K) -> Expr<'a>,
108{
109 let terms: Vec<Expr<'a>> = domain.keys().map(f).collect();
110 assert!(!terms.is_empty(), "sum_over on empty domain");
111 terms.into_iter().sum()
112}
113
#[cfg(test)]
// These tests drive the deprecated `sum_over` builder directly, so silence the deprecation lint.
#[allow(deprecated)]
mod tests {
    use oximo_expr::extract_linear;

    use super::*;
    use crate::model::Model;

    #[test]
    fn sum_over_scalar_set() {
        let m = Model::new("scalar");
        let items = Set::range(0..4);
        let x = m.indexed_var("x", &items).lb(0.0).build();

        let total = sum_over(&items, |i: usize| x[i]);
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 4);
        assert!(terms.coeffs.iter().all(|(_, c)| (c - 1.0).abs() < f64::EPSILON));
    }

    #[test]
    fn sum_over_tuple_set() {
        let m = Model::new("tuple");
        let plants = Set::strings(["seattle", "san-diego"]);
        let markets = Set::strings(["nyc", "chicago", "topeka"]);
        let routes = &plants * &markets;
        let x = m.indexed_var("x", &routes).lb(0.0).build();

        let total = sum_over(&routes, |(p, q): (String, String)| x[(p, q)]);
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 6);
    }

    #[test]
    fn nested_sum_over_double_sum() {
        let m = Model::new("nested");
        let plants = Set::strings(["a", "b"]);
        let markets = Set::strings(["x", "y", "z"]);
        let routes = &plants * &markets;
        let x = m.indexed_var("x", &routes).lb(0.0).build();

        let total = sum_over(&plants, |p: String| sum_over(&markets, |q: String| x[(&p, q)]));
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 6);
    }

    #[test]
    fn sum_over_passes_typed_usize_key() {
        let m = Model::new("usizekey");
        let items = Set::range(0..3);
        let x = m.indexed_var("x", &items).lb(0.0).build();

        let total = sum_over(&items, |i: usize| x[i]);
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 3);
    }

    #[test]
    fn sum_over_slice_of_usize() {
        let m = Model::new("slice");
        let items = Set::range(0..5);
        let x = m.indexed_var("x", &items).lb(0.0).build();

        let picked: &[usize] = &[0, 2, 4];
        let total = sum_over(picked, |i: usize| x[i]);
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 3);
    }

    #[test]
    fn sum_over_vec_of_usize() {
        let m = Model::new("vec");
        let items = Set::range(0..5);
        let x = m.indexed_var("x", &items).lb(0.0).build();

        let picked: Vec<usize> = vec![1, 3];
        let total = sum_over(&picked, |i: usize| x[i]);
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 2);
    }

    #[test]
    fn sum_over_array_of_usize() {
        let m = Model::new("array");
        let items = Set::range(0..5);
        let x = m.indexed_var("x", &items).lb(0.0).build();

        let picked: [usize; 4] = [0, 1, 2, 3];
        let total = sum_over(&picked, |i: usize| x[i]);
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 4);
    }

    // The integer-`Range` domains below were previously not exercised by any test.

    #[test]
    fn sum_over_usize_range() {
        let m = Model::new("range_usize");
        let items = Set::range(0..5);
        let x = m.indexed_var("x", &items).lb(0.0).build();

        let total = sum_over(&(1usize..4), |i: usize| x[i]);
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 3);
    }

    #[test]
    fn sum_over_i64_range() {
        let m = Model::new("range_i64");
        let items = Set::range(0..5);
        let x = m.indexed_var("x", &items).lb(0.0).build();

        let total = sum_over(&(0i64..2), |i: i64| {
            x[usize::try_from(i).expect("non-negative key")]
        });
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 2);
    }

    #[test]
    fn sum_over_i32_range() {
        let m = Model::new("range_i32");
        let items = Set::range(0..5);
        let x = m.indexed_var("x", &items).lb(0.0).build();

        let total = sum_over(&(2i32..5), |i: i32| {
            x[usize::try_from(i).expect("non-negative key")]
        });
        let arena = m.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 3);
    }

    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_set_panics() {
        let m = Model::new("empty");
        let empty = Set::range(0..0);
        let _x = m.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let _ = sum_over(&empty, |_: usize| panic!("closure should not run"));
    }

    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_slice_panics() {
        let m = Model::new("empty_slice");
        let _x = m.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let empty: &[usize] = &[];
        let _ = sum_over(empty, |_: usize| panic!("closure should not run"));
    }

    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_range_panics() {
        let m = Model::new("empty_range");
        let _x = m.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let _ = sum_over(&(0usize..0), |_: usize| panic!("closure should not run"));
    }
}