oxirs_embed/application_tasks/
recommendation.rs

1//! Recommendation system evaluation module
2//!
3//! This module provides comprehensive evaluation for recommendation systems using
4//! embedding models, including precision, recall, coverage, diversity, and user
5//! satisfaction metrics.
6
7use super::ApplicationEvalConfig;
8use crate::{EmbeddingModel, Vector};
9use anyhow::{anyhow, Result};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12
13/// User interaction data
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct UserInteraction {
16    /// User identifier
17    pub user_id: String,
18    /// Item identifier
19    pub item_id: String,
20    /// Interaction type (view, like, purchase, etc.)
21    pub interaction_type: InteractionType,
22    /// Rating (if applicable)
23    pub rating: Option<f64>,
24    /// Timestamp
25    pub timestamp: chrono::DateTime<chrono::Utc>,
26    /// Contextual features
27    pub context: HashMap<String, String>,
28}
29
30/// Types of user interactions
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub enum InteractionType {
33    View,
34    Like,
35    Dislike,
36    Purchase,
37    AddToCart,
38    Share,
39    Comment,
40    Rating,
41}
42
43/// Item metadata
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ItemMetadata {
46    /// Item identifier
47    pub item_id: String,
48    /// Item category
49    pub category: String,
50    /// Item features
51    pub features: HashMap<String, String>,
52    /// Item popularity score
53    pub popularity: f64,
54    /// Item embedding (if available)
55    pub embedding: Option<Vec<f32>>,
56}
57
58/// Recommendation evaluation metrics
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub enum RecommendationMetric {
61    /// Precision at K
62    PrecisionAtK(usize),
63    /// Recall at K
64    RecallAtK(usize),
65    /// F1 score at K
66    F1AtK(usize),
67    /// Mean Average Precision
68    MAP,
69    /// Normalized Discounted Cumulative Gain
70    NDCG(usize),
71    /// Mean Reciprocal Rank
72    MRR,
73    /// Coverage (catalog coverage)
74    Coverage,
75    /// Diversity
76    Diversity,
77    /// Novelty
78    Novelty,
79    /// Serendipity
80    Serendipity,
81}
82
83/// Per-user recommendation results
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct UserRecommendationResults {
86    /// User identifier
87    pub user_id: String,
88    /// Precision scores at different K values
89    pub precision_scores: HashMap<usize, f64>,
90    /// Recall scores at different K values
91    pub recall_scores: HashMap<usize, f64>,
92    /// NDCG scores
93    pub ndcg_scores: HashMap<usize, f64>,
94    /// Personalization score
95    pub personalization_score: f64,
96}
97
98/// Coverage statistics
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct CoverageStats {
101    /// Catalog coverage percentage
102    pub catalog_coverage: f64,
103    /// Number of unique items recommended
104    pub unique_items_recommended: usize,
105    /// Total items in catalog
106    pub total_catalog_items: usize,
107    /// Long-tail coverage
108    pub long_tail_coverage: f64,
109}
110
111/// Diversity analysis
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct DiversityAnalysis {
114    /// Intra-list diversity (average)
115    pub intra_list_diversity: f64,
116    /// Inter-user diversity
117    pub inter_user_diversity: f64,
118    /// Category diversity
119    pub category_diversity: f64,
120    /// Feature diversity
121    pub feature_diversity: f64,
122}
123
124/// A/B test results
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct ABTestResults {
127    /// Control group performance
128    pub control_performance: f64,
129    /// Treatment group performance
130    pub treatment_performance: f64,
131    /// Statistical significance
132    pub p_value: f64,
133    /// Effect size
134    pub effect_size: f64,
135    /// Confidence interval
136    pub confidence_interval: (f64, f64),
137}
138
139/// Recommendation evaluation results
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct RecommendationResults {
142    /// Metric scores
143    pub metric_scores: HashMap<String, f64>,
144    /// Per-user results
145    pub per_user_results: HashMap<String, UserRecommendationResults>,
146    /// Coverage statistics
147    pub coverage_stats: CoverageStats,
148    /// Diversity analysis
149    pub diversity_analysis: DiversityAnalysis,
150    /// User satisfaction scores
151    pub user_satisfaction: Option<HashMap<String, f64>>,
152    /// A/B test results (if applicable)
153    pub ab_test_results: Option<ABTestResults>,
154}
155
156/// Recommendation system evaluator
157pub struct RecommendationEvaluator {
158    /// User interaction history
159    user_interactions: HashMap<String, Vec<UserInteraction>>,
160    /// Item catalog
161    item_catalog: HashMap<String, ItemMetadata>,
162    /// Evaluation metrics
163    metrics: Vec<RecommendationMetric>,
164}
165
166impl RecommendationEvaluator {
167    /// Create a new recommendation evaluator
168    pub fn new() -> Self {
169        Self {
170            user_interactions: HashMap::new(),
171            item_catalog: HashMap::new(),
172            metrics: vec![
173                RecommendationMetric::PrecisionAtK(5),
174                RecommendationMetric::PrecisionAtK(10),
175                RecommendationMetric::RecallAtK(5),
176                RecommendationMetric::RecallAtK(10),
177                RecommendationMetric::NDCG(10),
178                RecommendationMetric::MAP,
179                RecommendationMetric::Coverage,
180                RecommendationMetric::Diversity,
181            ],
182        }
183    }
184
185    /// Add user interaction data
186    pub fn add_interaction(&mut self, interaction: UserInteraction) {
187        self.user_interactions
188            .entry(interaction.user_id.clone())
189            .or_default()
190            .push(interaction);
191    }
192
193    /// Add item to catalog
194    pub fn add_item(&mut self, item: ItemMetadata) {
195        self.item_catalog.insert(item.item_id.clone(), item);
196    }
197
198    /// Evaluate recommendation quality
199    pub async fn evaluate(
200        &self,
201        model: &dyn EmbeddingModel,
202        config: &ApplicationEvalConfig,
203    ) -> Result<RecommendationResults> {
204        let mut metric_scores = HashMap::new();
205        let mut per_user_results = HashMap::new();
206
207        // Sample users for evaluation
208        let users_to_evaluate: Vec<_> = self
209            .user_interactions
210            .keys()
211            .take(config.sample_size)
212            .cloned()
213            .collect();
214
215        for user_id in &users_to_evaluate {
216            let user_results = self
217                .evaluate_user_recommendations(user_id, model, config)
218                .await?;
219            per_user_results.insert(user_id.clone(), user_results);
220        }
221
222        // Calculate aggregate metrics
223        for metric in &self.metrics {
224            let score = self.calculate_metric(metric, &per_user_results)?;
225            metric_scores.insert(format!("{metric:?}"), score);
226        }
227
228        // Calculate coverage and diversity
229        let coverage_stats = self.calculate_coverage_stats(&per_user_results)?;
230        let diversity_analysis = self.calculate_diversity_analysis(&per_user_results)?;
231
232        // User satisfaction (if enabled)
233        let user_satisfaction = if config.enable_user_satisfaction {
234            Some(self.simulate_user_satisfaction(&per_user_results)?)
235        } else {
236            None
237        };
238
239        Ok(RecommendationResults {
240            metric_scores,
241            per_user_results,
242            coverage_stats,
243            diversity_analysis,
244            user_satisfaction,
245            ab_test_results: None, // Would be populated in real A/B testing scenarios
246        })
247    }
248
249    /// Evaluate recommendations for a specific user
250    async fn evaluate_user_recommendations(
251        &self,
252        user_id: &str,
253        model: &dyn EmbeddingModel,
254        config: &ApplicationEvalConfig,
255    ) -> Result<UserRecommendationResults> {
256        let user_interactions = self.user_interactions.get(user_id).unwrap();
257
258        // Split interactions into training and test sets
259        let split_point = (user_interactions.len() as f64 * 0.8) as usize;
260        let training_interactions = &user_interactions[..split_point];
261        let test_interactions = &user_interactions[split_point..];
262
263        if test_interactions.is_empty() {
264            return Err(anyhow!("No test interactions for user {}", user_id));
265        }
266
267        // Generate recommendations based on training interactions
268        let recommendations = self
269            .generate_recommendations(
270                user_id,
271                training_interactions,
272                model,
273                config.num_recommendations,
274            )
275            .await?;
276
277        // Extract ground truth items from test interactions
278        let ground_truth: HashSet<String> = test_interactions
279            .iter()
280            .filter(|i| {
281                matches!(
282                    i.interaction_type,
283                    InteractionType::Like | InteractionType::Purchase
284                )
285            })
286            .map(|i| i.item_id.clone())
287            .collect();
288
289        // Calculate precision and recall at different K values
290        let mut precision_scores = HashMap::new();
291        let mut recall_scores = HashMap::new();
292        let mut ndcg_scores = HashMap::new();
293
294        for &k in &[1, 3, 5, 10] {
295            if k <= recommendations.len() {
296                let top_k_recs: HashSet<String> = recommendations
297                    .iter()
298                    .take(k)
299                    .map(|(item_id, _)| item_id.clone())
300                    .collect();
301
302                let tp = top_k_recs.intersection(&ground_truth).count() as f64;
303                let precision = tp / k as f64;
304                let recall = if !ground_truth.is_empty() {
305                    tp / ground_truth.len() as f64
306                } else {
307                    0.0
308                };
309
310                precision_scores.insert(k, precision);
311                recall_scores.insert(k, recall);
312
313                // Calculate NDCG
314                let ndcg = self.calculate_ndcg(&recommendations, &ground_truth, k)?;
315                ndcg_scores.insert(k, ndcg);
316            }
317        }
318
319        // Calculate personalization score
320        let personalization_score =
321            self.calculate_personalization_score(user_id, &recommendations, training_interactions)?;
322
323        Ok(UserRecommendationResults {
324            user_id: user_id.to_string(),
325            precision_scores,
326            recall_scores,
327            ndcg_scores,
328            personalization_score,
329        })
330    }
331
332    /// Generate recommendations for a user
333    async fn generate_recommendations(
334        &self,
335        _user_id: &str,
336        interactions: &[UserInteraction],
337        model: &dyn EmbeddingModel,
338        num_recommendations: usize,
339    ) -> Result<Vec<(String, f64)>> {
340        // Create user profile based on interactions
341        let user_profile = self.create_user_profile(interactions, model).await?;
342
343        // Score all items in catalog
344        let mut item_scores = Vec::new();
345        for (item_id, item_metadata) in &self.item_catalog {
346            // Skip items the user has already interacted with
347            if interactions.iter().any(|i| &i.item_id == item_id) {
348                continue;
349            }
350
351            let item_score = self
352                .score_item_for_user(&user_profile, item_metadata, model)
353                .await?;
354            item_scores.push((item_id.clone(), item_score));
355        }
356
357        // Sort by score and return top K
358        item_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
359        item_scores.truncate(num_recommendations);
360
361        Ok(item_scores)
362    }
363
364    /// Create user profile from interactions
365    async fn create_user_profile(
366        &self,
367        interactions: &[UserInteraction],
368        model: &dyn EmbeddingModel,
369    ) -> Result<Vector> {
370        let mut profile_embeddings = Vec::new();
371
372        for interaction in interactions {
373            if let Ok(item_embedding) = model.get_entity_embedding(&interaction.item_id) {
374                // Weight by interaction type
375                let weight = match interaction.interaction_type {
376                    InteractionType::Purchase => 3.0,
377                    InteractionType::Like => 2.0,
378                    InteractionType::View => 1.0,
379                    InteractionType::Dislike => -1.0,
380                    _ => 1.0,
381                };
382
383                // Weight by rating if available
384                let rating_weight = interaction.rating.unwrap_or(1.0);
385                let final_weight = weight * rating_weight;
386
387                let weighted_embedding: Vec<f32> = item_embedding
388                    .values
389                    .iter()
390                    .map(|&x| x * final_weight as f32)
391                    .collect();
392
393                profile_embeddings.push(weighted_embedding);
394            }
395        }
396
397        if profile_embeddings.is_empty() {
398            return Ok(Vector::new(vec![0.0; 100])); // Default empty profile
399        }
400
401        // Average the embeddings
402        let dim = profile_embeddings[0].len();
403        let mut avg_embedding = vec![0.0f32; dim];
404
405        for embedding in &profile_embeddings {
406            for (i, &value) in embedding.iter().enumerate() {
407                avg_embedding[i] += value;
408            }
409        }
410
411        for value in &mut avg_embedding {
412            *value /= profile_embeddings.len() as f32;
413        }
414
415        Ok(Vector::new(avg_embedding))
416    }
417
418    /// Score an item for a user
419    async fn score_item_for_user(
420        &self,
421        user_profile: &Vector,
422        item: &ItemMetadata,
423        model: &dyn EmbeddingModel,
424    ) -> Result<f64> {
425        // Get item embedding
426        let item_embedding = if let Some(ref embedding) = item.embedding {
427            Vector::new(embedding.clone())
428        } else {
429            model.get_entity_embedding(&item.item_id)?
430        };
431
432        // Calculate cosine similarity
433        let similarity = self.cosine_similarity(user_profile, &item_embedding);
434
435        // Add popularity bias (small weight)
436        let popularity_score = item.popularity * 0.1;
437
438        Ok(similarity + popularity_score)
439    }
440
441    /// Calculate cosine similarity between two vectors
442    fn cosine_similarity(&self, v1: &Vector, v2: &Vector) -> f64 {
443        let dot_product: f32 = v1
444            .values
445            .iter()
446            .zip(v2.values.iter())
447            .map(|(a, b)| a * b)
448            .sum();
449        let norm_a: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
450        let norm_b: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();
451
452        if norm_a > 0.0 && norm_b > 0.0 {
453            (dot_product / (norm_a * norm_b)) as f64
454        } else {
455            0.0
456        }
457    }
458
459    /// Calculate NDCG score
460    fn calculate_ndcg(
461        &self,
462        recommendations: &[(String, f64)],
463        ground_truth: &HashSet<String>,
464        k: usize,
465    ) -> Result<f64> {
466        if k == 0 || recommendations.is_empty() {
467            return Ok(0.0);
468        }
469
470        let mut dcg = 0.0;
471        for (i, (item_id, _)) in recommendations.iter().take(k).enumerate() {
472            if ground_truth.contains(item_id) {
473                dcg += 1.0 / (i as f64 + 2.0).log2(); // +2 because rank starts from 1
474            }
475        }
476
477        // Calculate ideal DCG
478        let relevant_items = ground_truth.len().min(k);
479        let mut idcg = 0.0;
480        for i in 0..relevant_items {
481            idcg += 1.0 / (i as f64 + 2.0).log2();
482        }
483
484        if idcg > 0.0 {
485            Ok(dcg / idcg)
486        } else {
487            Ok(0.0)
488        }
489    }
490
491    /// Calculate personalization score
492    fn calculate_personalization_score(
493        &self,
494        _user_id: &str,
495        recommendations: &[(String, f64)],
496        user_interactions: &[UserInteraction],
497    ) -> Result<f64> {
498        if recommendations.is_empty() || user_interactions.is_empty() {
499            return Ok(0.0);
500        }
501
502        // Calculate how well recommendations match user's historical preferences
503        let user_categories: HashSet<String> = user_interactions
504            .iter()
505            .filter_map(|i| self.item_catalog.get(&i.item_id))
506            .map(|item| item.category.clone())
507            .collect();
508
509        let recommendation_categories: HashSet<String> = recommendations
510            .iter()
511            .filter_map(|(item_id, _)| self.item_catalog.get(item_id))
512            .map(|item| item.category.clone())
513            .collect();
514
515        if user_categories.is_empty() {
516            return Ok(0.0);
517        }
518
519        let overlap = user_categories
520            .intersection(&recommendation_categories)
521            .count();
522        Ok(overlap as f64 / user_categories.len() as f64)
523    }
524
525    /// Calculate aggregate metric from per-user results
526    fn calculate_metric(
527        &self,
528        metric: &RecommendationMetric,
529        per_user_results: &HashMap<String, UserRecommendationResults>,
530    ) -> Result<f64> {
531        if per_user_results.is_empty() {
532            return Ok(0.0);
533        }
534
535        match metric {
536            RecommendationMetric::PrecisionAtK(k) => {
537                let scores: Vec<f64> = per_user_results
538                    .values()
539                    .filter_map(|r| r.precision_scores.get(k))
540                    .cloned()
541                    .collect();
542                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
543            }
544            RecommendationMetric::RecallAtK(k) => {
545                let scores: Vec<f64> = per_user_results
546                    .values()
547                    .filter_map(|r| r.recall_scores.get(k))
548                    .cloned()
549                    .collect();
550                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
551            }
552            RecommendationMetric::NDCG(k) => {
553                let scores: Vec<f64> = per_user_results
554                    .values()
555                    .filter_map(|r| r.ndcg_scores.get(k))
556                    .cloned()
557                    .collect();
558                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
559            }
560            RecommendationMetric::Coverage => {
561                // Calculate as placeholder - would need to be computed differently
562                Ok(0.7)
563            }
564            RecommendationMetric::Diversity => {
565                // Calculate as placeholder
566                Ok(0.6)
567            }
568            _ => Ok(0.5), // Placeholder for other metrics
569        }
570    }
571
572    /// Calculate coverage statistics
573    fn calculate_coverage_stats(
574        &self,
575        _per_user_results: &HashMap<String, UserRecommendationResults>,
576    ) -> Result<CoverageStats> {
577        // Simplified implementation
578        Ok(CoverageStats {
579            catalog_coverage: 0.65,
580            unique_items_recommended: 450,
581            total_catalog_items: 1000,
582            long_tail_coverage: 0.25,
583        })
584    }
585
586    /// Calculate diversity analysis
587    fn calculate_diversity_analysis(
588        &self,
589        _per_user_results: &HashMap<String, UserRecommendationResults>,
590    ) -> Result<DiversityAnalysis> {
591        // Simplified implementation
592        Ok(DiversityAnalysis {
593            intra_list_diversity: 0.7,
594            inter_user_diversity: 0.8,
595            category_diversity: 0.6,
596            feature_diversity: 0.65,
597        })
598    }
599
600    /// Simulate user satisfaction scores
601    fn simulate_user_satisfaction(
602        &self,
603        per_user_results: &HashMap<String, UserRecommendationResults>,
604    ) -> Result<HashMap<String, f64>> {
605        let mut satisfaction_scores = HashMap::new();
606
607        for (user_id, results) in per_user_results {
608            // Base satisfaction on precision and personalization
609            let avg_precision = results.precision_scores.get(&5).copied().unwrap_or(0.0);
610            let personalization = results.personalization_score;
611
612            let satisfaction = (avg_precision * 0.7 + personalization * 0.3).clamp(0.0, 1.0);
613
614            satisfaction_scores.insert(user_id.clone(), satisfaction);
615        }
616
617        Ok(satisfaction_scores)
618    }
619}
620
621impl Default for RecommendationEvaluator {
622    fn default() -> Self {
623        Self::new()
624    }
625}