1use serde::{Deserialize, Serialize};
7
8use crate::error::FoldError;
9
/// One candidate item offered to a [`Selector`].
///
/// Generic over the payload type `T`. `GreedySelector` never reads `content`;
/// it only moves the value into [`SelectorOutput::selected`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelectorInput<T> {
    /// Item identifier. `GreedySelector` uses it as the final ranking
    /// tie-breaker (lexicographically smaller id ranks first).
    pub id: String,
    /// Caller-defined payload, carried through to the output unchanged.
    pub content: T,
    /// Cost of this item against the selection budget. Presumably in the same
    /// unit as the `budget` argument (e.g. tokens or bytes) — TODO confirm.
    pub size: usize,
    /// Base relevance score. `GreedySelector` drops items whose score is
    /// non-finite or below `SelectorWeights::min_score`.
    pub score: f32,
    /// Optional category, used for `SelectorWeights::category_weights` and
    /// the diversity penalty.
    #[serde(default)]
    pub category: Option<String>,
    /// Optional information-gain estimate, added to the score scaled by
    /// `SelectorWeights::epistemic_weight`. `None` is treated as 0.
    #[serde(default)]
    pub information_gain: Option<f32>,
}
31
/// Result of one selection pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelectorOutput<T> {
    /// Chosen items, in the order the selector picked them.
    pub selected: Vec<SelectorInput<T>>,
    /// Sum of `size` over `selected`. `GreedySelector` never lets this exceed `budget`.
    pub total_size: usize,
    /// The budget that was passed to `select`, echoed back.
    pub budget: usize,
}
42
/// Tuning knobs for a [`Selector`].
///
/// `Default` gives a neutral configuration: no category weights, `min_score`
/// of 0.0, no diversity bias and no epistemic term.
///
/// NOTE(review): only `epistemic_weight` carries `#[serde(default)]`. The other
/// three fields must be present when deserializing — confirm that is intended.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SelectorWeights {
    /// Per-category multipliers applied to an item's `score`. `GreedySelector`
    /// clamps negative weights to 0. Categories not listed are left unchanged.
    pub category_weights: std::collections::BTreeMap<String, f32>,
    /// Items scoring below this threshold are dropped. `GreedySelector`
    /// applies it both before and after category weighting.
    pub min_score: f32,
    /// Strength of the penalty for picking several items from the same
    /// category. 0.0 disables it.
    pub diversity_bias: f32,
    /// Multiplier on `SelectorInput::information_gain`. 0.0 disables the
    /// epistemic term. Defaults to 0.0 when absent from serialized input.
    #[serde(default)]
    pub epistemic_weight: f32,
}
62
/// Strategy that chooses a subset of candidate items under a size budget.
pub trait Selector<T> {
    /// Chooses items from `inputs` whose combined `size` fits in `budget`,
    /// using `weights` to filter and rank them.
    ///
    /// # Errors
    ///
    /// Implementations may return a [`FoldError`]. `GreedySelector` never does.
    fn select(
        &self,
        inputs: Vec<SelectorInput<T>>,
        budget: usize,
        weights: &SelectorWeights,
    ) -> Result<SelectorOutput<T>, FoldError>;
}
75
/// Deterministic greedy budget packer.
///
/// Ranks items by score (plus the optional epistemic term) and packs the
/// best-ranked items that still fit. When `diversity_bias` is non-zero, the
/// ranking is re-evaluated after every pick so that items from categories
/// already selected are discounted.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreedySelector;
92
93#[inline]
98fn pragmatic_plus_epistemic<T>(item: &SelectorInput<T>, epistemic_weight: f32) -> f32 {
99 if epistemic_weight == 0.0 {
100 return item.score;
101 }
102 item.score + epistemic_weight * item.information_gain.unwrap_or(0.0)
103}
104
105fn effective_score<T>(
106 item: &SelectorInput<T>,
107 counts: &std::collections::BTreeMap<String, usize>,
108 bias: f32,
109 epistemic_weight: f32,
110) -> f32 {
111 let base = pragmatic_plus_epistemic(item, epistemic_weight);
112 if bias == 0.0 {
113 return base;
114 }
115 let count = item
116 .category
117 .as_ref()
118 .and_then(|c| counts.get(c))
119 .copied()
120 .unwrap_or(0);
121 base * (1.0 - bias * count as f32 / (count as f32 + 1.0))
122}
123
124impl<T: Clone> Selector<T> for GreedySelector {
125 fn select(
126 &self,
127 mut inputs: Vec<SelectorInput<T>>,
128 budget: usize,
129 weights: &SelectorWeights,
130 ) -> Result<SelectorOutput<T>, FoldError> {
131 inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
133
134 if !weights.category_weights.is_empty() {
136 for item in &mut inputs {
137 if let Some(ref cat) = item.category {
138 if let Some(&w) = weights.category_weights.get(cat.as_str()) {
139 item.score *= w.max(0.0);
140 }
141 }
142 }
143 inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
144 }
145
146 let ew = weights.epistemic_weight;
147
148 inputs.sort_by(|a, b| {
151 let a_eff = pragmatic_plus_epistemic(a, ew);
152 let b_eff = pragmatic_plus_epistemic(b, ew);
153 b_eff
154 .total_cmp(&a_eff)
155 .then_with(|| a.size.cmp(&b.size))
156 .then_with(|| a.id.cmp(&b.id))
157 });
158
159 let mut selected = Vec::new();
160 let mut total_size = 0usize;
161
162 if weights.diversity_bias == 0.0 {
163 for input in inputs {
165 if input.size <= budget.saturating_sub(total_size) {
166 total_size += input.size;
167 selected.push(input);
168 }
169 }
170 } else {
171 let mut remaining = inputs;
173 let mut category_counts: std::collections::BTreeMap<String, usize> =
174 std::collections::BTreeMap::new();
175
176 while !remaining.is_empty() && total_size < budget {
177 let best_idx = remaining
178 .iter()
179 .enumerate()
180 .filter(|(_, item)| item.size <= budget.saturating_sub(total_size))
181 .max_by(|(_, a), (_, b)| {
182 let a_eff =
183 effective_score(a, &category_counts, weights.diversity_bias, ew);
184 let b_eff =
185 effective_score(b, &category_counts, weights.diversity_bias, ew);
186 a_eff
187 .total_cmp(&b_eff)
188 .then_with(|| b.size.cmp(&a.size))
189 .then_with(|| a.id.cmp(&b.id))
190 })
191 .map(|(i, _)| i);
192
193 match best_idx {
194 Some(idx) => {
195 let item = remaining.swap_remove(idx);
196 if let Some(ref cat) = item.category {
197 *category_counts.entry(cat.clone()).or_default() += 1;
198 }
199 total_size += item.size;
200 selected.push(item);
201 }
202 None => break,
203 }
204 }
205 }
206
207 Ok(SelectorOutput {
208 selected,
209 total_size,
210 budget,
211 })
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncategorized input with no information gain.
    fn input(id: &str, size: usize, score: f32) -> SelectorInput<()> {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: None,
            information_gain: None,
        }
    }

    /// Categorized input with no information gain.
    fn input_cat(id: &str, size: usize, score: f32, cat: &str) -> SelectorInput<()> {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: Some(cat.to_string()),
            information_gain: None,
        }
    }

    /// Uncategorized input carrying an information-gain estimate.
    fn input_with_gain(id: &str, size: usize, score: f32, gain: f32) -> SelectorInput<()> {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: None,
            information_gain: Some(gain),
        }
    }

    /// Default weights with only `min_score` set.
    fn weights(min_score: f32) -> SelectorWeights {
        SelectorWeights {
            min_score,
            ..Default::default()
        }
    }

    /// Ids of the selected items, in selection order.
    fn ids<T>(out: &SelectorOutput<T>) -> Vec<&str> {
        out.selected.iter().map(|i| i.id.as_str()).collect()
    }

    #[test]
    fn empty_input() {
        let inputs: Vec<SelectorInput<()>> = vec![];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert!(out.selected.is_empty());
        assert_eq!(out.total_size, 0);
        assert_eq!(out.budget, 1000);
    }

    #[test]
    fn packs_highest_scores_first() {
        let inputs = vec![
            input("a", 100, 0.5),
            input("b", 100, 0.9),
            input("c", 100, 0.7),
        ];
        let out = GreedySelector.select(inputs, 200, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "b");
        assert_eq!(out.selected[1].id, "c");
        assert_eq!(out.total_size, 200);
    }

    #[test]
    fn respects_budget() {
        let inputs = vec![
            input("a", 300, 0.9),
            input("b", 300, 0.8),
            input("c", 300, 0.7),
        ];
        let out = GreedySelector.select(inputs, 500, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.total_size, 300);
    }

    #[test]
    fn filters_below_min_score() {
        let inputs = vec![
            input("a", 10, 0.8),
            input("b", 10, 0.1),
            input("c", 10, 0.5),
        ];
        let out = GreedySelector.select(inputs, 1000, &weights(0.3)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "c");
    }

    #[test]
    fn filters_nan_and_inf() {
        let inputs = vec![
            input("nan", 10, f32::NAN),
            input("inf", 10, f32::INFINITY),
            input("neg_inf", 10, f32::NEG_INFINITY),
            input("ok", 10, 0.5),
        ];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "ok");
    }

    #[test]
    fn tie_break_size_ascending() {
        let inputs = vec![input("big", 200, 0.5), input("small", 50, 0.5)];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected[0].id, "small");
        assert_eq!(out.selected[1].id, "big");
    }

    #[test]
    fn tie_break_id_ascending() {
        let inputs = vec![input("z", 100, 0.5), input("a", 100, 0.5)];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "z");
    }

    #[test]
    fn skips_oversized_items_takes_smaller() {
        let inputs = vec![
            input("huge", 900, 0.9),
            input("small1", 40, 0.3),
            input("small2", 40, 0.2),
        ];
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "small1");
        assert_eq!(out.selected[1].id, "small2");
        assert_eq!(out.total_size, 80);
    }

    #[test]
    fn zero_budget() {
        let inputs = vec![input("a", 1, 0.9)];
        let out = GreedySelector.select(inputs, 0, &weights(0.0)).unwrap();
        assert!(out.selected.is_empty());
    }

    #[test]
    fn deterministic_across_input_order() {
        let a = vec![
            input("x", 50, 0.7),
            input("y", 50, 0.7),
            input("z", 50, 0.7),
        ];
        let b = vec![
            input("z", 50, 0.7),
            input("x", 50, 0.7),
            input("y", 50, 0.7),
        ];
        let out_a = GreedySelector.select(a, 100, &weights(0.0)).unwrap();
        let out_b = GreedySelector.select(b, 100, &weights(0.0)).unwrap();
        let ids_a = ids(&out_a);
        let ids_b = ids(&out_b);
        assert_eq!(ids_a, ids_b);
        assert_eq!(ids_a, vec!["x", "y"]);
    }

    #[test]
    fn exact_budget_fit() {
        let inputs = vec![input("a", 50, 0.9), input("b", 50, 0.8)];
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.total_size, 100);
    }

    #[test]
    fn category_weights_boost_preferred_category() {
        // b: 0.5 * 2.0 = 1.0 beats a: 0.9 * 1.0 = 0.9.
        let inputs = vec![
            input_cat("a", 100, 0.9, "low"),
            input_cat("b", 100, 0.5, "high"),
        ];
        let w = SelectorWeights {
            category_weights: [("high".to_string(), 2.0f32), ("low".to_string(), 1.0f32)]
                .into_iter()
                .collect(),
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 100, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    #[test]
    fn category_weights_can_push_below_min_score() {
        // a: 0.4 * 0.5 = 0.2 falls below min_score 0.3 after weighting.
        let inputs = vec![
            input_cat("a", 10, 0.4, "bad"),
            input_cat("b", 10, 0.8, "good"),
        ];
        let w = SelectorWeights {
            min_score: 0.3,
            category_weights: [("bad".to_string(), 0.5f32)].into_iter().collect(),
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 1000, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    #[test]
    fn negative_category_weight_clamps_to_zero() {
        // A negative weight is clamped to 0: "a" keeps a score of exactly 0.0,
        // which still passes min_score 0.0 and ranks after "b".
        let inputs = vec![input_cat("a", 10, 0.8, "neg"), input("b", 10, 0.1)];
        let w = SelectorWeights {
            category_weights: [("neg".to_string(), -2.0f32)].into_iter().collect(),
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 1000, &w).unwrap();
        assert_eq!(ids(&out), vec!["b", "a"]);
    }

    #[test]
    fn diversity_bias_zero_identical_to_greedy() {
        // `SelectorWeights::default()` already has `diversity_bias == 0.0`, so the
        // two configurations are equal by construction. Pin the expected order
        // as well, so that this test can fail if plain greedy packing regresses.
        let make = || {
            vec![
                input_cat("a", 100, 0.9, "x"),
                input_cat("b", 100, 0.8, "x"),
                input_cat("c", 100, 0.7, "y"),
            ]
        };
        let w_greedy = SelectorWeights::default();
        let w_bias0 = SelectorWeights {
            diversity_bias: 0.0,
            ..Default::default()
        };
        let out_g = GreedySelector.select(make(), 200, &w_greedy).unwrap();
        let out_b = GreedySelector.select(make(), 200, &w_bias0).unwrap();
        assert_eq!(ids(&out_g), ids(&out_b));
        assert_eq!(ids(&out_g), vec!["a", "b"]);
    }

    #[test]
    fn diversity_bias_prefers_different_categories() {
        let inputs = vec![
            input_cat("a", 100, 0.9, "x"),
            input_cat("b", 100, 0.8, "x"),
            input_cat("c", 100, 0.7, "y"),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        let ids = ids(&out);
        assert!(ids.contains(&"a"), "a should always be selected");
        assert!(
            ids.contains(&"c"),
            "c should be preferred over b due to diversity"
        );
    }

    #[test]
    fn no_overflow_near_usize_max() {
        // An item far larger than the budget must be skipped without overflowing
        // the remaining-room arithmetic.
        let inputs = vec![input("a", usize::MAX - 1, 0.9), input("b", 10, 0.8)];
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    #[test]
    fn diversity_path_no_overflow_near_usize_max() {
        // Same as above, but exercising the diversity-aware packing loop.
        let inputs = vec![input("a", usize::MAX - 1, 0.9), input("b", 10, 0.8)];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 100, &w).unwrap();
        assert_eq!(ids(&out), vec!["b"]);
        assert_eq!(out.total_size, 10);
    }

    #[test]
    fn diversity_bias_no_categories_unaffected() {
        let inputs = vec![
            input("a", 100, 0.9),
            input("b", 100, 0.8),
            input("c", 100, 0.7),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "b");
    }

    #[test]
    fn epistemic_weight_zero_preserves_behavior() {
        // Large gains must have no effect while the epistemic weight is zero.
        let make = || {
            vec![
                input_with_gain("a", 100, 0.9, 10.0),
                input_with_gain("b", 100, 0.8, 0.0),
                input_with_gain("c", 100, 0.7, 5.0),
            ]
        };
        let w_default = SelectorWeights::default();
        let w_zero = SelectorWeights {
            epistemic_weight: 0.0,
            ..Default::default()
        };
        let out_d = GreedySelector.select(make(), 200, &w_default).unwrap();
        let out_z = GreedySelector.select(make(), 200, &w_zero).unwrap();
        assert_eq!(ids(&out_d), ids(&out_z));
        assert_eq!(ids(&out_d), vec!["a", "b"]);
    }

    #[test]
    fn epistemic_weight_positive_reorders_by_gain() {
        // a: 0.5 + 1.0 * 10.0 = 10.5 outranks b: 0.9 + 0.0.
        let inputs = vec![
            input_with_gain("a", 100, 0.5, 10.0),
            input_with_gain("b", 100, 0.9, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 100, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "a");
    }

    #[test]
    fn information_gain_none_equivalent_to_zero() {
        let with_none = vec![input("a", 100, 0.9), input("b", 100, 0.8)];
        let with_zero = vec![
            input_with_gain("a", 100, 0.9, 0.0),
            input_with_gain("b", 100, 0.8, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out_none = GreedySelector.select(with_none, 200, &w).unwrap();
        let out_zero = GreedySelector.select(with_zero, 200, &w).unwrap();
        assert_eq!(ids(&out_none), ids(&out_zero));
    }

    #[test]
    fn epistemic_weight_works_with_diversity_bias() {
        // a wins on gain (10.5). Then b (0.8 discounted to 0.6) still beats c (0.3).
        let with_cat = |mut i: SelectorInput<()>, cat: &str| {
            i.category = Some(cat.to_string());
            i
        };
        let inputs = vec![
            with_cat(input_with_gain("a", 100, 0.5, 10.0), "x"),
            with_cat(input_with_gain("b", 100, 0.8, 0.0), "x"),
            with_cat(input_with_gain("c", 100, 0.3, 0.0), "y"),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            diversity_bias: 0.5,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "b");
    }
}