Skip to main content

khive_fold/
selector.rs

//! Selector: many → subset under budget.
//!
//! Picks, from a set of scored inputs, a subset whose total size fits a target
//! budget (tokens, bytes, count). Pure in-memory, synchronous selection.
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use crate::error::FoldError;
10
/// One candidate offered to a selector.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorInput<T> {
    /// Stable string identifier, used for deterministic tie-breaking.
    pub id: String,
    /// Payload that travels with the item through selection.
    pub content: T,
    /// Cost of the item, in the same unit as the caller's budget
    /// (tokens, bytes, count).
    pub size: usize,
    /// Relevance score, computed by the caller ahead of time.
    pub score: f32,
    /// Optional category, used for diversity and category-weight scoring.
    ///
    /// With the `serde` feature, a missing field deserializes to `None`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub category: Option<String>,
    /// Information gain (a KL-divergence proxy) for this candidate, computed
    /// by the caller ahead of time.
    ///
    /// The selector is pure math and has no access to the embedding space
    /// needed to estimate KL divergence, so it cannot derive this itself.
    /// `None` (the default) is treated as 0.0. The value only matters when
    /// `SelectorWeights.epistemic_weight` is non-zero.
    #[cfg_attr(feature = "serde", serde(default))]
    pub information_gain: Option<f32>,
}
35
36/// Result of a selector operation.
37#[derive(Debug, Clone)]
38#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
39pub struct SelectorOutput<T> {
40    /// Selected inputs in final order.
41    pub selected: Vec<SelectorInput<T>>,
42    /// Total budget consumed.
43    pub total_size: usize,
44    /// Budget cap the caller requested.
45    pub budget: usize,
46}
47
/// Learned tuning knobs a selector implementation may consult.
///
/// Callers persist this across sessions. `Default` yields an empty category
/// map and 0.0 for every numeric field.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SelectorWeights {
    /// Score multiplier per input category, keyed by category name.
    pub category_weights: std::collections::BTreeMap<String, f32>,
    /// Score floor: inputs scoring below this are excluded even when the
    /// budget would allow them.
    pub min_score: f32,
    /// Trade-off between relevance and diversity
    /// (0.0 = pure relevance, 1.0 = pure diversity).
    pub diversity_bias: f32,
    /// Weight of the epistemic (uncertainty-reducing) term.
    ///
    /// The effective selection score is
    /// `pragmatic_score + epistemic_weight * information_gain`. The default of
    /// 0.0 means pure pragmatic selection, identical to the behavior before
    /// this field existed; larger values favor candidates that reduce
    /// uncertainty. With the `serde` feature, a missing field deserializes to 0.0.
    #[cfg_attr(feature = "serde", serde(default))]
    pub epistemic_weight: f32,
}
68
/// The Selector primitive.
///
/// An implementation collapses N inputs into a subset that fits a budget,
/// guided by the supplied [`SelectorWeights`]. Implementors must be
/// `Send + Sync`.
pub trait Selector<T>: Send + Sync {
    /// Select a budget-constrained subset from `inputs`.
    ///
    /// * `inputs` – the candidates, taken by value.
    /// * `budget` – the cap to select against, in the same unit as
    ///   [`SelectorInput::size`].
    /// * `weights` – tuning knobs the implementation may consult.
    ///
    /// # Errors
    ///
    /// Failure is reported as a [`FoldError`]; the conditions are up to the
    /// implementation.
    fn select(
        &self,
        inputs: Vec<SelectorInput<T>>,
        budget: usize,
        weights: &SelectorWeights,
    ) -> Result<SelectorOutput<T>, FoldError>;
}
82
83// ── GreedySelector ──────────────────────────────────────────────────────────
84
/// Greedy packer that works within a budget.
///
/// Inputs are first filtered by `SelectorWeights.min_score`, then their scores
/// are scaled by the matching `category_weights` multiplier, and finally they
/// are packed greedily until nothing more fits.
///
/// With `diversity_bias == 0` a single sorted pass does the packing. With any
/// other value the selector re-picks the best remaining item at every step,
/// ranking by the *effective* score `score * (1 - bias * n / (n + 1))`, where
/// `n` is how many items of the same category have already been selected.
///
/// Ties are broken deterministically: smaller size first, then smaller id.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreedySelector;
99
100/// Compute the base pragmatic score adjusted for epistemic weight.
101///
102/// `base` is the pragmatic score (after category-weight multipliers).
103/// `epistemic_weight * information_gain` is the epistemic bonus.
104#[inline]
105fn pragmatic_plus_epistemic<T>(item: &SelectorInput<T>, epistemic_weight: f32) -> f32 {
106    if epistemic_weight == 0.0 {
107        return item.score;
108    }
109    item.score + epistemic_weight * item.information_gain.unwrap_or(0.0)
110}
111
112fn effective_score<T>(
113    item: &SelectorInput<T>,
114    counts: &std::collections::BTreeMap<String, usize>,
115    bias: f32,
116    epistemic_weight: f32,
117) -> f32 {
118    let base = pragmatic_plus_epistemic(item, epistemic_weight);
119    if bias == 0.0 {
120        return base;
121    }
122    let count = item
123        .category
124        .as_ref()
125        .and_then(|c| counts.get(c))
126        .copied()
127        .unwrap_or(0);
128    base * (1.0 - bias * count as f32 / (count as f32 + 1.0))
129}
130
131impl<T: Clone> Selector<T> for GreedySelector {
132    fn select(
133        &self,
134        mut inputs: Vec<SelectorInput<T>>,
135        budget: usize,
136        weights: &SelectorWeights,
137    ) -> Result<SelectorOutput<T>, FoldError> {
138        // Filter non-finite and below min_score.
139        inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
140
141        // Apply category_weights multipliers.
142        if !weights.category_weights.is_empty() {
143            for item in &mut inputs {
144                if let Some(ref cat) = item.category {
145                    if let Some(&w) = weights.category_weights.get(cat.as_str()) {
146                        item.score *= w.max(0.0);
147                    }
148                }
149            }
150            inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
151        }
152
153        let ew = weights.epistemic_weight;
154
155        // Initial sort: effective score (pragmatic + epistemic bonus) desc, size asc, id asc —
156        // deterministic across platforms.
157        inputs.sort_by(|a, b| {
158            let a_eff = pragmatic_plus_epistemic(a, ew);
159            let b_eff = pragmatic_plus_epistemic(b, ew);
160            b_eff
161                .total_cmp(&a_eff)
162                .then_with(|| a.size.cmp(&b.size))
163                .then_with(|| a.id.cmp(&b.id))
164        });
165
166        let mut selected = Vec::new();
167        let mut total_size = 0usize;
168
169        if weights.diversity_bias == 0.0 {
170            // Fast path: single-pass greedy.
171            for input in inputs {
172                if input.size <= budget.saturating_sub(total_size) {
173                    total_size += input.size;
174                    selected.push(input);
175                }
176            }
177        } else {
178            // Diversity path: pick-best-remaining with per-step effective score.
179            let mut remaining = inputs;
180            let mut category_counts: std::collections::BTreeMap<String, usize> =
181                std::collections::BTreeMap::new();
182
183            while !remaining.is_empty() && total_size < budget {
184                let best_idx = remaining
185                    .iter()
186                    .enumerate()
187                    .filter(|(_, item)| item.size <= budget.saturating_sub(total_size))
188                    .max_by(|(_, a), (_, b)| {
189                        let a_eff =
190                            effective_score(a, &category_counts, weights.diversity_bias, ew);
191                        let b_eff =
192                            effective_score(b, &category_counts, weights.diversity_bias, ew);
193                        a_eff
194                            .total_cmp(&b_eff)
195                            .then_with(|| b.size.cmp(&a.size))
196                            .then_with(|| a.id.cmp(&b.id))
197                    })
198                    .map(|(i, _)| i);
199
200                match best_idx {
201                    Some(idx) => {
202                        let item = remaining.swap_remove(idx);
203                        if let Some(ref cat) = item.category {
204                            *category_counts.entry(cat.clone()).or_default() += 1;
205                        }
206                        total_size += item.size;
207                        selected.push(item);
208                    }
209                    None => break,
210                }
211            }
212        }
213
214        Ok(SelectorOutput {
215            selected,
216            total_size,
217            budget,
218        })
219    }
220}
221
// INLINE TEST JUSTIFICATION: these tests cover the private helper functions
// (pragmatic_plus_epistemic, effective_score) and the internal sort logic,
// none of which is reachable from a separate crate-level tests/ file. Keeping
// them here also avoids duplicating the SelectorInput construction scaffolding.
#[cfg(test)]
mod tests {
    use super::*;

    type Item = SelectorInput<()>;
    type Picked = SelectorOutput<()>;

    /// Uncategorised candidate without information gain.
    fn item(id: &str, size: usize, score: f32) -> Item {
        SelectorInput {
            id: id.to_owned(),
            content: (),
            size,
            score,
            category: None,
            information_gain: None,
        }
    }

    /// Candidate that belongs to category `cat`.
    fn item_in(id: &str, size: usize, score: f32, cat: &str) -> Item {
        SelectorInput {
            category: Some(cat.to_owned()),
            ..item(id, size, score)
        }
    }

    /// Uncategorised candidate with an explicit information gain.
    fn item_gaining(id: &str, size: usize, score: f32, gain: f32) -> Item {
        SelectorInput {
            information_gain: Some(gain),
            ..item(id, size, score)
        }
    }

    /// Default weights with only `min_score` set.
    fn min_score_only(min_score: f32) -> SelectorWeights {
        SelectorWeights {
            min_score,
            ..SelectorWeights::default()
        }
    }

    /// Run the greedy selector and unwrap its result.
    fn run(inputs: Vec<Item>, budget: usize, w: &SelectorWeights) -> Picked {
        GreedySelector.select(inputs, budget, w).unwrap()
    }

    /// Ids of the selected items, in selection order.
    fn ids(out: &Picked) -> Vec<&str> {
        out.selected.iter().map(|i| i.id.as_str()).collect()
    }

    #[test]
    fn empty_input() {
        let out = run(Vec::new(), 1000, &min_score_only(0.0));
        assert!(out.selected.is_empty());
        assert_eq!(out.total_size, 0);
        assert_eq!(out.budget, 1000);
    }

    #[test]
    fn packs_highest_scores_first() {
        let inputs = vec![item("a", 100, 0.5), item("b", 100, 0.9), item("c", 100, 0.7)];
        let out = run(inputs, 200, &min_score_only(0.0));
        assert_eq!(ids(&out), ["b", "c"]);
        assert_eq!(out.total_size, 200);
    }

    #[test]
    fn respects_budget() {
        let inputs = vec![item("a", 300, 0.9), item("b", 300, 0.8), item("c", 300, 0.7)];
        let out = run(inputs, 500, &min_score_only(0.0));
        assert_eq!(ids(&out), ["a"]);
        assert_eq!(out.total_size, 300);
    }

    #[test]
    fn filters_below_min_score() {
        let inputs = vec![item("a", 10, 0.8), item("b", 10, 0.1), item("c", 10, 0.5)];
        let out = run(inputs, 1000, &min_score_only(0.3));
        assert_eq!(ids(&out), ["a", "c"]);
    }

    #[test]
    fn filters_nan_and_inf() {
        let inputs = vec![
            item("nan", 10, f32::NAN),
            item("inf", 10, f32::INFINITY),
            item("neg_inf", 10, f32::NEG_INFINITY),
            item("ok", 10, 0.5),
        ];
        let out = run(inputs, 1000, &min_score_only(0.0));
        assert_eq!(ids(&out), ["ok"]);
    }

    #[test]
    fn tie_break_size_ascending() {
        let inputs = vec![item("big", 200, 0.5), item("small", 50, 0.5)];
        let out = run(inputs, 1000, &min_score_only(0.0));
        assert_eq!(ids(&out), ["small", "big"]);
    }

    #[test]
    fn tie_break_id_ascending() {
        let inputs = vec![item("z", 100, 0.5), item("a", 100, 0.5)];
        let out = run(inputs, 1000, &min_score_only(0.0));
        assert_eq!(ids(&out), ["a", "z"]);
    }

    #[test]
    fn skips_oversized_items_takes_smaller() {
        let inputs = vec![
            item("huge", 900, 0.9),
            item("small1", 40, 0.3),
            item("small2", 40, 0.2),
        ];
        let out = run(inputs, 100, &min_score_only(0.0));
        assert_eq!(ids(&out), ["small1", "small2"]);
        assert_eq!(out.total_size, 80);
    }

    #[test]
    fn zero_budget() {
        let out = run(vec![item("a", 1, 0.9)], 0, &min_score_only(0.0));
        assert!(out.selected.is_empty());
    }

    #[test]
    fn deterministic_across_input_order() {
        let first = vec![item("x", 50, 0.7), item("y", 50, 0.7), item("z", 50, 0.7)];
        let second = vec![item("z", 50, 0.7), item("x", 50, 0.7), item("y", 50, 0.7)];
        let out_first = run(first, 100, &min_score_only(0.0));
        let out_second = run(second, 100, &min_score_only(0.0));
        assert_eq!(ids(&out_first), ids(&out_second));
        assert_eq!(ids(&out_first), ["x", "y"]);
    }

    #[test]
    fn exact_budget_fit() {
        let inputs = vec![item("a", 50, 0.9), item("b", 50, 0.8)];
        let out = run(inputs, 100, &min_score_only(0.0));
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.total_size, 100);
    }

    #[test]
    fn category_weights_boost_preferred_category() {
        let inputs = vec![item_in("a", 100, 0.9, "low"), item_in("b", 100, 0.5, "high")];
        let mut w = SelectorWeights::default();
        w.category_weights.insert("high".to_owned(), 2.0);
        w.category_weights.insert("low".to_owned(), 1.0);
        let out = run(inputs, 100, &w);
        assert_eq!(ids(&out), ["b"]);
    }

    #[test]
    fn category_weights_can_push_below_min_score() {
        let inputs = vec![item_in("a", 10, 0.4, "bad"), item_in("b", 10, 0.8, "good")];
        let mut w = min_score_only(0.3);
        w.category_weights.insert("bad".to_owned(), 0.5);
        let out = run(inputs, 1000, &w);
        assert_eq!(ids(&out), ["b"]);
    }

    #[test]
    fn diversity_bias_zero_identical_to_greedy() {
        let make = || {
            vec![
                item_in("a", 100, 0.9, "x"),
                item_in("b", 100, 0.8, "x"),
                item_in("c", 100, 0.7, "y"),
            ]
        };
        let explicit_zero = SelectorWeights {
            diversity_bias: 0.0,
            ..SelectorWeights::default()
        };
        let out_default = run(make(), 200, &SelectorWeights::default());
        let out_zero = run(make(), 200, &explicit_zero);
        assert_eq!(ids(&out_default), ids(&out_zero));
    }

    #[test]
    fn diversity_bias_prefers_different_categories() {
        let inputs = vec![
            item_in("a", 100, 0.9, "x"),
            item_in("b", 100, 0.8, "x"),
            item_in("c", 100, 0.7, "y"),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..SelectorWeights::default()
        };
        let out = run(inputs, 200, &w);
        assert_eq!(out.selected.len(), 2);
        let picked = ids(&out);
        assert!(picked.contains(&"a"), "a should always be selected");
        assert!(
            picked.contains(&"c"),
            "c should be preferred over b due to diversity"
        );
    }

    #[test]
    fn no_overflow_near_usize_max() {
        // A size close to usize::MAX must not overflow the budget check;
        // with a budget of 100 only "b" fits.
        let inputs = vec![item("a", usize::MAX - 1, 0.9), item("b", 10, 0.8)];
        let out = run(inputs, 100, &min_score_only(0.0));
        assert_eq!(ids(&out), ["b"]);
    }

    #[test]
    fn diversity_bias_no_categories_unaffected() {
        let inputs = vec![item("a", 100, 0.9), item("b", 100, 0.8), item("c", 100, 0.7)];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..SelectorWeights::default()
        };
        let out = run(inputs, 200, &w);
        assert_eq!(ids(&out), ["a", "b"]);
    }

    // ── epistemic weight tests ────────────────────────────────────────────────

    #[test]
    fn epistemic_weight_zero_preserves_behavior() {
        // An explicit epistemic_weight of 0 must match the default weights.
        let make = || {
            vec![
                item_gaining("a", 100, 0.9, 10.0),
                item_gaining("b", 100, 0.8, 0.0),
                item_gaining("c", 100, 0.7, 5.0),
            ]
        };
        let explicit_zero = SelectorWeights {
            epistemic_weight: 0.0,
            ..SelectorWeights::default()
        };
        let out_default = run(make(), 200, &SelectorWeights::default());
        let out_zero = run(make(), 200, &explicit_zero);
        assert_eq!(ids(&out_default), ids(&out_zero));
        // Pure score order: a (0.9), then b (0.8).
        assert_eq!(ids(&out_default), ["a", "b"]);
    }

    #[test]
    fn epistemic_weight_positive_reorders_by_gain() {
        // With epistemic_weight = 1.0:
        //   a: 0.5 + 1.0 * 10.0 = 10.5
        //   b: 0.9 + 1.0 * 0.0  = 0.9
        // so a wins the single slot.
        let inputs = vec![
            item_gaining("a", 100, 0.5, 10.0),
            item_gaining("b", 100, 0.9, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..SelectorWeights::default()
        };
        let out = run(inputs, 100, &w);
        assert_eq!(ids(&out), ["a"]);
    }

    #[test]
    fn information_gain_none_equivalent_to_zero() {
        // A gain of None and a gain of Some(0.0) must order items identically.
        let with_none = vec![item("a", 100, 0.9), item("b", 100, 0.8)];
        let with_zero = vec![
            item_gaining("a", 100, 0.9, 0.0),
            item_gaining("b", 100, 0.8, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..SelectorWeights::default()
        };
        let out_none = run(with_none, 200, &w);
        let out_zero = run(with_zero, 200, &w);
        assert_eq!(ids(&out_none), ids(&out_zero));
    }

    #[test]
    fn epistemic_weight_works_with_diversity_bias() {
        // Epistemic bonus and diversity penalty combined (weight 1.0, bias 0.5):
        //   a: category x, 0.5 + 1.0 * 10.0 = 10.5 → picked first
        //   b: category x, 0.8, penalised to 0.8 * (1 - 0.5 * 1 / 2) = 0.6
        //   c: category y, 0.3
        // After a is placed in x, b (0.6) still beats c (0.3).
        let in_category = |cat: &str, base: Item| SelectorInput {
            category: Some(cat.to_owned()),
            ..base
        };
        let inputs = vec![
            in_category("x", item_gaining("a", 100, 0.5, 10.0)),
            in_category("x", item_gaining("b", 100, 0.8, 0.0)),
            in_category("y", item_gaining("c", 100, 0.3, 0.0)),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            diversity_bias: 0.5,
            ..SelectorWeights::default()
        };
        let out = run(inputs, 200, &w);
        assert_eq!(ids(&out), ["a", "b"]);
    }
}