use crate::Vector;
use crate::VectorStore;
use anyhow::{anyhow, Result};
use parking_lot::RwLock;
use scirs2_core::random::RngExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// Shared, lock-protected slot holding an optional map from a pair of string keys to an `f32`.
///
/// It is the type of `PersonalizedSearchEngine::similarity_matrix`, which the code visible in
/// this file initializes to `None` and does not read or fill afterwards.
// NOTE(review): presumably keyed by a pair of user ids (or item ids) — confirm before relying on it.
type SimilarityMatrix = Arc<RwLock<Option<HashMap<(String, String), f32>>>>;
/// Search engine that re-ranks vector-store hits per user, using stored user profiles,
/// item profiles and the recorded interaction history.
///
/// Every piece of mutable state sits behind its own `Arc<RwLock<..>>`.
pub struct PersonalizedSearchEngine {
/// Tuning knobs read by the scoring and profile-update methods.
config: PersonalizationConfig,
/// Store queried with `similarity_search` to obtain the content-based candidates.
vector_store: Arc<RwLock<VectorStore>>,
/// User profiles keyed by user id; entries are inserted by `register_user`.
user_profiles: Arc<RwLock<HashMap<String, UserProfile>>>,
/// Item profiles keyed by item id.
// NOTE(review): nothing in the code visible here inserts into this map, so every item
// lookup misses unless it is populated elsewhere — confirm.
item_profiles: Arc<RwLock<HashMap<String, ItemProfile>>>,
/// Append-only log of interactions; `record_feedback` pushes one entry per feedback event.
interaction_history: Arc<RwLock<Vec<UserInteraction>>>,
/// Initialized to `None` in `new`; not otherwise used by the code visible here.
similarity_matrix: SimilarityMatrix,
}
/// Configuration for [`PersonalizedSearchEngine`].
///
/// Several fields are declared but not read anywhere in this file; they are marked below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonalizationConfig {
/// Length of the embedding created for a newly registered user.
pub user_embedding_dim: usize,
/// Step size applied when a user embedding is moved towards an item embedding.
pub learning_rate: f32,
/// Not read by the code visible in this file.
// NOTE(review): presumably a decay applied to older interactions — confirm.
pub time_decay_factor: f32,
/// Weight of the collaborative score in the blended result score.
pub collaborative_weight: f32,
/// Weight of the content (vector-store) score in the blended result score.
pub content_weight: f32,
/// When `true`, an exploration bonus is added to each personalized result.
pub enable_bandits: bool,
/// Scale of the exploration bonus; also the bonus given to items with no interactions.
pub exploration_rate: f32,
/// Not read by the code visible in this file.
pub enable_privacy: bool,
/// Not read by the code visible in this file.
pub privacy_epsilon: f32,
/// Minimum `interaction_count` a user needs before personalization replaces the
/// cold-start strategy.
pub min_interactions: usize,
/// Minimum cosine similarity for another user to be stored as a similar user.
pub user_similarity_threshold: f32,
/// When `true`, `record_feedback` updates the user profile immediately.
pub enable_realtime_updates: bool,
/// Ranking strategy used while a user has fewer than `min_interactions` interactions.
pub cold_start_strategy: ColdStartStrategy,
}
impl Default for PersonalizationConfig {
fn default() -> Self {
Self {
user_embedding_dim: 128,
learning_rate: 0.01,
time_decay_factor: 0.95,
collaborative_weight: 0.4,
content_weight: 0.6,
enable_bandits: true,
exploration_rate: 0.1,
enable_privacy: false,
privacy_epsilon: 1.0,
min_interactions: 5,
user_similarity_threshold: 0.7,
enable_realtime_updates: true,
cold_start_strategy: ColdStartStrategy::PopularityBased,
}
}
}
/// How results are ranked for a user with fewer than `min_interactions` interactions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColdStartStrategy {
/// Adds `popularity_score * 0.3` to the score of every result that has an item profile.
PopularityBased,
/// Currently only sorts by the existing score; no demographic signal is applied in this file.
DemographicBased,
/// Adds a random value (scaled by 0.2) to every result's score.
RandomExploration,
/// Adds `popularity_score * 0.2` plus a random value (scaled by 0.1) to results that have
/// an item profile.
Hybrid,
}
/// Per-user state kept by the engine and updated from recorded feedback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserProfile {
/// Identifier the profile is stored under.
pub user_id: String,
/// User embedding; re-normalized to unit length after each update when its norm is non-zero.
pub embedding: Vec<f32>,
/// Created empty at registration; not modified by the code visible in this file.
pub preferences: HashMap<String, f32>,
/// Number of interactions applied to this profile by `update_user_profile`.
pub interaction_count: usize,
/// Time of registration or of the most recent profile update.
pub last_updated: SystemTime,
/// Optional demographics supplied at registration.
pub demographics: Option<UserDemographics>,
// `favorite_categories` (second field on the next line): per-category running value,
// updated as `old * 0.9 + interaction score * 0.1` for each category of an interacted item.
/// `(user_id, cosine similarity)` pairs written by `update_user_similarities`, sorted by
/// descending similarity and truncated to ten entries.
pub similar_users: Vec<(String, f32)>, pub favorite_categories: HashMap<String, f32>,
/// Ids of items pushed here when an interaction with a negative score is applied.
pub negative_items: Vec<String>, }
/// Optional self-described attributes of a user, supplied at registration.
///
/// Only `interests` is read by the code visible in this file (to seed the initial embedding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserDemographics {
/// Free-form age bucket label; not read in this file.
pub age_group: Option<String>,
/// Free-form location label; not read in this file.
pub location: Option<String>,
/// Free-form language label; not read in this file.
pub language: Option<String>,
/// Interest strings; each one is hashed to an embedding slot when the user is registered.
pub interests: Vec<String>,
}
/// Per-item state used when scoring results and when learning from feedback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ItemProfile {
/// Identifier of the item.
pub item_id: String,
/// Item embedding; compared against user embeddings with cosine similarity.
pub embedding: Vec<f32>,
/// Running value updated as `old * 0.95 + interaction score * 0.05` on every feedback event.
pub popularity_score: f32,
/// Category labels; used for the category boost and for `favorite_categories` updates.
pub categories: Vec<String>,
/// Number of feedback events applied to this item.
pub interaction_count: usize,
/// Incremental mean of the interaction scores applied to this item.
pub average_rating: f32,
/// Time of the most recent feedback event for this item.
pub last_accessed: SystemTime,
}
/// One entry of the interaction history, built from a `UserFeedback` by `record_feedback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInteraction {
/// Id of the user who interacted.
pub user_id: String,
/// Id of the item interacted with.
pub item_id: String,
/// Kind of interaction, mapped from the feedback type.
pub interaction_type: InteractionType,
/// Signed strength of the interaction; values above zero count as positive in collaborative
/// scoring and values below zero mark the item as negative for the user.
pub score: f32,
/// Timestamp copied from the originating feedback.
pub timestamp: SystemTime,
/// Free-form key/value context copied from the feedback's `metadata`.
pub context: HashMap<String, String>,
}
/// Kind of a recorded interaction.
///
/// `feedback_to_interaction_type` produces `Rating`, `Click`, `View`, `Purchase`, `Share`,
/// `DwellTime` and `Custom`; `Like` and `Dislike` are not produced by the code in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InteractionType {
View,
Click,
Like,
Dislike,
Share,
Purchase,
/// Explicit rating value carried over from `FeedbackType::Explicit`.
Rating(f32),
/// Dwell duration; the feedback mapping uses 60 s for a long dwell and 5 s for a quick bounce.
DwellTime(Duration),
/// Named custom interaction; a skipped item is recorded as `Custom("skip")`.
Custom(String),
}
/// Feedback event submitted by a caller through `record_feedback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserFeedback {
/// Id of the user giving the feedback.
pub user_id: String,
/// Id of the item the feedback is about.
pub item_id: String,
/// Kind of feedback; converted to an `InteractionType` when recorded.
pub feedback_type: FeedbackType,
/// Signed strength of the feedback; used directly as the interaction score.
pub score: f32,
/// Caller-supplied time of the event.
pub timestamp: SystemTime,
/// Free-form key/value data; stored as the interaction's `context`.
pub metadata: HashMap<String, String>,
}
/// Kind of feedback a caller can submit.
///
/// `Explicit` carries a rating value and `Custom` carries a caller-chosen name; the remaining
/// variants are plain markers. See `feedback_to_interaction_type` for how each one is recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FeedbackType {
Explicit(f32), Click, View, Skip, Purchase, Share, LongDwell, QuickBounce, Custom(String),
}
/// One ranked search hit together with the score components that produced its rank.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonalizedResult {
/// Id of the matched item as returned by the vector store.
pub id: String,
/// Final ranking score; starts as the content score and is overwritten or adjusted by
/// personalization or the cold-start strategy.
pub score: f32,
/// Embedding and category affinity between the user and the item (0.0 until personalized).
pub personalization_score: f32,
/// Raw similarity score returned by the vector store.
pub content_score: f32,
/// Score derived from similar users' interactions and item popularity (0.0 until personalized).
pub collaborative_score: f32,
/// Exploration bonus added to the final score (0.0 unless bandits are enabled).
pub exploration_bonus: f32,
/// Free-form key/value data; created empty by the content search. The `"category"` key is
/// read by the diversity step.
pub metadata: HashMap<String, String>,
/// Human-readable reason string; set only on the personalized path.
pub explanation: Option<String>,
}
impl PersonalizedSearchEngine {
/// Builds an engine from `PersonalizationConfig::default()` without a caller-supplied store.
///
/// # Errors
/// Forwards whatever `Self::new` returns.
pub fn new_default() -> Result<Self> {
    let config = PersonalizationConfig::default();
    Self::new(config, None)
}
/// Creates an engine with the given configuration.
///
/// `vector_store` is used as the content-search backend when supplied; otherwise a fresh
/// `VectorStore::new()` is created. All profile maps and the interaction history start
/// empty, and the similarity matrix starts as `None`.
///
/// # Errors
/// Currently always returns `Ok`; the `Result` is kept for interface stability.
pub fn new(config: PersonalizationConfig, vector_store: Option<VectorStore>) -> Result<Self> {
    // Build the fallback store only when the caller did not provide one, instead of
    // constructing it unconditionally and throwing it away.
    let vector_store = vector_store.unwrap_or_else(VectorStore::new);
    Ok(Self {
        config,
        vector_store: Arc::new(RwLock::new(vector_store)),
        user_profiles: Arc::new(RwLock::new(HashMap::new())),
        item_profiles: Arc::new(RwLock::new(HashMap::new())),
        interaction_history: Arc::new(RwLock::new(Vec::new())),
        similarity_matrix: Arc::new(RwLock::new(None)),
    })
}
/// Registers a user, creating a fresh profile with an initial embedding.
///
/// The embedding is seeded from `demographics` when present and randomly otherwise.
/// Registering an id that already exists replaces the stored profile.
///
/// # Errors
/// Forwards any error from the embedding initialization.
pub fn register_user(
    &mut self,
    user_id: impl Into<String>,
    demographics: Option<UserDemographics>,
) -> Result<()> {
    let id: String = user_id.into();
    let initial_embedding = self.initialize_user_embedding(&id, demographics.as_ref())?;

    let fresh_profile = UserProfile {
        user_id: id.clone(),
        embedding: initial_embedding,
        demographics,
        last_updated: SystemTime::now(),
        interaction_count: 0,
        preferences: HashMap::new(),
        favorite_categories: HashMap::new(),
        similar_users: Vec::new(),
        negative_items: Vec::new(),
    };

    let mut profiles = self.user_profiles.write();
    profiles.insert(id, fresh_profile);
    Ok(())
}
/// Runs a search for `query` and ranks up to `k` results for `user_id`.
///
/// Three times `k` candidates are fetched from the content search. Users with at least
/// `min_interactions` recorded interactions get the personalized ranking; everyone else
/// gets the configured cold-start strategy.
///
/// # Errors
/// Returns an error when the user is not registered, or forwards any error from the
/// content search or the ranking step.
pub fn personalized_search(
    &self,
    user_id: impl Into<String>,
    query: impl Into<String>,
    k: usize,
) -> Result<Vec<PersonalizedResult>> {
    let user_id = user_id.into();
    let query = query.into();

    // Decide the ranking path inside a short scope so the read guard is released before
    // the ranking step re-acquires `user_profiles`. A recursive read on a parking_lot
    // `RwLock` can deadlock when a writer queues up between the two acquisitions.
    let use_personalization = {
        let user_profiles = self.user_profiles.read();
        let user_profile = user_profiles
            .get(&user_id)
            .ok_or_else(|| anyhow!("User not found: {}", user_id))?;
        user_profile.interaction_count >= self.config.min_interactions
    };

    // Over-fetch so re-ranking has room to reorder; saturate rather than overflow on huge `k`.
    let base_results = self.content_based_search(&query, k.saturating_mul(3))?;

    if use_personalization {
        self.apply_personalization(&user_id, base_results, k)
    } else {
        self.apply_cold_start_strategy(&user_id, base_results, k)
    }
}
/// Fetches up to `k` candidates from the vector store and wraps them as results.
///
/// Each hit's store score becomes both `score` and `content_score`; all other score
/// components start at zero, with empty metadata and no explanation. A query embedding is
/// computed first but its value is not used; only a failure from that step is propagated.
fn content_based_search(&self, query: &str, k: usize) -> Result<Vec<PersonalizedResult>> {
    let _ = self.create_query_embedding(query)?;

    let store = self.vector_store.read();
    let hits = store.similarity_search(query, k)?;

    let mut candidates = Vec::new();
    for (id, content_score) in hits {
        candidates.push(PersonalizedResult {
            id,
            score: content_score,
            content_score,
            personalization_score: 0.0,
            collaborative_score: 0.0,
            exploration_bonus: 0.0,
            metadata: HashMap::new(),
            explanation: None,
        });
    }
    Ok(candidates)
}
fn apply_personalization(
&self,
user_id: &str,
mut results: Vec<PersonalizedResult>,
k: usize,
) -> Result<Vec<PersonalizedResult>> {
let user_profiles = self.user_profiles.read();
let user_profile = user_profiles
.get(user_id)
.ok_or_else(|| anyhow!("User not found"))?;
for result in &mut results {
let collab_score = self.compute_collaborative_score(user_profile, &result.id)?;
let personal_score = self.compute_personalization_score(user_profile, &result.id)?;
let exploration_bonus = if self.config.enable_bandits {
self.compute_exploration_bonus(user_profile, &result.id)?
} else {
0.0
};
result.collaborative_score = collab_score;
result.personalization_score = personal_score;
result.exploration_bonus = exploration_bonus;
result.score = self.config.content_weight * result.content_score
+ self.config.collaborative_weight * collab_score
+ (1.0 - self.config.content_weight - self.config.collaborative_weight)
* personal_score
+ exploration_bonus;
result.explanation = Some(self.generate_explanation(result));
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let diversified = self.apply_diversity(&results, k)?;
Ok(diversified)
}
/// Scores an item by how strongly the user's similar users endorsed it.
///
/// The base score is the similarity-weighted fraction of `similar_users` that have at
/// least one interaction with a positive score on `item_id`. A popularity term
/// (`popularity_score * 0.1`) is added and the total is capped at 1.0. Items without an
/// item profile score 0.0.
///
/// # Errors
/// Currently always returns `Ok`.
fn compute_collaborative_score(
    &self,
    user_profile: &UserProfile,
    item_id: &str,
) -> Result<f32> {
    let item_profiles = self.item_profiles.read();
    let item_profile = match item_profiles.get(item_id) {
        Some(profile) => profile,
        None => return Ok(0.0),
    };

    let mut collab_score = 0.0f32;
    if !user_profile.similar_users.is_empty() {
        // One pass over the history (and one lock acquisition) instead of one scan per
        // similar user.
        let interactions = self.interaction_history.read();
        let positive_users: std::collections::HashSet<&str> = interactions
            .iter()
            .filter(|i| i.item_id == item_id && i.score > 0.0)
            .map(|i| i.user_id.as_str())
            .collect();

        let mut endorsed_weight = 0.0f32;
        let mut total_weight = 0.0f32;
        for (similar_user_id, similarity) in &user_profile.similar_users {
            // The denominator covers every similar user. Accumulating it only for
            // endorsers would make the ratio 1.0 whenever any neighbour interacted.
            total_weight += *similarity;
            if positive_users.contains(similar_user_id.as_str()) {
                endorsed_weight += *similarity;
            }
        }
        if total_weight > 0.0 {
            collab_score = endorsed_weight / total_weight;
        }
    }

    collab_score += item_profile.popularity_score * 0.1;
    Ok(collab_score.min(1.0))
}
/// Scores how well an item matches the user's embedding and favourite categories.
///
/// The base is the cosine similarity between the user and item embeddings. Items the user
/// reacted negatively to return half of that similarity. Otherwise the mean
/// `favorite_categories` value over the item's categories (missing categories count as
/// zero) is added with weight 0.3 and the total is capped at 1.0. Items without an item
/// profile score 0.0.
///
/// # Errors
/// Currently always returns `Ok`.
fn compute_personalization_score(
    &self,
    user_profile: &UserProfile,
    item_id: &str,
) -> Result<f32> {
    let item_profiles = self.item_profiles.read();
    let item_profile = match item_profiles.get(item_id) {
        Some(profile) => profile,
        None => return Ok(0.0),
    };

    let similarity = self.cosine_similarity(&user_profile.embedding, &item_profile.embedding);

    // Compare as `&str` so no `String` is allocated for the lookup.
    let is_negative = user_profile
        .negative_items
        .iter()
        .any(|negative_id| negative_id.as_str() == item_id);
    if is_negative {
        return Ok(similarity * 0.5);
    }

    let category_boost = item_profile
        .categories
        .iter()
        .filter_map(|cat| user_profile.favorite_categories.get(cat))
        .sum::<f32>()
        // `max(1)` avoids dividing by zero for items without categories.
        / item_profile.categories.len().max(1) as f32;

    Ok((similarity + category_boost * 0.3).min(1.0))
}
/// Computes a UCB1-style exploration bonus for an item.
///
/// With `n` the user's interaction count and `n_i` the item's interaction count, the
/// bonus is `exploration_rate * sqrt(2 * ln(n) / n_i)`, capped at 0.5. Items with no
/// interactions get exactly `exploration_rate`; items without an item profile get 0.0.
///
/// # Errors
/// Currently always returns `Ok`.
fn compute_exploration_bonus(&self, user_profile: &UserProfile, item_id: &str) -> Result<f32> {
    let item_profiles = self.item_profiles.read();
    let item_profile = match item_profiles.get(item_id) {
        Some(profile) => profile,
        None => return Ok(0.0),
    };

    let n = user_profile.interaction_count as f32;
    let n_i = item_profile.interaction_count as f32;
    if n_i == 0.0 {
        return Ok(self.config.exploration_rate);
    }

    // For n == 0, `ln(0)` is -inf and the square root becomes NaN; `NaN.min(0.5)` then
    // returns 0.5, handing the maximum bonus to a user with no history. Clamping n to 1
    // makes the log term zero instead, and leaves n >= 1 unchanged.
    let log_term = n.max(1.0).ln();
    let exploration_bonus = self.config.exploration_rate * (2.0 * log_term / n_i).sqrt();
    Ok(exploration_bonus.min(0.5))
}
/// Ranks results for a user without enough history, using the configured strategy.
///
/// Each strategy first adjusts the scores (see [`ColdStartStrategy`]); the results are then
/// sorted by descending score and cut to the first `k`.
///
/// # Errors
/// Currently always returns `Ok`.
fn apply_cold_start_strategy(
    &self,
    _user_id: &str,
    mut results: Vec<PersonalizedResult>,
    k: usize,
) -> Result<Vec<PersonalizedResult>> {
    use scirs2_core::random::rng;

    // Step 1: strategy-specific score adjustment.
    match self.config.cold_start_strategy {
        ColdStartStrategy::PopularityBased => {
            let catalog = self.item_profiles.read();
            for entry in results.iter_mut() {
                if let Some(item) = catalog.get(&entry.id) {
                    entry.score += item.popularity_score * 0.3;
                }
            }
        }
        ColdStartStrategy::RandomExploration => {
            let mut generator = rng();
            for entry in results.iter_mut() {
                let jitter = (generator.random::<u64>() as f32 / u64::MAX as f32) * 0.2;
                entry.score += jitter;
            }
        }
        ColdStartStrategy::DemographicBased => {
            // No adjustment: the content scores are used as they are.
        }
        ColdStartStrategy::Hybrid => {
            let catalog = self.item_profiles.read();
            let mut generator = rng();
            for entry in results.iter_mut() {
                if let Some(item) = catalog.get(&entry.id) {
                    let jitter = (generator.random::<u64>() as f32 / u64::MAX as f32) * 0.1;
                    entry.score += item.popularity_score * 0.2 + jitter;
                }
            }
        }
    }

    // Step 2: every strategy ends with the same descending sort; NaN compares as equal.
    results.sort_by(|left, right| {
        right
            .score
            .partial_cmp(&left.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Step 3: keep the best `k`.
    results.truncate(k);
    Ok(results)
}
pub fn record_feedback(&mut self, feedback: UserFeedback) -> Result<()> {
let interaction = UserInteraction {
user_id: feedback.user_id.clone(),
item_id: feedback.item_id.clone(),
interaction_type: Self::feedback_to_interaction_type(&feedback.feedback_type),
score: feedback.score,
timestamp: feedback.timestamp,
context: feedback.metadata.clone(),
};
self.interaction_history.write().push(interaction.clone());
if self.config.enable_realtime_updates {
self.update_user_profile(&feedback.user_id, &interaction)?;
}
self.update_item_profile(&feedback.item_id, &interaction)?;
Ok(())
}
/// Applies one interaction to a user's profile.
///
/// For a known user the interaction count and `last_updated` are always bumped. When the
/// interacted item has an item profile, the user embedding is additionally moved towards
/// the item embedding (scaled by the learning rate and the interaction score) and
/// re-normalized, the item's categories are blended into `favorite_categories`
/// (`old * 0.9 + score * 0.1`), and a negative score records the item in `negative_items`.
/// Unknown users are ignored.
///
/// # Errors
/// Currently always returns `Ok`.
fn update_user_profile(&mut self, user_id: &str, interaction: &UserInteraction) -> Result<()> {
    let mut user_profiles = self.user_profiles.write();
    let profile = match user_profiles.get_mut(user_id) {
        Some(profile) => profile,
        None => return Ok(()),
    };

    profile.interaction_count += 1;
    profile.last_updated = SystemTime::now();

    let item_profiles = self.item_profiles.read();
    let item_profile = match item_profiles.get(&interaction.item_id) {
        Some(item_profile) => item_profile,
        None => return Ok(()),
    };

    // `zip` stops at the shorter embedding, so mismatched dimensions only update the
    // shared prefix.
    let learning_rate = self.config.learning_rate;
    for (emb_val, &target) in profile.embedding.iter_mut().zip(&item_profile.embedding) {
        let gradient = (target - *emb_val) * interaction.score;
        *emb_val += learning_rate * gradient;
    }

    // Keep the embedding at unit length so cosine similarities stay comparable.
    let norm: f32 = profile.embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        profile.embedding.iter_mut().for_each(|x| *x /= norm);
    }

    for category in &item_profile.categories {
        let preference = profile
            .favorite_categories
            .entry(category.clone())
            .or_insert(0.0);
        *preference = *preference * 0.9 + interaction.score * 0.1;
    }

    // Record each negatively rated item once; repeated negative feedback on the same
    // item would otherwise grow the list without bound.
    if interaction.score < 0.0 && !profile.negative_items.contains(&interaction.item_id) {
        profile.negative_items.push(interaction.item_id.clone());
    }

    Ok(())
}
/// Applies one interaction to an item's profile; unknown items are ignored.
///
/// Bumps the interaction count and `last_accessed`, folds the score into the running mean
/// `average_rating`, and blends it into `popularity_score` (`old * 0.95 + score * 0.05`).
///
/// # Errors
/// Currently always returns `Ok`.
fn update_item_profile(&mut self, item_id: &str, interaction: &UserInteraction) -> Result<()> {
    let mut catalog = self.item_profiles.write();
    let entry = match catalog.get_mut(item_id) {
        Some(entry) => entry,
        None => return Ok(()),
    };

    entry.interaction_count += 1;
    entry.last_accessed = SystemTime::now();

    // Incremental mean over the post-increment sample count.
    let samples = entry.interaction_count as f32;
    let previous_mean = entry.average_rating;
    entry.average_rating = (previous_mean * (samples - 1.0) + interaction.score) / samples;

    entry.popularity_score = entry.popularity_score * 0.95 + interaction.score * 0.05;
    Ok(())
}
/// Recomputes the `similar_users` list of every registered user.
///
/// For each user, every other user whose embedding has a cosine similarity of at least
/// `user_similarity_threshold` is collected, sorted by descending similarity and cut to
/// the ten closest. All lists are computed under a read lock and then stored under a
/// single write lock.
///
/// # Errors
/// Currently always returns `Ok`.
pub fn update_user_similarities(&mut self) -> Result<()> {
    const MAX_SIMILAR_USERS: usize = 10;
    let threshold = self.config.user_similarity_threshold;

    // Phase 1: compute the neighbour list of *every* user. Returning after the first
    // user would leave all other profiles with stale or empty lists.
    let neighbour_lists: Vec<(String, Vec<(String, f32)>)> = {
        let user_profiles = self.user_profiles.read();
        user_profiles
            .iter()
            .map(|(user_id, user_profile)| {
                let mut similar_users: Vec<(String, f32)> = user_profiles
                    .iter()
                    .filter(|(other_id, _)| *other_id != user_id)
                    .filter_map(|(other_id, other_profile)| {
                        let similarity = self.cosine_similarity(
                            &user_profile.embedding,
                            &other_profile.embedding,
                        );
                        if similarity >= threshold {
                            Some((other_id.clone(), similarity))
                        } else {
                            None
                        }
                    })
                    .collect();
                similar_users
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                similar_users.truncate(MAX_SIMILAR_USERS);
                (user_id.clone(), similar_users)
            })
            .collect()
    };

    // Phase 2: store the lists. The read guard above is already released, so taking the
    // write lock here cannot self-deadlock.
    let mut user_profiles = self.user_profiles.write();
    for (user_id, similar_users) in neighbour_lists {
        if let Some(profile) = user_profiles.get_mut(&user_id) {
            profile.similar_users = similar_users;
        }
    }
    Ok(())
}
/// Selects up to `k` results with a maximal-marginal-relevance (MMR) style greedy pass.
///
/// `results` is expected to be sorted by descending score; its first element is always
/// selected first. Each further pick maximizes
/// `lambda * score + (1 - lambda) * (1 - max_similarity)` with `lambda = 0.7`, where
/// `max_similarity` is the highest similarity between the candidate and any result
/// already selected. Similarity is 0.8 when the `"category"` metadata values are equal
/// (two results without the key also count as equal) and 0.2 otherwise.
///
/// Returns an empty list when `k` is zero or `results` is empty.
///
/// # Errors
/// Currently always returns `Ok`.
fn apply_diversity(
    &self,
    results: &[PersonalizedResult],
    k: usize,
) -> Result<Vec<PersonalizedResult>> {
    // Without this guard the seed pick below would return one item even for k == 0.
    if k == 0 || results.is_empty() {
        return Ok(Vec::new());
    }

    let lambda: f32 = 0.7;
    let mut remaining: Vec<PersonalizedResult> = results.to_vec();
    let mut diversified = Vec::with_capacity(k.min(remaining.len()));
    diversified.push(remaining.remove(0));

    while diversified.len() < k && !remaining.is_empty() {
        let mut best_idx = 0;
        let mut best_score = f32::NEG_INFINITY;

        for (i, candidate) in remaining.iter().enumerate() {
            let candidate_category = candidate.metadata.get("category");
            // MMR penalizes a candidate by its closest already-selected neighbour, i.e.
            // the maximum similarity. Using the minimum would treat a candidate as
            // diverse as soon as any single selected result differs from it.
            let max_similarity = diversified
                .iter()
                .map(|selected| {
                    if selected.metadata.get("category") == candidate_category {
                        0.8f32
                    } else {
                        0.2f32
                    }
                })
                .fold(0.0f32, f32::max);

            let mmr_score = lambda * candidate.score + (1.0 - lambda) * (1.0 - max_similarity);
            if mmr_score > best_score {
                best_score = mmr_score;
                best_idx = i;
            }
        }

        diversified.push(remaining.remove(best_idx));
    }

    Ok(diversified)
}
/// Builds the human-readable reason string for a scored result.
///
/// A reason is listed for each score component above its threshold (personalization and
/// collaborative above 0.5, exploration bonus above 0.1); when none applies, a generic
/// query-relevance reason is used.
fn generate_explanation(&self, result: &PersonalizedResult) -> String {
    let signals = [
        (result.personalization_score > 0.5, "matches your interests"),
        (result.collaborative_score > 0.5, "liked by similar users"),
        (result.exploration_bonus > 0.1, "new discovery"),
    ];

    let mut reasons: Vec<&str> = signals
        .iter()
        .filter(|(triggered, _)| *triggered)
        .map(|(_, text)| *text)
        .collect();

    if reasons.is_empty() {
        reasons.push("relevant to your query");
    }

    format!("Recommended because: {}", reasons.join(", "))
}
/// Creates the initial embedding for a new user.
///
/// With demographics, up to `user_embedding_dim / 2` interests are hashed to slots that
/// are set to 0.5. Without demographics every slot gets a random value in roughly
/// `[-0.1, 0.1]`. The vector is scaled to unit length unless its norm is zero.
///
/// # Errors
/// Currently always returns `Ok`.
fn initialize_user_embedding(
    &self,
    _user_id: &str,
    demographics: Option<&UserDemographics>,
) -> Result<Vec<f32>> {
    use scirs2_core::random::rng;

    let dim = self.config.user_embedding_dim;
    let mut embedding = vec![0.0f32; dim];

    match demographics {
        Some(demo) => {
            let interest_budget = embedding.len() / 2;
            for interest in demo.interests.iter().take(interest_budget) {
                let slot = (Self::hash_string(interest) % dim as u64) as usize;
                embedding[slot] = 0.5;
            }
        }
        None => {
            let mut generator = rng();
            for slot in embedding.iter_mut() {
                *slot = (generator.random::<u64>() as f32 / u64::MAX as f32) * 0.2 - 0.1;
            }
        }
    }

    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for slot in embedding.iter_mut() {
            *slot /= norm;
        }
    }
    Ok(embedding)
}
/// Builds a 128-dimensional hashed bag-of-words vector for `query`.
///
/// The query is lower-cased and split on whitespace; each token increments the slot its
/// hash maps to. The vector is scaled to unit length unless it is all zeros.
///
/// # Errors
/// Currently always returns `Ok`.
fn create_query_embedding(&self, query: &str) -> Result<Vector> {
    let lowered = query.to_lowercase();
    let mut embedding = vec![0.0f32; 128];
    let bucket_count = embedding.len() as u64;

    for token in lowered.split_whitespace() {
        let slot = (Self::hash_string(token) % bucket_count) as usize;
        embedding[slot] += 1.0;
    }

    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for value in embedding.iter_mut() {
            *value /= norm;
        }
    }
    Ok(Vector::new(embedding))
}
/// Cosine similarity of two vectors.
///
/// Returns 0.0 when the lengths differ or when either vector has zero norm.
fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let (len_a, len_b) = (magnitude(a), magnitude(b));

    if len_a == 0.0 || len_b == 0.0 {
        0.0
    } else {
        dot / (len_a * len_b)
    }
}
/// Maps a feedback kind to the interaction kind stored in the history.
///
/// `Skip` becomes `Custom("skip")`; `LongDwell` and `QuickBounce` become dwell times of
/// 60 and 5 seconds respectively; everything else maps to its namesake.
fn feedback_to_interaction_type(feedback_type: &FeedbackType) -> InteractionType {
    let dwell = |seconds: u64| InteractionType::DwellTime(Duration::from_secs(seconds));

    match feedback_type {
        // Variants carrying data.
        FeedbackType::Explicit(rating) => InteractionType::Rating(*rating),
        FeedbackType::Custom(name) => InteractionType::Custom(name.clone()),
        // Direct namesakes.
        FeedbackType::Click => InteractionType::Click,
        FeedbackType::View => InteractionType::View,
        FeedbackType::Purchase => InteractionType::Purchase,
        FeedbackType::Share => InteractionType::Share,
        // Variants without a namesake.
        FeedbackType::Skip => InteractionType::Custom(String::from("skip")),
        FeedbackType::LongDwell => dwell(60),
        FeedbackType::QuickBounce => dwell(5),
    }
}
/// Hashes a string with the standard library's `DefaultHasher`.
///
/// The value is stable within one build of the program; `DefaultHasher`'s algorithm is
/// not guaranteed to stay the same across Rust releases.
fn hash_string(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(s, &mut state);
    state.finish()
}
/// Returns a clone of the stored profile for `user_id`, or `None` when the user is unknown.
pub fn get_user_profile(&self, user_id: &str) -> Option<UserProfile> {
    let profiles = self.user_profiles.read();
    profiles.get(user_id).cloned()
}
/// Returns a snapshot of the engine's counters.
///
/// The average is the history length divided by the number of registered users, or 0.0
/// when no user is registered.
pub fn get_statistics(&self) -> PersonalizationStatistics {
    let users = self.user_profiles.read();
    let items = self.item_profiles.read();
    let history = self.interaction_history.read();

    let total_users = users.len();
    let total_interactions = history.len();
    let average_interactions_per_user = match total_users {
        0 => 0.0,
        count => total_interactions as f32 / count as f32,
    };

    PersonalizationStatistics {
        total_users,
        total_items: items.len(),
        total_interactions,
        average_interactions_per_user,
    }
}
}
/// Counters reported by `PersonalizedSearchEngine::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonalizationStatistics {
/// Number of registered user profiles.
pub total_users: usize,
/// Number of stored item profiles.
pub total_items: usize,
/// Number of entries in the interaction history.
pub total_interactions: usize,
/// `total_interactions / total_users`, or 0.0 when there are no users.
pub average_interactions_per_user: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered user can be looked up; an unregistered one cannot.
    #[test]
    fn test_register_user() -> Result<()> {
        let mut engine = PersonalizedSearchEngine::new_default()?;
        engine.register_user("user1", None)?;
        assert!(engine.get_user_profile("user1").is_some());
        assert!(engine.get_user_profile("missing").is_none());
        Ok(())
    }

    /// Recording feedback appends exactly one entry to the interaction history.
    #[test]
    fn test_feedback_recording() -> Result<()> {
        let mut engine = PersonalizedSearchEngine::new_default()?;
        engine.register_user("user1", None)?;
        let feedback = UserFeedback {
            user_id: "user1".to_string(),
            item_id: "item1".to_string(),
            feedback_type: FeedbackType::Click,
            score: 1.0,
            timestamp: SystemTime::now(),
            metadata: HashMap::new(),
        };
        engine.record_feedback(feedback)?;
        let stats = engine.get_statistics();
        assert_eq!(stats.total_users, 1);
        assert_eq!(stats.total_interactions, 1);
        Ok(())
    }

    /// The query embedding always has 128 dimensions.
    // Renamed from `test_cold_start_strategy`: the body never exercised a cold-start path.
    #[test]
    fn test_query_embedding_dimensions() -> Result<()> {
        let engine = PersonalizedSearchEngine::new_default()?;
        let query_embedding = engine.create_query_embedding("test query")?;
        assert_eq!(query_embedding.dimensions, 128);
        Ok(())
    }

    /// Identical vectors score 1, orthogonal vectors score 0, mismatched lengths score 0.
    #[test]
    fn test_cosine_similarity() -> Result<()> {
        let engine = PersonalizedSearchEngine::new_default()?;
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let orthogonal = vec![0.0, 1.0, 0.0];
        assert!((engine.cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
        assert!(engine.cosine_similarity(&a, &orthogonal).abs() < 0.001);
        assert_eq!(engine.cosine_similarity(&a[..2], &b), 0.0);
        Ok(())
    }
}