use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};
/// Kind of interaction a user had with an item during a session.
///
/// The scorer in this module turns each kind into a recency-weighted signal:
/// playback-style events add positive weight proportional to playback position,
/// `Complete`/`Like` add fixed positive weight and flag the item, and `Dislike`
/// accumulates negative (dislike) weight.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum SessionEventType {
    /// Playback started or resumed.
    Play,
    /// Playback paused.
    Pause,
    /// User skipped away; counted at half the weight of a play at the same position.
    Skip,
    /// User finished the item.
    Complete,
    /// Explicit positive feedback.
    Like,
    /// Explicit negative feedback.
    Dislike,
}
/// A single user interaction recorded during a session.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SessionEvent {
    /// Id of the item interacted with; matched against [`CatalogItem::id`].
    pub item_id: String,
    /// Event time in seconds. Only differences from the newest event in the same
    /// session are used (for recency decay); the epoch is not interpreted here.
    pub timestamp_secs: u64,
    /// What the user did.
    pub event_type: SessionEventType,
    /// Playback position at the time of the event. The scorer compares it to a 0.85
    /// completion threshold, so it is presumably a fraction in `0.0..=1.0` despite the
    /// `_pct` suffix — NOTE(review): confirm with producers of these events.
    pub position_pct: f32,
}
/// Everything known about the current session that is fed to the recommender.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SessionContext {
    /// Interactions in the session, in any order; recency is measured against the
    /// largest `timestamp_secs` present.
    pub events: Vec<SessionEvent>,
    /// Optional user identifier; not read by the scoring code in this module.
    pub user_id: Option<String>,
}
/// A recommendable item and the metadata the scorer uses.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CatalogItem {
    /// Unique item id; session events refer to items by this value.
    pub id: String,
    /// Genre tags used to build and match session genre affinity.
    pub genres: Vec<String>,
    /// Item length in seconds; not currently read by the scorer.
    pub duration_secs: u32,
    /// Global popularity. Used as the whole score when the session yields no genre
    /// signal, and blended in at a small weight otherwise. Range is not enforced here.
    pub popularity_score: f32,
}
/// A scored recommendation candidate, with an explanation of the score.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SessionScore {
    /// Id of the scored catalog item.
    pub item_id: String,
    /// Ranking score; higher ranks first. May be negative for penalised items.
    pub score: f32,
    /// Machine-readable tags such as `popularity_fallback:0.90`, `genre_affinity:0.123`,
    /// `popularity:0.80` or `direct_dislike_penalty`.
    pub reasons: Vec<String>,
}
/// Per-second decay rate for event recency: an event's weight is
/// `exp(-RECENCY_LAMBDA * age_secs)`, so it drops to `1/e` after 30 minutes.
const RECENCY_LAMBDA: f64 = 1.0 / (30.0 * 60.0);
/// Playback position (fraction of the item) at or above which an item counts as completed.
const COMPLETION_THRESHOLD: f32 = 0.85;
/// Multiplier on a completed item's positive weight when building genre affinity.
const COMPLETE_BOOST: f32 = 1.5;
/// Scale on a disliked item's dislike weight before it is subtracted from each of its genres.
const DISLIKE_PENALTY: f32 = 0.25;
/// Flat amount subtracted from the final score of an item the session disliked directly.
const DIRECT_DISLIKE_PENALTY: f32 = 0.6;
/// Scale converting an item's (boosted) positive weight into per-genre affinity.
const GENRE_OVERLAP_BASE: f32 = 0.15;
/// Jaccard similarity `|A ∩ B| / |A ∪ B|` of two genre lists treated as sets.
///
/// Duplicate entries within either list are ignored, so the result is always in
/// `0.0..=1.0`. Returns `0.0` when both lists are empty.
///
/// NOTE(review): not referenced by the scoring code visible in this module.
fn genre_jaccard(a: &[String], b: &[String]) -> f32 {
    // Deduplicate up front; counting raw list entries over-counts the intersection
    // (and the union) whenever a list repeats a genre.
    let set_a: HashSet<&str> = a.iter().map(String::as_str).collect();
    let set_b: HashSet<&str> = b.iter().map(String::as_str).collect();
    let union = set_a.union(&set_b).count();
    if union == 0 {
        return 0.0;
    }
    let intersection = set_a.intersection(&set_b).count();
    intersection as f32 / union as f32
}
/// Exponential recency weight of an event relative to the newest event in the session.
///
/// Returns `1.0` for an event at (or, defensively, after) `latest_ts` and decays as
/// `exp(-RECENCY_LAMBDA * age_secs)` for older events.
fn recency_weight(event_ts: u64, latest_ts: u64) -> f32 {
    let age = if latest_ts > event_ts {
        latest_ts - event_ts
    } else {
        0
    };
    let decayed = (-RECENCY_LAMBDA * age as f64).exp();
    decayed as f32
}
/// Per-item aggregate of all session events that referenced the item.
#[derive(Default, Debug)]
struct ItemSignal {
    /// Sum of recency-weighted positive contributions (plays, skips, completes, likes).
    positive_weight: f32,
    /// Sum of recency weights of `Dislike` events.
    dislike_weight: f32,
    /// Set by a `Complete` event or any event at/after the completion threshold.
    completed: bool,
    /// Set by a `Like` event.
    liked: bool,
    /// Set by a `Dislike` event.
    disliked: bool,
    /// Genres copied from the catalog; empty when the item is not in the catalog.
    genres: Vec<String>,
}
/// Folds the session's events into one [`ItemSignal`] per distinct `item_id`.
///
/// Each event is weighted by [`recency_weight`] relative to the newest event in the
/// session. Genres are copied from `catalog_map` the first time an item is seen with a
/// non-empty genre list; items absent from the catalog keep an empty genre list.
///
/// `position_pct` comes from deserialized client data, so it is sanitised to
/// `0.0..=1.0` before use (NaN becomes `0.0`). Without this, a negative, oversized, or
/// NaN position would inject negative, inflated, or NaN weights that then propagate
/// into every genre affinity and score.
fn extract_signals(
    context: &SessionContext,
    catalog_map: &HashMap<&str, &CatalogItem>,
) -> HashMap<String, ItemSignal> {
    // Decay is measured against the newest event in the session, not wall-clock time.
    // The `0` fallback only applies to an empty session, where the loop never runs.
    let latest_ts = context
        .events
        .iter()
        .map(|e| e.timestamp_secs)
        .max()
        .unwrap_or(0);
    let mut signals: HashMap<String, ItemSignal> = HashMap::new();
    for event in &context.events {
        let sig = signals.entry(event.item_id.clone()).or_default();
        if sig.genres.is_empty() {
            if let Some(item) = catalog_map.get(event.item_id.as_str()) {
                sig.genres = item.genres.clone();
            }
        }
        let w = recency_weight(event.timestamp_secs, latest_ts);
        let position = sanitize_position(event.position_pct);
        match event.event_type {
            SessionEventType::Play | SessionEventType::Pause => {
                sig.positive_weight += w * position;
            }
            SessionEventType::Skip => {
                // A skip still shows some interest, but only half as much as a play.
                sig.positive_weight += w * position * 0.5;
            }
            SessionEventType::Complete => {
                sig.completed = true;
                sig.positive_weight += w * 1.0;
            }
            SessionEventType::Like => {
                sig.liked = true;
                sig.positive_weight += w * 1.2;
            }
            SessionEventType::Dislike => {
                sig.disliked = true;
                sig.dislike_weight += w;
            }
        }
        // Watching far enough into an item counts as completing it, whatever the event kind.
        if position >= COMPLETION_THRESHOLD {
            sig.completed = true;
        }
    }
    signals
}

/// Clamps a reported playback position to `0.0..=1.0`, mapping NaN to `0.0`.
fn sanitize_position(pct: f32) -> f32 {
    // `f32::clamp` returns NaN for NaN input, so that case is handled explicitly.
    if pct.is_nan() {
        0.0
    } else {
        pct.clamp(0.0, 1.0)
    }
}
/// Stateless session-based recommender.
///
/// All scoring inputs are passed per call, so a single value can be reused for any
/// number of sessions; it is `Copy` so it can be passed around by value freely.
#[derive(Debug, Default, Clone, Copy)]
pub struct SessionRecommender;
impl SessionRecommender {
    /// Extra multiplier on a liked item's positive weight when building genre affinity.
    const LIKE_AFFINITY_BOOST: f32 = 1.3;
    /// Fraction of `popularity_score` blended into a personalised (non-fallback) score.
    const POPULARITY_BLEND: f32 = 0.2;

    /// Creates a recommender. It holds no state, so this equals `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Returns up to `n` catalog item ids ranked for the session, best first.
    ///
    /// This is [`Self::recommend_scored`] with scores and reasons dropped; see it for
    /// the ranking rules.
    #[must_use]
    pub fn recommend(
        &self,
        context: &SessionContext,
        catalog: &[CatalogItem],
        n: usize,
    ) -> Vec<String> {
        self.recommend_scored(context, catalog, n)
            .into_iter()
            .map(|s| s.item_id)
            .collect()
    }

    /// Scores and ranks catalog items for the session, returning at most `n`, best first.
    ///
    /// * Items the session completed (a `Complete` event, or any event at or past the
    ///   completion threshold) or liked are excluded.
    /// * If the session yields no genre affinity at all (empty session, or events only
    ///   on items without catalog genres), items are ranked by `popularity_score` alone.
    /// * Otherwise an item scores its mean genre affinity plus a small popularity blend.
    /// * A directly disliked item additionally loses `DIRECT_DISLIKE_PENALTY`.
    ///
    /// Ties keep catalog order (the sort is stable). NaN scores are ranked last.
    #[must_use]
    pub fn recommend_scored(
        &self,
        context: &SessionContext,
        catalog: &[CatalogItem],
        n: usize,
    ) -> Vec<SessionScore> {
        if n == 0 || catalog.is_empty() {
            return Vec::new();
        }
        let catalog_map: HashMap<&str, &CatalogItem> =
            catalog.iter().map(|c| (c.id.as_str(), c)).collect();
        let signals = extract_signals(context, &catalog_map);
        let genre_affinity = self.build_genre_affinity(&signals);
        let popularity_fallback = genre_affinity.is_empty();
        let mut scores: Vec<SessionScore> = catalog
            .iter()
            .filter_map(|item| {
                let sig = signals.get(&item.id);
                // Already consumed or explicitly liked this session: don't re-recommend.
                if matches!(sig, Some(s) if s.completed || s.liked) {
                    return None;
                }
                Some(self.score_item(item, sig, &genre_affinity, popularity_fallback))
            })
            .collect();
        scores.sort_by(Self::by_score_desc);
        scores.truncate(n);
        scores
    }

    /// Descending-by-score comparator that stays a total order when scores are NaN.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` makes NaN "equal" to every value, which breaks
    /// transitivity: the ranking can come out scrambled, and `slice::sort_by` may panic
    /// on a comparator that is not a total order. NaN scores sort after all real scores.
    fn by_score_desc(a: &SessionScore, b: &SessionScore) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.score.is_nan(), b.score.is_nan()) {
            (false, false) => b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (true, true) => Ordering::Equal,
        }
    }

    /// Aggregates per-item signals into a per-genre affinity score.
    ///
    /// A disliked item lowers each of its genres by `dislike_weight * DISLIKE_PENALTY`.
    /// Any positive weight on that item is ignored. Every other item raises its genres by
    /// `positive_weight * GENRE_OVERLAP_BASE`, multiplied up when the item was completed
    /// and/or liked. Items without known genres contribute nothing, so the map may be empty.
    fn build_genre_affinity(&self, signals: &HashMap<String, ItemSignal>) -> HashMap<String, f32> {
        let mut affinity: HashMap<String, f32> = HashMap::new();
        for sig in signals.values() {
            let delta = if sig.disliked {
                -(sig.dislike_weight * DISLIKE_PENALTY)
            } else {
                let mut weight = sig.positive_weight;
                if sig.completed {
                    weight *= COMPLETE_BOOST;
                }
                if sig.liked {
                    weight *= Self::LIKE_AFFINITY_BOOST;
                }
                weight * GENRE_OVERLAP_BASE
            };
            for genre in &sig.genres {
                *affinity.entry(genre.clone()).or_default() += delta;
            }
        }
        affinity
    }

    /// Computes one item's score plus the reasons behind it.
    ///
    /// With `popularity_fallback`, the base score is the raw `popularity_score`. Otherwise
    /// it is the mean affinity over the item's genres plus `POPULARITY_BLEND` times
    /// `popularity_score`; genres missing from `genre_affinity` count as 0. A direct
    /// dislike recorded in `sig` subtracts `DIRECT_DISLIKE_PENALTY` in either mode.
    /// Only positive genre affinity and positive popularity are listed in `reasons`.
    fn score_item(
        &self,
        item: &CatalogItem,
        sig: Option<&ItemSignal>,
        genre_affinity: &HashMap<String, f32>,
        popularity_fallback: bool,
    ) -> SessionScore {
        let mut reasons: Vec<String> = Vec::new();
        let base_score = if popularity_fallback {
            reasons.push(format!("popularity_fallback:{:.2}", item.popularity_score));
            item.popularity_score
        } else {
            let mut genre_score = item
                .genres
                .iter()
                .filter_map(|g| genre_affinity.get(g))
                .fold(0.0_f32, |acc, &aff| acc + aff);
            // Average rather than sum so multi-genre items are not favoured by tag count.
            if !item.genres.is_empty() {
                genre_score /= item.genres.len() as f32;
            }
            if genre_score > 0.0 {
                reasons.push(format!("genre_affinity:{genre_score:.3}"));
            }
            let pop_contribution = item.popularity_score * Self::POPULARITY_BLEND;
            if pop_contribution > 0.0 {
                reasons.push(format!("popularity:{:.2}", item.popularity_score));
            }
            genre_score + pop_contribution
        };
        let mut score = base_score;
        if matches!(sig, Some(s) if s.disliked) {
            score -= DIRECT_DISLIKE_PENALTY;
            reasons.push("direct_dislike_penalty".to_string());
        }
        SessionScore {
            item_id: item.id.clone(),
            score,
            reasons,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture catalog: four `action` titles (`already_liked` is reserved for the
    /// exclusion test), one comedy and one drama, all with distinct popularity.
    fn make_catalog() -> Vec<CatalogItem> {
        vec![
            CatalogItem {
                id: "action_1".to_string(),
                genres: vec!["action".to_string(), "thriller".to_string()],
                duration_secs: 5400,
                popularity_score: 0.8,
            },
            CatalogItem {
                id: "action_2".to_string(),
                genres: vec!["action".to_string()],
                duration_secs: 4800,
                popularity_score: 0.6,
            },
            CatalogItem {
                id: "comedy_1".to_string(),
                genres: vec!["comedy".to_string()],
                duration_secs: 3600,
                popularity_score: 0.9,
            },
            CatalogItem {
                id: "drama_1".to_string(),
                genres: vec!["drama".to_string()],
                duration_secs: 6000,
                popularity_score: 0.5,
            },
            CatalogItem {
                id: "action_3".to_string(),
                genres: vec!["action".to_string(), "sci-fi".to_string()],
                duration_secs: 7200,
                popularity_score: 0.7,
            },
            CatalogItem {
                id: "already_liked".to_string(),
                genres: vec!["action".to_string()],
                duration_secs: 5000,
                popularity_score: 0.75,
            },
        ]
    }

    /// Builds a `SessionEvent` from its parts.
    fn event(item_id: &str, ts: u64, evt: SessionEventType, pos: f32) -> SessionEvent {
        SessionEvent {
            item_id: item_id.to_string(),
            timestamp_secs: ts,
            event_type: evt,
            position_pct: pos,
        }
    }

    /// Index of `id` in `results`. Panics with the full list if it is missing, so a
    /// vanished item fails the test instead of silently skipping an assertion.
    fn pos_of(results: &[String], id: &str) -> usize {
        results
            .iter()
            .position(|r| r == id)
            .unwrap_or_else(|| panic!("{id} missing from results: {results:?}"))
    }

    /// Score entry for `id`; panics with the full list if it is missing.
    fn score_of<'a>(scores: &'a [SessionScore], id: &str) -> &'a SessionScore {
        scores
            .iter()
            .find(|s| s.item_id == id)
            .unwrap_or_else(|| panic!("{id} missing from scores: {scores:?}"))
    }

    /// Owned genre list from string literals.
    fn genres(names: &[&str]) -> Vec<String> {
        names.iter().map(|g| (*g).to_string()).collect()
    }

    /// A session with both positive and negative signals, shared by ordering tests.
    fn mixed_context() -> SessionContext {
        SessionContext {
            events: vec![
                event("action_1", 1000, SessionEventType::Complete, 1.0),
                event("comedy_1", 1500, SessionEventType::Dislike, 0.1),
            ],
            user_id: None,
        }
    }

    #[test]
    fn test_empty_session_popularity_fallback() {
        let ctx = SessionContext::default();
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let results = rec.recommend(&ctx, &catalog, 3);
        assert_eq!(results, vec!["comedy_1", "action_1", "already_liked"]);
    }

    #[test]
    fn test_n_limit_respected() {
        let ctx = SessionContext::default();
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let results = rec.recommend(&ctx, &catalog, 2);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_zero_n_returns_empty() {
        let ctx = SessionContext::default();
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        assert!(rec.recommend(&ctx, &catalog, 0).is_empty());
        assert!(rec.recommend_scored(&ctx, &catalog, 0).is_empty());
    }

    #[test]
    fn test_complete_boosts_similar_genre() {
        let ctx = SessionContext {
            events: vec![event("action_1", 1000, SessionEventType::Complete, 1.0)],
            user_id: None,
        };
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let results = rec.recommend(&ctx, &catalog, 5);
        assert!(
            !results.contains(&"action_1".to_string()),
            "completed item must be excluded: {results:?}"
        );
        let comedy = pos_of(&results, "comedy_1");
        for id in ["action_2", "action_3"] {
            assert!(
                pos_of(&results, id) < comedy,
                "{id} should rank before comedy_1 after completing an action item, got: {results:?}"
            );
        }
    }

    #[test]
    fn test_liked_item_excluded() {
        let ctx = SessionContext {
            events: vec![event("already_liked", 1000, SessionEventType::Like, 0.5)],
            user_id: None,
        };
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let results = rec.recommend(&ctx, &catalog, 10);
        assert!(!results.contains(&"already_liked".to_string()));
    }

    #[test]
    fn test_dislike_penalises_item() {
        let ctx = SessionContext {
            events: vec![event("action_2", 1000, SessionEventType::Dislike, 0.1)],
            user_id: None,
        };
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let scores = rec.recommend_scored(&ctx, &catalog, 10);
        let a = score_of(&scores, "action_2");
        let c = score_of(&scores, "comedy_1");
        assert!(
            a.reasons.iter().any(|r| r == "direct_dislike_penalty"),
            "missing penalty reason: {:?}",
            a.reasons
        );
        assert!(
            a.score < c.score,
            "disliked action_2 ({:.3}) should score below comedy_1 ({:.3})",
            a.score,
            c.score
        );
    }

    #[test]
    fn test_dislike_genre_penalises_similar() {
        let ctx = SessionContext {
            events: vec![event("action_1", 1000, SessionEventType::Dislike, 0.2)],
            user_id: None,
        };
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let scores = rec.recommend_scored(&ctx, &catalog, 10);
        let a = score_of(&scores, "action_2");
        let c = score_of(&scores, "comedy_1");
        assert!(
            a.score < c.score,
            "action_2 score ({:.3}) should be below comedy_1 ({:.3}) after disliking action genre",
            a.score,
            c.score
        );
    }

    #[test]
    fn test_multiple_positive_events_accumulate() {
        let ctx = SessionContext {
            events: vec![
                event("action_1", 1000, SessionEventType::Complete, 1.0),
                event("action_2", 2000, SessionEventType::Like, 0.8),
            ],
            user_id: Some("user_abc".to_string()),
        };
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let results = rec.recommend(&ctx, &catalog, 5);
        assert!(!results.contains(&"action_1".to_string()), "{results:?}");
        assert!(!results.contains(&"action_2".to_string()), "{results:?}");
        assert!(
            pos_of(&results, "action_3") < pos_of(&results, "drama_1"),
            "action_3 should rank before drama_1: {results:?}"
        );
    }

    #[test]
    fn test_scored_reasons_present() {
        let ctx = SessionContext::default();
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let scores = rec.recommend_scored(&ctx, &catalog, 3);
        assert_eq!(scores.len(), 3);
        for s in &scores {
            assert!(
                s.reasons.iter().any(|r| r.starts_with("popularity_fallback:")),
                "item {} should carry a popularity_fallback reason, got {:?}",
                s.item_id,
                s.reasons
            );
        }
    }

    #[test]
    fn test_empty_catalog_returns_empty() {
        let ctx = SessionContext::default();
        let rec = SessionRecommender::new();
        assert!(rec.recommend(&ctx, &[], 5).is_empty());
        assert!(rec.recommend_scored(&ctx, &[], 5).is_empty());
    }

    #[test]
    fn test_recommend_matches_recommend_scored() {
        let ctx = mixed_context();
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let ids = rec.recommend(&ctx, &catalog, 4);
        let scored: Vec<String> = rec
            .recommend_scored(&ctx, &catalog, 4)
            .into_iter()
            .map(|s| s.item_id)
            .collect();
        assert_eq!(ids, scored);
    }

    #[test]
    fn test_scores_sorted_descending() {
        let ctx = mixed_context();
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let scores = rec.recommend_scored(&ctx, &catalog, 10);
        assert!(
            scores.windows(2).all(|w| w[0].score >= w[1].score),
            "scores not descending: {scores:?}"
        );
    }

    #[test]
    fn test_events_on_unknown_items_use_popularity_fallback() {
        let ctx = SessionContext {
            events: vec![event("not_in_catalog", 1000, SessionEventType::Complete, 1.0)],
            user_id: None,
        };
        let catalog = make_catalog();
        let rec = SessionRecommender::new();
        let scores = rec.recommend_scored(&ctx, &catalog, 3);
        assert_eq!(scores.len(), 3);
        assert_eq!(scores[0].item_id, "comedy_1");
        assert!(scores
            .iter()
            .all(|s| s.reasons.iter().any(|r| r.starts_with("popularity_fallback:"))));
    }

    #[test]
    fn test_recency_weight_decays_and_clamps() {
        assert_eq!(recency_weight(500, 500), 1.0);
        // An event newer than the reference timestamp is treated as age zero.
        assert_eq!(recency_weight(900, 100), 1.0);
        assert!(recency_weight(0, 1800) < recency_weight(0, 900));
        assert!((recency_weight(0, 1800) - (-1.0_f64).exp() as f32).abs() < 1e-6);
    }

    #[test]
    fn test_genre_jaccard_basic_sets() {
        assert_eq!(genre_jaccard(&[], &[]), 0.0);
        assert_eq!(genre_jaccard(&genres(&["a"]), &genres(&["a"])), 1.0);
        assert_eq!(genre_jaccard(&genres(&["a"]), &genres(&["b"])), 0.0);
        assert!(
            (genre_jaccard(&genres(&["a", "b"]), &genres(&["b", "c"])) - 1.0 / 3.0).abs() < 1e-6
        );
    }
}