/// Scope in which a hint is eligible to be ranked.
///
/// During ranking, `Global` hints are always eligible. `Widget` and `Mode`
/// hints are eligible when their key equals the requested context key, and
/// every hint is eligible when no context key is given.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum HintContext {
    /// Eligible in every context.
    Global,
    /// Eligible when the requested context key equals this widget key.
    Widget(String),
    /// Eligible when the requested context key equals this mode key.
    Mode(String),
}
54
/// Learned statistics for one hint, modelled as a Beta(`alpha`, `beta`) posterior.
///
/// The ranker adds 1 to `alpha` for each recorded usage and 1 to `beta` for
/// each recorded "shown but not used" event.
#[derive(Clone, Debug)]
pub struct HintStats {
    /// Posterior "success" parameter (prior plus recorded usages).
    pub alpha: f64,
    /// Posterior "failure" parameter (prior plus recorded non-usages).
    pub beta: f64,
    /// Cost of showing the hint, as passed at registration.
    pub cost: f64,
    /// Cold-start ordering key; lower values rank first before any observation.
    pub static_priority: u32,
    /// Total number of recorded usage / non-usage events.
    pub observations: u64,
}

impl HintStats {
    /// Posterior mean `alpha / (alpha + beta)`.
    #[inline]
    pub fn expected_utility(&self) -> f64 {
        let total = self.alpha + self.beta;
        self.alpha / total
    }

    /// Posterior variance `alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))`.
    #[inline]
    pub fn variance(&self) -> f64 {
        let n = self.alpha + self.beta;
        let numerator = self.alpha * self.beta;
        let denominator = n * n * (n + 1.0);
        numerator / denominator
    }

    /// Value-of-information proxy: the posterior standard deviation.
    #[inline]
    pub fn voi(&self) -> f64 {
        f64::sqrt(self.variance())
    }
}
90
91#[derive(Debug, Clone)]
93pub struct HintEntry {
94 pub id: usize,
96 pub label: String,
98 pub cost: f64,
100 pub context: HintContext,
102 pub stats: HintStats,
104}
105
/// One row of the ranking ledger: the quantities behind a hint's position.
#[derive(Debug, Clone)]
pub struct RankingEvidence {
    /// Hint id.
    pub id: usize,
    /// Hint label.
    pub label: String,
    /// Posterior mean utility of the hint.
    pub expected_utility: f64,
    /// Cost of the hint.
    pub cost: f64,
    /// Expected utility plus VOI bonus minus cost penalty.
    pub net_value: f64,
    /// Value-of-information proxy (posterior standard deviation).
    pub voi: f64,
    /// Zero-based position in the returned ordering.
    pub rank: usize,
}

impl RankingEvidence {
    /// Serializes this row as a single JSON object (one JSONL line).
    ///
    /// The label is JSON-escaped so quotes, backslashes and control characters
    /// cannot corrupt the line. Non-finite numbers (NaN, ±inf) are not valid
    /// JSON and are emitted as `null`. Finite numbers use fixed precision:
    /// 6 decimals for utility/net value/VOI and 4 for cost.
    #[must_use]
    pub fn to_jsonl(&self) -> String {
        format!(
            r#"{{"schema":"hint-ranking-v1","id":{},"label":"{}","expected_utility":{},"cost":{},"net_value":{},"voi":{},"rank":{}}}"#,
            self.id,
            Self::escape_json_str(&self.label),
            Self::json_number(self.expected_utility, 6),
            Self::json_number(self.cost, 4),
            Self::json_number(self.net_value, 6),
            Self::json_number(self.voi, 6),
            self.rank,
        )
    }

    /// Escapes `raw` for use inside a JSON string literal (without the quotes).
    fn escape_json_str(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for c in raw.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                // JSON forbids raw control characters inside strings.
                c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
                c => out.push(c),
            }
        }
        out
    }

    /// Formats `value` with `decimals` fractional digits, or `null` if not finite.
    fn json_number(value: f64, decimals: usize) -> String {
        if value.is_finite() {
            format!("{:.*}", decimals, value)
        } else {
            String::from("null")
        }
    }
}
134
/// Tuning parameters for [`HintRanker`].
#[derive(Clone, Debug)]
pub struct RankerConfig {
    /// Initial `alpha` given to every newly registered hint.
    pub prior_alpha: f64,
    /// Initial `beta` given to every newly registered hint.
    pub prior_beta: f64,
    /// Penalty per unit of hint cost subtracted from net value.
    pub lambda: f64,
    /// Score margin a lower-ranked hint must exceed to overtake its neighbour
    /// when re-ranking in the same context.
    pub hysteresis: f64,
    /// Weight of the value-of-information bonus added to net value.
    pub voi_weight: f64,
}

impl Default for RankerConfig {
    /// Uniform Beta(1, 1) prior, a small cost penalty, a small hysteresis
    /// margin and a modest exploration bonus.
    fn default() -> Self {
        let (prior_alpha, prior_beta) = (1.0, 1.0);
        RankerConfig {
            prior_alpha,
            prior_beta,
            lambda: 0.01,
            hysteresis: 0.02,
            voi_weight: 0.1,
        }
    }
}
161
162#[derive(Debug, Clone)]
164pub struct HintRanker {
165 config: RankerConfig,
166 hints: Vec<HintEntry>,
167 last_ordering: Vec<usize>,
169 last_context: Option<String>,
171}
172
173impl HintRanker {
174 pub fn new(config: RankerConfig) -> Self {
176 Self {
177 config,
178 hints: Vec::new(),
179 last_ordering: Vec::new(),
180 last_context: None,
181 }
182 }
183
184 pub fn register(
186 &mut self,
187 label: impl Into<String>,
188 cost_columns: f64,
189 context: HintContext,
190 static_priority: u32,
191 ) -> usize {
192 let id = self.hints.len();
193 self.hints.push(HintEntry {
194 id,
195 label: label.into(),
196 cost: cost_columns,
197 context,
198 stats: HintStats {
199 alpha: self.config.prior_alpha,
200 beta: self.config.prior_beta,
201 cost: cost_columns,
202 static_priority,
203 observations: 0,
204 },
205 });
206 id
207 }
208
209 pub fn record_usage(&mut self, hint_id: usize) {
211 if let Some(h) = self.hints.get_mut(hint_id) {
212 h.stats.alpha += 1.0;
213 h.stats.observations += 1;
214 }
215 }
216
217 pub fn record_shown_not_used(&mut self, hint_id: usize) {
219 if let Some(h) = self.hints.get_mut(hint_id) {
220 h.stats.beta += 1.0;
221 h.stats.observations += 1;
222 }
223 }
224
225 fn net_value(&self, h: &HintEntry) -> f64 {
227 let eu = h.stats.expected_utility();
228 let voi = h.stats.voi();
229 eu + self.config.voi_weight * voi - self.config.lambda * h.cost
230 }
231
232 pub fn rank(&mut self, context_key: Option<&str>) -> (Vec<usize>, Vec<RankingEvidence>) {
236 let context_str = context_key.map(String::from);
237
238 let mut candidates: Vec<(usize, f64)> = self
240 .hints
241 .iter()
242 .filter(|h| match (&h.context, context_key) {
243 (HintContext::Global, _) => true,
244 (HintContext::Widget(w), Some(ctx)) => w == ctx,
245 (HintContext::Mode(m), Some(ctx)) => m == ctx,
246 _ => context_key.is_none(), })
248 .map(|h| {
249 let v = if h.stats.observations == 0 {
250 -(h.stats.static_priority as f64)
252 } else {
253 self.net_value(h)
254 };
255 (h.id, v)
256 })
257 .collect();
258
259 candidates.sort_by(|a, b| {
261 b.1.partial_cmp(&a.1)
262 .unwrap_or(std::cmp::Ordering::Equal)
263 .then_with(|| a.0.cmp(&b.0))
264 });
265
266 let new_ordering: Vec<usize> = candidates.iter().map(|(id, _)| *id).collect();
267
268 let ordering = if self.last_context == context_str && !self.last_ordering.is_empty() {
270 self.apply_hysteresis(&new_ordering, &candidates)
271 } else {
272 new_ordering.clone()
273 };
274
275 let ledger: Vec<RankingEvidence> = ordering
277 .iter()
278 .enumerate()
279 .map(|(rank, &id)| {
280 let h = &self.hints[id];
281 RankingEvidence {
282 id,
283 label: h.label.clone(),
284 expected_utility: h.stats.expected_utility(),
285 cost: h.cost,
286 net_value: self.net_value(h),
287 voi: h.stats.voi(),
288 rank,
289 }
290 })
291 .collect();
292
293 self.last_ordering = ordering.clone();
294 self.last_context = context_str;
295
296 (ordering, ledger)
297 }
298
299 fn apply_hysteresis(&self, new_order: &[usize], scores: &[(usize, f64)]) -> Vec<usize> {
301 let score_map: std::collections::HashMap<usize, f64> = scores.iter().copied().collect();
303
304 let mut result = self.last_ordering.clone();
305
306 result.retain(|id| new_order.contains(id));
308
309 for &id in new_order {
311 if !result.contains(&id) {
312 result.push(id);
313 }
314 }
315
316 let eps = self.config.hysteresis;
318 let mut changed = true;
319 while changed {
320 changed = false;
321 for i in 0..result.len().saturating_sub(1) {
322 let a = result[i];
323 let b = result[i + 1];
324 let sa = score_map.get(&a).copied().unwrap_or(f64::NEG_INFINITY);
325 let sb = score_map.get(&b).copied().unwrap_or(f64::NEG_INFINITY);
326 if sb > sa + eps {
327 result.swap(i, i + 1);
328 changed = true;
329 }
330 }
331 }
332
333 result
334 }
335
336 pub fn top_n(&mut self, n: usize, context_key: Option<&str>) -> Vec<&HintEntry> {
338 let (ordering, _) = self.rank(context_key);
339 ordering
340 .into_iter()
341 .take(n)
342 .filter_map(|id| self.hints.get(id))
343 .collect()
344 }
345
346 #[must_use = "use the returned stats (if any)"]
348 pub fn stats(&self, id: usize) -> Option<&HintStats> {
349 self.hints.get(id).map(|h| &h.stats)
350 }
351
352 pub fn hint_count(&self) -> usize {
354 self.hints.len()
355 }
356}
357
358impl Default for HintRanker {
359 fn default() -> Self {
360 Self::new(RankerConfig::default())
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Five hints with static priorities 1..=5 in registration order: four
    /// global ones and one (id 3) bound to the "input" widget.
    fn make_ranker() -> HintRanker {
        let mut r = HintRanker::new(RankerConfig::default());
        r.register("Ctrl+S Save", 12.0, HintContext::Global, 1);
        r.register("Ctrl+Z Undo", 12.0, HintContext::Global, 2);
        r.register("Ctrl+F Find", 12.0, HintContext::Global, 3);
        r.register("Tab Complete", 13.0, HintContext::Widget("input".into()), 4);
        r.register("Esc Cancel", 11.0, HintContext::Global, 5);
        r
    }

    #[test]
    fn empty_ranker_returns_empty() {
        let mut r = HintRanker::default();
        let (ordering, ledger) = r.rank(None);
        assert!(ordering.is_empty());
        assert!(ledger.is_empty());
    }

    #[test]
    fn cold_start_uses_static_priority() {
        let mut r = make_ranker();
        let (ordering, _) = r.rank(None);
        // No observations yet: lower static priority ranks first.
        assert_eq!(ordering, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn unit_prior_updates() {
        let mut r = HintRanker::default();
        let id = r.register("test", 10.0, HintContext::Global, 1);

        // Beta(1, 1) prior has mean 1/2.
        assert!((r.stats(id).unwrap().expected_utility() - 0.5).abs() < 1e-10);

        for _ in 0..4 {
            r.record_usage(id);
        }
        // Beta(5, 1) -> 5/6.
        assert!((r.stats(id).unwrap().expected_utility() - 5.0 / 6.0).abs() < 1e-10);

        for _ in 0..2 {
            r.record_shown_not_used(id);
        }
        // Beta(5, 3) -> 5/8.
        assert!((r.stats(id).unwrap().expected_utility() - 5.0 / 8.0).abs() < 1e-10);
    }

    #[test]
    fn unit_ranking_stability() {
        let config = RankerConfig {
            hysteresis: 0.05,
            voi_weight: 0.0,
            ..Default::default()
        };
        let mut r = HintRanker::new(config.clone());
        let a = r.register("A", 10.0, HintContext::Global, 1);
        let b = r.register("B", 10.0, HintContext::Global, 2);

        for _ in 0..9 {
            r.record_usage(a); // EU(A) = 10/11 ~ 0.909
        }
        for _ in 0..4 {
            r.record_usage(b); // EU(B) = 5/6 ~ 0.833
        }

        let (order1, _) = r.rank(None);
        assert_eq!(order1, vec![a, b]);

        // EU(B) = 11/12 ~ 0.917: B now leads on score, but by less than 0.05.
        for _ in 0..6 {
            r.record_usage(b);
        }
        let (order2, _) = r.rank(None);
        assert_eq!(order2[0], a, "hysteresis should prevent flicker");

        // Control: with no ranking history the same scores put B first, so the
        // assertion above really depends on hysteresis.
        let mut fresh = HintRanker::new(config);
        let fa = fresh.register("A", 10.0, HintContext::Global, 1);
        let fb = fresh.register("B", 10.0, HintContext::Global, 2);
        for _ in 0..9 {
            fresh.record_usage(fa);
        }
        for _ in 0..10 {
            fresh.record_usage(fb);
        }
        let (control, _) = fresh.rank(None);
        assert_eq!(control[0], fb, "without history B should lead");
    }

    #[test]
    fn hysteresis_yields_to_large_margin() {
        let mut r = HintRanker::new(RankerConfig {
            hysteresis: 0.05,
            voi_weight: 0.0,
            ..Default::default()
        });
        let a = r.register("A", 10.0, HintContext::Global, 1);
        let b = r.register("B", 10.0, HintContext::Global, 2);

        for _ in 0..9 {
            r.record_usage(a); // 10/11 ~ 0.909
        }
        for _ in 0..4 {
            r.record_usage(b); // 5/6 ~ 0.833
        }
        let (order1, _) = r.rank(None);
        assert_eq!(order1[0], a);

        // EU(B) = 101/102 ~ 0.990: ahead of A by ~0.081 > 0.05.
        for _ in 0..96 {
            r.record_usage(b);
        }
        let (order2, _) = r.rank(None);
        assert_eq!(order2[0], b, "a lead larger than the margin should reorder");
    }

    #[test]
    fn context_filtering() {
        let mut r = make_ranker();
        let (ordering, _) = r.rank(Some("input"));
        assert!(ordering.contains(&3), "input widget hint should appear");

        let (ordering2, _) = r.rank(Some("list"));
        assert!(
            !ordering2.contains(&3),
            "input widget hint should not appear for list"
        );
    }

    #[test]
    fn property_context_switch_reranks() {
        let mut r = make_ranker();

        for _ in 0..10 {
            r.record_usage(0);
        }
        for _ in 0..5 {
            r.record_usage(2);
        }

        let (order_none, _) = r.rank(None);
        let (order_list, _) = r.rank(Some("list"));

        assert!(
            order_none.contains(&3),
            "None context should include input widget hint"
        );
        assert!(
            !order_list.contains(&3),
            "list context should exclude input widget hint"
        );
    }

    #[test]
    fn voi_exploration_bonus() {
        let mut r = HintRanker::new(RankerConfig {
            voi_weight: 1.0,
            lambda: 0.0,
            hysteresis: 0.0,
            ..Default::default()
        });
        let a = r.register("A", 10.0, HintContext::Global, 1);
        let b = r.register("B", 10.0, HintContext::Global, 2);

        for _ in 0..100 {
            r.record_usage(a);
            r.record_shown_not_used(a);
        }

        let (ordering, _) = r.rank(None);
        let a_stats = r.stats(a).unwrap();
        assert!((a_stats.expected_utility() - 0.5).abs() < 1e-10);
        assert!(a_stats.voi() < 0.1, "well-explored hint has little VOI");
        assert_eq!(ordering[0], a, "observed A outranks cold-start B");

        // B gets the same mean (Beta(2, 2)) from far fewer observations; its
        // larger posterior spread (~0.224 vs ~0.035) lifts it above A.
        r.record_usage(b);
        r.record_shown_not_used(b);
        let (ordering, _) = r.rank(None);
        assert_eq!(ordering[0], b, "VOI bonus should favour the less-explored hint");
    }

    #[test]
    fn top_n_returns_limited() {
        let mut r = make_ranker();
        let top = r.top_n(2, None);
        assert_eq!(top.len(), 2);
    }

    #[test]
    fn deterministic_under_same_history() {
        let run = || {
            let mut r = make_ranker();
            r.record_usage(0);
            r.record_usage(0);
            r.record_usage(2);
            r.record_shown_not_used(1);
            r.record_shown_not_used(4);
            let (ordering, _) = r.rank(None);
            ordering
        };

        assert_eq!(run(), run());
    }

    #[test]
    fn ledger_records_all_ranked_hints() {
        let mut r = make_ranker();
        for _ in 0..5 {
            r.record_usage(0);
        }
        let (ordering, ledger) = r.rank(None);
        assert_eq!(ordering.len(), ledger.len());

        for (i, entry) in ledger.iter().enumerate() {
            assert_eq!(entry.rank, i);
            assert_eq!(entry.id, ordering[i]);
        }
    }

    #[test]
    fn usage_promotes_hint() {
        let mut r = HintRanker::new(RankerConfig {
            hysteresis: 0.0,
            ..Default::default()
        });
        let a = r.register("A", 10.0, HintContext::Global, 2);
        let b = r.register("B", 10.0, HintContext::Global, 1);

        // Cold start: B has the lower static priority.
        let (order1, _) = r.rank(None);
        assert_eq!(order1[0], b);

        for _ in 0..20 {
            r.record_usage(a);
        }
        r.record_shown_not_used(b);

        let (order2, _) = r.rank(None);
        assert_eq!(order2[0], a, "heavy usage should promote A above B");
    }

    #[test]
    fn ranker_config_defaults() {
        let cfg = RankerConfig::default();
        assert!((cfg.prior_alpha - 1.0).abs() < f64::EPSILON);
        assert!((cfg.prior_beta - 1.0).abs() < f64::EPSILON);
        assert!((cfg.lambda - 0.01).abs() < f64::EPSILON);
        assert!((cfg.hysteresis - 0.02).abs() < f64::EPSILON);
        assert!((cfg.voi_weight - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn hint_ranker_default_is_empty() {
        let r = HintRanker::default();
        assert_eq!(r.hint_count(), 0);
    }

    #[test]
    fn hint_count_tracks_registrations() {
        let mut r = HintRanker::default();
        assert_eq!(r.hint_count(), 0);
        r.register("A", 10.0, HintContext::Global, 1);
        assert_eq!(r.hint_count(), 1);
        r.register("B", 5.0, HintContext::Global, 2);
        assert_eq!(r.hint_count(), 2);
    }

    #[test]
    fn stats_returns_none_for_invalid_id() {
        let r = HintRanker::default();
        assert!(r.stats(0).is_none());
        assert!(r.stats(999).is_none());
    }

    #[test]
    fn record_usage_invalid_id_is_noop() {
        let mut r = HintRanker::default();
        r.record_usage(0);
        r.record_usage(999);
        assert_eq!(r.hint_count(), 0);
    }

    #[test]
    fn record_shown_not_used_invalid_id_is_noop() {
        let mut r = HintRanker::default();
        r.record_shown_not_used(0);
        r.record_shown_not_used(42);
        assert_eq!(r.hint_count(), 0);
    }

    #[test]
    fn variance_and_voi_computation() {
        let s = HintStats {
            alpha: 3.0,
            beta: 7.0,
            cost: 10.0,
            static_priority: 1,
            observations: 10,
        };
        assert!((s.expected_utility() - 0.3).abs() < 1e-10);
        // 3 * 7 / (10^2 * 11)
        let expected_var = 21.0 / 1100.0;
        assert!((s.variance() - expected_var).abs() < 1e-10);
        assert!((s.voi() - expected_var.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn observations_track_both_usage_and_nonusage() {
        let mut r = HintRanker::default();
        let id = r.register("test", 10.0, HintContext::Global, 1);
        r.record_usage(id);
        r.record_usage(id);
        r.record_shown_not_used(id);
        let s = r.stats(id).unwrap();
        assert_eq!(s.observations, 3);
        // Prior 1 + 2 usages / prior 1 + 1 non-usage.
        assert!((s.alpha - 3.0).abs() < f64::EPSILON);
        assert!((s.beta - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn mode_context_filtering() {
        let mut r = HintRanker::new(RankerConfig {
            hysteresis: 0.0,
            ..Default::default()
        });
        let g = r.register("Global", 10.0, HintContext::Global, 1);
        let ins = r.register("Insert", 10.0, HintContext::Mode("insert".into()), 2);
        let norm = r.register("Normal", 10.0, HintContext::Mode("normal".into()), 3);

        let (order, _) = r.rank(Some("insert"));
        assert!(order.contains(&g));
        assert!(order.contains(&ins));
        assert!(
            !order.contains(&norm),
            "normal mode hint should not appear in insert context"
        );

        let (order2, _) = r.rank(Some("normal"));
        assert!(order2.contains(&g));
        assert!(order2.contains(&norm));
        assert!(
            !order2.contains(&ins),
            "insert mode hint should not appear in normal context"
        );
    }

    #[test]
    fn high_lambda_penalises_costly_hints() {
        let mut r = HintRanker::new(RankerConfig {
            lambda: 1.0,
            hysteresis: 0.0,
            voi_weight: 0.0,
            ..Default::default()
        });
        let cheap = r.register("Cheap", 1.0, HintContext::Global, 2);
        let expensive = r.register("Expensive", 100.0, HintContext::Global, 1);

        for _ in 0..10 {
            r.record_usage(cheap);
            r.record_usage(expensive);
        }

        let (order, _) = r.rank(None);
        assert_eq!(
            order[0], cheap,
            "cheap hint should rank first with high lambda"
        );
    }

    #[test]
    fn ledger_fields_are_accurate() {
        let mut r = HintRanker::new(RankerConfig {
            hysteresis: 0.0,
            ..Default::default()
        });
        let id = r.register("Ctrl+X Cut", 11.0, HintContext::Global, 1);
        for _ in 0..5 {
            r.record_usage(id);
        }

        let (_, ledger) = r.rank(None);
        assert_eq!(ledger.len(), 1);
        let entry = &ledger[0];
        assert_eq!(entry.id, id);
        assert_eq!(entry.label, "Ctrl+X Cut");
        assert!((entry.cost - 11.0).abs() < f64::EPSILON);
        assert_eq!(entry.rank, 0);
        // Beta(6, 1) -> 6/7.
        assert!((entry.expected_utility - 6.0 / 7.0).abs() < 1e-10);
        assert!(entry.voi > 0.0);
    }

    #[test]
    fn hysteresis_with_new_hint_appearing() {
        let mut r = HintRanker::new(RankerConfig {
            hysteresis: 0.05,
            ..Default::default()
        });
        let a = r.register("A", 10.0, HintContext::Global, 1);

        let (order1, _) = r.rank(None);
        assert_eq!(order1, vec![a]);

        // B appears between two same-context rankings and must be included.
        let b = r.register("B", 10.0, HintContext::Global, 2);
        let (order2, _) = r.rank(None);
        assert_eq!(order2, vec![a, b]);
    }

    #[test]
    fn top_n_with_zero_returns_empty() {
        let mut r = make_ranker();
        let top = r.top_n(0, None);
        assert!(top.is_empty());
    }

    #[test]
    fn top_n_exceeding_count_returns_all() {
        let mut r = make_ranker();
        let all = r.top_n(100, None);
        assert_eq!(all.len(), 5);
    }

    #[test]
    fn register_returns_sequential_ids() {
        let mut r = HintRanker::default();
        assert_eq!(r.register("A", 1.0, HintContext::Global, 1), 0);
        assert_eq!(r.register("B", 1.0, HintContext::Global, 2), 1);
        assert_eq!(r.register("C", 1.0, HintContext::Global, 3), 2);
    }

    #[test]
    fn zero_cost_hint_net_value() {
        let mut r = HintRanker::new(RankerConfig {
            lambda: 0.5,
            hysteresis: 0.0,
            voi_weight: 0.0,
            ..Default::default()
        });
        let id = r.register("Free", 0.0, HintContext::Global, 1);
        for _ in 0..10 {
            r.record_usage(id);
        }
        let (_, ledger) = r.rank(None);
        // Zero cost and zero VOI weight: net value is just Beta(11, 1)'s mean.
        assert!((ledger[0].net_value - 11.0 / 12.0).abs() < 1e-10);
    }

    #[test]
    fn repeated_rank_same_context_uses_hysteresis_path() {
        let mut r = HintRanker::new(RankerConfig {
            hysteresis: 0.5,
            voi_weight: 0.0,
            ..Default::default()
        });
        let a = r.register("A", 10.0, HintContext::Global, 1);
        let b = r.register("B", 10.0, HintContext::Global, 2);

        for _ in 0..10 {
            r.record_usage(a); // 11/12 ~ 0.917
        }
        for _ in 0..5 {
            r.record_usage(b); // 6/7 ~ 0.857
        }

        let (order1, _) = r.rank(Some("ctx"));
        assert_eq!(order1[0], a);

        // B now leads on score (26/27 ~ 0.963), but by far less than 0.5.
        for _ in 0..20 {
            r.record_usage(b);
        }
        let (order2, _) = r.rank(Some("ctx"));
        assert_eq!(order2[0], a, "hysteresis should stabilize ordering");

        // A different context key bypasses hysteresis and re-sorts by score.
        let (order3, _) = r.rank(Some("other"));
        assert_eq!(order3[0], b, "context switch should re-rank from scratch");
    }

    #[test]
    fn hint_context_equality() {
        assert_eq!(HintContext::Global, HintContext::Global);
        assert_eq!(
            HintContext::Widget("foo".into()),
            HintContext::Widget("foo".into())
        );
        assert_ne!(
            HintContext::Widget("foo".into()),
            HintContext::Mode("foo".into())
        );
        assert_ne!(HintContext::Global, HintContext::Mode("x".into()));
    }
}
846}