1#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use crate::error::FoldError;
10
/// A single candidate handed to a [`Selector`], carrying a payload of type `T`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorInput<T> {
    /// Identifier of the candidate. [`GreedySelector`] uses it as the last
    /// tie-breaker when scores and sizes are equal.
    pub id: String,

    /// Caller-defined payload; the selection logic in this file never inspects it.
    pub content: T,

    /// Cost of the candidate, counted against the selection budget.
    pub size: usize,

    /// Base relevance score. [`GreedySelector`] discards candidates whose score
    /// is non-finite or below `SelectorWeights::min_score`.
    pub score: f32,

    /// Optional category label, looked up in `SelectorWeights::category_weights`
    /// and counted by the diversity penalty. Defaults to `None` when
    /// deserialized with the `serde` feature.
    #[cfg_attr(feature = "serde", serde(default))]
    pub category: Option<String>,

    /// Optional information-gain estimate. It is multiplied by
    /// `SelectorWeights::epistemic_weight` and added to `score`; `None` is
    /// treated as `0.0`. Defaults to `None` when deserialized with the `serde`
    /// feature.
    #[cfg_attr(feature = "serde", serde(default))]
    pub information_gain: Option<f32>,
}
35
36#[derive(Debug, Clone)]
38#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
39pub struct SelectorOutput<T> {
40 pub selected: Vec<SelectorInput<T>>,
42 pub total_size: usize,
44 pub budget: usize,
46}
47
/// Tuning knobs passed to [`Selector::select`].
///
/// `Default` yields an empty weight map and `0.0` for every numeric field,
/// which makes [`GreedySelector`] rank purely by `score`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorWeights {
    /// Per-category multiplier applied to an item's `score`. Categories that
    /// are not listed leave the score untouched; [`GreedySelector`] clamps
    /// negative multipliers to `0.0`.
    pub category_weights: std::collections::BTreeMap<String, f32>,

    /// Minimum score an item needs in order to be considered. It is checked
    /// before and again after category weighting.
    pub min_score: f32,

    /// Strength of the same-category penalty. `0.0` disables the
    /// diversity-aware pass entirely.
    pub diversity_bias: f32,

    /// Multiplier for `SelectorInput::information_gain`; `0.0` ignores the
    /// gain. Defaults to `0.0` when absent from serialized data (with the
    /// `serde` feature).
    #[cfg_attr(feature = "serde", serde(default))]
    pub epistemic_weight: f32,
}
68
69pub trait Selector<T>: Send + Sync {
74 fn select(
76 &self,
77 inputs: Vec<SelectorInput<T>>,
78 budget: usize,
79 weights: &SelectorWeights,
80 ) -> Result<SelectorOutput<T>, FoldError>;
81}
82
/// Stateless greedy [`Selector`].
///
/// Its `select` implementation ranks candidates by effective score and packs
/// them into the budget, optionally penalising repeated categories.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreedySelector;
99
100#[inline]
105fn pragmatic_plus_epistemic<T>(item: &SelectorInput<T>, epistemic_weight: f32) -> f32 {
106 if epistemic_weight == 0.0 {
107 return item.score;
108 }
109 item.score + epistemic_weight * item.information_gain.unwrap_or(0.0)
110}
111
112fn effective_score<T>(
113 item: &SelectorInput<T>,
114 counts: &std::collections::BTreeMap<String, usize>,
115 bias: f32,
116 epistemic_weight: f32,
117) -> f32 {
118 let base = pragmatic_plus_epistemic(item, epistemic_weight);
119 if bias == 0.0 {
120 return base;
121 }
122 let count = item
123 .category
124 .as_ref()
125 .and_then(|c| counts.get(c))
126 .copied()
127 .unwrap_or(0);
128 base * (1.0 - bias * count as f32 / (count as f32 + 1.0))
129}
130
131impl<T: Clone> Selector<T> for GreedySelector {
132 fn select(
133 &self,
134 mut inputs: Vec<SelectorInput<T>>,
135 budget: usize,
136 weights: &SelectorWeights,
137 ) -> Result<SelectorOutput<T>, FoldError> {
138 inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
140
141 if !weights.category_weights.is_empty() {
143 for item in &mut inputs {
144 if let Some(ref cat) = item.category {
145 if let Some(&w) = weights.category_weights.get(cat.as_str()) {
146 item.score *= w.max(0.0);
147 }
148 }
149 }
150 inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
151 }
152
153 let ew = weights.epistemic_weight;
154
155 inputs.sort_by(|a, b| {
158 let a_eff = pragmatic_plus_epistemic(a, ew);
159 let b_eff = pragmatic_plus_epistemic(b, ew);
160 b_eff
161 .total_cmp(&a_eff)
162 .then_with(|| a.size.cmp(&b.size))
163 .then_with(|| a.id.cmp(&b.id))
164 });
165
166 let mut selected = Vec::new();
167 let mut total_size = 0usize;
168
169 if weights.diversity_bias == 0.0 {
170 for input in inputs {
172 if input.size <= budget.saturating_sub(total_size) {
173 total_size += input.size;
174 selected.push(input);
175 }
176 }
177 } else {
178 let mut remaining = inputs;
180 let mut category_counts: std::collections::BTreeMap<String, usize> =
181 std::collections::BTreeMap::new();
182
183 while !remaining.is_empty() && total_size < budget {
184 let best_idx = remaining
185 .iter()
186 .enumerate()
187 .filter(|(_, item)| item.size <= budget.saturating_sub(total_size))
188 .max_by(|(_, a), (_, b)| {
189 let a_eff =
190 effective_score(a, &category_counts, weights.diversity_bias, ew);
191 let b_eff =
192 effective_score(b, &category_counts, weights.diversity_bias, ew);
193 a_eff
194 .total_cmp(&b_eff)
195 .then_with(|| b.size.cmp(&a.size))
196 .then_with(|| a.id.cmp(&b.id))
197 })
198 .map(|(i, _)| i);
199
200 match best_idx {
201 Some(idx) => {
202 let item = remaining.swap_remove(idx);
203 if let Some(ref cat) = item.category {
204 *category_counts.entry(cat.clone()).or_default() += 1;
205 }
206 total_size += item.size;
207 selected.push(item);
208 }
209 None => break,
210 }
211 }
212 }
213
214 Ok(SelectorOutput {
215 selected,
216 total_size,
217 budget,
218 })
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncategorised item without information gain.
    fn input(id: &str, size: usize, score: f32) -> SelectorInput<()> {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: None,
            information_gain: None,
        }
    }

    /// Like [`input`], but tagged with category `cat`.
    fn input_cat(id: &str, size: usize, score: f32, cat: &str) -> SelectorInput<()> {
        SelectorInput {
            category: Some(cat.to_string()),
            ..input(id, size, score)
        }
    }

    /// Like [`input`], but carrying an information-gain estimate.
    fn input_with_gain(id: &str, size: usize, score: f32, gain: f32) -> SelectorInput<()> {
        SelectorInput {
            information_gain: Some(gain),
            ..input(id, size, score)
        }
    }

    /// Default weights with only `min_score` overridden.
    fn weights(min_score: f32) -> SelectorWeights {
        SelectorWeights {
            min_score,
            ..Default::default()
        }
    }

    /// Runs the greedy selector and unwraps the result.
    fn run(inputs: Vec<SelectorInput<()>>, budget: usize, w: &SelectorWeights) -> SelectorOutput<()> {
        GreedySelector.select(inputs, budget, w).unwrap()
    }

    /// Ids of the selected items, in selection order.
    fn ids(out: &SelectorOutput<()>) -> Vec<&str> {
        out.selected.iter().map(|item| item.id.as_str()).collect()
    }

    #[test]
    fn empty_input() {
        let out = run(Vec::new(), 1000, &weights(0.0));
        assert!(out.selected.is_empty());
        assert_eq!(out.total_size, 0);
        assert_eq!(out.budget, 1000);
    }

    #[test]
    fn packs_highest_scores_first() {
        let candidates = vec![
            input("a", 100, 0.5),
            input("b", 100, 0.9),
            input("c", 100, 0.7),
        ];
        let out = run(candidates, 200, &weights(0.0));
        assert_eq!(ids(&out), ["b", "c"]);
        assert_eq!(out.total_size, 200);
    }

    #[test]
    fn respects_budget() {
        let candidates = vec![
            input("a", 300, 0.9),
            input("b", 300, 0.8),
            input("c", 300, 0.7),
        ];
        let out = run(candidates, 500, &weights(0.0));
        assert_eq!(ids(&out), ["a"]);
        assert_eq!(out.total_size, 300);
    }

    #[test]
    fn filters_below_min_score() {
        let candidates = vec![
            input("a", 10, 0.8),
            input("b", 10, 0.1),
            input("c", 10, 0.5),
        ];
        let out = run(candidates, 1000, &weights(0.3));
        assert_eq!(ids(&out), ["a", "c"]);
    }

    #[test]
    fn filters_nan_and_inf() {
        let candidates = vec![
            input("nan", 10, f32::NAN),
            input("inf", 10, f32::INFINITY),
            input("neg_inf", 10, f32::NEG_INFINITY),
            input("ok", 10, 0.5),
        ];
        let out = run(candidates, 1000, &weights(0.0));
        assert_eq!(ids(&out), ["ok"]);
    }

    #[test]
    fn tie_break_size_ascending() {
        let out = run(
            vec![input("big", 200, 0.5), input("small", 50, 0.5)],
            1000,
            &weights(0.0),
        );
        let picked = ids(&out);
        assert_eq!(picked[0], "small");
        assert_eq!(picked[1], "big");
    }

    #[test]
    fn tie_break_id_ascending() {
        let out = run(
            vec![input("z", 100, 0.5), input("a", 100, 0.5)],
            1000,
            &weights(0.0),
        );
        let picked = ids(&out);
        assert_eq!(picked[0], "a");
        assert_eq!(picked[1], "z");
    }

    #[test]
    fn skips_oversized_items_takes_smaller() {
        let candidates = vec![
            input("huge", 900, 0.9),
            input("small1", 40, 0.3),
            input("small2", 40, 0.2),
        ];
        let out = run(candidates, 100, &weights(0.0));
        assert_eq!(ids(&out), ["small1", "small2"]);
        assert_eq!(out.total_size, 80);
    }

    #[test]
    fn zero_budget() {
        let out = run(vec![input("a", 1, 0.9)], 0, &weights(0.0));
        assert!(out.selected.is_empty());
    }

    #[test]
    fn deterministic_across_input_order() {
        // Same three equally scored items, presented in two different orders.
        let in_order = |order: [&str; 3]| -> Vec<SelectorInput<()>> {
            order.iter().map(|id| input(id, 50, 0.7)).collect()
        };
        let out_a = run(in_order(["x", "y", "z"]), 100, &weights(0.0));
        let out_b = run(in_order(["z", "x", "y"]), 100, &weights(0.0));
        let ids_a = ids(&out_a);
        let ids_b = ids(&out_b);
        assert_eq!(ids_a, ids_b);
        assert_eq!(ids_a, vec!["x", "y"]);
    }

    #[test]
    fn exact_budget_fit() {
        let out = run(
            vec![input("a", 50, 0.9), input("b", 50, 0.8)],
            100,
            &weights(0.0),
        );
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.total_size, 100);
    }

    #[test]
    fn category_weights_boost_preferred_category() {
        let candidates = vec![
            input_cat("a", 100, 0.9, "low"),
            input_cat("b", 100, 0.5, "high"),
        ];
        let w = SelectorWeights {
            category_weights: [("high".to_string(), 2.0f32), ("low".to_string(), 1.0f32)]
                .into_iter()
                .collect(),
            ..Default::default()
        };
        let out = run(candidates, 100, &w);
        assert_eq!(ids(&out), ["b"]);
    }

    #[test]
    fn category_weights_can_push_below_min_score() {
        let candidates = vec![
            input_cat("a", 10, 0.4, "bad"),
            input_cat("b", 10, 0.8, "good"),
        ];
        let w = SelectorWeights {
            min_score: 0.3,
            category_weights: [("bad".to_string(), 0.5f32)].into_iter().collect(),
            ..Default::default()
        };
        let out = run(candidates, 1000, &w);
        assert_eq!(ids(&out), ["b"]);
    }

    #[test]
    fn diversity_bias_zero_identical_to_greedy() {
        let make = || {
            vec![
                input_cat("a", 100, 0.9, "x"),
                input_cat("b", 100, 0.8, "x"),
                input_cat("c", 100, 0.7, "y"),
            ]
        };
        let w_greedy = SelectorWeights::default();
        let w_bias0 = SelectorWeights {
            diversity_bias: 0.0,
            ..Default::default()
        };
        let out_g = run(make(), 200, &w_greedy);
        let out_b = run(make(), 200, &w_bias0);
        assert_eq!(ids(&out_g), ids(&out_b));
    }

    #[test]
    fn diversity_bias_prefers_different_categories() {
        let candidates = vec![
            input_cat("a", 100, 0.9, "x"),
            input_cat("b", 100, 0.8, "x"),
            input_cat("c", 100, 0.7, "y"),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = run(candidates, 200, &w);
        assert_eq!(out.selected.len(), 2);
        let picked = ids(&out);
        assert!(picked.contains(&"a"), "a should always be selected");
        assert!(
            picked.contains(&"c"),
            "c should be preferred over b due to diversity"
        );
    }

    #[test]
    fn no_overflow_near_usize_max() {
        // The oversized item must be skipped without the size accounting
        // overflowing.
        let large = usize::MAX - 1;
        let candidates = vec![input("a", large, 0.9), input("b", 10, 0.8)];
        let out = run(candidates, 100, &weights(0.0));
        assert_eq!(ids(&out), ["b"]);
    }

    #[test]
    fn diversity_bias_no_categories_unaffected() {
        let candidates = vec![
            input("a", 100, 0.9),
            input("b", 100, 0.8),
            input("c", 100, 0.7),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = run(candidates, 200, &w);
        assert_eq!(ids(&out), ["a", "b"]);
    }

    #[test]
    fn epistemic_weight_zero_preserves_behavior() {
        // Gains are present but must be ignored while the weight is zero.
        let make = || {
            vec![
                input_with_gain("a", 100, 0.9, 10.0),
                input_with_gain("b", 100, 0.8, 0.0),
                input_with_gain("c", 100, 0.7, 5.0),
            ]
        };
        let w_default = SelectorWeights::default();
        let w_zero = SelectorWeights {
            epistemic_weight: 0.0,
            ..Default::default()
        };
        let out_d = run(make(), 200, &w_default);
        let out_z = run(make(), 200, &w_zero);
        let ids_d = ids(&out_d);
        let ids_z = ids(&out_z);
        assert_eq!(ids_d, ids_z);
        assert_eq!(ids_d, vec!["a", "b"]);
    }

    #[test]
    fn epistemic_weight_positive_reorders_by_gain() {
        // "a" has the lower score but a large gain, so it must win the one slot.
        let candidates = vec![
            input_with_gain("a", 100, 0.5, 10.0),
            input_with_gain("b", 100, 0.9, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out = run(candidates, 100, &w);
        assert_eq!(ids(&out), ["a"]);
    }

    #[test]
    fn information_gain_none_equivalent_to_zero() {
        let with_none = vec![input("a", 100, 0.9), input("b", 100, 0.8)];
        let with_zero = vec![
            input_with_gain("a", 100, 0.9, 0.0),
            input_with_gain("b", 100, 0.8, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out_none = run(with_none, 200, &w);
        let out_zero = run(with_zero, 200, &w);
        assert_eq!(ids(&out_none), ids(&out_zero));
    }

    #[test]
    fn epistemic_weight_works_with_diversity_bias() {
        let categorised = |id: &str, score: f32, gain: f32, cat: &str| SelectorInput {
            category: Some(cat.to_string()),
            ..input_with_gain(id, 100, score, gain)
        };
        let candidates = vec![
            categorised("a", 0.5, 10.0, "x"),
            categorised("b", 0.8, 0.0, "x"),
            categorised("c", 0.3, 0.0, "y"),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            diversity_bias: 0.5,
            ..Default::default()
        };
        let out = run(candidates, 200, &w);
        assert_eq!(ids(&out), ["a", "b"]);
    }
}