Skip to main content

oximo_core/
sum.rs

1use oximo_expr::Expr;
2
3use crate::set::{FromIndexKey, Set};
4
/// Domain over which [`sum_over`] iterates.
///
/// Lets a single `sum_over` call accept either a [`Set`] (with typed key
/// decoding via [`FromIndexKey`]) or a borrowed slice, `Vec`, or array of
/// `Copy` keys, without intermediate conversions.
///
/// Returns an iterator (rather than taking a callback) so the trait method
/// monomorphizes through to the loop body in [`sum_over`], allowing inlining
/// in hot sums. Implementations are typically one line.
pub trait SumDomain<K> {
    /// Iterates over every key in the domain, yielding each one as a `K`.
    ///
    /// The returned iterator may borrow from `self` (it is bounded by the
    /// lifetime of the `&self` borrow).
    fn keys(&self) -> impl Iterator<Item = K> + '_;
}
15
16impl<K: FromIndexKey> SumDomain<K> for Set {
17    fn keys(&self) -> impl Iterator<Item = K> + '_ {
18        self.iter().map(|k| K::from_index_key(&k))
19    }
20}
21
22impl<K: Copy> SumDomain<K> for [K] {
23    fn keys(&self) -> impl Iterator<Item = K> + '_ {
24        self.iter().copied()
25    }
26}
27
28impl<K: Copy> SumDomain<K> for Vec<K> {
29    fn keys(&self) -> impl Iterator<Item = K> + '_ {
30        self.iter().copied()
31    }
32}
33
34impl<K: Copy, const N: usize> SumDomain<K> for [K; N] {
35    fn keys(&self) -> impl Iterator<Item = K> + '_ {
36        self.iter().copied()
37    }
38}
39
40/// Sum an expression over every element of a domain.
41///
42/// Reads as the mathematical `sum_{k in domain} f(k)`. The closure parameter is
43/// either decoded from the domain's [`crate::set::IndexKey`] via [`FromIndexKey`] (when
44/// the domain is a [`Set`]) or yielded directly (when the domain is a slice
45/// of `Copy` keys).
46///
47/// # Panics
48/// Panics if `domain` is empty, the resulting expression has no arena to
49/// attach to.
50pub fn sum_over<'a, K, D, F>(domain: &D, mut f: F) -> Expr<'a>
51where
52    D: SumDomain<K> + ?Sized,
53    F: FnMut(K) -> Expr<'a>,
54{
55    let mut iter = domain.keys();
56    let first = f(iter.next().expect("sum_over on empty domain"));
57    iter.fold(first, |acc, k| acc + f(k))
58}
59
#[cfg(test)]
mod tests {
    use oximo_expr::extract_linear;

    use super::*;
    use crate::model::Model;
    use crate::set::IndexKey;

    /// Extracts the linear form of `total` from `model`'s arena and returns
    /// how many terms it has. Panics (failing the test) if `total` is not linear.
    fn linear_term_count(model: &Model, total: Expr<'_>) -> usize {
        let arena = model.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        terms.coeffs.len()
    }

    #[test]
    fn sum_over_scalar_set() {
        let model = Model::new("scalar");
        let items = Set::range(0..4);
        let x = model.indexed_var("x", &items).lb(0.0).build();

        let total = sum_over(&items, |i: usize| x[i]);

        let arena = model.arena();
        let terms = extract_linear(&arena, total.id).expect("linear");
        assert_eq!(terms.coeffs.len(), 4);
        // Every variable should enter the sum with unit weight.
        let all_unit = terms.coeffs.iter().all(|(_, c)| (c - 1.0).abs() < f64::EPSILON);
        assert!(all_unit);
    }

    #[test]
    fn sum_over_tuple_set() {
        let model = Model::new("tuple");
        let plants = Set::strings(["seattle", "san-diego"]);
        let markets = Set::strings(["nyc", "chicago", "topeka"]);
        let routes = &plants * &markets;
        let x = model.indexed_var("x", &routes).lb(0.0).build();

        let total = sum_over(&routes, |(p, q): (String, String)| x[(p, q)]);
        assert_eq!(linear_term_count(&model, total), 6);
    }

    #[test]
    fn nested_sum_over_double_sum() {
        let model = Model::new("nested");
        let plants = Set::strings(["a", "b"]);
        let markets = Set::strings(["x", "y", "z"]);
        let routes = &plants * &markets;
        let x = model.indexed_var("x", &routes).lb(0.0).build();

        let total = sum_over(&plants, |p: String| {
            // Inner sum over every market for the fixed plant `p`.
            sum_over(&markets, |q: String| x[(&p, q)])
        });
        assert_eq!(linear_term_count(&model, total), 6);
    }

    #[test]
    fn sum_over_passes_raw_index_key() {
        let model = Model::new("rawkey");
        let items = Set::range(0..3);
        let x = model.indexed_var("x", &items).lb(0.0).build();

        let total = sum_over(&items, |k: IndexKey| x[k]);
        assert_eq!(linear_term_count(&model, total), 3);
    }

    #[test]
    fn sum_over_slice_of_usize() {
        let model = Model::new("slice");
        let items = Set::range(0..5);
        let x = model.indexed_var("x", &items).lb(0.0).build();

        let picked: &[usize] = &[0, 2, 4];
        let total = sum_over(picked, |i: usize| x[i]);
        assert_eq!(linear_term_count(&model, total), 3);
    }

    #[test]
    fn sum_over_vec_of_usize() {
        let model = Model::new("vec");
        let items = Set::range(0..5);
        let x = model.indexed_var("x", &items).lb(0.0).build();

        let picked: Vec<usize> = vec![1, 3];
        let total = sum_over(&picked, |i: usize| x[i]);
        assert_eq!(linear_term_count(&model, total), 2);
    }

    #[test]
    fn sum_over_array_of_usize() {
        let model = Model::new("array");
        let items = Set::range(0..5);
        let x = model.indexed_var("x", &items).lb(0.0).build();

        let picked: [usize; 4] = [0, 1, 2, 3];
        let total = sum_over(&picked, |i: usize| x[i]);
        assert_eq!(linear_term_count(&model, total), 4);
    }

    // In the two tests below, the closure panics with a different message, so
    // `should_panic(expected = ...)` only passes if `sum_over` panics on the
    // empty domain before ever invoking the closure.

    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_set_panics() {
        let model = Model::new("empty");
        let empty = Set::range(0..0);
        let _x = model.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let _ = sum_over(&empty, |_: usize| panic!("closure should not run"));
    }

    #[test]
    #[should_panic(expected = "sum_over on empty domain")]
    fn sum_over_empty_slice_panics() {
        let model = Model::new("empty_slice");
        let _x = model.indexed_var("x", &Set::range(0..1)).lb(0.0).build();
        let empty: &[usize] = &[];
        let _ = sum_over(empty, |_: usize| panic!("closure should not run"));
    }
}