1#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use crate::error::FoldError;
10
/// One candidate offered to a [`Selector`], carrying the metadata used to rank it.
///
/// With the `serde` feature enabled the struct derives `Serialize`/`Deserialize`;
/// `category` and `information_gain` are marked `serde(default)`, so they may be
/// omitted from serialized input.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorInput<T> {
    /// Identifier of the item. [`GreedySelector`] uses it as the last tie-breaker
    /// when two items rank equally and have the same `size`.
    pub id: String,
    /// Caller-supplied payload. None of the selection code in this file reads it.
    pub content: T,
    /// Cost of the item, counted against the `budget` passed to `select`.
    /// NOTE(review): the unit (bytes, tokens, ...) is not stated in this file.
    pub size: usize,
    /// Ranking score; higher ranks first. [`GreedySelector`] drops items whose
    /// score is NaN/infinite or below `SelectorWeights::min_score`.
    pub score: f32,
    /// Optional category label. It is looked up in
    /// `SelectorWeights::category_weights` and counted per category when
    /// `SelectorWeights::diversity_bias` is non-zero.
    #[cfg_attr(feature = "serde", serde(default))]
    pub category: Option<String>,
    /// Optional bonus term: when ranking, [`GreedySelector`] adds
    /// `epistemic_weight * information_gain` to `score`. `None` counts as `0.0`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub information_gain: Option<f32>,
}
33
/// Result of a [`Selector::select`] call.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorOutput<T> {
    /// The chosen inputs. [`GreedySelector`] pushes them in the order it picks them.
    pub selected: Vec<SelectorInput<T>>,
    /// Sum of `size` over `selected`, as accumulated by [`GreedySelector`].
    pub total_size: usize,
    /// The budget the selection was made against; [`GreedySelector`] copies the
    /// `budget` argument here unchanged.
    pub budget: usize,
}
45
/// Tuning knobs passed to [`Selector::select`].
///
/// `Default` yields an empty `category_weights` map and `0.0` for every numeric
/// field, which in [`GreedySelector`] means: no category re-weighting, a minimum
/// score of `0.0`, no diversity penalty and no epistemic bonus.
///
/// NOTE(review): with the `serde` feature only `epistemic_weight` is marked
/// `serde(default)`; the other three fields must be present when deserializing.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorWeights {
    /// Per-category multiplier applied to an input's `score`. [`GreedySelector`]
    /// clamps negative weights to `0.0` and leaves inputs whose category is not
    /// in the map untouched.
    pub category_weights: std::collections::BTreeMap<String, f32>,
    /// Inputs scoring below this value are dropped by [`GreedySelector`].
    pub min_score: f32,
    /// Strength of the per-category diversity penalty. Exactly `0.0` disables it;
    /// otherwise an item's rank value is scaled by `1 - bias * n / (n + 1)`, where
    /// `n` is the number of already selected items sharing its category.
    pub diversity_bias: f32,
    /// Multiplier for `SelectorInput::information_gain` when ranking. Exactly
    /// `0.0` disables the bonus.
    #[cfg_attr(feature = "serde", serde(default))]
    pub epistemic_weight: f32,
}
66
/// Strategy for choosing a subset of candidate inputs.
///
/// Implementors must be `Send + Sync`.
///
/// NOTE(review): the exact contract is not written down in this file beyond the
/// signature. The one implementation visible here, [`GreedySelector`], keeps the
/// summed `size` of the selection within `budget`.
pub trait Selector<T>: Send + Sync {
    /// Selects from `inputs` using `budget` and `weights`.
    ///
    /// Takes ownership of `inputs` and returns the chosen items inside a
    /// [`SelectorOutput`].
    ///
    /// # Errors
    ///
    /// Implementations may report failure through [`FoldError`].
    fn select(
        &self,
        inputs: Vec<SelectorInput<T>>,
        budget: usize,
        weights: &SelectorWeights,
    ) -> Result<SelectorOutput<T>, FoldError>;
}
79
/// Stateless [`Selector`] that ranks inputs by score and takes them in rank order
/// for as long as each one still fits in the remaining budget.
///
/// It carries no fields, so it is `Copy` and can be used directly as the value
/// `GreedySelector`.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreedySelector;
96
97#[inline]
102fn pragmatic_plus_epistemic<T>(item: &SelectorInput<T>, epistemic_weight: f32) -> f32 {
103 if epistemic_weight == 0.0 {
104 return item.score;
105 }
106 item.score + epistemic_weight * item.information_gain.unwrap_or(0.0)
107}
108
109fn effective_score<T>(
110 item: &SelectorInput<T>,
111 counts: &std::collections::BTreeMap<String, usize>,
112 bias: f32,
113 epistemic_weight: f32,
114) -> f32 {
115 let base = pragmatic_plus_epistemic(item, epistemic_weight);
116 if bias == 0.0 {
117 return base;
118 }
119 let count = item
120 .category
121 .as_ref()
122 .and_then(|c| counts.get(c))
123 .copied()
124 .unwrap_or(0);
125 base * (1.0 - bias * count as f32 / (count as f32 + 1.0))
126}
127
128impl<T: Clone> Selector<T> for GreedySelector {
129 fn select(
130 &self,
131 mut inputs: Vec<SelectorInput<T>>,
132 budget: usize,
133 weights: &SelectorWeights,
134 ) -> Result<SelectorOutput<T>, FoldError> {
135 inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
137
138 if !weights.category_weights.is_empty() {
140 for item in &mut inputs {
141 if let Some(ref cat) = item.category {
142 if let Some(&w) = weights.category_weights.get(cat.as_str()) {
143 item.score *= w.max(0.0);
144 }
145 }
146 }
147 inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
148 }
149
150 let ew = weights.epistemic_weight;
151
152 inputs.sort_by(|a, b| {
155 let a_eff = pragmatic_plus_epistemic(a, ew);
156 let b_eff = pragmatic_plus_epistemic(b, ew);
157 b_eff
158 .total_cmp(&a_eff)
159 .then_with(|| a.size.cmp(&b.size))
160 .then_with(|| a.id.cmp(&b.id))
161 });
162
163 let mut selected = Vec::new();
164 let mut total_size = 0usize;
165
166 if weights.diversity_bias == 0.0 {
167 for input in inputs {
169 if input.size <= budget.saturating_sub(total_size) {
170 total_size += input.size;
171 selected.push(input);
172 }
173 }
174 } else {
175 let mut remaining = inputs;
177 let mut category_counts: std::collections::BTreeMap<String, usize> =
178 std::collections::BTreeMap::new();
179
180 while !remaining.is_empty() && total_size < budget {
181 let best_idx = remaining
182 .iter()
183 .enumerate()
184 .filter(|(_, item)| item.size <= budget.saturating_sub(total_size))
185 .max_by(|(_, a), (_, b)| {
186 let a_eff =
187 effective_score(a, &category_counts, weights.diversity_bias, ew);
188 let b_eff =
189 effective_score(b, &category_counts, weights.diversity_bias, ew);
190 a_eff
191 .total_cmp(&b_eff)
192 .then_with(|| b.size.cmp(&a.size))
193 .then_with(|| a.id.cmp(&b.id))
194 })
195 .map(|(i, _)| i);
196
197 match best_idx {
198 Some(idx) => {
199 let item = remaining.swap_remove(idx);
200 if let Some(ref cat) = item.category {
201 *category_counts.entry(cat.clone()).or_default() += 1;
202 }
203 total_size += item.size;
204 selected.push(item);
205 }
206 None => break,
207 }
208 }
209 }
210
211 Ok(SelectorOutput {
212 selected,
213 total_size,
214 budget,
215 })
216 }
217}
218
// Unit tests for `GreedySelector`. All inputs use `()` as payload, so only the
// ranking metadata (`id`, `size`, `score`, `category`, `information_gain`)
// matters.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an uncategorised input with no information gain.
    fn input(id: &str, size: usize, score: f32) -> SelectorInput<()> {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: None,
            information_gain: None,
        }
    }

    // Like `input`, but tagged with category `cat`.
    fn input_cat(id: &str, size: usize, score: f32, cat: &str) -> SelectorInput<()> {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: Some(cat.to_string()),
            information_gain: None,
        }
    }

    // Default weights with only `min_score` overridden.
    fn weights(min_score: f32) -> SelectorWeights {
        SelectorWeights {
            min_score,
            ..Default::default()
        }
    }

    // No inputs: empty selection, zero total, and the budget is echoed back.
    #[test]
    fn empty_input() {
        let inputs: Vec<SelectorInput<()>> = vec![];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert!(out.selected.is_empty());
        assert_eq!(out.total_size, 0);
        assert_eq!(out.budget, 1000);
    }

    // Room for two of three equal-sized items: the two highest scores win, in
    // descending score order.
    #[test]
    fn packs_highest_scores_first() {
        let inputs = vec![
            input("a", 100, 0.5),
            input("b", 100, 0.9),
            input("c", 100, 0.7),
        ];
        let out = GreedySelector.select(inputs, 200, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "b");
        assert_eq!(out.selected[1].id, "c");
        assert_eq!(out.total_size, 200);
    }

    // 300-sized items against a budget of 500: only the best one fits.
    #[test]
    fn respects_budget() {
        let inputs = vec![
            input("a", 300, 0.9),
            input("b", 300, 0.8),
            input("c", 300, 0.7),
        ];
        let out = GreedySelector.select(inputs, 500, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.total_size, 300);
    }

    // `min_score` of 0.3 drops "b" (0.1) even though the budget has room.
    #[test]
    fn filters_below_min_score() {
        let inputs = vec![
            input("a", 10, 0.8),
            input("b", 10, 0.1),
            input("c", 10, 0.5),
        ];
        let out = GreedySelector.select(inputs, 1000, &weights(0.3)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "c");
    }

    // NaN and both infinities are rejected by the `is_finite` filter.
    #[test]
    fn filters_nan_and_inf() {
        let inputs = vec![
            input("nan", 10, f32::NAN),
            input("inf", 10, f32::INFINITY),
            input("neg_inf", 10, f32::NEG_INFINITY),
            input("ok", 10, 0.5),
        ];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "ok");
    }

    // Equal scores: the smaller item is ordered first.
    #[test]
    fn tie_break_size_ascending() {
        let inputs = vec![input("big", 200, 0.5), input("small", 50, 0.5)];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected[0].id, "small");
        assert_eq!(out.selected[1].id, "big");
    }

    // Equal scores and sizes: the smaller id is ordered first.
    #[test]
    fn tie_break_id_ascending() {
        let inputs = vec![input("z", 100, 0.5), input("a", 100, 0.5)];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "z");
    }

    // The top-scoring item exceeds the whole budget; it is skipped and the
    // lower-scoring items that do fit are taken instead.
    #[test]
    fn skips_oversized_items_takes_smaller() {
        let inputs = vec![
            input("huge", 900, 0.9),
            input("small1", 40, 0.3),
            input("small2", 40, 0.2),
        ];
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "small1");
        assert_eq!(out.selected[1].id, "small2");
        assert_eq!(out.total_size, 80);
    }

    // A budget of zero admits no item of non-zero size.
    #[test]
    fn zero_budget() {
        let inputs = vec![input("a", 1, 0.9)];
        let out = GreedySelector.select(inputs, 0, &weights(0.0)).unwrap();
        assert!(out.selected.is_empty());
    }

    // Fully tied items given in two different orders yield the same selection,
    // because the id tie-break makes the sort order independent of input order.
    #[test]
    fn deterministic_across_input_order() {
        let a = vec![
            input("x", 50, 0.7),
            input("y", 50, 0.7),
            input("z", 50, 0.7),
        ];
        let b = vec![
            input("z", 50, 0.7),
            input("x", 50, 0.7),
            input("y", 50, 0.7),
        ];
        let out_a = GreedySelector.select(a, 100, &weights(0.0)).unwrap();
        let out_b = GreedySelector.select(b, 100, &weights(0.0)).unwrap();
        let ids_a: Vec<&str> = out_a.selected.iter().map(|i| i.id.as_str()).collect();
        let ids_b: Vec<&str> = out_b.selected.iter().map(|i| i.id.as_str()).collect();
        assert_eq!(ids_a, ids_b);
        assert_eq!(ids_a, vec!["x", "y"]);
    }

    // Items that sum to exactly the budget are all accepted (the fit check is `<=`).
    #[test]
    fn exact_budget_fit() {
        let inputs = vec![input("a", 50, 0.9), input("b", 50, 0.8)];
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.total_size, 100);
    }

    // Category weight 2.0 lifts "b" from 0.5 to 1.0, above "a" at 0.9 * 1.0.
    #[test]
    fn category_weights_boost_preferred_category() {
        let inputs = vec![
            input_cat("a", 100, 0.9, "low"),
            input_cat("b", 100, 0.5, "high"),
        ];
        let w = SelectorWeights {
            category_weights: [("high".to_string(), 2.0f32), ("low".to_string(), 1.0f32)]
                .into_iter()
                .collect(),
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 100, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    // "a" passes the first filter (0.4 >= 0.3) but is removed by the second one
    // after its category weight halves the score to 0.2.
    #[test]
    fn category_weights_can_push_below_min_score() {
        let inputs = vec![
            input_cat("a", 10, 0.4, "bad"),
            input_cat("b", 10, 0.8, "good"),
        ];
        let w = SelectorWeights {
            min_score: 0.3,
            category_weights: [("bad".to_string(), 0.5f32)].into_iter().collect(),
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 1000, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    // NOTE(review): `Default` already sets `diversity_bias` to 0.0, so the two
    // weight values below are field-for-field equal; this pins the default
    // rather than comparing two different code paths.
    #[test]
    fn diversity_bias_zero_identical_to_greedy() {
        let make = || {
            vec![
                input_cat("a", 100, 0.9, "x"),
                input_cat("b", 100, 0.8, "x"),
                input_cat("c", 100, 0.7, "y"),
            ]
        };
        let w_greedy = SelectorWeights {
            ..Default::default()
        };
        let w_bias0 = SelectorWeights {
            diversity_bias: 0.0,
            ..Default::default()
        };
        let out_g = GreedySelector.select(make(), 200, &w_greedy).unwrap();
        let out_b = GreedySelector.select(make(), 200, &w_bias0).unwrap();
        let ids_g: Vec<&str> = out_g.selected.iter().map(|i| i.id.as_str()).collect();
        let ids_b: Vec<&str> = out_b.selected.iter().map(|i| i.id.as_str()).collect();
        assert_eq!(ids_g, ids_b);
    }

    // After "a" (category "x") is picked, bias 1.0 scales "b" to
    // 0.8 * (1 - 1/2) = 0.4, so "c" (0.7, category "y") is taken next.
    #[test]
    fn diversity_bias_prefers_different_categories() {
        let inputs = vec![
            input_cat("a", 100, 0.9, "x"),
            input_cat("b", 100, 0.8, "x"),
            input_cat("c", 100, 0.7, "y"),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        let ids: Vec<&str> = out.selected.iter().map(|i| i.id.as_str()).collect();
        assert!(ids.contains(&"a"), "a should always be selected");
        assert!(
            ids.contains(&"c"),
            "c should be preferred over b due to diversity"
        );
    }

    // An item far larger than the budget must be skipped without the size
    // arithmetic overflowing; the small item is still selected.
    #[test]
    fn no_overflow_near_usize_max() {
        let large = usize::MAX - 1;
        let inputs = vec![
            SelectorInput {
                id: "a".to_string(),
                content: (),
                size: large,
                score: 0.9,
                category: None,
                information_gain: None,
            },
            SelectorInput {
                id: "b".to_string(),
                content: (),
                size: 10,
                score: 0.8,
                category: None,
                information_gain: None,
            },
        ];
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    // Uncategorised items always have a category count of 0, so the diversity
    // penalty never applies and selection follows plain score order.
    #[test]
    fn diversity_bias_no_categories_unaffected() {
        let inputs = vec![
            input("a", 100, 0.9),
            input("b", 100, 0.8),
            input("c", 100, 0.7),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "b");
    }

    // Builds an uncategorised input that carries `information_gain: Some(gain)`.
    fn input_with_gain(id: &str, size: usize, score: f32, gain: f32) -> SelectorInput<()> {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: None,
            information_gain: Some(gain),
        }
    }

    // With `epistemic_weight` at 0.0 the gains are ignored and ranking follows
    // `score` alone.
    // NOTE(review): as with the diversity test above, `w_default` and `w_zero`
    // are field-for-field equal because `Default` already yields 0.0.
    #[test]
    fn epistemic_weight_zero_preserves_behavior() {
        let make = || {
            vec![
                input_with_gain("a", 100, 0.9, 10.0),
                input_with_gain("b", 100, 0.8, 0.0),
                input_with_gain("c", 100, 0.7, 5.0),
            ]
        };
        let w_default = SelectorWeights {
            ..Default::default()
        };
        let w_zero = SelectorWeights {
            epistemic_weight: 0.0,
            ..Default::default()
        };
        let out_d = GreedySelector.select(make(), 200, &w_default).unwrap();
        let out_z = GreedySelector.select(make(), 200, &w_zero).unwrap();
        let ids_d: Vec<&str> = out_d.selected.iter().map(|i| i.id.as_str()).collect();
        let ids_z: Vec<&str> = out_z.selected.iter().map(|i| i.id.as_str()).collect();
        assert_eq!(ids_d, ids_z);
        assert_eq!(ids_d, vec!["a", "b"]);
    }

    // With weight 1.0, "a" ranks at 0.5 + 10.0 = 10.5 and beats "b" at 0.9 for the
    // single slot.
    #[test]
    fn epistemic_weight_positive_reorders_by_gain() {
        let inputs = vec![
            input_with_gain("a", 100, 0.5, 10.0),
            input_with_gain("b", 100, 0.9, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 100, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "a");
    }

    // `information_gain: None` must rank exactly like `Some(0.0)`.
    #[test]
    fn information_gain_none_equivalent_to_zero() {
        let with_none = vec![
            input("a", 100, 0.9), input("b", 100, 0.8),
        ];
        let with_zero = vec![
            input_with_gain("a", 100, 0.9, 0.0),
            input_with_gain("b", 100, 0.8, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out_none = GreedySelector.select(with_none, 200, &w).unwrap();
        let out_zero = GreedySelector.select(with_zero, 200, &w).unwrap();
        let ids_none: Vec<&str> = out_none.selected.iter().map(|i| i.id.as_str()).collect();
        let ids_zero: Vec<&str> = out_zero.selected.iter().map(|i| i.id.as_str()).collect();
        assert_eq!(ids_none, ids_zero);
    }

    // Both knobs together: "a" goes first at 0.5 + 10.0 = 10.5. For the second
    // slot "b" (same category as "a") is scaled to 0.8 * (1 - 0.5 * 1/2) = 0.6,
    // which still beats the unpenalised "c" at 0.3.
    #[test]
    fn epistemic_weight_works_with_diversity_bias() {
        let inputs = vec![
            {
                let mut i = input_with_gain("a", 100, 0.5, 10.0);
                i.category = Some("x".to_string());
                i
            },
            {
                let mut i = input_with_gain("b", 100, 0.8, 0.0);
                i.category = Some("x".to_string());
                i
            },
            {
                let mut i = input_with_gain("c", 100, 0.3, 0.0);
                i.category = Some("y".to_string());
                i
            },
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            diversity_bias: 0.5,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "b");
    }
}