Skip to main content

oxirs_embed/application_tasks/
recommendation.rs

1//! Recommendation system evaluation module
2//!
3//! This module provides comprehensive evaluation for recommendation systems using
4//! embedding models, including precision, recall, coverage, diversity, and user
5//! satisfaction metrics.
6
7use super::ApplicationEvalConfig;
8use crate::{EmbeddingModel, Vector};
9use anyhow::{anyhow, Result};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12
13/// User interaction data
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct UserInteraction {
16    /// User identifier
17    pub user_id: String,
18    /// Item identifier
19    pub item_id: String,
20    /// Interaction type (view, like, purchase, etc.)
21    pub interaction_type: InteractionType,
22    /// Rating (if applicable)
23    pub rating: Option<f64>,
24    /// Timestamp
25    pub timestamp: chrono::DateTime<chrono::Utc>,
26    /// Contextual features
27    pub context: HashMap<String, String>,
28}
29
30/// Types of user interactions
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub enum InteractionType {
33    View,
34    Like,
35    Dislike,
36    Purchase,
37    AddToCart,
38    Share,
39    Comment,
40    Rating,
41}
42
43/// Item metadata
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ItemMetadata {
46    /// Item identifier
47    pub item_id: String,
48    /// Item category
49    pub category: String,
50    /// Item features
51    pub features: HashMap<String, String>,
52    /// Item popularity score
53    pub popularity: f64,
54    /// Item embedding (if available)
55    pub embedding: Option<Vec<f32>>,
56}
57
58/// Recommendation evaluation metrics
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub enum RecommendationMetric {
61    /// Precision at K
62    PrecisionAtK(usize),
63    /// Recall at K
64    RecallAtK(usize),
65    /// F1 score at K
66    F1AtK(usize),
67    /// Mean Average Precision
68    MAP,
69    /// Normalized Discounted Cumulative Gain
70    NDCG(usize),
71    /// Mean Reciprocal Rank
72    MRR,
73    /// Coverage (catalog coverage)
74    Coverage,
75    /// Diversity
76    Diversity,
77    /// Novelty
78    Novelty,
79    /// Serendipity
80    Serendipity,
81}
82
83/// Per-user recommendation results
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct UserRecommendationResults {
86    /// User identifier
87    pub user_id: String,
88    /// Precision scores at different K values
89    pub precision_scores: HashMap<usize, f64>,
90    /// Recall scores at different K values
91    pub recall_scores: HashMap<usize, f64>,
92    /// NDCG scores
93    pub ndcg_scores: HashMap<usize, f64>,
94    /// Personalization score
95    pub personalization_score: f64,
96}
97
98/// Coverage statistics
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct CoverageStats {
101    /// Catalog coverage percentage
102    pub catalog_coverage: f64,
103    /// Number of unique items recommended
104    pub unique_items_recommended: usize,
105    /// Total items in catalog
106    pub total_catalog_items: usize,
107    /// Long-tail coverage
108    pub long_tail_coverage: f64,
109}
110
111/// Diversity analysis
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct DiversityAnalysis {
114    /// Intra-list diversity (average)
115    pub intra_list_diversity: f64,
116    /// Inter-user diversity
117    pub inter_user_diversity: f64,
118    /// Category diversity
119    pub category_diversity: f64,
120    /// Feature diversity
121    pub feature_diversity: f64,
122}
123
124/// A/B test results
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct ABTestResults {
127    /// Control group performance
128    pub control_performance: f64,
129    /// Treatment group performance
130    pub treatment_performance: f64,
131    /// Statistical significance
132    pub p_value: f64,
133    /// Effect size
134    pub effect_size: f64,
135    /// Confidence interval
136    pub confidence_interval: (f64, f64),
137}
138
/// Aggregate output of `RecommendationEvaluator::evaluate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecommendationResults {
    /// Aggregate metric scores keyed by the metric's `Debug` name,
    /// e.g. `"PrecisionAtK(5)"`.
    pub metric_scores: HashMap<String, f64>,
    /// Per-user results keyed by user id.
    pub per_user_results: HashMap<String, UserRecommendationResults>,
    /// Coverage statistics.
    pub coverage_stats: CoverageStats,
    /// Diversity analysis.
    pub diversity_analysis: DiversityAnalysis,
    /// Simulated satisfaction in `[0, 1]` keyed by user id; `Some` only when
    /// user-satisfaction evaluation is enabled in the evaluation config.
    pub user_satisfaction: Option<HashMap<String, f64>>,
    /// A/B test results; `evaluate` always leaves this `None`.
    pub ab_test_results: Option<ABTestResults>,
}
155
156/// Recommendation system evaluator
157pub struct RecommendationEvaluator {
158    /// User interaction history
159    user_interactions: HashMap<String, Vec<UserInteraction>>,
160    /// Item catalog
161    item_catalog: HashMap<String, ItemMetadata>,
162    /// Evaluation metrics
163    metrics: Vec<RecommendationMetric>,
164}
165
166impl RecommendationEvaluator {
167    /// Create a new recommendation evaluator
168    pub fn new() -> Self {
169        Self {
170            user_interactions: HashMap::new(),
171            item_catalog: HashMap::new(),
172            metrics: vec![
173                RecommendationMetric::PrecisionAtK(5),
174                RecommendationMetric::PrecisionAtK(10),
175                RecommendationMetric::RecallAtK(5),
176                RecommendationMetric::RecallAtK(10),
177                RecommendationMetric::NDCG(10),
178                RecommendationMetric::MAP,
179                RecommendationMetric::Coverage,
180                RecommendationMetric::Diversity,
181            ],
182        }
183    }
184
185    /// Add user interaction data
186    pub fn add_interaction(&mut self, interaction: UserInteraction) {
187        self.user_interactions
188            .entry(interaction.user_id.clone())
189            .or_default()
190            .push(interaction);
191    }
192
193    /// Add item to catalog
194    pub fn add_item(&mut self, item: ItemMetadata) {
195        self.item_catalog.insert(item.item_id.clone(), item);
196    }
197
198    /// Evaluate recommendation quality
199    pub async fn evaluate(
200        &self,
201        model: &dyn EmbeddingModel,
202        config: &ApplicationEvalConfig,
203    ) -> Result<RecommendationResults> {
204        let mut metric_scores = HashMap::new();
205        let mut per_user_results = HashMap::new();
206
207        // Sample users for evaluation
208        let users_to_evaluate: Vec<_> = self
209            .user_interactions
210            .keys()
211            .take(config.sample_size)
212            .cloned()
213            .collect();
214
215        for user_id in &users_to_evaluate {
216            let user_results = self
217                .evaluate_user_recommendations(user_id, model, config)
218                .await?;
219            per_user_results.insert(user_id.clone(), user_results);
220        }
221
222        // Calculate aggregate metrics
223        for metric in &self.metrics {
224            let score = self.calculate_metric(metric, &per_user_results)?;
225            metric_scores.insert(format!("{metric:?}"), score);
226        }
227
228        // Calculate coverage and diversity
229        let coverage_stats = self.calculate_coverage_stats(&per_user_results)?;
230        let diversity_analysis = self.calculate_diversity_analysis(&per_user_results)?;
231
232        // User satisfaction (if enabled)
233        let user_satisfaction = if config.enable_user_satisfaction {
234            Some(self.simulate_user_satisfaction(&per_user_results)?)
235        } else {
236            None
237        };
238
239        Ok(RecommendationResults {
240            metric_scores,
241            per_user_results,
242            coverage_stats,
243            diversity_analysis,
244            user_satisfaction,
245            ab_test_results: None, // Would be populated in real A/B testing scenarios
246        })
247    }
248
249    /// Evaluate recommendations for a specific user
250    async fn evaluate_user_recommendations(
251        &self,
252        user_id: &str,
253        model: &dyn EmbeddingModel,
254        config: &ApplicationEvalConfig,
255    ) -> Result<UserRecommendationResults> {
256        let user_interactions = self
257            .user_interactions
258            .get(user_id)
259            .expect("user_id should exist in user_interactions");
260
261        // Split interactions into training and test sets
262        let split_point = (user_interactions.len() as f64 * 0.8) as usize;
263        let training_interactions = &user_interactions[..split_point];
264        let test_interactions = &user_interactions[split_point..];
265
266        if test_interactions.is_empty() {
267            return Err(anyhow!("No test interactions for user {}", user_id));
268        }
269
270        // Generate recommendations based on training interactions
271        let recommendations = self
272            .generate_recommendations(
273                user_id,
274                training_interactions,
275                model,
276                config.num_recommendations,
277            )
278            .await?;
279
280        // Extract ground truth items from test interactions
281        let ground_truth: HashSet<String> = test_interactions
282            .iter()
283            .filter(|i| {
284                matches!(
285                    i.interaction_type,
286                    InteractionType::Like | InteractionType::Purchase
287                )
288            })
289            .map(|i| i.item_id.clone())
290            .collect();
291
292        // Calculate precision and recall at different K values
293        let mut precision_scores = HashMap::new();
294        let mut recall_scores = HashMap::new();
295        let mut ndcg_scores = HashMap::new();
296
297        for &k in &[1, 3, 5, 10] {
298            if k <= recommendations.len() {
299                let top_k_recs: HashSet<String> = recommendations
300                    .iter()
301                    .take(k)
302                    .map(|(item_id, _)| item_id.clone())
303                    .collect();
304
305                let tp = top_k_recs.intersection(&ground_truth).count() as f64;
306                let precision = tp / k as f64;
307                let recall = if !ground_truth.is_empty() {
308                    tp / ground_truth.len() as f64
309                } else {
310                    0.0
311                };
312
313                precision_scores.insert(k, precision);
314                recall_scores.insert(k, recall);
315
316                // Calculate NDCG
317                let ndcg = self.calculate_ndcg(&recommendations, &ground_truth, k)?;
318                ndcg_scores.insert(k, ndcg);
319            }
320        }
321
322        // Calculate personalization score
323        let personalization_score =
324            self.calculate_personalization_score(user_id, &recommendations, training_interactions)?;
325
326        Ok(UserRecommendationResults {
327            user_id: user_id.to_string(),
328            precision_scores,
329            recall_scores,
330            ndcg_scores,
331            personalization_score,
332        })
333    }
334
335    /// Generate recommendations for a user
336    async fn generate_recommendations(
337        &self,
338        _user_id: &str,
339        interactions: &[UserInteraction],
340        model: &dyn EmbeddingModel,
341        num_recommendations: usize,
342    ) -> Result<Vec<(String, f64)>> {
343        // Create user profile based on interactions
344        let user_profile = self.create_user_profile(interactions, model).await?;
345
346        // Score all items in catalog
347        let mut item_scores = Vec::new();
348        for (item_id, item_metadata) in &self.item_catalog {
349            // Skip items the user has already interacted with
350            if interactions.iter().any(|i| &i.item_id == item_id) {
351                continue;
352            }
353
354            let item_score = self
355                .score_item_for_user(&user_profile, item_metadata, model)
356                .await?;
357            item_scores.push((item_id.clone(), item_score));
358        }
359
360        // Sort by score and return top K
361        item_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
362        item_scores.truncate(num_recommendations);
363
364        Ok(item_scores)
365    }
366
367    /// Create user profile from interactions
368    async fn create_user_profile(
369        &self,
370        interactions: &[UserInteraction],
371        model: &dyn EmbeddingModel,
372    ) -> Result<Vector> {
373        let mut profile_embeddings = Vec::new();
374
375        for interaction in interactions {
376            if let Ok(item_embedding) = model.get_entity_embedding(&interaction.item_id) {
377                // Weight by interaction type
378                let weight = match interaction.interaction_type {
379                    InteractionType::Purchase => 3.0,
380                    InteractionType::Like => 2.0,
381                    InteractionType::View => 1.0,
382                    InteractionType::Dislike => -1.0,
383                    _ => 1.0,
384                };
385
386                // Weight by rating if available
387                let rating_weight = interaction.rating.unwrap_or(1.0);
388                let final_weight = weight * rating_weight;
389
390                let weighted_embedding: Vec<f32> = item_embedding
391                    .values
392                    .iter()
393                    .map(|&x| x * final_weight as f32)
394                    .collect();
395
396                profile_embeddings.push(weighted_embedding);
397            }
398        }
399
400        if profile_embeddings.is_empty() {
401            return Ok(Vector::new(vec![0.0; 100])); // Default empty profile
402        }
403
404        // Average the embeddings
405        let dim = profile_embeddings[0].len();
406        let mut avg_embedding = vec![0.0f32; dim];
407
408        for embedding in &profile_embeddings {
409            for (i, &value) in embedding.iter().enumerate() {
410                avg_embedding[i] += value;
411            }
412        }
413
414        for value in &mut avg_embedding {
415            *value /= profile_embeddings.len() as f32;
416        }
417
418        Ok(Vector::new(avg_embedding))
419    }
420
421    /// Score an item for a user
422    async fn score_item_for_user(
423        &self,
424        user_profile: &Vector,
425        item: &ItemMetadata,
426        model: &dyn EmbeddingModel,
427    ) -> Result<f64> {
428        // Get item embedding
429        let item_embedding = if let Some(ref embedding) = item.embedding {
430            Vector::new(embedding.clone())
431        } else {
432            model.get_entity_embedding(&item.item_id)?
433        };
434
435        // Calculate cosine similarity
436        let similarity = self.cosine_similarity(user_profile, &item_embedding);
437
438        // Add popularity bias (small weight)
439        let popularity_score = item.popularity * 0.1;
440
441        Ok(similarity + popularity_score)
442    }
443
444    /// Calculate cosine similarity between two vectors
445    fn cosine_similarity(&self, v1: &Vector, v2: &Vector) -> f64 {
446        let dot_product: f32 = v1
447            .values
448            .iter()
449            .zip(v2.values.iter())
450            .map(|(a, b)| a * b)
451            .sum();
452        let norm_a: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
453        let norm_b: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();
454
455        if norm_a > 0.0 && norm_b > 0.0 {
456            (dot_product / (norm_a * norm_b)) as f64
457        } else {
458            0.0
459        }
460    }
461
462    /// Calculate NDCG score
463    fn calculate_ndcg(
464        &self,
465        recommendations: &[(String, f64)],
466        ground_truth: &HashSet<String>,
467        k: usize,
468    ) -> Result<f64> {
469        if k == 0 || recommendations.is_empty() {
470            return Ok(0.0);
471        }
472
473        let mut dcg = 0.0;
474        for (i, (item_id, _)) in recommendations.iter().take(k).enumerate() {
475            if ground_truth.contains(item_id) {
476                dcg += 1.0 / (i as f64 + 2.0).log2(); // +2 because rank starts from 1
477            }
478        }
479
480        // Calculate ideal DCG
481        let relevant_items = ground_truth.len().min(k);
482        let mut idcg = 0.0;
483        for i in 0..relevant_items {
484            idcg += 1.0 / (i as f64 + 2.0).log2();
485        }
486
487        if idcg > 0.0 {
488            Ok(dcg / idcg)
489        } else {
490            Ok(0.0)
491        }
492    }
493
494    /// Calculate personalization score
495    fn calculate_personalization_score(
496        &self,
497        _user_id: &str,
498        recommendations: &[(String, f64)],
499        user_interactions: &[UserInteraction],
500    ) -> Result<f64> {
501        if recommendations.is_empty() || user_interactions.is_empty() {
502            return Ok(0.0);
503        }
504
505        // Calculate how well recommendations match user's historical preferences
506        let user_categories: HashSet<String> = user_interactions
507            .iter()
508            .filter_map(|i| self.item_catalog.get(&i.item_id))
509            .map(|item| item.category.clone())
510            .collect();
511
512        let recommendation_categories: HashSet<String> = recommendations
513            .iter()
514            .filter_map(|(item_id, _)| self.item_catalog.get(item_id))
515            .map(|item| item.category.clone())
516            .collect();
517
518        if user_categories.is_empty() {
519            return Ok(0.0);
520        }
521
522        let overlap = user_categories
523            .intersection(&recommendation_categories)
524            .count();
525        Ok(overlap as f64 / user_categories.len() as f64)
526    }
527
528    /// Calculate aggregate metric from per-user results
529    fn calculate_metric(
530        &self,
531        metric: &RecommendationMetric,
532        per_user_results: &HashMap<String, UserRecommendationResults>,
533    ) -> Result<f64> {
534        if per_user_results.is_empty() {
535            return Ok(0.0);
536        }
537
538        match metric {
539            RecommendationMetric::PrecisionAtK(k) => {
540                let scores: Vec<f64> = per_user_results
541                    .values()
542                    .filter_map(|r| r.precision_scores.get(k))
543                    .cloned()
544                    .collect();
545                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
546            }
547            RecommendationMetric::RecallAtK(k) => {
548                let scores: Vec<f64> = per_user_results
549                    .values()
550                    .filter_map(|r| r.recall_scores.get(k))
551                    .cloned()
552                    .collect();
553                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
554            }
555            RecommendationMetric::NDCG(k) => {
556                let scores: Vec<f64> = per_user_results
557                    .values()
558                    .filter_map(|r| r.ndcg_scores.get(k))
559                    .cloned()
560                    .collect();
561                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
562            }
563            RecommendationMetric::Coverage => {
564                // Calculate as placeholder - would need to be computed differently
565                Ok(0.7)
566            }
567            RecommendationMetric::Diversity => {
568                // Calculate as placeholder
569                Ok(0.6)
570            }
571            _ => Ok(0.5), // Placeholder for other metrics
572        }
573    }
574
575    /// Calculate coverage statistics
576    fn calculate_coverage_stats(
577        &self,
578        _per_user_results: &HashMap<String, UserRecommendationResults>,
579    ) -> Result<CoverageStats> {
580        // Simplified implementation
581        Ok(CoverageStats {
582            catalog_coverage: 0.65,
583            unique_items_recommended: 450,
584            total_catalog_items: 1000,
585            long_tail_coverage: 0.25,
586        })
587    }
588
589    /// Calculate diversity analysis
590    fn calculate_diversity_analysis(
591        &self,
592        _per_user_results: &HashMap<String, UserRecommendationResults>,
593    ) -> Result<DiversityAnalysis> {
594        // Simplified implementation
595        Ok(DiversityAnalysis {
596            intra_list_diversity: 0.7,
597            inter_user_diversity: 0.8,
598            category_diversity: 0.6,
599            feature_diversity: 0.65,
600        })
601    }
602
603    /// Simulate user satisfaction scores
604    fn simulate_user_satisfaction(
605        &self,
606        per_user_results: &HashMap<String, UserRecommendationResults>,
607    ) -> Result<HashMap<String, f64>> {
608        let mut satisfaction_scores = HashMap::new();
609
610        for (user_id, results) in per_user_results {
611            // Base satisfaction on precision and personalization
612            let avg_precision = results.precision_scores.get(&5).copied().unwrap_or(0.0);
613            let personalization = results.personalization_score;
614
615            let satisfaction = (avg_precision * 0.7 + personalization * 0.3).clamp(0.0, 1.0);
616
617            satisfaction_scores.insert(user_id.clone(), satisfaction);
618        }
619
620        Ok(satisfaction_scores)
621    }
622}
623
624impl Default for RecommendationEvaluator {
625    fn default() -> Self {
626        Self::new()
627    }
628}