1use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23pub enum FeedbackKind {
24 Positive,
25 Negative,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct FeedbackEvent {
31 pub triple_key: String,
33 pub kind: FeedbackKind,
35 pub note: Option<String>,
37}
38
39impl FeedbackEvent {
40 pub fn new(subject: &str, predicate: &str, object: &str, kind: FeedbackKind) -> Self {
42 Self {
43 triple_key: triple_key(subject, predicate, object),
44 kind,
45 note: None,
46 }
47 }
48
49 pub fn with_note(mut self, note: impl Into<String>) -> Self {
50 self.note = Some(note.into());
51 self
52 }
53}
54
/// Joins a triple's components into a single `subject|predicate|object` key.
///
/// Components are not escaped, so a component that itself contains `|` can
/// produce the same key as a different triple (e.g. `("a|b", "c", "d")` and
/// `("a", "b|c", "d")` both yield `"a|b|c|d"`).
fn triple_key(s: &str, p: &str, o: &str) -> String {
    [s, p, o].join("|")
}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct FeedbackConfig {
66 pub positive_factor: f64,
68 pub negative_factor: f64,
70 pub min_weight: f64,
72 pub max_weight: f64,
74}
75
76impl Default for FeedbackConfig {
77 fn default() -> Self {
78 Self {
79 positive_factor: 1.5,
80 negative_factor: 0.6,
81 min_weight: 0.01,
82 max_weight: 10.0,
83 }
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct FeedbackSession {
98 weights: HashMap<String, f64>,
100 history: Vec<FeedbackEvent>,
102 pub config: FeedbackConfig,
104}
105
106impl FeedbackSession {
107 pub fn new() -> Self {
108 Self::with_config(FeedbackConfig::default())
109 }
110
111 pub fn with_config(config: FeedbackConfig) -> Self {
112 Self {
113 weights: HashMap::new(),
114 history: Vec::new(),
115 config,
116 }
117 }
118
119 pub fn record(&mut self, event: FeedbackEvent) {
121 let factor = match event.kind {
122 FeedbackKind::Positive => self.config.positive_factor,
123 FeedbackKind::Negative => self.config.negative_factor,
124 };
125 let w = self.weights.entry(event.triple_key.clone()).or_insert(1.0);
126 *w = (*w * factor).clamp(self.config.min_weight, self.config.max_weight);
127 self.history.push(event);
128 }
129
130 pub fn like(&mut self, subject: &str, predicate: &str, object: &str) {
132 self.record(FeedbackEvent::new(
133 subject,
134 predicate,
135 object,
136 FeedbackKind::Positive,
137 ));
138 }
139
140 pub fn dislike(&mut self, subject: &str, predicate: &str, object: &str) {
142 self.record(FeedbackEvent::new(
143 subject,
144 predicate,
145 object,
146 FeedbackKind::Negative,
147 ));
148 }
149
150 pub fn weight(&self, subject: &str, predicate: &str, object: &str) -> f64 {
152 let key = triple_key(subject, predicate, object);
153 *self.weights.get(&key).unwrap_or(&1.0)
154 }
155
156 pub fn apply_weights(&self, scores: &HashMap<String, f64>) -> HashMap<String, f64> {
160 scores
161 .iter()
162 .map(|(k, &v)| {
163 let w = self.weights.get(k).copied().unwrap_or(1.0);
164 (k.clone(), v * w)
165 })
166 .collect()
167 }
168
169 pub fn history(&self) -> &[FeedbackEvent] {
171 &self.history
172 }
173
174 pub fn positive_count(&self) -> usize {
176 self.history
177 .iter()
178 .filter(|e| e.kind == FeedbackKind::Positive)
179 .count()
180 }
181
182 pub fn negative_count(&self) -> usize {
184 self.history
185 .iter()
186 .filter(|e| e.kind == FeedbackKind::Negative)
187 .count()
188 }
189
190 pub fn reset(&mut self) {
192 self.weights.clear();
193 self.history.clear();
194 }
195}
196
197impl Default for FeedbackSession {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_initial_weight_is_one() {
        let session = FeedbackSession::new();
        assert_eq!(session.weight("A", "p", "B"), 1.0);
    }

    #[test]
    fn test_positive_feedback_increases_weight() {
        let mut session = FeedbackSession::new();
        session.like("A", "p", "B");
        assert!(session.weight("A", "p", "B") > 1.0);
    }

    #[test]
    fn test_negative_feedback_decreases_weight() {
        let mut session = FeedbackSession::new();
        session.dislike("A", "p", "B");
        assert!(session.weight("A", "p", "B") < 1.0);
    }

    #[test]
    fn test_repeated_positive_capped_at_max() {
        let mut session = FeedbackSession::new();
        for _ in 0..100 {
            session.like("A", "p", "B");
        }
        // Repeated boosts must actually reach the cap, not merely stay below it.
        assert_eq!(session.weight("A", "p", "B"), session.config.max_weight);
    }

    #[test]
    fn test_repeated_negative_capped_at_min() {
        let mut session = FeedbackSession::new();
        for _ in 0..100 {
            session.dislike("A", "p", "B");
        }
        // Repeated penalties must actually reach the floor, not merely stay above it.
        assert_eq!(session.weight("A", "p", "B"), session.config.min_weight);
    }

    #[test]
    fn test_apply_weights() {
        let mut session = FeedbackSession::new();
        session.like("A", "knows", "B");
        let key = "A|knows|B".to_string();
        let mut scores = HashMap::new();
        scores.insert(key.clone(), 0.5);
        let adjusted = session.apply_weights(&scores);
        // 0.5 * default positive factor 1.5
        assert_eq!(adjusted[&key], 0.75);
    }

    #[test]
    fn test_apply_weights_passes_through_unweighted_keys() {
        let session = FeedbackSession::new();
        let mut scores = HashMap::new();
        scores.insert("X|q|Y".to_string(), 0.3);
        let adjusted = session.apply_weights(&scores);
        assert_eq!(adjusted.len(), 1);
        assert_eq!(adjusted["X|q|Y"], 0.3);
    }

    #[test]
    fn test_event_history() {
        let mut session = FeedbackSession::new();
        session.like("X", "p", "Y");
        session.dislike("X", "q", "Z");
        assert_eq!(session.history().len(), 2);
        assert_eq!(session.positive_count(), 1);
        assert_eq!(session.negative_count(), 1);
    }

    #[test]
    fn test_reset_clears_state() {
        let mut session = FeedbackSession::new();
        session.like("A", "p", "B");
        session.reset();
        assert_eq!(session.weight("A", "p", "B"), 1.0);
        assert_eq!(session.history().len(), 0);
    }

    #[test]
    fn test_feedback_event_with_note() {
        let event =
            FeedbackEvent::new("A", "p", "B", FeedbackKind::Positive).with_note("very relevant");
        assert_eq!(event.note.as_deref(), Some("very relevant"));
    }

    #[test]
    fn test_custom_config() {
        let config = FeedbackConfig {
            positive_factor: 2.0,
            negative_factor: 0.5,
            min_weight: 0.1,
            max_weight: 5.0,
        };
        let mut session = FeedbackSession::with_config(config);
        session.like("A", "p", "B");
        assert_eq!(session.weight("A", "p", "B"), 2.0);
        session.dislike("C", "p", "D");
        assert_eq!(session.weight("C", "p", "D"), 0.5);
    }
}
299
/// Explicit relevance signal for a triple, consumed by
/// [`TripleRelevanceFeedback::record_feedback`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Relevance {
    /// Boost the triple's weight.
    Positive,
    /// Penalise the triple's weight.
    Negative,
    /// Reset the triple's weight to the neutral value `1.0`.
    Neutral,
}
314
315pub type TripleId = u64;
317
318fn triple_id(s: &str, p: &str, o: &str) -> TripleId {
320 seahash::hash(format!("{s}|{p}|{o}").as_bytes())
321}
322
323pub struct TripleRelevanceFeedback {
332 positive: std::collections::HashSet<TripleId>,
333 negative: std::collections::HashSet<TripleId>,
334 weights: std::collections::HashMap<TripleId, f64>,
335}
336
337impl TripleRelevanceFeedback {
338 const POSITIVE_FACTOR: f64 = 1.5;
340 const NEGATIVE_FACTOR: f64 = 0.5;
342 const MIN_WEIGHT: f64 = 0.1;
344 const MAX_WEIGHT: f64 = 2.0;
346
347 pub fn new() -> Self {
348 Self {
349 positive: std::collections::HashSet::new(),
350 negative: std::collections::HashSet::new(),
351 weights: std::collections::HashMap::new(),
352 }
353 }
354
355 pub fn record_feedback(
361 &mut self,
362 subject: &str,
363 predicate: &str,
364 object: &str,
365 signal: Relevance,
366 ) {
367 let id = triple_id(subject, predicate, object);
368 match signal {
369 Relevance::Positive => {
370 let current = self.weights.get(&id).copied().unwrap_or(1.0);
371 let next = (current * Self::POSITIVE_FACTOR).min(Self::MAX_WEIGHT);
372 self.weights.insert(id, next);
373 self.positive.insert(id);
374 self.negative.remove(&id);
375 }
376 Relevance::Negative => {
377 let current = self.weights.get(&id).copied().unwrap_or(1.0);
378 let next = (current * Self::NEGATIVE_FACTOR).max(Self::MIN_WEIGHT);
379 self.weights.insert(id, next);
380 self.negative.insert(id);
381 self.positive.remove(&id);
382 }
383 Relevance::Neutral => {
384 self.weights.insert(id, 1.0);
385 self.positive.remove(&id);
386 self.negative.remove(&id);
387 }
388 }
389 }
390
391 pub fn weight_of(&self, subject: &str, predicate: &str, object: &str) -> f64 {
393 let id = triple_id(subject, predicate, object);
394 self.weights.get(&id).copied().unwrap_or(1.0)
395 }
396
397 pub fn apply_to_scores(
402 &self,
403 scores: Vec<((String, String, String), f64)>,
404 ) -> Vec<((String, String, String), f64)> {
405 let mut weighted: Vec<((String, String, String), f64)> = scores
406 .into_iter()
407 .map(|((s, p, o), raw)| {
408 let w = self.weight_of(&s, &p, &o);
409 ((s, p, o), raw * w)
410 })
411 .collect();
412 weighted.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
413 weighted
414 }
415
416 pub fn reset(&mut self) {
418 self.positive.clear();
419 self.negative.clear();
420 self.weights.clear();
421 }
422}
423
424impl Default for TripleRelevanceFeedback {
425 fn default() -> Self {
426 Self::new()
427 }
428}
429
#[cfg(test)]
mod triple_relevance_tests {
    use super::{Relevance, TripleRelevanceFeedback};

    #[test]
    fn test_positive_feedback_boosts_score() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Positive);
        let raw = 0.5_f64;
        let boosted = session.weight_of("A", "p", "B") * raw;
        assert!(
            boosted > raw,
            "positive feedback should boost score: {boosted} vs {raw}"
        );
    }

    #[test]
    fn test_negative_feedback_reduces_score() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Negative);
        let raw = 0.5_f64;
        let reduced = session.weight_of("A", "p", "B") * raw;
        assert!(
            reduced < raw,
            "negative feedback should reduce score: {reduced} vs {raw}"
        );
    }

    #[test]
    fn test_neutral_no_feedback_leaves_score_unchanged() {
        let session = TripleRelevanceFeedback::new();
        let raw = 0.42_f64;
        // A triple with no recorded feedback is weighted 1.0.
        let result = session.weight_of("X", "q", "Y") * raw;
        assert!(
            (result - raw).abs() < 1e-12,
            "neutral (no feedback) should leave score unchanged: {result} vs {raw}"
        );
    }

    #[test]
    fn test_neutral_signal_resets_weight() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Positive);
        session.record_feedback("A", "p", "B", Relevance::Neutral);
        let w = session.weight_of("A", "p", "B");
        assert!(
            (w - 1.0).abs() < 1e-12,
            "neutral feedback should reset weight to 1.0, got {w}"
        );
    }

    #[test]
    fn test_repeated_positive_capped_at_max() {
        let mut session = TripleRelevanceFeedback::new();
        for _ in 0..100 {
            session.record_feedback("A", "p", "B", Relevance::Positive);
        }
        let w = session.weight_of("A", "p", "B");
        // Must saturate exactly at the cap, not merely stay below it.
        assert!(
            (w - 2.0).abs() < 1e-12,
            "repeated positive feedback should saturate at 2.0, got {w}"
        );
    }

    #[test]
    fn test_repeated_negative_floored_at_min() {
        let mut session = TripleRelevanceFeedback::new();
        for _ in 0..100 {
            session.record_feedback("A", "p", "B", Relevance::Negative);
        }
        let w = session.weight_of("A", "p", "B");
        // Must saturate exactly at the floor, not merely stay above it.
        assert!(
            (w - 0.1).abs() < 1e-12,
            "repeated negative feedback should saturate at 0.1, got {w}"
        );
    }

    #[test]
    fn test_reset_clears_all_weights() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Positive);
        session.reset();
        let w = session.weight_of("A", "p", "B");
        assert!(
            (w - 1.0).abs() < 1e-12,
            "after reset, weight should be 1.0, got {w}"
        );
    }

    #[test]
    fn test_apply_to_scores_sorted_descending() {
        let mut session = TripleRelevanceFeedback::new();
        session.record_feedback("A", "p", "B", Relevance::Positive);
        // A|p|B becomes 0.5 * 1.5 = 0.75, still below X|q|Y at 0.8.
        let scores = vec![
            (("X".into(), "q".into(), "Y".into()), 0.8_f64),
            (("A".into(), "p".into(), "B".into()), 0.5_f64),
        ];
        let result = session.apply_to_scores(scores);
        assert_eq!(result.len(), 2, "should return same number of triples");
        let ((s, p, o), first_score) = &result[0];
        let (_, second_score) = &result[1];
        assert!(
            first_score >= second_score,
            "results should be sorted descending: {first_score} >= {second_score}"
        );
        assert_eq!((s.as_str(), p.as_str(), o.as_str()), ("X", "q", "Y"));
        assert!(
            (second_score - 0.75).abs() < 1e-12,
            "boosted triple should score 0.75, got {second_score}"
        );
    }
}
529}