use std::collections::{HashMap, HashSet};

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};

use crate::simd::cosine_similarity_simd;
use crate::store::Metadata;
/// One explicit signal from a user: the rating they gave a single item.
///
/// The content-based recommender in this module multiplies each item's vector by
/// `rating` when building a profile, so the rating acts as a linear weight.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct UserPreference {
    /// Identifier of the rated item.
    pub item_id: String,
    /// The user's rating for the item (used as a weight).
    pub rating: f32,
}
impl UserPreference {
pub fn new(item_id: impl Into<String>, rating: f32) -> Self {
Self {
item_id: item_id.into(),
rating,
}
}
}
/// A single scored suggestion produced by one of the recommenders.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Recommendation {
    /// Identifier of the recommended item.
    pub item_id: String,
    /// Ranking score; higher means a stronger recommendation.
    pub score: f32,
    /// Human-readable explanation of why the item was suggested.
    pub reason: String,
    /// Item metadata, or `None` when the producing recommender attached none.
    pub metadata: Option<Metadata>,
}
impl Recommendation {
pub fn new(item_id: String, score: f32, reason: String) -> Self {
Self {
item_id,
score,
reason,
metadata: None,
}
}
pub fn with_metadata(mut self, metadata: Metadata) -> Self {
self.metadata = Some(metadata);
self
}
}
/// A catalogue entry: an item's feature vector together with its metadata.
#[derive(Clone, Debug)]
struct Item {
    /// Identifier; equal to the key this item is stored under.
    id: String,
    /// Feature/embedding vector compared with `cosine_similarity_simd`.
    vector: Vec<f32>,
    /// Metadata copied onto recommendations for this item.
    metadata: Metadata,
}
/// Recommends items whose feature vectors resemble those the user rated.
#[derive(Clone, Debug)]
pub struct ContentBasedRecommender {
    /// Catalogue keyed by item id.
    items: HashMap<String, Item>,
}
impl ContentBasedRecommender {
pub fn new() -> Self {
Self {
items: HashMap::new(),
}
}
pub fn add_item(
&mut self,
id: impl Into<String>,
vector: Vec<f32>,
metadata: Metadata,
) -> Result<()> {
let id = id.into();
self.items.insert(
id.clone(),
Item {
id,
vector,
metadata,
},
);
Ok(())
}
pub fn recommend(
&self,
preferences: &[UserPreference],
top_k: usize,
) -> Result<Vec<Recommendation>> {
if preferences.is_empty() {
return Err(anyhow!("No preferences provided"));
}
let mut profile = self.build_user_profile(preferences)?;
let norm = profile.iter().map(|&x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for val in &mut profile {
*val /= norm;
}
}
let mut scores: Vec<(String, f32)> = self
.items
.iter()
.filter_map(|(id, item)| {
if preferences.iter().any(|p| &p.item_id == id) {
return None;
}
let similarity = cosine_similarity_simd(&profile, &item.vector);
Some((id.clone(), similarity))
})
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.truncate(top_k);
let recommendations: Vec<Recommendation> = scores
.into_iter()
.map(|(item_id, score)| {
let reason = format!("Similar to items you liked (score: {:.3})", score);
let metadata = self.items.get(&item_id).map(|item| item.metadata.clone());
let mut rec = Recommendation::new(item_id, score, reason);
if let Some(meta) = metadata {
rec = rec.with_metadata(meta);
}
rec
})
.collect();
Ok(recommendations)
}
fn build_user_profile(&self, preferences: &[UserPreference]) -> Result<Vec<f32>> {
let mut profile: Option<Vec<f32>> = None;
let mut total_weight = 0.0;
for pref in preferences {
let item = self
.items
.get(&pref.item_id)
.ok_or_else(|| anyhow!("Item not found: {}", pref.item_id))?;
let weight = pref.rating;
total_weight += weight;
if let Some(ref mut p) = profile {
for (i, &val) in item.vector.iter().enumerate() {
p[i] += val * weight;
}
} else {
profile = Some(item.vector.iter().map(|&v| v * weight).collect());
}
}
let mut profile = profile.ok_or_else(|| anyhow!("No valid items found"))?;
if total_weight > 0.0 {
for val in &mut profile {
*val /= total_weight;
}
}
Ok(profile)
}
pub fn similar_items(&self, item_id: &str, top_k: usize) -> Result<Vec<Recommendation>> {
let item = self
.items
.get(item_id)
.ok_or_else(|| anyhow!("Item not found: {}", item_id))?;
let mut scores: Vec<(String, f32)> = self
.items
.iter()
.filter_map(|(id, other_item)| {
if id == item_id {
return None;
}
let similarity = cosine_similarity_simd(&item.vector, &other_item.vector);
Some((id.clone(), similarity))
})
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.truncate(top_k);
let recommendations: Vec<Recommendation> = scores
.into_iter()
.map(|(id, score)| {
let reason = format!("Similar to {} (score: {:.3})", item_id, score);
let metadata = self.items.get(&id).map(|item| item.metadata.clone());
let mut rec = Recommendation::new(id, score, reason);
if let Some(meta) = metadata {
rec = rec.with_metadata(meta);
}
rec
})
.collect();
Ok(recommendations)
}
}
impl Default for ContentBasedRecommender {
fn default() -> Self {
Self::new()
}
}
/// User-based collaborative-filtering recommender driven by explicit ratings.
#[derive(Clone, Debug)]
pub struct CollaborativeRecommender {
    /// user_id -> (item_id -> rating).
    user_ratings: HashMap<String, HashMap<String, f32>>,
    /// item_id -> ids of users who rated it (inverse index; not read anywhere in this file).
    item_users: HashMap<String, Vec<String>>,
}
impl CollaborativeRecommender {
    /// Maximum number of neighbouring users consulted by [`Self::recommend`].
    const NEIGHBOUR_COUNT: usize = 20;

    /// Creates an empty recommender with no ratings.
    pub fn new() -> Self {
        Self {
            user_ratings: HashMap::new(),
            item_users: HashMap::new(),
        }
    }

    /// Records `rating` for (`user_id`, `item_id`), overwriting any earlier rating for the pair.
    ///
    /// The item→users index gains an entry only the first time a user rates an item.
    /// Re-rating therefore does not create duplicate entries.
    pub fn add_rating(
        &mut self,
        user_id: impl Into<String>,
        item_id: impl Into<String>,
        rating: f32,
    ) {
        let user_id = user_id.into();
        let item_id = item_id.into();
        let previous = self
            .user_ratings
            .entry(user_id.clone())
            .or_default()
            .insert(item_id.clone(), rating);
        if previous.is_none() {
            self.item_users.entry(item_id).or_default().push(user_id);
        }
    }

    /// Recommends up to `top_k` items that `user_id` has not rated yet.
    ///
    /// Neighbours are the (at most [`Self::NEIGHBOUR_COUNT`]) other users with a positive
    /// Pearson correlation to `user_id`. Each candidate's score is the similarity-weighted
    /// average of the neighbours' ratings for it. Results are sorted by descending score.
    /// If no neighbours are found, the result is an empty list.
    ///
    /// # Errors
    ///
    /// Returns an error if `user_id` has no recorded ratings.
    pub fn recommend(&self, user_id: &str, top_k: usize) -> Result<Vec<Recommendation>> {
        let user_ratings = self
            .user_ratings
            .get(user_id)
            .ok_or_else(|| anyhow!("User not found: {}", user_id))?;
        let neighbours = self.find_similar_users(user_id, Self::NEIGHBOUR_COUNT)?;
        // item_id -> (sum of rating * similarity, sum of similarity) over neighbours who rated it.
        let mut item_scores: HashMap<&str, (f32, f32)> = HashMap::new();
        for (other_user, similarity) in neighbours {
            if let Some(other_ratings) = self.user_ratings.get(&other_user) {
                for (item_id, &rating) in other_ratings {
                    if user_ratings.contains_key(item_id) {
                        continue;
                    }
                    let entry = item_scores.entry(item_id.as_str()).or_insert((0.0, 0.0));
                    entry.0 += rating * similarity;
                    entry.1 += similarity;
                }
            }
        }
        let mut scores: Vec<(&str, f32)> = item_scores
            .into_iter()
            .map(|(item_id, (weighted_sum, weight_sum))| {
                let score = if weight_sum > 0.0 {
                    weighted_sum / weight_sum
                } else {
                    0.0
                };
                (item_id, score)
            })
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(top_k);
        Ok(scores
            .into_iter()
            .map(|(item_id, score)| {
                let reason = format!("Users like you also liked this (score: {:.3})", score);
                Recommendation::new(item_id.to_string(), score, reason)
            })
            .collect())
    }

    /// Returns up to `top_k` other users with positive Pearson correlation to `user_id`.
    ///
    /// Results are sorted by descending similarity.
    ///
    /// # Errors
    ///
    /// Returns an error if `user_id` has no recorded ratings.
    fn find_similar_users(&self, user_id: &str, top_k: usize) -> Result<Vec<(String, f32)>> {
        let user_ratings = self
            .user_ratings
            .get(user_id)
            .ok_or_else(|| anyhow!("User not found: {}", user_id))?;
        let mut similarities: Vec<(String, f32)> = self
            .user_ratings
            .iter()
            .filter(|(other_id, _)| other_id.as_str() != user_id)
            .filter_map(|(other_id, other_ratings)| {
                let similarity = self.pearson_correlation(user_ratings, other_ratings);
                // NaN fails this comparison, so it is dropped as well.
                (similarity > 0.0).then(|| (other_id.clone(), similarity))
            })
            .collect();
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        similarities.truncate(top_k);
        Ok(similarities)
    }

    /// Pearson correlation of two users' ratings over the items both have rated.
    ///
    /// Returns `0.0` in either of these cases:
    /// - fewer than two items are shared;
    /// - the shared ratings of either user have zero or non-finite variance.
    ///
    /// The sums are mean-centred (two-pass). The single-pass `Σxy − ΣxΣy/n` form cancels
    /// catastrophically in `f32` for near-constant ratings. That can leave a tiny non-zero
    /// denominator and an unbounded result. The output is clamped to `[-1, 1]` to absorb
    /// any residual rounding.
    fn pearson_correlation(
        &self,
        ratings1: &HashMap<String, f32>,
        ratings2: &HashMap<String, f32>,
    ) -> f32 {
        let pairs: Vec<(f32, f32)> = ratings1
            .iter()
            .filter_map(|(item, &r1)| ratings2.get(item).map(|&r2| (r1, r2)))
            .collect();
        if pairs.len() < 2 {
            return 0.0;
        }
        let n = pairs.len() as f32;
        let mean1 = pairs.iter().map(|&(a, _)| a).sum::<f32>() / n;
        let mean2 = pairs.iter().map(|&(_, b)| b).sum::<f32>() / n;
        let (covariance, var1, var2) = pairs.iter().fold(
            (0.0_f32, 0.0_f32, 0.0_f32),
            |(cov, v1, v2), &(a, b)| {
                let (da, db) = (a - mean1, b - mean2);
                (cov + da * db, v1 + da * da, v2 + db * db)
            },
        );
        let denominator = (var1 * var2).sqrt();
        if denominator > 0.0 && denominator.is_finite() {
            (covariance / denominator).clamp(-1.0, 1.0)
        } else {
            0.0
        }
    }
}
impl Default for CollaborativeRecommender {
fn default() -> Self {
Self::new()
}
}
/// Blends content-based and collaborative recommendations into one ranking.
///
/// Content-based scores are multiplied by `content_weight` and collaborative scores by
/// `1 - content_weight`. The weight is set through [`HybridRecommender::new`], which
/// clamps it.
///
/// Derives `Debug` and `Clone` like both component recommenders, so it can be logged
/// and copied.
#[derive(Debug, Clone)]
pub struct HybridRecommender {
    content_based: ContentBasedRecommender,
    collaborative: CollaborativeRecommender,
    content_weight: f32,
}
impl HybridRecommender {
    /// Creates a hybrid recommender giving content-based scores weight `content_weight`.
    ///
    /// The weight is clamped to `[0, 1]`. `f32::clamp` propagates NaN, which would make
    /// every blended score NaN, so a NaN weight falls back to an even `0.5` blend.
    pub fn new(content_weight: f32) -> Self {
        let content_weight = if content_weight.is_nan() {
            0.5
        } else {
            content_weight.clamp(0.0, 1.0)
        };
        Self {
            content_based: ContentBasedRecommender::new(),
            collaborative: CollaborativeRecommender::new(),
            content_weight,
        }
    }

    /// Adds an item (vector + metadata) to the content-based catalogue.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ContentBasedRecommender::add_item`].
    pub fn add_item(
        &mut self,
        id: impl Into<String>,
        vector: Vec<f32>,
        metadata: Metadata,
    ) -> Result<()> {
        self.content_based.add_item(id, vector, metadata)
    }

    /// Records a user's rating in the collaborative model.
    pub fn add_rating(
        &mut self,
        user_id: impl Into<String>,
        item_id: impl Into<String>,
        rating: f32,
    ) {
        self.collaborative.add_rating(user_id, item_id, rating);
    }

    /// Returns up to `top_k` items ranked by a weighted blend of both recommenders.
    ///
    /// Each source is asked for up to `2 * top_k` candidates (saturating, so a huge
    /// `top_k` cannot overflow). The per-item scores are then summed after weighting.
    /// Items found by both sources get the reason "hybrid (content + collaborative)".
    /// Results carry catalogue metadata whenever the item is known to the content-based
    /// side.
    ///
    /// # Errors
    ///
    /// Propagates errors from either source. That includes empty or unknown `preferences`
    /// and a `user_id` with no ratings.
    pub fn recommend(
        &self,
        user_id: &str,
        preferences: &[UserPreference],
        top_k: usize,
    ) -> Result<Vec<Recommendation>> {
        let candidate_count = top_k.saturating_mul(2);
        let content_recs = self.content_based.recommend(preferences, candidate_count)?;
        let collab_recs = self.collaborative.recommend(user_id, candidate_count)?;
        let mut combined: HashMap<String, (f32, String)> = HashMap::new();
        for rec in content_recs {
            let score = rec.score * self.content_weight;
            combined.insert(rec.item_id, (score, "content-based".to_string()));
        }
        let collab_weight = 1.0 - self.content_weight;
        for rec in collab_recs {
            let collab_score = rec.score * collab_weight;
            combined
                .entry(rec.item_id)
                .and_modify(|(score, reason)| {
                    *score += collab_score;
                    *reason = "hybrid (content + collaborative)".to_string();
                })
                .or_insert((collab_score, "collaborative".to_string()));
        }
        let mut results: Vec<(String, f32, String)> = combined
            .into_iter()
            .map(|(id, (score, reason))| (id, score, reason))
            .collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(top_k);
        Ok(results
            .into_iter()
            .map(|(item_id, score, reason)| {
                // Keep metadata consistent with the content-based recommender instead of dropping it.
                let metadata = self
                    .content_based
                    .items
                    .get(&item_id)
                    .map(|item| item.metadata.clone());
                let rec = Recommendation::new(item_id, score, reason);
                match metadata {
                    Some(meta) => rec.with_metadata(meta),
                    None => rec,
                }
            })
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds metadata containing only a `title` field.
    fn create_test_metadata(title: &str) -> Metadata {
        let mut fields = HashMap::new();
        fields.insert("title".to_string(), serde_json::json!(title));
        Metadata { fields }
    }

    #[test]
    fn test_content_based_recommender() -> Result<()> {
        let mut recommender = ContentBasedRecommender::new();
        recommender.add_item(
            "movie1",
            vec![1.0, 0.0, 0.0],
            create_test_metadata("Action Movie"),
        )?;
        recommender.add_item(
            "movie2",
            vec![0.9, 0.1, 0.0],
            create_test_metadata("Action Comedy"),
        )?;
        recommender.add_item(
            "movie3",
            vec![0.0, 1.0, 0.0],
            create_test_metadata("Comedy"),
        )?;
        recommender.add_item("movie4", vec![0.0, 0.0, 1.0], create_test_metadata("Drama"))?;
        let preferences = vec![UserPreference::new("movie1", 5.0)];
        let recs = recommender.recommend(&preferences, 2)?;
        assert_eq!(recs.len(), 2);
        assert_eq!(recs[0].item_id, "movie2");
        Ok(())
    }

    #[test]
    fn test_similar_items() -> Result<()> {
        let mut recommender = ContentBasedRecommender::new();
        recommender.add_item("item1", vec![1.0, 0.0], create_test_metadata("Item 1"))?;
        recommender.add_item("item2", vec![0.9, 0.1], create_test_metadata("Item 2"))?;
        recommender.add_item("item3", vec![0.0, 1.0], create_test_metadata("Item 3"))?;
        let similar = recommender.similar_items("item1", 2)?;
        assert_eq!(similar.len(), 2);
        assert_eq!(similar[0].item_id, "item2");
        Ok(())
    }

    #[test]
    fn test_collaborative_recommender() -> Result<()> {
        let mut recommender = CollaborativeRecommender::new();
        recommender.add_rating("user1", "item1", 5.0);
        recommender.add_rating("user1", "item2", 5.0);
        recommender.add_rating("user1", "item3", 1.0);
        recommender.add_rating("user2", "item1", 4.0);
        recommender.add_rating("user2", "item2", 5.0);
        recommender.add_rating("user2", "item3", 2.0);
        recommender.add_rating("user2", "item4", 5.0);
        let recs = recommender.recommend("user1", 2)?;
        assert!(!recs.is_empty());
        assert!(recs.iter().any(|r| r.item_id == "item4"));
        Ok(())
    }

    #[test]
    fn test_hybrid_recommender() -> Result<()> {
        let mut recommender = HybridRecommender::new(0.5);
        recommender.add_item("item1", vec![1.0, 0.0], create_test_metadata("Item 1"))?;
        recommender.add_item("item2", vec![0.9, 0.1], create_test_metadata("Item 2"))?;
        recommender.add_item("item3", vec![0.0, 1.0], create_test_metadata("Item 3"))?;
        recommender.add_rating("user1", "item1", 5.0);
        let preferences = vec![UserPreference::new("item1", 5.0)];
        let recs = recommender.recommend("user1", &preferences, 2)?;
        assert!(!recs.is_empty());
        Ok(())
    }

    #[test]
    fn test_user_profile_building() -> Result<()> {
        let mut recommender = ContentBasedRecommender::new();
        recommender.add_item("item1", vec![1.0, 0.0], create_test_metadata("Item 1"))?;
        recommender.add_item("item2", vec![0.0, 1.0], create_test_metadata("Item 2"))?;
        let preferences = vec![
            UserPreference::new("item1", 4.0),
            UserPreference::new("item2", 2.0),
        ];
        let profile = recommender.build_user_profile(&preferences)?;
        assert!(profile[0] > profile[1]);
        Ok(())
    }

    #[test]
    fn test_content_based_rejects_empty_preferences() -> Result<()> {
        let mut recommender = ContentBasedRecommender::new();
        recommender.add_item("item1", vec![1.0, 0.0], create_test_metadata("Item 1"))?;
        assert!(recommender.recommend(&[], 1).is_err());
        Ok(())
    }

    #[test]
    fn test_unknown_ids_are_errors() -> Result<()> {
        let mut content = ContentBasedRecommender::new();
        content.add_item("item1", vec![1.0, 0.0], create_test_metadata("Item 1"))?;
        assert!(content.similar_items("missing", 1).is_err());
        let unknown = vec![UserPreference::new("missing", 1.0)];
        assert!(content.recommend(&unknown, 1).is_err());
        let collaborative = CollaborativeRecommender::new();
        assert!(collaborative.recommend("nobody", 1).is_err());
        Ok(())
    }

    #[test]
    fn test_content_recommendations_exclude_rated_and_carry_metadata() -> Result<()> {
        let mut recommender = ContentBasedRecommender::new();
        recommender.add_item("item1", vec![1.0, 0.0], create_test_metadata("Item 1"))?;
        recommender.add_item("item2", vec![0.9, 0.1], create_test_metadata("Item 2"))?;
        recommender.add_item("item3", vec![0.0, 1.0], create_test_metadata("Item 3"))?;
        let preferences = vec![UserPreference::new("item1", 5.0)];
        let recs = recommender.recommend(&preferences, 10)?;
        assert_eq!(recs.len(), 2);
        assert!(recs.iter().all(|r| r.item_id != "item1"));
        assert!(recs.iter().all(|r| r.metadata.is_some()));
        assert!(recommender.recommend(&preferences, 0)?.is_empty());
        Ok(())
    }
}