Skip to main content

oximo_core/
sum.rs

1use oximo_expr::Expr;
2
3use crate::set::{FromIndexKey, Set};
4
/// A collection of keys that [`sum_over`] can iterate.
///
/// Putting the domain behind this trait lets one `sum_over` call take a
/// [`Set`] (whose keys are decoded to a typed value through [`FromIndexKey`])
/// as well as a borrowed slice of `Copy` keys, with no intermediate
/// conversion step.
///
/// The trait hands back an iterator instead of accepting a callback. That way
/// the method monomorphizes straight into the loop body of [`sum_over`], which
/// allows inlining in hot sums. Most implementations are a single line.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid index domain over key type `{K}`",
    label = "the domain's keys are not `{K}`",
    note = "the loop/closure binding type must match the domain's keys",
    note = "for a `Set<T>` write `for x in set` (the key type is inferred) or annotate `for x: T in set`. Integer ranges yield `usize`/`i64`/`i32`. A slice/`Vec`/array yields its element type"
)]
pub trait SumDomain<K> {
    /// Yields the domain's keys as values of type `K`. The returned iterator
    /// borrows the domain for as long as it is alive.
    fn keys(&self) -> impl Iterator<Item = K> + '_;
}
21
22// A typed set yields exactly its own key type.
23// The single `SumDomain` impl for `Set<K>`, so `sum!`/`constraint!`
24// can infer the closure parameter without an annotation
25// (the erased `Set` defaulted to `Set<IndexKey>`).
26impl<K: FromIndexKey> SumDomain<K> for Set<K> {
27    fn keys(&self) -> impl Iterator<Item = K> + '_ {
28        self.iter().map(|k| K::from_index_key(&k))
29    }
30}
31
32impl<K: Copy> SumDomain<K> for [K] {
33    fn keys(&self) -> impl Iterator<Item = K> + '_ {
34        self.iter().copied()
35    }
36}
37
38impl<K: Copy> SumDomain<K> for Vec<K> {
39    fn keys(&self) -> impl Iterator<Item = K> + '_ {
40        self.iter().copied()
41    }
42}
43
44impl<K: Copy, const N: usize> SumDomain<K> for [K; N] {
45    fn keys(&self) -> impl Iterator<Item = K> + '_ {
46        self.iter().copied()
47    }
48}
49
50// Forward through a reference, so a domain that is itself a reference (e.g. a
51// `&Set` function parameter passed to `sum!`/`constraint!`) is accepted.
52impl<K, D: SumDomain<K> + ?Sized> SumDomain<K> for &D {
53    fn keys(&self) -> impl Iterator<Item = K> + '_ {
54        (**self).keys()
55    }
56}
57
58// Integer ranges as sum domains. Iteration is lazy, so `sum!(x[i] for i in 0..n)`
59// allocates nothing. Provided for the common integer types the `sum!`/`constraint!`
60// macros default to.
61impl SumDomain<usize> for std::ops::Range<usize> {
62    fn keys(&self) -> impl Iterator<Item = usize> + '_ {
63        self.clone()
64    }
65}
66
67impl SumDomain<i64> for std::ops::Range<i64> {
68    fn keys(&self) -> impl Iterator<Item = i64> + '_ {
69        self.clone()
70    }
71}
72
73impl SumDomain<i32> for std::ops::Range<i32> {
74    fn keys(&self) -> impl Iterator<Item = i32> + '_ {
75        self.clone()
76    }
77}
78
79/// Sum an expression over every element of a domain.
80///
81/// Reads as the mathematical `sum_{k in domain} f(k)`. The closure parameter is
82/// either decoded from the domain's [`crate::set::IndexKey`] via [`FromIndexKey`] (when
83/// the domain is a [`Set`]) or yielded directly (when the domain is a slice
84/// of `Copy` keys).
85///
86/// # Panics
87/// Panics if `domain` is empty, the resulting expression has no arena to
88/// attach to.
89#[deprecated(
90    since = "0.3.0",
91    note = "use the `sum!` macro, the builder API is scheduled for removal in 0.4.0"
92)]
93pub fn sum_over<'a, K, D, F>(domain: &D, f: F) -> Expr<'a>
94where
95    D: SumDomain<K> + ?Sized,
96    F: FnMut(K) -> Expr<'a>,
97{
98    __sum_over(domain, f)
99}
100
101/// Macro-facing entry point behind [`sum_over`]. Backs the `sum!` macro. Not
102/// part of the stable public API.
103#[doc(hidden)]
104pub fn __sum_over<'a, K, D, F>(domain: &D, f: F) -> Expr<'a>
105where
106    D: SumDomain<K> + ?Sized,
107    F: FnMut(K) -> Expr<'a>,
108{
109    let terms: Vec<Expr<'a>> = domain.keys().map(f).collect();
110    assert!(!terms.is_empty(), "sum_over on empty domain");
111    terms.into_iter().sum()
112}
113
#[cfg(test)]
// These tests call the deprecated builder API directly until it is removed in 0.4.0.
#[allow(deprecated)]
mod tests {
    use oximo_expr::extract_linear;

    use super::*;
    use crate::model::Model;

    /// Summing over a four-element integer set gives four terms, each with coefficient 1.
    #[test]
    fn sum_over_scalar_set() {
        let model = Model::new("scalar");
        let index_set = Set::range(0..4);
        let x = model.indexed_var("x", &index_set).lb(0.0).build();

        let summed = sum_over(&index_set, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, summed.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 4);
        assert!(linear.coeffs.iter().all(|(_, coeff)| (coeff - 1.0).abs() < f64::EPSILON));
    }

    /// A product set of 2 x 3 string keys is summed with a tuple-typed closure parameter.
    #[test]
    fn sum_over_tuple_set() {
        let model = Model::new("tuple");
        let plants = Set::strings(["seattle", "san-diego"]);
        let markets = Set::strings(["nyc", "chicago", "topeka"]);
        let routes = &plants * &markets;
        let x = model.indexed_var("x", &routes).lb(0.0).build();

        let summed = sum_over(&routes, |(p, q): (String, String)| x[(p, q)]);
        let arena = model.arena();
        let linear = extract_linear(&arena, summed.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 6);
    }

    /// Nesting one `sum_over` inside another covers the same six routes as the product set.
    #[test]
    fn nested_sum_over_double_sum() {
        let model = Model::new("nested");
        let plants = Set::strings(["a", "b"]);
        let markets = Set::strings(["x", "y", "z"]);
        let routes = &plants * &markets;
        let x = model.indexed_var("x", &routes).lb(0.0).build();

        let summed = sum_over(&plants, |p: String| sum_over(&markets, |q: String| x[(&p, q)]));
        let arena = model.arena();
        let linear = extract_linear(&arena, summed.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 6);
    }

    /// Keys of an integer-range set reach the closure as `usize`.
    #[test]
    fn sum_over_passes_typed_usize_key() {
        let model = Model::new("usizekey");
        let index_set = Set::range(0..3);
        let x = model.indexed_var("x", &index_set).lb(0.0).build();

        let summed = sum_over(&index_set, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, summed.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 3);
    }

    /// A borrowed slice restricts the sum to the listed keys only.
    #[test]
    fn sum_over_slice_of_usize() {
        let model = Model::new("slice");
        let index_set = Set::range(0..5);
        let x = model.indexed_var("x", &index_set).lb(0.0).build();

        let chosen: &[usize] = &[0, 2, 4];
        let summed = sum_over(chosen, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, summed.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 3);
    }

    /// A `Vec` of keys works as a domain when passed by reference.
    #[test]
    fn sum_over_vec_of_usize() {
        let model = Model::new("vec");
        let index_set = Set::range(0..5);
        let x = model.indexed_var("x", &index_set).lb(0.0).build();

        let chosen: Vec<usize> = vec![1, 3];
        let summed = sum_over(&chosen, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, summed.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 2);
    }

    /// A fixed-size array of keys works as a domain when passed by reference.
    #[test]
    fn sum_over_array_of_usize() {
        let model = Model::new("array");
        let index_set = Set::range(0..5);
        let x = model.indexed_var("x", &index_set).lb(0.0).build();

        let chosen: [usize; 4] = [0, 1, 2, 3];
        let summed = sum_over(&chosen, |i: usize| x[i]);
        let arena = model.arena();
        let linear = extract_linear(&arena, summed.id).expect("linear");
        assert_eq!(linear.coeffs.len(), 4);
    }

    /// An empty `Set` trips the emptiness assertion; the closure itself never runs.
    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_set_panics() {
        let model = Model::new("empty");
        let no_keys = Set::range(0..0);
        let _x = model.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let _ = sum_over(&no_keys, |_: usize| panic!("closure should not run"));
    }

    /// An empty slice trips the emptiness assertion; the closure itself never runs.
    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_slice_panics() {
        let model = Model::new("empty_slice");
        let _x = model.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let no_keys: &[usize] = &[];
        let _ = sum_over(no_keys, |_: usize| panic!("closure should not run"));
    }
}