// khive_fold/selector.rs

//! Selector: many → subset under budget.
//!
//! Collapses a set of inputs into a compressed representation that fits a
//! target budget (tokens, bytes, count). Pure in-memory, synchronous collapse.
5
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::error::FoldError;
10
/// A single candidate handed to a selector.
///
/// `T` is the caller's payload type; the selection logic in this module only
/// looks at the metadata fields (`size`, `score`, `category`,
/// `information_gain`, and `id` for tie-breaking).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorInput<T> {
    /// Caller-assigned identifier. `GreedySelector` uses it as the last
    /// tie-breaker, so unique ids give a fully deterministic ordering.
    pub id: String,
    /// Opaque payload carried through selection; never inspected here.
    pub content: T,
    /// Size in the unit of the caller's budget (tokens, bytes, count).
    pub size: usize,
    /// Pre-computed relevance score. `GreedySelector` drops candidates whose
    /// score is NaN or infinite.
    pub score: f32,
    /// Optional category for diversity and category-weight scoring.
    #[cfg_attr(feature = "serde", serde(default))]
    pub category: Option<String>,
    /// Pre-computed information gain (KL divergence proxy) for this candidate.
    ///
    /// Callers pre-compute this because the Selector is pure-math and has no
    /// access to the embedding space required to estimate KL divergence. When
    /// `None` (the default), the value is treated as 0.0. Only has an effect
    /// when `SelectorWeights.epistemic_weight` is non-zero (ADR-059).
    #[cfg_attr(feature = "serde", serde(default))]
    pub information_gain: Option<f32>,
}
33
34/// Result of a selector operation.
35#[derive(Debug, Clone)]
36#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
37pub struct SelectorOutput<T> {
38    /// Selected inputs in final order.
39    pub selected: Vec<SelectorInput<T>>,
40    /// Total budget consumed.
41    pub total_size: usize,
42    /// Budget cap the caller requested.
43    pub budget: usize,
44}
45
/// Learned weights that a selector implementation may use.
///
/// Callers persist this across sessions. `Default` yields the neutral
/// configuration: no category multipliers and every numeric weight at 0.0.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorWeights {
    /// Weight multiplier by input category.
    ///
    /// `GreedySelector` multiplies an item's score by the entry for its
    /// category; negative multipliers are clamped to 0.0 and categories
    /// without an entry are left unchanged.
    pub category_weights: std::collections::BTreeMap<String, f32>,
    /// Minimum score threshold (inputs below this are excluded even if budget allows).
    ///
    /// `GreedySelector` applies it to the raw score and again after category
    /// weighting.
    pub min_score: f32,
    /// Preference for diversity vs. relevance (0.0 = pure relevance, 1.0 = pure diversity).
    pub diversity_bias: f32,
    /// Weight for epistemic (uncertainty-reducing) selection.
    ///
    /// The effective selection score is `pragmatic_score + epistemic_weight * information_gain`.
    /// Default 0.0 (pure pragmatic). Higher values prefer candidates that reduce uncertainty.
    /// When 0.0, behavior is identical to ADR-058 (backwards-compatible, ADR-059).
    #[cfg_attr(feature = "serde", serde(default))]
    pub epistemic_weight: f32,
}
66
67/// The Selector primitive.
68///
69/// An implementation collapses N inputs into a subset that fits a budget,
70/// using weights and an optional query for relevance context.
71pub trait Selector<T>: Send + Sync {
72    fn select(
73        &self,
74        inputs: Vec<SelectorInput<T>>,
75        budget: usize,
76        weights: &SelectorWeights,
77    ) -> Result<SelectorOutput<T>, FoldError>;
78}
79
// ── GreedySelector ──────────────────────────────────────────────────────────

/// Budget-constrained greedy packer.
///
/// Filters by `SelectorWeights.min_score`, applies `category_weights` multipliers
/// to adjust scores, then greedily packs until the budget is exhausted.
///
/// When `diversity_bias` is non-zero, uses a pick-best-remaining loop: at each
/// step the item with the highest *effective* score (after diversity penalty)
/// is selected. The penalty is `score * (1 - bias * n / (n + 1))` where `n` is
/// the number of already-selected items in the same category. At bias=0 this
/// collapses to a single-pass sort (backward-compatible).
///
/// Tie-breaking is deterministic: size ascending, then id ascending.
///
/// The type is a stateless, zero-sized unit struct; all configuration is
/// passed per call through `SelectorWeights`.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreedySelector;
96
97/// Compute the base pragmatic score adjusted for epistemic weight.
98///
99/// `base` is the pragmatic score (after category-weight multipliers).
100/// `epistemic_weight * information_gain` is the epistemic bonus (ADR-059).
101#[inline]
102fn pragmatic_plus_epistemic<T>(item: &SelectorInput<T>, epistemic_weight: f32) -> f32 {
103    if epistemic_weight == 0.0 {
104        return item.score;
105    }
106    item.score + epistemic_weight * item.information_gain.unwrap_or(0.0)
107}
108
109fn effective_score<T>(
110    item: &SelectorInput<T>,
111    counts: &std::collections::BTreeMap<String, usize>,
112    bias: f32,
113    epistemic_weight: f32,
114) -> f32 {
115    let base = pragmatic_plus_epistemic(item, epistemic_weight);
116    if bias == 0.0 {
117        return base;
118    }
119    let count = item
120        .category
121        .as_ref()
122        .and_then(|c| counts.get(c))
123        .copied()
124        .unwrap_or(0);
125    base * (1.0 - bias * count as f32 / (count as f32 + 1.0))
126}
127
128impl<T: Clone> Selector<T> for GreedySelector {
129    fn select(
130        &self,
131        mut inputs: Vec<SelectorInput<T>>,
132        budget: usize,
133        weights: &SelectorWeights,
134    ) -> Result<SelectorOutput<T>, FoldError> {
135        // Filter non-finite and below min_score.
136        inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
137
138        // Apply category_weights multipliers.
139        if !weights.category_weights.is_empty() {
140            for item in &mut inputs {
141                if let Some(ref cat) = item.category {
142                    if let Some(&w) = weights.category_weights.get(cat.as_str()) {
143                        item.score *= w.max(0.0);
144                    }
145                }
146            }
147            inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
148        }
149
150        let ew = weights.epistemic_weight;
151
152        // Initial sort: effective score (pragmatic + epistemic bonus) desc, size asc, id asc —
153        // deterministic across platforms.
154        inputs.sort_by(|a, b| {
155            let a_eff = pragmatic_plus_epistemic(a, ew);
156            let b_eff = pragmatic_plus_epistemic(b, ew);
157            b_eff
158                .total_cmp(&a_eff)
159                .then_with(|| a.size.cmp(&b.size))
160                .then_with(|| a.id.cmp(&b.id))
161        });
162
163        let mut selected = Vec::new();
164        let mut total_size = 0usize;
165
166        if weights.diversity_bias == 0.0 {
167            // Fast path: single-pass greedy.
168            for input in inputs {
169                if input.size <= budget.saturating_sub(total_size) {
170                    total_size += input.size;
171                    selected.push(input);
172                }
173            }
174        } else {
175            // Diversity path: pick-best-remaining with per-step effective score.
176            let mut remaining = inputs;
177            let mut category_counts: std::collections::BTreeMap<String, usize> =
178                std::collections::BTreeMap::new();
179
180            while !remaining.is_empty() && total_size < budget {
181                let best_idx = remaining
182                    .iter()
183                    .enumerate()
184                    .filter(|(_, item)| item.size <= budget.saturating_sub(total_size))
185                    .max_by(|(_, a), (_, b)| {
186                        let a_eff =
187                            effective_score(a, &category_counts, weights.diversity_bias, ew);
188                        let b_eff =
189                            effective_score(b, &category_counts, weights.diversity_bias, ew);
190                        a_eff
191                            .total_cmp(&b_eff)
192                            .then_with(|| b.size.cmp(&a.size))
193                            .then_with(|| a.id.cmp(&b.id))
194                    })
195                    .map(|(i, _)| i);
196
197                match best_idx {
198                    Some(idx) => {
199                        let item = remaining.swap_remove(idx);
200                        if let Some(ref cat) = item.category {
201                            *category_counts.entry(cat.clone()).or_default() += 1;
202                        }
203                        total_size += item.size;
204                        selected.push(item);
205                    }
206                    None => break,
207                }
208            }
209        }
210
211        Ok(SelectorOutput {
212            selected,
213            total_size,
214            budget,
215        })
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    // ── Fixtures ─────────────────────────────────────────────────────────────

    /// Uncategorized candidate with no information gain.
    fn input(id: &str, size: usize, score: f32) -> SelectorInput<()> {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: None,
            information_gain: None,
        }
    }

    /// Candidate assigned to category `cat`.
    fn input_cat(id: &str, size: usize, score: f32, cat: &str) -> SelectorInput<()> {
        SelectorInput {
            category: Some(cat.to_string()),
            ..input(id, size, score)
        }
    }

    /// Uncategorized candidate with an explicit information gain.
    fn input_with_gain(id: &str, size: usize, score: f32, gain: f32) -> SelectorInput<()> {
        SelectorInput {
            information_gain: Some(gain),
            ..input(id, size, score)
        }
    }

    /// Candidate with both a category and an explicit information gain.
    fn input_cat_with_gain(
        id: &str,
        size: usize,
        score: f32,
        gain: f32,
        cat: &str,
    ) -> SelectorInput<()> {
        SelectorInput {
            category: Some(cat.to_string()),
            ..input_with_gain(id, size, score, gain)
        }
    }

    /// Default weights with only `min_score` overridden.
    fn weights(min_score: f32) -> SelectorWeights {
        SelectorWeights {
            min_score,
            ..Default::default()
        }
    }

    /// Ids of the selected items, in selection order.
    fn ids(out: &SelectorOutput<()>) -> Vec<&str> {
        out.selected.iter().map(|i| i.id.as_str()).collect()
    }

    // ── Core packing behavior ────────────────────────────────────────────────

    #[test]
    fn empty_input() {
        let inputs: Vec<SelectorInput<()>> = vec![];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert!(out.selected.is_empty());
        assert_eq!(out.total_size, 0);
        assert_eq!(out.budget, 1000);
    }

    #[test]
    fn packs_highest_scores_first() {
        let inputs = vec![
            input("a", 100, 0.5),
            input("b", 100, 0.9),
            input("c", 100, 0.7),
        ];
        let out = GreedySelector.select(inputs, 200, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "b");
        assert_eq!(out.selected[1].id, "c");
        assert_eq!(out.total_size, 200);
    }

    #[test]
    fn respects_budget() {
        let inputs = vec![
            input("a", 300, 0.9),
            input("b", 300, 0.8),
            input("c", 300, 0.7),
        ];
        let out = GreedySelector.select(inputs, 500, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.total_size, 300);
    }

    #[test]
    fn filters_below_min_score() {
        let inputs = vec![
            input("a", 10, 0.8),
            input("b", 10, 0.1),
            input("c", 10, 0.5),
        ];
        let out = GreedySelector.select(inputs, 1000, &weights(0.3)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "c");
    }

    #[test]
    fn filters_nan_and_inf() {
        let inputs = vec![
            input("nan", 10, f32::NAN),
            input("inf", 10, f32::INFINITY),
            input("neg_inf", 10, f32::NEG_INFINITY),
            input("ok", 10, 0.5),
        ];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "ok");
    }

    #[test]
    fn tie_break_size_ascending() {
        let inputs = vec![input("big", 200, 0.5), input("small", 50, 0.5)];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected[0].id, "small");
        assert_eq!(out.selected[1].id, "big");
    }

    #[test]
    fn tie_break_id_ascending() {
        let inputs = vec![input("z", 100, 0.5), input("a", 100, 0.5)];
        let out = GreedySelector.select(inputs, 1000, &weights(0.0)).unwrap();
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "z");
    }

    #[test]
    fn skips_oversized_items_takes_smaller() {
        let inputs = vec![
            input("huge", 900, 0.9),
            input("small1", 40, 0.3),
            input("small2", 40, 0.2),
        ];
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "small1");
        assert_eq!(out.selected[1].id, "small2");
        assert_eq!(out.total_size, 80);
    }

    #[test]
    fn zero_budget() {
        let inputs = vec![input("a", 1, 0.9)];
        let out = GreedySelector.select(inputs, 0, &weights(0.0)).unwrap();
        assert!(out.selected.is_empty());
    }

    #[test]
    fn deterministic_across_input_order() {
        let a = vec![
            input("x", 50, 0.7),
            input("y", 50, 0.7),
            input("z", 50, 0.7),
        ];
        let b = vec![
            input("z", 50, 0.7),
            input("x", 50, 0.7),
            input("y", 50, 0.7),
        ];
        let out_a = GreedySelector.select(a, 100, &weights(0.0)).unwrap();
        let out_b = GreedySelector.select(b, 100, &weights(0.0)).unwrap();
        assert_eq!(ids(&out_a), ids(&out_b));
        assert_eq!(ids(&out_a), vec!["x", "y"]);
    }

    #[test]
    fn exact_budget_fit() {
        let inputs = vec![input("a", 50, 0.9), input("b", 50, 0.8)];
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.total_size, 100);
    }

    #[test]
    fn no_overflow_near_usize_max() {
        // Items with near-usize::MAX sizes must not overflow when checking budget.
        let inputs = vec![input("a", usize::MAX - 1, 0.9), input("b", 10, 0.8)];
        // Budget is 100 — only item "b" fits.
        let out = GreedySelector.select(inputs, 100, &weights(0.0)).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    // ── Category weights ─────────────────────────────────────────────────────

    #[test]
    fn category_weights_boost_preferred_category() {
        let inputs = vec![
            input_cat("a", 100, 0.9, "low"),
            input_cat("b", 100, 0.5, "high"),
        ];
        let w = SelectorWeights {
            category_weights: [("high".to_string(), 2.0f32), ("low".to_string(), 1.0f32)]
                .into_iter()
                .collect(),
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 100, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    #[test]
    fn category_weights_can_push_below_min_score() {
        let inputs = vec![
            input_cat("a", 10, 0.4, "bad"),
            input_cat("b", 10, 0.8, "good"),
        ];
        let w = SelectorWeights {
            min_score: 0.3,
            category_weights: [("bad".to_string(), 0.5f32)].into_iter().collect(),
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 1000, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "b");
    }

    // ── Diversity bias ───────────────────────────────────────────────────────

    #[test]
    fn diversity_bias_zero_identical_to_greedy() {
        let make = || {
            vec![
                input_cat("a", 100, 0.9, "x"),
                input_cat("b", 100, 0.8, "x"),
                input_cat("c", 100, 0.7, "y"),
            ]
        };
        let w_greedy = SelectorWeights::default();
        let w_bias0 = SelectorWeights {
            diversity_bias: 0.0,
            ..Default::default()
        };
        let out_g = GreedySelector.select(make(), 200, &w_greedy).unwrap();
        let out_b = GreedySelector.select(make(), 200, &w_bias0).unwrap();
        assert_eq!(ids(&out_g), ids(&out_b));
    }

    #[test]
    fn diversity_bias_prefers_different_categories() {
        let inputs = vec![
            input_cat("a", 100, 0.9, "x"),
            input_cat("b", 100, 0.8, "x"),
            input_cat("c", 100, 0.7, "y"),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        let picked = ids(&out);
        assert!(picked.contains(&"a"), "a should always be selected");
        assert!(
            picked.contains(&"c"),
            "c should be preferred over b due to diversity"
        );
    }

    #[test]
    fn diversity_bias_no_categories_unaffected() {
        let inputs = vec![
            input("a", 100, 0.9),
            input("b", 100, 0.8),
            input("c", 100, 0.7),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        assert_eq!(out.selected[1].id, "b");
    }

    // ── ADR-059: epistemic weight tests ──────────────────────────────────────

    #[test]
    fn epistemic_weight_zero_preserves_behavior() {
        // With epistemic_weight=0, result must be identical to the default (no epistemic).
        let make = || {
            vec![
                input_with_gain("a", 100, 0.9, 10.0),
                input_with_gain("b", 100, 0.8, 0.0),
                input_with_gain("c", 100, 0.7, 5.0),
            ]
        };
        let w_default = SelectorWeights::default();
        let w_zero = SelectorWeights {
            epistemic_weight: 0.0,
            ..Default::default()
        };
        let out_d = GreedySelector.select(make(), 200, &w_default).unwrap();
        let out_z = GreedySelector.select(make(), 200, &w_zero).unwrap();
        assert_eq!(ids(&out_d), ids(&out_z));
        // Pure score order: a (0.9), b (0.8).
        assert_eq!(ids(&out_d), vec!["a", "b"]);
    }

    #[test]
    fn epistemic_weight_positive_reorders_by_gain() {
        // a: score=0.5, gain=10.0  → effective = 0.5 + 1.0 * 10.0 = 10.5
        // b: score=0.9, gain=0.0   → effective = 0.9 + 1.0 * 0.0  = 0.9
        // With epistemic_weight=1.0, a should be selected first.
        let inputs = vec![
            input_with_gain("a", 100, 0.5, 10.0),
            input_with_gain("b", 100, 0.9, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 100, &w).unwrap();
        assert_eq!(out.selected.len(), 1);
        assert_eq!(out.selected[0].id, "a");
    }

    #[test]
    fn information_gain_none_equivalent_to_zero() {
        // None and Some(0.0) must produce identical ordering.
        let with_none = vec![
            input("a", 100, 0.9), // information_gain: None
            input("b", 100, 0.8),
        ];
        let with_zero = vec![
            input_with_gain("a", 100, 0.9, 0.0),
            input_with_gain("b", 100, 0.8, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out_none = GreedySelector.select(with_none, 200, &w).unwrap();
        let out_zero = GreedySelector.select(with_zero, 200, &w).unwrap();
        assert_eq!(ids(&out_none), ids(&out_zero));
    }

    #[test]
    fn epistemic_weight_works_with_diversity_bias() {
        // Combines epistemic and diversity: the effective score incorporates both.
        // a: score=0.5, gain=10.0, category=x → base effective = 0.5 + 1.0 * 10.0 = 10.5
        // b: score=0.8, gain=0.0,  category=x → base effective = 0.8
        // c: score=0.3, gain=0.0,  category=y → base effective = 0.3
        // Budget=200, bias=0.5: a selected first (10.5 wins), then after a is in x,
        // b's diversity penalty is 0.8*(1-0.5*1/2)=0.8*0.75=0.6 vs c at 0.3 — b wins.
        let inputs = vec![
            input_cat_with_gain("a", 100, 0.5, 10.0, "x"),
            input_cat_with_gain("b", 100, 0.8, 0.0, "x"),
            input_cat_with_gain("c", 100, 0.3, 0.0, "y"),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            diversity_bias: 0.5,
            ..Default::default()
        };
        let out = GreedySelector.select(inputs, 200, &w).unwrap();
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.selected[0].id, "a");
        // b (eff=0.8*0.75=0.6) > c (eff=0.3) after a is placed in category x.
        assert_eq!(out.selected[1].id, "b");
    }
}