Skip to main content

oxirs_graphrag/
feedback.rs

//! Interactive feedback loop for graph-based RAG retrieval refinement.
//!
//! Users can mark triples as relevant (positive feedback) or irrelevant
//! (negative feedback). The [`FeedbackSession`] adjusts per-triple weights
//! multiplicatively so that subsequent retrievals favour positively-rated
//! triples and suppress negatively-rated ones.
//!
//! # v0.4.0 additions
//!
//! [`TripleRelevanceFeedback`] provides a seahash-keyed, multiplicative-weight
//! session with a typed [`Relevance`] enum and a sorted `apply_to_scores`
//! method, suitable for single-session adaptive retrieval.
13
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17// ─────────────────────────────────────────────────────────────────────────────
18// Feedback types
19// ─────────────────────────────────────────────────────────────────────────────
20
21/// The valence of a feedback signal.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23pub enum FeedbackKind {
24    Positive,
25    Negative,
26}
27
28/// A single feedback event from the user.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct FeedbackEvent {
31    /// Fingerprint of the triple: `"{subject}|{predicate}|{object}"`.
32    pub triple_key: String,
33    /// Whether the user found this triple useful.
34    pub kind: FeedbackKind,
35    /// Optional textual note from the user.
36    pub note: Option<String>,
37}
38
39impl FeedbackEvent {
40    /// Construct a feedback event for the given triple.
41    pub fn new(subject: &str, predicate: &str, object: &str, kind: FeedbackKind) -> Self {
42        Self {
43            triple_key: triple_key(subject, predicate, object),
44            kind,
45            note: None,
46        }
47    }
48
49    pub fn with_note(mut self, note: impl Into<String>) -> Self {
50        self.note = Some(note.into());
51        self
52    }
53}
54
/// Build the string fingerprint `"{s}|{p}|{o}"` used to key per-triple weights.
///
/// The components are joined verbatim; a `|` inside a component is not
/// escaped.
fn triple_key(s: &str, p: &str, o: &str) -> String {
    [s, p, o].join("|")
}
58
59// ─────────────────────────────────────────────────────────────────────────────
60// Weight configuration
61// ─────────────────────────────────────────────────────────────────────────────
62
63/// Hyperparameters for weight adjustment.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct FeedbackConfig {
66    /// Multiplicative boost applied per positive feedback event.
67    pub positive_factor: f64,
68    /// Multiplicative penalty applied per negative feedback event.
69    pub negative_factor: f64,
70    /// Minimum allowed weight (prevents complete suppression).
71    pub min_weight: f64,
72    /// Maximum allowed weight (prevents runaway amplification).
73    pub max_weight: f64,
74}
75
76impl Default for FeedbackConfig {
77    fn default() -> Self {
78        Self {
79            positive_factor: 1.5,
80            negative_factor: 0.6,
81            min_weight: 0.01,
82            max_weight: 10.0,
83        }
84    }
85}
86
87// ─────────────────────────────────────────────────────────────────────────────
88// FeedbackSession
89// ─────────────────────────────────────────────────────────────────────────────
90
91/// Maintains per-triple learned weights across a user session.
92///
93/// Weights start at 1.0 and are adjusted multiplicatively with each
94/// feedback signal.  Use [`FeedbackSession::apply_weights`] to rescore a result set
95/// before the next retrieval round.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct FeedbackSession {
98    /// Per-triple multiplicative weight, keyed by triple fingerprint.
99    weights: HashMap<String, f64>,
100    /// Ordered history of all feedback events.
101    history: Vec<FeedbackEvent>,
102    /// Configuration hyperparameters.
103    pub config: FeedbackConfig,
104}
105
106impl FeedbackSession {
107    pub fn new() -> Self {
108        Self::with_config(FeedbackConfig::default())
109    }
110
111    pub fn with_config(config: FeedbackConfig) -> Self {
112        Self {
113            weights: HashMap::new(),
114            history: Vec::new(),
115            config,
116        }
117    }
118
119    /// Record a feedback event and update the triple's weight.
120    pub fn record(&mut self, event: FeedbackEvent) {
121        let factor = match event.kind {
122            FeedbackKind::Positive => self.config.positive_factor,
123            FeedbackKind::Negative => self.config.negative_factor,
124        };
125        let w = self.weights.entry(event.triple_key.clone()).or_insert(1.0);
126        *w = (*w * factor).clamp(self.config.min_weight, self.config.max_weight);
127        self.history.push(event);
128    }
129
130    /// Convenience: record a positive signal for a triple.
131    pub fn like(&mut self, subject: &str, predicate: &str, object: &str) {
132        self.record(FeedbackEvent::new(
133            subject,
134            predicate,
135            object,
136            FeedbackKind::Positive,
137        ));
138    }
139
140    /// Convenience: record a negative signal for a triple.
141    pub fn dislike(&mut self, subject: &str, predicate: &str, object: &str) {
142        self.record(FeedbackEvent::new(
143            subject,
144            predicate,
145            object,
146            FeedbackKind::Negative,
147        ));
148    }
149
150    /// Return the learned weight for a triple (1.0 if unseen).
151    pub fn weight(&self, subject: &str, predicate: &str, object: &str) -> f64 {
152        let key = triple_key(subject, predicate, object);
153        *self.weights.get(&key).unwrap_or(&1.0)
154    }
155
156    /// Apply learned weights to a set of `(triple_key, score)` pairs.
157    ///
158    /// The adjusted score is `original_score * weight`.
159    pub fn apply_weights(&self, scores: &HashMap<String, f64>) -> HashMap<String, f64> {
160        scores
161            .iter()
162            .map(|(k, &v)| {
163                let w = self.weights.get(k).copied().unwrap_or(1.0);
164                (k.clone(), v * w)
165            })
166            .collect()
167    }
168
169    /// Return the full event history.
170    pub fn history(&self) -> &[FeedbackEvent] {
171        &self.history
172    }
173
174    /// Count positive feedback events.
175    pub fn positive_count(&self) -> usize {
176        self.history
177            .iter()
178            .filter(|e| e.kind == FeedbackKind::Positive)
179            .count()
180    }
181
182    /// Count negative feedback events.
183    pub fn negative_count(&self) -> usize {
184        self.history
185            .iter()
186            .filter(|e| e.kind == FeedbackKind::Negative)
187            .count()
188    }
189
190    /// Reset all learned weights and clear history.
191    pub fn reset(&mut self) {
192        self.weights.clear();
193        self.history.clear();
194    }
195}
196
197impl Default for FeedbackSession {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203// ─────────────────────────────────────────────────────────────────────────────
204// Tests
205// ─────────────────────────────────────────────────────────────────────────────
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Learned weight of the `("A", "p", "B")` triple used by most tests.
    fn weight_apb(session: &FeedbackSession) -> f64 {
        session.weight("A", "p", "B")
    }

    /// A triple without feedback has the neutral weight.
    #[test]
    fn test_initial_weight_is_one() {
        assert_eq!(weight_apb(&FeedbackSession::new()), 1.0);
    }

    /// One positive signal pushes the weight above neutral.
    #[test]
    fn test_positive_feedback_increases_weight() {
        let mut session = FeedbackSession::new();
        session.like("A", "p", "B");
        assert!(weight_apb(&session) > 1.0);
    }

    /// One negative signal pushes the weight below neutral.
    #[test]
    fn test_negative_feedback_decreases_weight() {
        let mut session = FeedbackSession::new();
        session.dislike("A", "p", "B");
        assert!(weight_apb(&session) < 1.0);
    }

    /// Many positive signals never exceed the configured ceiling.
    #[test]
    fn test_repeated_positive_capped_at_max() {
        let mut session = FeedbackSession::new();
        (0..100).for_each(|_| session.like("A", "p", "B"));
        assert!(weight_apb(&session) <= session.config.max_weight);
    }

    /// Many negative signals never drop below the configured floor.
    #[test]
    fn test_repeated_negative_capped_at_min() {
        let mut session = FeedbackSession::new();
        (0..100).for_each(|_| session.dislike("A", "p", "B"));
        assert!(weight_apb(&session) >= session.config.min_weight);
    }

    /// A liked triple's score is scaled up by `apply_weights`.
    #[test]
    fn test_apply_weights() {
        let mut session = FeedbackSession::new();
        session.like("A", "knows", "B");
        let key = "A|knows|B".to_string();
        let scores: HashMap<String, f64> = std::iter::once((key.clone(), 0.5)).collect();
        let adjusted = session.apply_weights(&scores);
        assert!(adjusted[&key] > 0.5);
    }

    /// History and per-kind counters reflect every recorded event.
    #[test]
    fn test_event_history() {
        let mut session = FeedbackSession::new();
        session.like("X", "p", "Y");
        session.dislike("X", "q", "Z");
        assert_eq!(session.history().len(), 2);
        assert_eq!(session.positive_count(), 1);
        assert_eq!(session.negative_count(), 1);
    }

    /// `reset` drops both the learned weights and the history.
    #[test]
    fn test_reset_clears_state() {
        let mut session = FeedbackSession::new();
        session.like("A", "p", "B");
        session.reset();
        assert_eq!(weight_apb(&session), 1.0);
        assert_eq!(session.history().len(), 0);
    }

    /// `with_note` attaches the note text to the event.
    #[test]
    fn test_feedback_event_with_note() {
        let base = FeedbackEvent::new("A", "p", "B", FeedbackKind::Positive);
        let event = base.with_note("very relevant");
        assert_eq!(event.note.as_deref(), Some("very relevant"));
    }

    /// A custom positive factor is applied exactly once per `like`.
    #[test]
    fn test_custom_config() {
        let mut session = FeedbackSession::with_config(FeedbackConfig {
            positive_factor: 2.0,
            negative_factor: 0.5,
            min_weight: 0.1,
            max_weight: 5.0,
        });
        session.like("A", "p", "B");
        assert_eq!(weight_apb(&session), 2.0);
    }
}
299
300// ─────────────────────────────────────────────────────────────────────────────
301// v0.4.0: TripleRelevanceFeedback (seahash-keyed adaptive weights)
302// ─────────────────────────────────────────────────────────────────────────────
303
/// Relevance signal for a triple.
///
/// A fieldless enum, so it is `Copy` like [`FeedbackKind`] and can be passed
/// by value repeatedly without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Relevance {
    /// User found this triple useful.
    Positive,
    /// User found this triple not useful.
    Negative,
    /// User has no preference (resets weight to 1.0).
    Neutral,
}
314
/// Numeric identifier of a triple: the seahash of its `"{s}|{p}|{o}"`
/// fingerprint.
pub type TripleId = u64;
317
318/// Compute a [`TripleId`] for the given triple components.
319fn triple_id(s: &str, p: &str, o: &str) -> TripleId {
320    seahash::hash(format!("{s}|{p}|{o}").as_bytes())
321}
322
323/// Session-scoped adaptive feedback for triple retrieval.
324///
325/// Weights start at 1.0 and are updated multiplicatively on each
326/// [`Relevance::Positive`] or [`Relevance::Negative`] signal.
327/// [`Relevance::Neutral`] resets the weight to exactly 1.0.
328///
329/// Use [`TripleRelevanceFeedback::apply_to_scores`] to re-rank a scored
330/// result list before serving the next retrieval round.
331pub struct TripleRelevanceFeedback {
332    positive: std::collections::HashSet<TripleId>,
333    negative: std::collections::HashSet<TripleId>,
334    weights: std::collections::HashMap<TripleId, f64>,
335}
336
337impl TripleRelevanceFeedback {
338    /// Positive weight multiplier per feedback event.
339    const POSITIVE_FACTOR: f64 = 1.5;
340    /// Negative weight multiplier per feedback event.
341    const NEGATIVE_FACTOR: f64 = 0.5;
342    /// Minimum weight (prevents complete suppression).
343    const MIN_WEIGHT: f64 = 0.1;
344    /// Maximum weight (prevents runaway amplification).
345    const MAX_WEIGHT: f64 = 2.0;
346
347    pub fn new() -> Self {
348        Self {
349            positive: std::collections::HashSet::new(),
350            negative: std::collections::HashSet::new(),
351            weights: std::collections::HashMap::new(),
352        }
353    }
354
355    /// Record a relevance signal for the triple `(subject, predicate, object)`.
356    ///
357    /// - `Positive`: weight *= 1.5, capped at 2.0
358    /// - `Negative`: weight *= 0.5, floored at 0.1
359    /// - `Neutral`:  weight reset to 1.0
360    pub fn record_feedback(
361        &mut self,
362        subject: &str,
363        predicate: &str,
364        object: &str,
365        signal: Relevance,
366    ) {
367        let id = triple_id(subject, predicate, object);
368        match signal {
369            Relevance::Positive => {
370                let current = self.weights.get(&id).copied().unwrap_or(1.0);
371                let next = (current * Self::POSITIVE_FACTOR).min(Self::MAX_WEIGHT);
372                self.weights.insert(id, next);
373                self.positive.insert(id);
374                self.negative.remove(&id);
375            }
376            Relevance::Negative => {
377                let current = self.weights.get(&id).copied().unwrap_or(1.0);
378                let next = (current * Self::NEGATIVE_FACTOR).max(Self::MIN_WEIGHT);
379                self.weights.insert(id, next);
380                self.negative.insert(id);
381                self.positive.remove(&id);
382            }
383            Relevance::Neutral => {
384                self.weights.insert(id, 1.0);
385                self.positive.remove(&id);
386                self.negative.remove(&id);
387            }
388        }
389    }
390
391    /// Get the multiplicative weight for a triple (default 1.0 if no feedback).
392    pub fn weight_of(&self, subject: &str, predicate: &str, object: &str) -> f64 {
393        let id = triple_id(subject, predicate, object);
394        self.weights.get(&id).copied().unwrap_or(1.0)
395    }
396
397    /// Apply weights to a scored list of triples and return re-weighted scores
398    /// sorted descending.
399    ///
400    /// `scores`: `Vec<((subject, predicate, object), raw_score)>`
401    pub fn apply_to_scores(
402        &self,
403        scores: Vec<((String, String, String), f64)>,
404    ) -> Vec<((String, String, String), f64)> {
405        let mut weighted: Vec<((String, String, String), f64)> = scores
406            .into_iter()
407            .map(|((s, p, o), raw)| {
408                let w = self.weight_of(&s, &p, &o);
409                ((s, p, o), raw * w)
410            })
411            .collect();
412        weighted.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
413        weighted
414    }
415
416    /// Clear all feedback and weights for this session.
417    pub fn reset(&mut self) {
418        self.positive.clear();
419        self.negative.clear();
420        self.weights.clear();
421    }
422}
423
424impl Default for TripleRelevanceFeedback {
425    fn default() -> Self {
426        Self::new()
427    }
428}
429
430// ─── TripleRelevanceFeedback tests ────────────────────────────────────────────
431
#[cfg(test)]
mod triple_relevance_tests {
    use super::{Relevance, TripleRelevanceFeedback};

    /// One positive signal makes the weighted score exceed the raw score.
    #[test]
    fn test_positive_feedback_boosts_score() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Positive);
        let raw = 0.5_f64;
        let boosted = session.weight_of("A", "p", "B") * raw;
        assert!(
            boosted > raw,
            "positive feedback should boost score: {boosted} vs {raw}"
        );
    }

    /// One negative signal makes the weighted score fall below the raw score.
    #[test]
    fn test_negative_feedback_reduces_score() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Negative);
        let raw = 0.5_f64;
        let reduced = session.weight_of("A", "p", "B") * raw;
        assert!(
            reduced < raw,
            "negative feedback should reduce score: {reduced} vs {raw}"
        );
    }

    /// A triple that never received feedback keeps its raw score.
    #[test]
    fn test_neutral_no_feedback_leaves_score_unchanged() {
        let session = TripleRelevanceFeedback::new();
        // No feedback recorded → weight is 1.0.
        let raw = 0.42_f64;
        let result = session.weight_of("X", "q", "Y") * raw;
        assert!(
            (result - raw).abs() < 1e-12,
            "neutral (no feedback) should leave score unchanged: {result} vs {raw}"
        );
    }

    /// An explicit `Neutral` signal undoes earlier feedback for that triple.
    #[test]
    fn test_neutral_signal_resets_weight() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Positive);
        session.record_feedback("A", "p", "B", Relevance::Neutral);
        let w = session.weight_of("A", "p", "B");
        assert!(
            (w - 1.0).abs() < 1e-12,
            "neutral signal should reset weight to 1.0, got {w}"
        );
    }

    /// Repeated boosts saturate at the maximum weight.
    #[test]
    fn test_repeated_positive_capped_at_max() {
        let mut session = TripleRelevanceFeedback::new();
        for _ in 0..100 {
            session.record_feedback("A", "p", "B", Relevance::Positive);
        }
        let w = session.weight_of("A", "p", "B");
        assert!(
            w <= 2.0,
            "repeated positive feedback must not exceed 2.0, got {w}"
        );
    }

    /// Repeated penalties saturate at the minimum weight.
    #[test]
    fn test_repeated_negative_floored_at_min() {
        let mut session = TripleRelevanceFeedback::new();
        for _ in 0..100 {
            session.record_feedback("A", "p", "B", Relevance::Negative);
        }
        let w = session.weight_of("A", "p", "B");
        assert!(
            w >= 0.1,
            "repeated negative feedback must not go below 0.1, got {w}"
        );
    }

    /// `reset` returns every triple to the neutral weight.
    #[test]
    fn test_reset_clears_all_weights() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Positive);
        session.reset();
        let w = session.weight_of("A", "p", "B");
        assert!(
            (w - 1.0).abs() < 1e-12,
            "after reset, weight should be 1.0, got {w}"
        );
    }

    /// Weighted scores come back in descending order with the expected values.
    #[test]
    fn test_apply_to_scores_sorted_descending() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Positive);
        // "A|p|B" gets weight 1.5, "X|q|Y" stays 1.0.
        let scores = vec![
            (("X".into(), "q".into(), "Y".into()), 0.8_f64),
            (("A".into(), "p".into(), "B".into()), 0.5_f64),
        ];
        let result = session.apply_to_scores(scores);
        // A|p|B: 0.5 * 1.5 = 0.75; X|q|Y: 0.8 * 1.0 = 0.8 → X first.
        assert_eq!(result.len(), 2, "should return same number of triples");
        let ((first_subject, _, _), first_score) = &result[0];
        let ((second_subject, _, _), second_score) = &result[1];
        assert!(
            first_score >= second_score,
            "results should be sorted descending: {first_score} >= {second_score}"
        );
        // Pin the actual ranking and the weighted values, not just the order.
        assert_eq!(first_subject.as_str(), "X", "unweighted 0.8 should rank first");
        assert_eq!(second_subject.as_str(), "A", "boosted 0.75 should rank second");
        assert!(
            (*first_score - 0.8).abs() < 1e-12,
            "X|q|Y should keep its raw score, got {first_score}"
        );
        assert!(
            (*second_score - 0.75).abs() < 1e-12,
            "A|p|B should be 0.5 * 1.5, got {second_score}"
        );
    }
}
529}