use crate::error::RecommendResult;
use crate::{Recommendation, RecommendationRequest};
use std::collections::HashMap;
use uuid::Uuid;
/// Blends several per-method recommendation lists into one ranked list by
/// weighting each method's scores with [`MethodWeights`].
#[derive(Debug, Clone)]
pub struct HybridCombiner {
    /// Per-method multipliers applied to incoming scores.
    weights: MethodWeights,
}
/// Relative weight given to each recommendation method when combining results.
///
/// The method-name strings `"content_based"`, `"collaborative"`, `"trending"` and
/// `"popularity"` select the field of the same name.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MethodWeights {
    /// Weight for results produced by the `"content_based"` method.
    pub content_based: f32,
    /// Weight for results produced by the `"collaborative"` method.
    pub collaborative: f32,
    /// Weight for results produced by the `"trending"` method.
    pub trending: f32,
    /// Weight for results produced by the `"popularity"` method.
    pub popularity: f32,
}

impl Default for MethodWeights {
    /// Content-based and collaborative get `0.4` each; trending and popularity
    /// get `0.1` each, so the defaults sum to `1.0`.
    fn default() -> Self {
        Self {
            content_based: 0.4,
            collaborative: 0.4,
            trending: 0.1,
            popularity: 0.1,
        }
    }
}
impl HybridCombiner {
#[must_use]
pub fn new() -> Self {
Self {
weights: MethodWeights::default(),
}
}
pub fn set_weights(&mut self, weights: MethodWeights) {
self.weights = weights;
}
pub fn combine_recommendations(
&self,
recommendations_by_method: Vec<(String, Vec<Recommendation>)>,
) -> RecommendResult<Vec<Recommendation>> {
let mut combined_scores: HashMap<Uuid, f32> = HashMap::new();
let mut recommendation_map: HashMap<Uuid, Recommendation> = HashMap::new();
for (method, recommendations) in recommendations_by_method {
let weight = self.get_method_weight(&method);
for rec in recommendations {
let weighted_score = rec.score * weight;
*combined_scores.entry(rec.content_id).or_insert(0.0) += weighted_score;
recommendation_map.entry(rec.content_id).or_insert(rec);
}
}
let mut combined: Vec<Recommendation> = combined_scores
.into_iter()
.filter_map(|(content_id, score)| {
recommendation_map.get(&content_id).map(|rec| {
let mut combined_rec = rec.clone();
combined_rec.score = score;
combined_rec
})
})
.collect();
combined.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
for (idx, rec) in combined.iter_mut().enumerate() {
rec.rank = idx + 1;
}
Ok(combined)
}
fn get_method_weight(&self, method: &str) -> f32 {
match method {
"content_based" => self.weights.content_based,
"collaborative" => self.weights.collaborative,
"trending" => self.weights.trending,
"popularity" => self.weights.popularity,
_ => 0.0,
}
}
pub fn recommend(
&self,
_request: &RecommendationRequest,
) -> RecommendResult<Vec<Recommendation>> {
Ok(Vec::new())
}
pub fn normalize_weights(&mut self) {
let total = self.weights.content_based
+ self.weights.collaborative
+ self.weights.trending
+ self.weights.popularity;
if total > f32::EPSILON {
self.weights.content_based /= total;
self.weights.collaborative /= total;
self.weights.trending /= total;
self.weights.popularity /= total;
}
}
}
impl Default for HybridCombiner {
fn default() -> Self {
Self::new()
}
}
/// How a `HybridRecommender` combines its underlying methods.
///
/// The variant names presumably follow the usual hybrid-recommender taxonomy.
/// NOTE(review): in this file only `Weighted` is wired to real logic; the others
/// dispatch to placeholder implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HybridStrategy {
    /// Weighted blend of method scores via the combiner.
    Weighted,
    /// Switch between methods.
    Switching,
    /// Present results from several methods together.
    Mixed,
    /// Refine one method's results with another.
    Cascade,
}
/// Strategy-driven recommender. It owns a [`HybridCombiner`] and dispatches each
/// request according to its [`HybridStrategy`].
pub struct HybridRecommender {
    /// Combiner used by the weighted strategy.
    combiner: HybridCombiner,
    /// Strategy chosen at construction time.
    strategy: HybridStrategy,
}
impl HybridRecommender {
    /// Builds a recommender for `strategy`.
    ///
    /// The combiner is created with [`HybridCombiner::new`], i.e. default method
    /// weights.
    #[must_use]
    pub fn new(strategy: HybridStrategy) -> Self {
        let combiner = HybridCombiner::new();
        Self { combiner, strategy }
    }

    /// Produces recommendations for `request` using the configured strategy.
    ///
    /// `Weighted` forwards to the combiner's `recommend`. `Switching`, `Mixed` and
    /// `Cascade` are placeholders that return an empty list.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected strategy returns.
    pub fn recommend(
        &self,
        request: &RecommendationRequest,
    ) -> RecommendResult<Vec<Recommendation>> {
        let strategy_fn: fn(&Self, &RecommendationRequest) -> RecommendResult<Vec<Recommendation>> =
            match self.strategy {
                HybridStrategy::Weighted => Self::weighted_recommend,
                HybridStrategy::Switching => Self::switching_recommend,
                HybridStrategy::Mixed => Self::mixed_recommend,
                HybridStrategy::Cascade => Self::cascade_recommend,
            };
        strategy_fn(self, request)
    }

    /// Weighted strategy: hands the request to the owned combiner.
    fn weighted_recommend(
        &self,
        request: &RecommendationRequest,
    ) -> RecommendResult<Vec<Recommendation>> {
        let combiner = &self.combiner;
        combiner.recommend(request)
    }

    /// Switching strategy. Not implemented; always returns an empty list.
    fn switching_recommend(
        &self,
        _request: &RecommendationRequest,
    ) -> RecommendResult<Vec<Recommendation>> {
        Ok(vec![])
    }

    /// Mixed strategy. Not implemented; always returns an empty list.
    fn mixed_recommend(
        &self,
        _request: &RecommendationRequest,
    ) -> RecommendResult<Vec<Recommendation>> {
        Ok(vec![])
    }

    /// Cascade strategy. Not implemented; always returns an empty list.
    fn cascade_recommend(
        &self,
        _request: &RecommendationRequest,
    ) -> RecommendResult<Vec<Recommendation>> {
        Ok(vec![])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison shared by the tests below.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < f32::EPSILON
    }

    /// Sum of all four method weights.
    fn weight_sum(w: &MethodWeights) -> f32 {
        w.content_based + w.collaborative + w.trending + w.popularity
    }

    #[test]
    fn test_hybrid_combiner_creation() {
        let combiner = HybridCombiner::new();
        assert!(approx_eq(combiner.weights.content_based, 0.4));
    }

    #[test]
    fn test_method_weights_default() {
        assert!(approx_eq(weight_sum(&MethodWeights::default()), 1.0));
    }

    #[test]
    fn test_normalize_weights() {
        let mut combiner = HybridCombiner::new();
        combiner.set_weights(MethodWeights {
            content_based: 2.0,
            collaborative: 2.0,
            trending: 2.0,
            popularity: 2.0,
        });
        combiner.normalize_weights();
        assert!(approx_eq(weight_sum(&combiner.weights), 1.0));
    }

    #[test]
    fn test_normalize_weights_leaves_zero_total_unchanged() {
        let mut combiner = HybridCombiner::new();
        combiner.set_weights(MethodWeights {
            content_based: 0.0,
            collaborative: 0.0,
            trending: 0.0,
            popularity: 0.0,
        });
        combiner.normalize_weights();
        // Dividing by a zero total would have produced NaN weights.
        assert!(!combiner.weights.content_based.is_nan());
        assert!(approx_eq(weight_sum(&combiner.weights), 0.0));
    }

    #[test]
    fn test_method_weight_lookup() {
        let combiner = HybridCombiner::default();
        assert!(approx_eq(combiner.get_method_weight("content_based"), 0.4));
        assert!(approx_eq(combiner.get_method_weight("collaborative"), 0.4));
        assert!(approx_eq(combiner.get_method_weight("trending"), 0.1));
        assert!(approx_eq(combiner.get_method_weight("popularity"), 0.1));
        assert!(approx_eq(combiner.get_method_weight("unknown"), 0.0));
    }

    #[test]
    fn test_hybrid_recommender_creation() {
        let recommender = HybridRecommender::new(HybridStrategy::Weighted);
        assert!(matches!(recommender.strategy, HybridStrategy::Weighted));
    }
}