Skip to main content

khive_fold/
selector.rs

1//! Selector: many → subset under budget.
2//!
3//! Collapses a set of inputs into a compressed representation that fits a
4//! target budget (tokens, bytes, count). Pure in-memory, synchronous collapse.
5
6use serde::{Deserialize, Serialize};
7
8use crate::error::FoldError;
9
10/// A single input item to a selector operation.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SelectorInput<T> {
13    pub id: String,
14    pub content: T,
15    /// Size in the unit of the caller's budget (tokens, bytes, count).
16    pub size: usize,
17    /// Pre-computed relevance score.
18    pub score: f32,
19    /// Optional category for diversity and category-weight scoring.
20    #[serde(default)]
21    pub category: Option<String>,
22    /// Pre-computed information gain (KL divergence proxy) for this candidate.
23    ///
24    /// Callers pre-compute this because the Selector is pure-math and has no
25    /// access to the embedding space required to estimate KL divergence. When
26    /// `None` (the default), the value is treated as 0.0. Only has an effect
27    /// when `SelectorWeights.epistemic_weight > 0.0` (ADR-059).
28    #[serde(default)]
29    pub information_gain: Option<f32>,
30}
31
32/// Result of a selector operation.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct SelectorOutput<T> {
35    /// Selected inputs in final order.
36    pub selected: Vec<SelectorInput<T>>,
37    /// Total budget consumed.
38    pub total_size: usize,
39    /// Budget cap the caller requested.
40    pub budget: usize,
41}
42
43/// Learned weights that a selector implementation may use.
44///
45/// Callers persist this across sessions.
46#[derive(Debug, Clone, Default, Serialize, Deserialize)]
47pub struct SelectorWeights {
48    /// Weight multiplier by input category.
49    pub category_weights: std::collections::BTreeMap<String, f32>,
50    /// Minimum score threshold (inputs below this are excluded even if budget allows).
51    pub min_score: f32,
52    /// Preference for diversity vs. relevance (0.0 = pure relevance, 1.0 = pure diversity).
53    pub diversity_bias: f32,
54    /// Weight for epistemic (uncertainty-reducing) selection.
55    ///
56    /// The effective selection score is `pragmatic_score + epistemic_weight * information_gain`.
57    /// Default 0.0 (pure pragmatic). Higher values prefer candidates that reduce uncertainty.
58    /// When 0.0, behavior is identical to ADR-058 (backwards-compatible, ADR-059).
59    #[serde(default)]
60    pub epistemic_weight: f32,
61}
62
63/// The Selector primitive.
64///
65/// An implementation collapses N inputs into a subset that fits a budget,
66/// using weights and an optional query for relevance context.
67pub trait Selector<T> {
68    fn select(
69        &self,
70        inputs: Vec<SelectorInput<T>>,
71        budget: usize,
72        weights: &SelectorWeights,
73    ) -> Result<SelectorOutput<T>, FoldError>;
74}
75
76// ── GreedySelector ──────────────────────────────────────────────────────────
77
/// Greedy packer that fills a budget with the best-scoring inputs.
///
/// Inputs below `SelectorWeights.min_score` are discarded, remaining scores are
/// scaled by their `category_weights` multiplier, and items are then packed
/// greedily for as long as they still fit into the budget.
///
/// With `diversity_bias == 0` the inputs are sorted once and packed in a single
/// pass. With any other bias the selector repeatedly picks the best remaining
/// item by its *effective* score, `score * (1 - bias * n / (n + 1))`, where `n`
/// counts the already-selected items sharing the candidate's category.
///
/// Candidates that compare equal on score are ordered deterministically: the
/// smaller size wins, and the id decides any remaining tie.
#[derive(Clone, Copy, Debug, Default)]
pub struct GreedySelector;
92
93/// Compute the base pragmatic score adjusted for epistemic weight.
94///
95/// `base` is the pragmatic score (after category-weight multipliers).
96/// `epistemic_weight * information_gain` is the epistemic bonus (ADR-059).
97#[inline]
98fn pragmatic_plus_epistemic<T>(item: &SelectorInput<T>, epistemic_weight: f32) -> f32 {
99    if epistemic_weight == 0.0 {
100        return item.score;
101    }
102    item.score + epistemic_weight * item.information_gain.unwrap_or(0.0)
103}
104
105fn effective_score<T>(
106    item: &SelectorInput<T>,
107    counts: &std::collections::BTreeMap<String, usize>,
108    bias: f32,
109    epistemic_weight: f32,
110) -> f32 {
111    let base = pragmatic_plus_epistemic(item, epistemic_weight);
112    if bias == 0.0 {
113        return base;
114    }
115    let count = item
116        .category
117        .as_ref()
118        .and_then(|c| counts.get(c))
119        .copied()
120        .unwrap_or(0);
121    base * (1.0 - bias * count as f32 / (count as f32 + 1.0))
122}
123
124impl<T: Clone> Selector<T> for GreedySelector {
125    fn select(
126        &self,
127        mut inputs: Vec<SelectorInput<T>>,
128        budget: usize,
129        weights: &SelectorWeights,
130    ) -> Result<SelectorOutput<T>, FoldError> {
131        // Filter non-finite and below min_score.
132        inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
133
134        // Apply category_weights multipliers.
135        if !weights.category_weights.is_empty() {
136            for item in &mut inputs {
137                if let Some(ref cat) = item.category {
138                    if let Some(&w) = weights.category_weights.get(cat.as_str()) {
139                        item.score *= w.max(0.0);
140                    }
141                }
142            }
143            inputs.retain(|i| i.score.is_finite() && i.score >= weights.min_score);
144        }
145
146        let ew = weights.epistemic_weight;
147
148        // Initial sort: effective score (pragmatic + epistemic bonus) desc, size asc, id asc —
149        // deterministic across platforms.
150        inputs.sort_by(|a, b| {
151            let a_eff = pragmatic_plus_epistemic(a, ew);
152            let b_eff = pragmatic_plus_epistemic(b, ew);
153            b_eff
154                .total_cmp(&a_eff)
155                .then_with(|| a.size.cmp(&b.size))
156                .then_with(|| a.id.cmp(&b.id))
157        });
158
159        let mut selected = Vec::new();
160        let mut total_size = 0usize;
161
162        if weights.diversity_bias == 0.0 {
163            // Fast path: single-pass greedy.
164            for input in inputs {
165                if input.size <= budget.saturating_sub(total_size) {
166                    total_size += input.size;
167                    selected.push(input);
168                }
169            }
170        } else {
171            // Diversity path: pick-best-remaining with per-step effective score.
172            let mut remaining = inputs;
173            let mut category_counts: std::collections::BTreeMap<String, usize> =
174                std::collections::BTreeMap::new();
175
176            while !remaining.is_empty() && total_size < budget {
177                let best_idx = remaining
178                    .iter()
179                    .enumerate()
180                    .filter(|(_, item)| item.size <= budget.saturating_sub(total_size))
181                    .max_by(|(_, a), (_, b)| {
182                        let a_eff =
183                            effective_score(a, &category_counts, weights.diversity_bias, ew);
184                        let b_eff =
185                            effective_score(b, &category_counts, weights.diversity_bias, ew);
186                        a_eff
187                            .total_cmp(&b_eff)
188                            .then_with(|| b.size.cmp(&a.size))
189                            .then_with(|| a.id.cmp(&b.id))
190                    })
191                    .map(|(i, _)| i);
192
193                match best_idx {
194                    Some(idx) => {
195                        let item = remaining.swap_remove(idx);
196                        if let Some(ref cat) = item.category {
197                            *category_counts.entry(cat.clone()).or_default() += 1;
198                        }
199                        total_size += item.size;
200                        selected.push(item);
201                    }
202                    None => break,
203                }
204            }
205        }
206
207        Ok(SelectorOutput {
208            selected,
209            total_size,
210            budget,
211        })
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload-free candidate type used by every test.
    type Item = SelectorInput<()>;

    // ── fixtures ─────────────────────────────────────────────────────────────

    /// Uncategorized candidate without information gain.
    fn input(id: &str, size: usize, score: f32) -> Item {
        SelectorInput {
            id: id.to_string(),
            content: (),
            size,
            score,
            category: None,
            information_gain: None,
        }
    }

    /// Candidate belonging to category `cat`.
    fn input_cat(id: &str, size: usize, score: f32, cat: &str) -> Item {
        SelectorInput {
            category: Some(cat.to_string()),
            ..input(id, size, score)
        }
    }

    /// Uncategorized candidate with an explicit information gain.
    fn input_with_gain(id: &str, size: usize, score: f32, gain: f32) -> Item {
        SelectorInput {
            information_gain: Some(gain),
            ..input(id, size, score)
        }
    }

    /// Default weights with only `min_score` overridden.
    fn weights(min_score: f32) -> SelectorWeights {
        SelectorWeights {
            min_score,
            ..Default::default()
        }
    }

    /// Builds a category-weight map from `(category, multiplier)` pairs.
    fn cat_weights(pairs: &[(&str, f32)]) -> std::collections::BTreeMap<String, f32> {
        pairs.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    /// Runs `GreedySelector` and unwraps the result.
    fn run(inputs: Vec<Item>, budget: usize, w: &SelectorWeights) -> SelectorOutput<()> {
        GreedySelector.select(inputs, budget, w).unwrap()
    }

    /// Ids of the selected items, in selection order.
    fn ids(out: &SelectorOutput<()>) -> Vec<&str> {
        out.selected.iter().map(|i| i.id.as_str()).collect()
    }

    // ── basic packing ────────────────────────────────────────────────────────

    #[test]
    fn empty_input() {
        let out = run(Vec::new(), 1000, &weights(0.0));
        assert!(out.selected.is_empty());
        assert_eq!(out.total_size, 0);
        assert_eq!(out.budget, 1000);
    }

    #[test]
    fn packs_highest_scores_first() {
        let candidates = vec![
            input("a", 100, 0.5),
            input("b", 100, 0.9),
            input("c", 100, 0.7),
        ];
        let out = run(candidates, 200, &weights(0.0));
        assert_eq!(ids(&out), ["b", "c"]);
        assert_eq!(out.total_size, 200);
    }

    #[test]
    fn respects_budget() {
        let candidates = vec![
            input("a", 300, 0.9),
            input("b", 300, 0.8),
            input("c", 300, 0.7),
        ];
        let out = run(candidates, 500, &weights(0.0));
        assert_eq!(ids(&out), ["a"]);
        assert_eq!(out.total_size, 300);
    }

    #[test]
    fn filters_below_min_score() {
        let candidates = vec![
            input("a", 10, 0.8),
            input("b", 10, 0.1),
            input("c", 10, 0.5),
        ];
        let out = run(candidates, 1000, &weights(0.3));
        assert_eq!(ids(&out), ["a", "c"]);
    }

    #[test]
    fn filters_nan_and_inf() {
        let candidates = vec![
            input("nan", 10, f32::NAN),
            input("inf", 10, f32::INFINITY),
            input("neg_inf", 10, f32::NEG_INFINITY),
            input("ok", 10, 0.5),
        ];
        let out = run(candidates, 1000, &weights(0.0));
        assert_eq!(ids(&out), ["ok"]);
    }

    #[test]
    fn tie_break_size_ascending() {
        let candidates = vec![input("big", 200, 0.5), input("small", 50, 0.5)];
        let out = run(candidates, 1000, &weights(0.0));
        assert_eq!(ids(&out), ["small", "big"]);
    }

    #[test]
    fn tie_break_id_ascending() {
        let candidates = vec![input("z", 100, 0.5), input("a", 100, 0.5)];
        let out = run(candidates, 1000, &weights(0.0));
        assert_eq!(ids(&out), ["a", "z"]);
    }

    #[test]
    fn skips_oversized_items_takes_smaller() {
        let candidates = vec![
            input("huge", 900, 0.9),
            input("small1", 40, 0.3),
            input("small2", 40, 0.2),
        ];
        let out = run(candidates, 100, &weights(0.0));
        assert_eq!(ids(&out), ["small1", "small2"]);
        assert_eq!(out.total_size, 80);
    }

    #[test]
    fn zero_budget() {
        let out = run(vec![input("a", 1, 0.9)], 0, &weights(0.0));
        assert!(out.selected.is_empty());
    }

    #[test]
    fn deterministic_across_input_order() {
        let first = vec![
            input("x", 50, 0.7),
            input("y", 50, 0.7),
            input("z", 50, 0.7),
        ];
        let second = vec![
            input("z", 50, 0.7),
            input("x", 50, 0.7),
            input("y", 50, 0.7),
        ];
        let out_first = run(first, 100, &weights(0.0));
        let out_second = run(second, 100, &weights(0.0));
        assert_eq!(ids(&out_first), ids(&out_second));
        assert_eq!(ids(&out_first), ["x", "y"]);
    }

    #[test]
    fn exact_budget_fit() {
        let candidates = vec![input("a", 50, 0.9), input("b", 50, 0.8)];
        let out = run(candidates, 100, &weights(0.0));
        assert_eq!(out.selected.len(), 2);
        assert_eq!(out.total_size, 100);
    }

    // ── category weights ─────────────────────────────────────────────────────

    #[test]
    fn category_weights_boost_preferred_category() {
        let candidates = vec![
            input_cat("a", 100, 0.9, "low"),
            input_cat("b", 100, 0.5, "high"),
        ];
        let w = SelectorWeights {
            category_weights: cat_weights(&[("high", 2.0), ("low", 1.0)]),
            ..Default::default()
        };
        let out = run(candidates, 100, &w);
        assert_eq!(ids(&out), ["b"]);
    }

    #[test]
    fn category_weights_can_push_below_min_score() {
        let candidates = vec![
            input_cat("a", 10, 0.4, "bad"),
            input_cat("b", 10, 0.8, "good"),
        ];
        let w = SelectorWeights {
            min_score: 0.3,
            category_weights: cat_weights(&[("bad", 0.5)]),
            ..Default::default()
        };
        let out = run(candidates, 1000, &w);
        assert_eq!(ids(&out), ["b"]);
    }

    // ── diversity bias ───────────────────────────────────────────────────────

    #[test]
    fn diversity_bias_zero_identical_to_greedy() {
        let make = || {
            vec![
                input_cat("a", 100, 0.9, "x"),
                input_cat("b", 100, 0.8, "x"),
                input_cat("c", 100, 0.7, "y"),
            ]
        };
        let w_greedy = SelectorWeights::default();
        let w_bias0 = SelectorWeights {
            diversity_bias: 0.0,
            ..Default::default()
        };
        let out_greedy = run(make(), 200, &w_greedy);
        let out_bias0 = run(make(), 200, &w_bias0);
        assert_eq!(ids(&out_greedy), ids(&out_bias0));
    }

    #[test]
    fn diversity_bias_prefers_different_categories() {
        let candidates = vec![
            input_cat("a", 100, 0.9, "x"),
            input_cat("b", 100, 0.8, "x"),
            input_cat("c", 100, 0.7, "y"),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = run(candidates, 200, &w);
        assert_eq!(out.selected.len(), 2);
        let picked = ids(&out);
        assert!(picked.contains(&"a"), "a should always be selected");
        assert!(
            picked.contains(&"c"),
            "c should be preferred over b due to diversity"
        );
    }

    #[test]
    fn no_overflow_near_usize_max() {
        // A near-usize::MAX size must not overflow the budget check.
        // Budget is 100, so only "b" fits.
        let candidates = vec![input("a", usize::MAX - 1, 0.9), input("b", 10, 0.8)];
        let out = run(candidates, 100, &weights(0.0));
        assert_eq!(ids(&out), ["b"]);
    }

    #[test]
    fn diversity_bias_no_categories_unaffected() {
        let candidates = vec![
            input("a", 100, 0.9),
            input("b", 100, 0.8),
            input("c", 100, 0.7),
        ];
        let w = SelectorWeights {
            diversity_bias: 1.0,
            ..Default::default()
        };
        let out = run(candidates, 200, &w);
        assert_eq!(ids(&out), ["a", "b"]);
    }

    // ── ADR-059: epistemic weight tests ──────────────────────────────────────

    #[test]
    fn epistemic_weight_zero_preserves_behavior() {
        // An explicit epistemic_weight of 0 must behave exactly like the default.
        let make = || {
            vec![
                input_with_gain("a", 100, 0.9, 10.0),
                input_with_gain("b", 100, 0.8, 0.0),
                input_with_gain("c", 100, 0.7, 5.0),
            ]
        };
        let w_default = SelectorWeights::default();
        let w_zero = SelectorWeights {
            epistemic_weight: 0.0,
            ..Default::default()
        };
        let out_default = run(make(), 200, &w_default);
        let out_zero = run(make(), 200, &w_zero);
        assert_eq!(ids(&out_default), ids(&out_zero));
        // Pure score order: a (0.9), b (0.8).
        assert_eq!(ids(&out_default), ["a", "b"]);
    }

    #[test]
    fn epistemic_weight_positive_reorders_by_gain() {
        // a: 0.5 + 1.0 * 10.0 = 10.5
        // b: 0.9 + 1.0 * 0.0  = 0.9
        // With epistemic_weight = 1.0, a outranks b.
        let candidates = vec![
            input_with_gain("a", 100, 0.5, 10.0),
            input_with_gain("b", 100, 0.9, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out = run(candidates, 100, &w);
        assert_eq!(ids(&out), ["a"]);
    }

    #[test]
    fn information_gain_none_equivalent_to_zero() {
        // `None` and `Some(0.0)` must order identically.
        let with_none = vec![input("a", 100, 0.9), input("b", 100, 0.8)];
        let with_zero = vec![
            input_with_gain("a", 100, 0.9, 0.0),
            input_with_gain("b", 100, 0.8, 0.0),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            ..Default::default()
        };
        let out_none = run(with_none, 200, &w);
        let out_zero = run(with_zero, 200, &w);
        assert_eq!(ids(&out_none), ids(&out_zero));
    }

    #[test]
    fn epistemic_weight_works_with_diversity_bias() {
        // Epistemic bonus and diversity penalty combined:
        // a: score 0.5, gain 10.0, category x → base 0.5 + 1.0 * 10.0 = 10.5
        // b: score 0.8, gain 0.0,  category x → base 0.8
        // c: score 0.3, gain 0.0,  category y → base 0.3
        // Budget 200, bias 0.5: a goes first (10.5). With one x already picked,
        // b becomes 0.8 * (1 - 0.5 * 1/2) = 0.6, which still beats c at 0.3.
        let in_category = |item: Item, cat: &str| SelectorInput {
            category: Some(cat.to_string()),
            ..item
        };
        let candidates = vec![
            in_category(input_with_gain("a", 100, 0.5, 10.0), "x"),
            in_category(input_with_gain("b", 100, 0.8, 0.0), "x"),
            in_category(input_with_gain("c", 100, 0.3, 0.0), "y"),
        ];
        let w = SelectorWeights {
            epistemic_weight: 1.0,
            diversity_bias: 0.5,
            ..Default::default()
        };
        let out = run(candidates, 200, &w);
        assert_eq!(ids(&out), ["a", "b"]);
    }
}