use crate::error::RecommendResult;
use crate::{ContentMetadata, Recommendation, RecommendationReason, RecommendationRequest};
use ndarray::{Array2, ArrayView2};
use std::collections::HashMap;
use uuid::Uuid;
/// Dense user × item rating matrix with bidirectional id ↔ index lookups.
///
/// Row `i` belongs to the user stored at `index_to_user[i]`, column `j` to the
/// item stored at `index_to_item[j]`. The backing array is zero-filled, so cells
/// that were never set read as `0.0`. When built with non-zero sizes via
/// `UserItemMatrix::new`, the array can have more rows/columns than registered
/// users/items.
#[derive(Debug, Clone)]
pub struct UserItemMatrix {
    // Ratings: rows are users, columns are items.
    data: Array2<f32>,
    // User id -> row index in `data`.
    user_to_index: HashMap<Uuid, usize>,
    // Item id -> column index in `data`.
    item_to_index: HashMap<Uuid, usize>,
    // Row index -> user id (inverse of `user_to_index`).
    index_to_user: Vec<Uuid>,
    // Column index -> item id (inverse of `item_to_index`).
    index_to_item: Vec<Uuid>,
}
impl UserItemMatrix {
#[must_use]
pub fn new(num_users: usize, num_items: usize) -> Self {
Self {
data: Array2::zeros((num_users, num_items)),
user_to_index: HashMap::new(),
item_to_index: HashMap::new(),
index_to_user: Vec::new(),
index_to_item: Vec::new(),
}
}
pub fn add_user(&mut self, user_id: Uuid) -> usize {
if let Some(&index) = self.user_to_index.get(&user_id) {
return index;
}
let index = self.index_to_user.len();
self.user_to_index.insert(user_id, index);
self.index_to_user.push(user_id);
if index >= self.data.nrows() {
let new_rows = index + 1 - self.data.nrows();
let zeros = Array2::zeros((new_rows, self.data.ncols()));
self.data = ndarray::concatenate![ndarray::Axis(0), self.data, zeros];
}
index
}
pub fn add_item(&mut self, item_id: Uuid) -> usize {
if let Some(&index) = self.item_to_index.get(&item_id) {
return index;
}
let index = self.index_to_item.len();
self.item_to_index.insert(item_id, index);
self.index_to_item.push(item_id);
if index >= self.data.ncols() {
let new_cols = index + 1 - self.data.ncols();
let zeros = Array2::zeros((self.data.nrows(), new_cols));
self.data = ndarray::concatenate![ndarray::Axis(1), self.data, zeros];
}
index
}
pub fn set_rating(&mut self, user_id: Uuid, item_id: Uuid, rating: f32) {
let user_idx = self.add_user(user_id);
let item_idx = self.add_item(item_id);
self.data[[user_idx, item_idx]] = rating;
}
#[must_use]
pub fn get_rating(&self, user_id: Uuid, item_id: Uuid) -> Option<f32> {
let user_idx = self.user_to_index.get(&user_id)?;
let item_idx = self.item_to_index.get(&item_id)?;
Some(self.data[[*user_idx, *item_idx]])
}
#[must_use]
pub fn get_user_ratings(&self, user_id: Uuid) -> Option<Vec<f32>> {
let user_idx = self.user_to_index.get(&user_id)?;
Some(self.data.row(*user_idx).to_vec())
}
#[must_use]
pub fn get_item_ratings(&self, item_id: Uuid) -> Option<Vec<f32>> {
let item_idx = self.item_to_index.get(&item_id)?;
Some(self.data.column(*item_idx).to_vec())
}
#[must_use]
pub fn as_view(&self) -> ArrayView2<'_, f32> {
self.data.view()
}
#[must_use]
pub fn num_users(&self) -> usize {
self.index_to_user.len()
}
#[must_use]
pub fn num_items(&self) -> usize {
self.index_to_item.len()
}
#[must_use]
pub fn get_item_id(&self, index: usize) -> Option<Uuid> {
self.index_to_item.get(index).copied()
}
#[must_use]
pub fn get_user_id(&self, index: usize) -> Option<Uuid> {
self.index_to_user.get(index).copied()
}
#[must_use]
pub fn get_rated_items(&self, user_id: Uuid) -> Vec<(Uuid, f32)> {
let Some(&user_idx) = self.user_to_index.get(&user_id) else {
return Vec::new();
};
self.data
.row(user_idx)
.iter()
.enumerate()
.filter(|(_, &rating)| rating > 0.0)
.filter_map(|(item_idx, &rating)| {
self.index_to_item
.get(item_idx)
.map(|&item_id| (item_id, rating))
})
.collect()
}
}
/// User-based collaborative-filtering recommender.
///
/// Holds the rating matrix, the per-content metadata attached to each
/// [`Recommendation`], and the nearest-neighbour calculator used to find users
/// similar to the requester.
pub struct CollaborativeEngine {
    // Ratings indexed by (user id, content id).
    matrix: UserItemMatrix,
    // Metadata copied into recommendations; content without an entry here is
    // never recommended.
    content_metadata: HashMap<Uuid, ContentMetadata>,
    // Finds users similar to a given user within `matrix`.
    knn: super::knn::KnnCalculator,
}
impl CollaborativeEngine {
    /// Creates an engine with an empty rating matrix and no content metadata.
    #[must_use]
    pub fn new() -> Self {
        Self {
            matrix: UserItemMatrix::new(0, 0),
            content_metadata: HashMap::new(),
            // NOTE(review): the meaning of `10` is defined by `KnnCalculator::new`
            // (presumably a default neighbour count) — confirm in the knn module.
            knn: super::knn::KnnCalculator::new(10),
        }
    }

    /// Records `rating` for (`user_id`, `content_id`), registering either id if new.
    ///
    /// A rating of `0.0` is indistinguishable from "not rated" in the matrix.
    pub fn add_rating(&mut self, user_id: Uuid, content_id: Uuid, rating: f32) {
        self.matrix.set_rating(user_id, content_id, rating);
    }

    /// Registers (or replaces) the metadata for `content_id`.
    ///
    /// Only content with metadata can appear in [`recommend`](Self::recommend) output.
    pub fn add_content(&mut self, content_id: Uuid, metadata: ContentMetadata) {
        self.content_metadata.insert(content_id, metadata);
    }

    /// Recommends content for `request.user_id` via user-based collaborative filtering.
    ///
    /// Up to 20 similar users are requested from the k-NN calculator. Each item
    /// those users rated is scored as the sum of `rating * similarity` across them.
    /// Items the requester already rated (rating > 0.0) are excluded, as are items
    /// without registered metadata.
    ///
    /// The result is sorted by descending score, with ties broken by ascending
    /// content id so output is deterministic. Ranks start at 1, and the list is
    /// truncated to `request.limit`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the similar-user search.
    pub fn recommend(
        &self,
        request: &RecommendationRequest,
    ) -> RecommendResult<Vec<Recommendation>> {
        // Neighbourhood size is left as a literal so it adopts whatever integer
        // type `find_similar_users` expects.
        let similar_users = self
            .knn
            .find_similar_users(&self.matrix, request.user_id, 20)?;

        // Sum similarity-weighted ratings for items the requester hasn't rated.
        let mut candidate_items: HashMap<Uuid, f32> = HashMap::new();
        for (similar_user, similarity) in similar_users {
            for (item_id, rating) in self.matrix.get_rated_items(similar_user) {
                if self.has_rated(request.user_id, item_id) {
                    continue;
                }
                *candidate_items.entry(item_id).or_insert(0.0) += rating * similarity;
            }
        }

        // Keep only content we can describe. Metadata is borrowed here so that
        // only the returned top `limit` entries are cloned.
        let mut scored: Vec<(Uuid, f32, &ContentMetadata)> = candidate_items
            .into_iter()
            .filter_map(|(content_id, score)| {
                self.content_metadata
                    .get(&content_id)
                    .map(|metadata| (content_id, score, metadata))
            })
            .collect();

        // `total_cmp` is a total order even when a score is NaN, which `sort_by`
        // requires. The id tie-break removes dependence on HashMap iteration order.
        scored.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        scored.truncate(request.limit);

        let recommendations: Vec<Recommendation> = scored
            .into_iter()
            .enumerate()
            .map(|(idx, (content_id, score, metadata))| Recommendation {
                content_id,
                score,
                rank: idx + 1,
                reasons: vec![RecommendationReason::CollaborativeFiltering {
                    confidence: score,
                }],
                metadata: metadata.clone(),
                explanation: None,
            })
            .collect();
        Ok(recommendations)
    }

    /// Returns `true` if `user_id` has a positive rating for `item_id`.
    ///
    /// `get_rating` returns `Some(0.0)` for a known user/item pair that was never
    /// rated, so `is_some()` alone would treat every known item as already rated.
    fn has_rated(&self, user_id: Uuid, item_id: Uuid) -> bool {
        matches!(self.matrix.get_rating(user_id, item_id), Some(rating) if rating > 0.0)
    }
}
impl Default for CollaborativeEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_user_item_matrix_creation() {
        let matrix = UserItemMatrix::new(10, 20);
        assert_eq!(matrix.data.nrows(), 10);
        assert_eq!(matrix.data.ncols(), 20);
        // Pre-allocated space does not register any users or items.
        assert_eq!(matrix.num_users(), 0);
        assert_eq!(matrix.num_items(), 0);
    }

    #[test]
    fn test_add_user() {
        let mut matrix = UserItemMatrix::new(0, 0);
        let user_id = Uuid::new_v4();
        let index = matrix.add_user(user_id);
        assert_eq!(index, 0);
        let index2 = matrix.add_user(user_id);
        assert_eq!(index2, 0);
        assert_eq!(matrix.num_users(), 1);
        assert_eq!(matrix.get_user_id(0), Some(user_id));
    }

    #[test]
    fn test_add_item() {
        let mut matrix = UserItemMatrix::new(0, 0);
        let item_id = Uuid::new_v4();
        let index = matrix.add_item(item_id);
        assert_eq!(index, 0);
        assert_eq!(matrix.add_item(item_id), 0);
        assert_eq!(matrix.add_item(Uuid::new_v4()), 1);
        assert_eq!(matrix.num_items(), 2);
        assert_eq!(matrix.get_item_id(0), Some(item_id));
    }

    #[test]
    fn test_matrix_grows_on_demand() {
        let mut matrix = UserItemMatrix::new(0, 0);
        let users = [Uuid::new_v4(), Uuid::new_v4()];
        let items = [Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];
        for &user in &users {
            for &item in &items {
                matrix.set_rating(user, item, 1.0);
            }
        }
        assert_eq!(matrix.as_view().dim(), (2, 3));
    }

    #[test]
    fn test_preallocated_space_is_reused() {
        let mut matrix = UserItemMatrix::new(2, 2);
        matrix.set_rating(Uuid::new_v4(), Uuid::new_v4(), 3.0);
        assert_eq!(matrix.as_view().dim(), (2, 2));
    }

    #[test]
    fn test_set_get_rating() {
        let mut matrix = UserItemMatrix::new(0, 0);
        let user_id = Uuid::new_v4();
        let item_id = Uuid::new_v4();
        matrix.set_rating(user_id, item_id, 4.5);
        let rating = matrix.get_rating(user_id, item_id);
        assert_eq!(rating, Some(4.5));
    }

    #[test]
    fn test_get_rating_unknown_ids() {
        let mut matrix = UserItemMatrix::new(0, 0);
        let user_id = Uuid::new_v4();
        let item_id = Uuid::new_v4();
        matrix.set_rating(user_id, item_id, 4.0);
        assert_eq!(matrix.get_rating(Uuid::new_v4(), item_id), None);
        assert_eq!(matrix.get_rating(user_id, Uuid::new_v4()), None);
    }

    #[test]
    fn test_get_rated_items_skips_unrated() {
        let mut matrix = UserItemMatrix::new(0, 0);
        let alice = Uuid::new_v4();
        let bob = Uuid::new_v4();
        let rated = Uuid::new_v4();
        let other = Uuid::new_v4();
        matrix.set_rating(alice, rated, 4.0);
        matrix.set_rating(bob, other, 2.0);
        // `other` is a known column that Alice never rated, so her cell reads 0.0.
        assert_eq!(matrix.get_rating(alice, other), Some(0.0));
        assert_eq!(matrix.get_rated_items(alice), vec![(rated, 4.0)]);
        assert!(matrix.get_rated_items(Uuid::new_v4()).is_empty());
    }

    #[test]
    fn test_user_and_item_ratings() {
        let mut matrix = UserItemMatrix::new(0, 0);
        let alice = Uuid::new_v4();
        let bob = Uuid::new_v4();
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        matrix.set_rating(alice, first, 4.0);
        matrix.set_rating(bob, second, 2.0);
        assert_eq!(matrix.get_user_ratings(alice), Some(vec![4.0, 0.0]));
        assert_eq!(matrix.get_item_ratings(second), Some(vec![0.0, 2.0]));
        assert_eq!(matrix.get_user_ratings(Uuid::new_v4()), None);
        assert_eq!(matrix.get_item_ratings(Uuid::new_v4()), None);
    }

    #[test]
    fn test_collaborative_engine_creation() {
        let engine = CollaborativeEngine::new();
        assert_eq!(engine.matrix.num_users(), 0);
        assert_eq!(engine.matrix.num_items(), 0);
    }

    #[test]
    fn test_add_rating_to_engine() {
        let mut engine = CollaborativeEngine::new();
        let user_id = Uuid::new_v4();
        let content_id = Uuid::new_v4();
        engine.add_rating(user_id, content_id, 5.0);
        let rating = engine.matrix.get_rating(user_id, content_id);
        assert_eq!(rating, Some(5.0));
    }
}