1use crate::{Objective, ObjectiveContext};
4
5pub struct WeightedObjective<T> {
10 objectives: Vec<(Box<dyn Objective<T>>, f64)>,
11}
12
13impl<T> WeightedObjective<T> {
14 pub fn new() -> Self {
16 Self {
17 objectives: Vec::new(),
18 }
19 }
20
21 pub fn add(mut self, objective: Box<dyn Objective<T>>, weight: f64) -> Self {
23 self.objectives.push((objective, weight));
24 self
25 }
26}
27
28impl<T> Default for WeightedObjective<T> {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl<T: Send + Sync> Objective<T> for WeightedObjective<T> {
35 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
36 if self.objectives.is_empty() {
37 return 0.0;
38 }
39
40 let mut weighted_sum = 0.0;
41 let mut weight_sum = 0.0;
42
43 for (objective, weight) in &self.objectives {
44 let w = *weight;
45 if !w.is_finite() || w <= 0.0 {
46 continue;
47 }
48
49 let score = objective.score(candidate, context);
50 if !score.is_finite() {
51 continue;
52 }
53
54 weighted_sum += score * w;
55 weight_sum += w;
56 }
57
58 if weight_sum > 0.0 {
59 weighted_sum / weight_sum
60 } else {
61 0.0
62 }
63 }
64
65 fn name(&self) -> &str {
66 "WeightedObjective"
67 }
68}
69
70pub struct PriorityObjective<T> {
76 objectives: Vec<(Box<dyn Objective<T>>, f64)>,
77 fallback: f64,
78}
79
80impl<T> PriorityObjective<T> {
81 pub fn new() -> Self {
83 Self {
84 objectives: Vec::new(),
85 fallback: 0.0,
86 }
87 }
88
89 pub fn add(mut self, objective: Box<dyn Objective<T>>, threshold: f64) -> Self {
91 self.objectives.push((objective, threshold));
92 self
93 }
94
95 pub fn with_fallback(mut self, score: f64) -> Self {
97 self.fallback = score;
98 self
99 }
100}
101
102impl<T> Default for PriorityObjective<T> {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl<T: Send + Sync> Objective<T> for PriorityObjective<T> {
109 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
110 for (objective, threshold) in &self.objectives {
111 let score = objective.score(candidate, context);
112 if score.is_finite() && score >= *threshold {
113 return score;
114 }
115 }
116
117 self.fallback
118 }
119
120 fn name(&self) -> &str {
121 "PriorityObjective"
122 }
123}
124
125pub struct ConsensusObjective<T> {
130 objectives: Vec<Box<dyn Objective<T>>>,
131 threshold: f64,
132}
133
134impl<T> ConsensusObjective<T> {
135 pub fn new(threshold: f64) -> Self {
137 Self {
138 objectives: Vec::new(),
139 threshold,
140 }
141 }
142
143 pub fn with_objective(mut self, objective: Box<dyn Objective<T>>) -> Self {
145 self.objectives.push(objective);
146 self
147 }
148}
149
150impl<T: Send + Sync> Objective<T> for ConsensusObjective<T> {
151 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
152 if self.objectives.is_empty() {
153 return 0.0;
154 }
155
156 let mut log_sum = 0.0f64;
157 let n = self.objectives.len();
158
159 for objective in &self.objectives {
160 let score = objective.score(candidate, context);
161 if !score.is_finite() || score < self.threshold {
162 return 0.0;
163 }
164 if score <= 0.0 {
165 return 0.0;
166 }
167 log_sum += score.ln();
168 }
169
170 (log_sum / n as f64).exp()
172 }
173
174 fn name(&self) -> &str {
175 "ConsensusObjective"
176 }
177}
178
179pub struct UnionObjective<T> {
184 objectives: Vec<Box<dyn Objective<T>>>,
185}
186
187impl<T> UnionObjective<T> {
188 pub fn new() -> Self {
190 Self {
191 objectives: Vec::new(),
192 }
193 }
194
195 pub fn with_objective(mut self, objective: Box<dyn Objective<T>>) -> Self {
197 self.objectives.push(objective);
198 self
199 }
200}
201
202impl<T> Default for UnionObjective<T> {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208impl<T: Send + Sync> Objective<T> for UnionObjective<T> {
209 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
210 self.objectives
211 .iter()
212 .map(|obj| obj.score(candidate, context))
213 .filter(|s| s.is_finite())
214 .fold(0.0f64, |a, b| a.max(b))
215 }
216
217 fn name(&self) -> &str {
218 "UnionObjective"
219 }
220}
221
222pub struct NegateObjective<T> {
224 inner: Box<dyn Objective<T>>,
225}
226
227impl<T> NegateObjective<T> {
228 pub fn new(inner: Box<dyn Objective<T>>) -> Self {
230 Self { inner }
231 }
232}
233
234impl<T: Send + Sync> Objective<T> for NegateObjective<T> {
235 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
236 1.0 - self.inner.score(candidate, context)
237 }
238
239 fn name(&self) -> &str {
240 "NegateObjective"
241 }
242}
243
244pub struct ScaleObjective<O> {
246 inner: O,
247 factor: f64,
248}
249
250impl<O> ScaleObjective<O> {
251 pub fn new(inner: O, factor: f64) -> Self {
253 Self { inner, factor }
254 }
255}
256
257impl<T, O: Objective<T>> Objective<T> for ScaleObjective<O> {
258 fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 {
259 self.inner.score(candidate, context) * self.factor
260 }
261
262 fn name(&self) -> &str {
263 "ScaleObjective"
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::objective_fn;

    #[test]
    fn test_weighted_objective() {
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n * 2) as f64);

        let weighted = WeightedObjective::new()
            .add(Box::new(obj1), 1.0)
            .add(Box::new(obj2), 1.0);

        let context = ObjectiveContext::new();

        // (5 * 1 + 10 * 1) / (1 + 1)
        assert_eq!(weighted.score(&5, &context), 7.5);
    }

    #[test]
    fn test_weighted_objective_ignores_invalid_weights() {
        let negative = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 100.0);
        let positive = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 4.0);

        let weighted = WeightedObjective::new()
            .add(Box::new(negative), -1.0)
            .add(Box::new(positive), 1.0);

        assert_eq!(weighted.score(&5, &ObjectiveContext::new()), 4.0);
    }

    #[test]
    fn test_weighted_objective_requires_positive_finite_denominator() {
        let negative_weight = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 100.0);
        let infinite_weight = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 4.0);

        let weighted = WeightedObjective::new()
            .add(Box::new(negative_weight), -1.0)
            .add(Box::new(infinite_weight), f64::INFINITY);

        // Both weights are rejected, leaving nothing to average.
        assert_eq!(weighted.score(&5, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_weighted_objective_skips_non_finite_scores() {
        let nan = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| f64::NAN);
        let finite = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 4.0);

        let weighted = WeightedObjective::new()
            .add(Box::new(nan), 1.0)
            .add(Box::new(finite), 3.0);

        // The NaN child contributes neither score nor weight.
        assert_eq!(weighted.score(&5, &ObjectiveContext::new()), 4.0);
    }

    #[test]
    fn test_priority_objective() {
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| {
            if *n > 10 {
                *n as f64
            } else {
                0.0
            }
        });

        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 / 2.0);

        let priority = PriorityObjective::new()
            .add(Box::new(obj1), 5.0)
            .add(Box::new(obj2), 0.0)
            .with_fallback(-1.0);

        let context = ObjectiveContext::new();

        assert_eq!(priority.score(&15, &context), 15.0);
        assert_eq!(priority.score(&5, &context), 2.5);
        // Neither objective clears its threshold (0.0 < 5.0 and -2.0 < 0.0),
        // so the configured fallback is returned.
        assert_eq!(priority.score(&-4, &context), -1.0);
    }

    #[test]
    fn test_priority_objective_default_fallback() {
        let never = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| f64::NAN);
        let priority = PriorityObjective::new().add(Box::new(never), f64::NEG_INFINITY);

        assert_eq!(priority.score(&1, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_consensus_objective() {
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n * 2) as f64);

        let consensus = ConsensusObjective::new(5.0)
            .with_objective(Box::new(obj1))
            .with_objective(Box::new(obj2));

        let context = ObjectiveContext::new();

        // Geometric mean of 10 and 20.
        let score = consensus.score(&10, &context);
        let expected = (10.0f64 * 20.0f64).sqrt();
        assert!(
            (score - expected).abs() < 1e-9,
            "expected {expected}, got {score}"
        );

        // obj1 scores 2.0, which is below the 5.0 threshold.
        assert_eq!(consensus.score(&2, &context), 0.0);
    }

    #[test]
    fn test_consensus_objective_empty() {
        let consensus: ConsensusObjective<i32> = ConsensusObjective::new(0.0);
        assert_eq!(consensus.score(&10, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_consensus_objective_zero_score_returns_zero() {
        let obj1 = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 0.0f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);

        let consensus = ConsensusObjective::new(0.0)
            .with_objective(Box::new(obj1))
            .with_objective(Box::new(obj2));

        assert_eq!(consensus.score(&10, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_consensus_objective_non_finite_returns_zero() {
        let infinite = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| f64::INFINITY);
        let finite = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);

        let consensus = ConsensusObjective::new(0.0)
            .with_objective(Box::new(infinite))
            .with_objective(Box::new(finite));

        assert_eq!(consensus.score(&10, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_union_objective() {
        let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| 100.0 - *n as f64);

        let union = UnionObjective::new()
            .with_objective(Box::new(obj1))
            .with_objective(Box::new(obj2));

        let context = ObjectiveContext::new();

        assert_eq!(union.score(&30, &context), 70.0);
        assert_eq!(union.score(&80, &context), 80.0);
    }

    #[test]
    fn test_union_objective_skips_non_finite() {
        let nan = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| f64::NAN);
        let finite = objective_fn(|_n: &i32, _ctx: &ObjectiveContext| 3.0);

        let union = UnionObjective::new()
            .with_objective(Box::new(nan))
            .with_objective(Box::new(finite));

        assert_eq!(union.score(&1, &ObjectiveContext::new()), 3.0);
    }

    #[test]
    fn test_union_objective_empty() {
        let union: UnionObjective<i32> = UnionObjective::new();
        assert_eq!(union.score(&1, &ObjectiveContext::new()), 0.0);
    }

    #[test]
    fn test_negate_objective() {
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64 / 100.0);
        let negated = NegateObjective::new(Box::new(obj));

        let context = ObjectiveContext::new();

        // 1.0 - 0.3
        assert!((negated.score(&30, &context) - 0.7).abs() < 0.001);
    }

    #[test]
    fn test_scale_objective() {
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let scaled = ScaleObjective::new(obj, 2.0);

        let context = ObjectiveContext::new();

        assert!((scaled.score(&0, &context) - 0.0).abs() < 0.001);
        assert!((scaled.score(&5, &context) - 10.0).abs() < 0.001);
        assert!((scaled.score(&10, &context) - 20.0).abs() < 0.001);
    }

    #[test]
    fn test_scale_objective_negative_factor() {
        let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64);
        let scaled = ScaleObjective::new(obj, -1.0);

        let context = ObjectiveContext::new();

        assert!((scaled.score(&5, &context) - (-5.0)).abs() < 0.001);
    }
}