use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Direction of a single piece of feedback about a triple.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FeedbackKind {
/// Approval: `FeedbackSession::record` scales the triple's weight by
/// `FeedbackConfig::positive_factor`.
Positive,
/// Disapproval: `FeedbackSession::record` scales the triple's weight by
/// `FeedbackConfig::negative_factor`.
Negative,
}
/// One piece of feedback about a single triple.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackEvent {
/// Key identifying the triple; `FeedbackEvent::new` builds it as
/// `subject|predicate|object`.
pub triple_key: String,
/// Whether the feedback is positive or negative.
pub kind: FeedbackKind,
/// Optional free-text note. `FeedbackEvent::new` leaves it `None`;
/// `FeedbackEvent::with_note` sets it.
pub note: Option<String>,
}
impl FeedbackEvent {
    /// Builds an event of the given `kind` for the triple
    /// `(subject, predicate, object)`, with no note attached.
    ///
    /// The three parts are combined into the `triple_key` field by this module's
    /// private `triple_key` helper.
    pub fn new(subject: &str, predicate: &str, object: &str, kind: FeedbackKind) -> Self {
        let key = triple_key(subject, predicate, object);
        Self {
            triple_key: key,
            kind,
            note: None,
        }
    }

    /// Consumes the event and returns it with `note` attached.
    ///
    /// Any note that was already present is replaced; the key and kind are kept.
    pub fn with_note(self, note: impl Into<String>) -> Self {
        Self {
            note: Some(note.into()),
            ..self
        }
    }
}
/// Builds the string key that identifies a triple: subject, predicate and object
/// joined with `|` (for example `A|knows|B`).
///
/// The parts are inserted verbatim and nothing is escaped, so parts that themselves
/// contain `|` can produce the same key as a different triple.
fn triple_key(s: &str, p: &str, o: &str) -> String {
    [s, p, o].join("|")
}
/// Tuning parameters that control how feedback changes a triple's weight.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackConfig {
/// Multiplier applied to a triple's weight for each positive event (default `1.5`).
pub positive_factor: f64,
/// Multiplier applied to a triple's weight for each negative event (default `0.6`).
pub negative_factor: f64,
/// Lower bound a weight is limited to after each update (default `0.01`).
/// Callers should keep `min_weight <= max_weight`.
pub min_weight: f64,
/// Upper bound a weight is limited to after each update (default `10.0`).
pub max_weight: f64,
}
impl Default for FeedbackConfig {
    /// Returns the stock tuning: a positive event multiplies a weight by `1.5`,
    /// a negative event by `0.6`, and the weight bounds are `0.01` and `10.0`.
    fn default() -> Self {
        let (positive_factor, negative_factor) = (1.5, 0.6);
        let (min_weight, max_weight) = (0.01, 10.0);
        Self {
            positive_factor,
            negative_factor,
            min_weight,
            max_weight,
        }
    }
}
/// Accumulates feedback events and the per-triple weights derived from them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackSession {
/// Current weight per triple key; a key with no entry is treated as weight `1.0`.
weights: HashMap<String, f64>,
/// Every recorded event, in the order it was recorded.
history: Vec<FeedbackEvent>,
/// Factors and bounds used when an event is recorded.
pub config: FeedbackConfig,
}
impl FeedbackSession {
pub fn new() -> Self {
Self::with_config(FeedbackConfig::default())
}
pub fn with_config(config: FeedbackConfig) -> Self {
Self {
weights: HashMap::new(),
history: Vec::new(),
config,
}
}
pub fn record(&mut self, event: FeedbackEvent) {
let factor = match event.kind {
FeedbackKind::Positive => self.config.positive_factor,
FeedbackKind::Negative => self.config.negative_factor,
};
let w = self.weights.entry(event.triple_key.clone()).or_insert(1.0);
*w = (*w * factor).clamp(self.config.min_weight, self.config.max_weight);
self.history.push(event);
}
pub fn like(&mut self, subject: &str, predicate: &str, object: &str) {
self.record(FeedbackEvent::new(
subject,
predicate,
object,
FeedbackKind::Positive,
));
}
pub fn dislike(&mut self, subject: &str, predicate: &str, object: &str) {
self.record(FeedbackEvent::new(
subject,
predicate,
object,
FeedbackKind::Negative,
));
}
pub fn weight(&self, subject: &str, predicate: &str, object: &str) -> f64 {
let key = triple_key(subject, predicate, object);
*self.weights.get(&key).unwrap_or(&1.0)
}
pub fn apply_weights(&self, scores: &HashMap<String, f64>) -> HashMap<String, f64> {
scores
.iter()
.map(|(k, &v)| {
let w = self.weights.get(k).copied().unwrap_or(1.0);
(k.clone(), v * w)
})
.collect()
}
pub fn history(&self) -> &[FeedbackEvent] {
&self.history
}
pub fn positive_count(&self) -> usize {
self.history
.iter()
.filter(|e| e.kind == FeedbackKind::Positive)
.count()
}
pub fn negative_count(&self) -> usize {
self.history
.iter()
.filter(|e| e.kind == FeedbackKind::Negative)
.count()
}
pub fn reset(&mut self) {
self.weights.clear();
self.history.clear();
}
}
impl Default for FeedbackSession {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A triple without any feedback reports weight 1.0.
    #[test]
    fn test_initial_weight_is_one() {
        let fresh = FeedbackSession::new();
        assert_eq!(fresh.weight("A", "p", "B"), 1.0);
    }

    /// One like pushes the weight above 1.0.
    #[test]
    fn test_positive_feedback_increases_weight() {
        let mut tracker = FeedbackSession::new();
        tracker.like("A", "p", "B");
        let after = tracker.weight("A", "p", "B");
        assert!(after > 1.0);
    }

    /// One dislike pulls the weight below 1.0.
    #[test]
    fn test_negative_feedback_decreases_weight() {
        let mut tracker = FeedbackSession::new();
        tracker.dislike("A", "p", "B");
        let after = tracker.weight("A", "p", "B");
        assert!(after < 1.0);
    }

    /// Many likes in a row never exceed the configured maximum.
    #[test]
    fn test_repeated_positive_capped_at_max() {
        let mut tracker = FeedbackSession::new();
        (0..100).for_each(|_| tracker.like("A", "p", "B"));
        let ceiling = tracker.config.max_weight;
        assert!(tracker.weight("A", "p", "B") <= ceiling);
    }

    /// Many dislikes in a row never drop below the configured minimum.
    #[test]
    fn test_repeated_negative_capped_at_min() {
        let mut tracker = FeedbackSession::new();
        (0..100).for_each(|_| tracker.dislike("A", "p", "B"));
        let floor = tracker.config.min_weight;
        assert!(tracker.weight("A", "p", "B") >= floor);
    }

    /// Scores keyed by a liked triple are scaled up.
    #[test]
    fn test_apply_weights() {
        let mut tracker = FeedbackSession::new();
        tracker.like("A", "knows", "B");

        let key = String::from("A|knows|B");
        let scores: HashMap<String, f64> = vec![(key.clone(), 0.5)].into_iter().collect();

        let adjusted = tracker.apply_weights(&scores);
        assert!(adjusted[&key] > 0.5);
    }

    /// Every event lands in the history and is counted by kind.
    #[test]
    fn test_event_history() {
        let mut tracker = FeedbackSession::new();
        tracker.like("X", "p", "Y");
        tracker.dislike("X", "q", "Z");

        assert_eq!(tracker.history().len(), 2);
        assert_eq!(tracker.positive_count(), 1);
        assert_eq!(tracker.negative_count(), 1);
    }

    /// `reset` forgets both the weights and the history.
    #[test]
    fn test_reset_clears_state() {
        let mut tracker = FeedbackSession::new();
        tracker.like("A", "p", "B");
        tracker.reset();

        assert_eq!(tracker.weight("A", "p", "B"), 1.0);
        assert!(tracker.history().is_empty());
    }

    /// `with_note` attaches the note to the event.
    #[test]
    fn test_feedback_event_with_note() {
        let base = FeedbackEvent::new("A", "p", "B", FeedbackKind::Positive);
        let annotated = base.with_note("very relevant");
        assert_eq!(annotated.note.as_deref(), Some("very relevant"));
    }

    /// A custom positive factor is the one applied by `like`.
    #[test]
    fn test_custom_config() {
        let custom = FeedbackConfig {
            positive_factor: 2.0,
            negative_factor: 0.5,
            min_weight: 0.1,
            max_weight: 5.0,
        };
        let mut tracker = FeedbackSession::with_config(custom);
        tracker.like("A", "p", "B");
        assert_eq!(tracker.weight("A", "p", "B"), 2.0);
    }
}