Skip to main content

khive_fold/objective/
compose.rs

1//! Objective composition utilities
2
3use crate::{Objective, ObjectiveContext};
4
5/// Weighted combination of multiple objectives.
6///
7/// The final score is: sum(weight_i * score_i) / sum(weight_i).
8/// Invalid weights (non-finite, zero, or negative) and non-finite scores are skipped.
9pub struct WeightedObjective<T> {
10    objectives: Vec<(Box<dyn Objective<T>>, f64)>,
11}
12
13impl<T> WeightedObjective<T> {
14    /// Create a new weighted objective
15    pub fn new() -> Self {
16        Self {
17            objectives: Vec::new(),
18        }
19    }
20
21    /// Add an objective with a weight
22    pub fn add(mut self, objective: Box<dyn Objective<T>>, weight: f64) -> Self {
23        self.objectives.push((objective, weight));
24        self
25    }
26}
27
28impl<T> Default for WeightedObjective<T> {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl<T: Send + Sync> Objective<T> for WeightedObjective<T> {
35    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
36        if self.objectives.is_empty() {
37            return 0.0;
38        }
39
40        let mut weighted_sum = 0.0;
41        let mut weight_sum = 0.0;
42
43        for (objective, weight) in &self.objectives {
44            let w = *weight;
45            if !w.is_finite() || w <= 0.0 {
46                continue;
47            }
48
49            let score = objective.score(candidate, context);
50            if !score.is_finite() {
51                continue;
52            }
53
54            weighted_sum += score * w;
55            weight_sum += w;
56        }
57
58        if weight_sum > 0.0 {
59            weighted_sum / weight_sum
60        } else {
61            0.0
62        }
63    }
64
65    fn name(&self) -> &str {
66        "WeightedObjective"
67    }
68}
69
70/// Priority-based objective combination.
71///
72/// Evaluates objectives in priority order. If an objective gives a score
73/// above the threshold, that score is used. Otherwise, falls through to
74/// the next priority level.
75pub struct PriorityObjective<T> {
76    objectives: Vec<(Box<dyn Objective<T>>, f64)>,
77    fallback: f64,
78}
79
80impl<T> PriorityObjective<T> {
81    /// Create a new priority objective
82    pub fn new() -> Self {
83        Self {
84            objectives: Vec::new(),
85            fallback: 0.0,
86        }
87    }
88
89    /// Add an objective with a threshold
90    pub fn add(mut self, objective: Box<dyn Objective<T>>, threshold: f64) -> Self {
91        self.objectives.push((objective, threshold));
92        self
93    }
94
95    /// Set the fallback score
96    pub fn with_fallback(mut self, score: f64) -> Self {
97        self.fallback = score;
98        self
99    }
100}
101
102impl<T> Default for PriorityObjective<T> {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl<T: Send + Sync> Objective<T> for PriorityObjective<T> {
109    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
110        for (objective, threshold) in &self.objectives {
111            let score = objective.score(candidate, context);
112            if score.is_finite() && score >= *threshold {
113                return score;
114            }
115        }
116
117        self.fallback
118    }
119
120    fn name(&self) -> &str {
121        "PriorityObjective"
122    }
123}
124
125/// Consensus-based objective combination.
126///
127/// All objectives must agree (scores above threshold) for a candidate
128/// to pass. The final score is the geometric mean of all sub-objective scores.
129pub struct ConsensusObjective<T> {
130    objectives: Vec<Box<dyn Objective<T>>>,
131    threshold: f64,
132}
133
134impl<T> ConsensusObjective<T> {
135    /// Create a new consensus objective
136    pub fn new(threshold: f64) -> Self {
137        Self {
138            objectives: Vec::new(),
139            threshold,
140        }
141    }
142
143    /// Add an objective
144    pub fn with_objective(mut self, objective: Box<dyn Objective<T>>) -> Self {
145        self.objectives.push(objective);
146        self
147    }
148}
149
150impl<T: Send + Sync> Objective<T> for ConsensusObjective<T> {
151    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
152        if self.objectives.is_empty() {
153            return 0.0;
154        }
155
156        let mut log_sum = 0.0f64;
157        let n = self.objectives.len();
158
159        for objective in &self.objectives {
160            let score = objective.score(candidate, context);
161            if !score.is_finite() || score < self.threshold {
162                return 0.0;
163            }
164            if score <= 0.0 {
165                return 0.0;
166            }
167            log_sum += score.ln();
168        }
169
170        // Geometric mean = exp(sum(ln(score_i)) / n)
171        (log_sum / n as f64).exp()
172    }
173
174    fn name(&self) -> &str {
175        "ConsensusObjective"
176    }
177}
178
179/// Union objective — OR semantics.
180///
181/// Candidate passes if ANY objective gives a score above threshold.
182/// The final score is the maximum of all scores.
183pub struct UnionObjective<T> {
184    objectives: Vec<Box<dyn Objective<T>>>,
185}
186
187impl<T> UnionObjective<T> {
188    /// Create a new union objective
189    pub fn new() -> Self {
190        Self {
191            objectives: Vec::new(),
192        }
193    }
194
195    /// Add an objective
196    pub fn with_objective(mut self, objective: Box<dyn Objective<T>>) -> Self {
197        self.objectives.push(objective);
198        self
199    }
200}
201
202impl<T> Default for UnionObjective<T> {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208impl<T: Send + Sync> Objective<T> for UnionObjective<T> {
209    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
210        self.objectives
211            .iter()
212            .map(|obj| obj.score(candidate, context))
213            .filter(|s| s.is_finite())
214            .fold(0.0f64, |a, b| a.max(b))
215    }
216
217    fn name(&self) -> &str {
218        "UnionObjective"
219    }
220}
221
222/// Negation objective — inverts another objective's score.
223pub struct NegateObjective<T> {
224    inner: Box<dyn Objective<T>>,
225}
226
227impl<T> NegateObjective<T> {
228    /// Create a negation of another objective
229    pub fn new(inner: Box<dyn Objective<T>>) -> Self {
230        Self { inner }
231    }
232}
233
234impl<T: Send + Sync> Objective<T> for NegateObjective<T> {
235    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
236        1.0 - self.inner.score(candidate, context)
237    }
238
239    fn name(&self) -> &str {
240        "NegateObjective"
241    }
242}
243
244/// Scale objective — multiplies another objective's score by a constant factor.
245pub struct ScaleObjective<O> {
246    inner: O,
247    factor: f64,
248}
249
250impl<O> ScaleObjective<O> {
251    /// Create a scaled objective that multiplies the inner score by `factor`.
252    pub fn new(inner: O, factor: f64) -> Self {
253        Self { inner, factor }
254    }
255}
256
257impl<T, O: Objective<T>> Objective<T> for ScaleObjective<O> {
258    fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
259        self.inner.score(candidate, context) * self.factor
260    }
261
262    fn name(&self) -> &str {
263        "ScaleObjective"
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::objective_fn;

    #[test]
    fn test_weighted_objective() {
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n * 2) as f64);

        let weighted = WeightedObjective::new()
            .add(Box::new(obj1), 1.0)
            .add(Box::new(obj2), 1.0);

        let context = ObjectiveContext::new();

        assert_eq!(weighted.score(&5, &context), 7.5);
    }

    #[test]
    fn test_weighted_objective_ignores_invalid_weights() {
        let negative = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 100.0);
        let positive = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 4.0);

        let weighted = WeightedObjective::new()
            .add(Box::new(negative), -1.0)
            .add(Box::new(positive), 1.0);

        assert_eq!(weighted.score(&5, &ObjectiveContext::new()), 4.0);
    }

    #[test]
    fn test_weighted_objective_skips_non_finite_scores() {
        let nan = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| f64::NAN);
        let finite = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 4.0);

        let weighted = WeightedObjective::new()
            .add(Box::new(nan), 1.0)
            .add(Box::new(finite), 1.0);

        // The NaN entry contributes to neither the numerator nor the denominator.
        assert_eq!(weighted.score(&5, &ObjectiveContext::new()), 4.0);
    }

    #[test]
    fn test_weighted_objective_requires_positive_finite_denominator() {
        let negative_weight = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 100.0);
        // The score is finite; it is the *weight* below that is non-finite.
        let infinite_weight = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 4.0);

        let weighted = WeightedObjective::new()
            .add(Box::new(negative_weight), -1.0)
            .add(Box::new(infinite_weight), f64::INFINITY);

        assert_eq!(weighted.score(&5, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_priority_objective() {
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| {
            if *n > 10 {
                *n as f64
            } else {
                0.0
            }
        });

        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 / 2.0);

        let priority = PriorityObjective::new()
            .add(Box::new(obj1), 5.0)
            .add(Box::new(obj2), 0.0)
            .with_fallback(-1.0);

        let context = ObjectiveContext::new();

        assert_eq!(priority.score(&15, &context), 15.0);
        assert_eq!(priority.score(&5, &context), 2.5);
        // candidate -4: obj1 scores 0.0 (< 5.0) and obj2 scores -2.0 (< 0.0) → fallback.
        assert_eq!(priority.score(&-4, &context), -1.0);
    }

    #[test]
    fn test_priority_objective_empty_uses_fallback() {
        let priority: PriorityObjective<i32> = PriorityObjective::new().with_fallback(0.25);
        assert_eq!(priority.score(&1, &ObjectiveContext::new()), 0.25);
    }

    #[test]
    fn test_consensus_objective() {
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n * 2) as f64);

        let consensus = ConsensusObjective::new(5.0)
            .with_objective(Box::new(obj1))
            .with_objective(Box::new(obj2));

        let context = ObjectiveContext::new();

        // scores are 10 and 20 → geometric mean = sqrt(10 * 20) = sqrt(200) ≈ 14.142
        let score = consensus.score(&10, &context);
        let expected = (10.0f64 * 20.0f64).sqrt();
        assert!(
            (score - expected).abs() < 1e-9,
            "expected {expected}, got {score}"
        );

        // candidate 2 → scores 2 and 4, both below threshold 5 → 0.0
        assert_eq!(consensus.score(&2, &context), 0.0);
    }

    #[test]
    fn test_consensus_objective_empty() {
        let consensus: ConsensusObjective<i32> = ConsensusObjective::new(0.0);
        assert_eq!(consensus.score(&10, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_consensus_objective_zero_score_returns_zero() {
        let obj1 = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 0.0f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);

        let consensus = ConsensusObjective::new(0.0)
            .with_objective(Box::new(obj1))
            .with_objective(Box::new(obj2));

        assert_eq!(consensus.score(&10, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_union_objective() {
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| 100.0 - *n as f64);

        let union = UnionObjective::new()
            .with_objective(Box::new(obj1))
            .with_objective(Box::new(obj2));

        let context = ObjectiveContext::new();

        assert_eq!(union.score(&30, &context), 70.0);
        assert_eq!(union.score(&80, &context), 80.0);
    }

    #[test]
    fn test_union_objective_empty() {
        let union: UnionObjective<i32> = UnionObjective::new();
        assert_eq!(union.score(&1, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_negate_objective() {
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 / 100.0);
        let negated = NegateObjective::new(Box::new(obj));

        let context = ObjectiveContext::new();

        assert!((negated.score(&30, &context) - 0.7).abs() < 0.001);
    }

    #[test]
    fn test_scale_objective() {
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let scaled = ScaleObjective::new(obj, 2.0);

        let context = ObjectiveContext::new();

        assert!((scaled.score(&0, &context) - 0.0).abs() < 0.001);
        assert!((scaled.score(&5, &context) - 10.0).abs() < 0.001);
        assert!((scaled.score(&10, &context) - 20.0).abs() < 0.001);
    }

    #[test]
    fn test_scale_objective_negative_factor() {
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let scaled = ScaleObjective::new(obj, -1.0);

        let context = ObjectiveContext::new();

        assert!((scaled.score(&5, &context) - (-5.0)).abs() < 0.001);
    }
}