1use crate::Vector;
43use crate::VectorStore;
44use anyhow::{anyhow, Result};
45use parking_lot::RwLock;
46use scirs2_core::random::RngCore;
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use std::sync::Arc;
50use std::time::{Duration, SystemTime};
51
52type SimilarityMatrix = Arc<RwLock<Option<HashMap<(String, String), f32>>>>;
54
55pub struct PersonalizedSearchEngine {
57 config: PersonalizationConfig,
58 vector_store: Arc<RwLock<VectorStore>>,
59 user_profiles: Arc<RwLock<HashMap<String, UserProfile>>>,
60 item_profiles: Arc<RwLock<HashMap<String, ItemProfile>>>,
61 interaction_history: Arc<RwLock<Vec<UserInteraction>>>,
62 similarity_matrix: SimilarityMatrix,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct PersonalizationConfig {
68 pub user_embedding_dim: usize,
70 pub learning_rate: f32,
72 pub time_decay_factor: f32,
74 pub collaborative_weight: f32,
76 pub content_weight: f32,
78 pub enable_bandits: bool,
80 pub exploration_rate: f32,
82 pub enable_privacy: bool,
84 pub privacy_epsilon: f32,
86 pub min_interactions: usize,
88 pub user_similarity_threshold: f32,
90 pub enable_realtime_updates: bool,
92 pub cold_start_strategy: ColdStartStrategy,
94}
95
96impl Default for PersonalizationConfig {
97 fn default() -> Self {
98 Self {
99 user_embedding_dim: 128,
100 learning_rate: 0.01,
101 time_decay_factor: 0.95,
102 collaborative_weight: 0.4,
103 content_weight: 0.6,
104 enable_bandits: true,
105 exploration_rate: 0.1,
106 enable_privacy: false,
107 privacy_epsilon: 1.0,
108 min_interactions: 5,
109 user_similarity_threshold: 0.7,
110 enable_realtime_updates: true,
111 cold_start_strategy: ColdStartStrategy::PopularityBased,
112 }
113 }
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
118pub enum ColdStartStrategy {
119 PopularityBased,
121 DemographicBased,
123 RandomExploration,
125 Hybrid,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct UserProfile {
132 pub user_id: String,
133 pub embedding: Vec<f32>,
134 pub preferences: HashMap<String, f32>,
135 pub interaction_count: usize,
136 pub last_updated: SystemTime,
137 pub demographics: Option<UserDemographics>,
138 pub similar_users: Vec<(String, f32)>, pub favorite_categories: HashMap<String, f32>,
140 pub negative_items: Vec<String>, }
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct UserDemographics {
146 pub age_group: Option<String>,
147 pub location: Option<String>,
148 pub language: Option<String>,
149 pub interests: Vec<String>,
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ItemProfile {
155 pub item_id: String,
156 pub embedding: Vec<f32>,
157 pub popularity_score: f32,
158 pub categories: Vec<String>,
159 pub interaction_count: usize,
160 pub average_rating: f32,
161 pub last_accessed: SystemTime,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct UserInteraction {
167 pub user_id: String,
168 pub item_id: String,
169 pub interaction_type: InteractionType,
170 pub score: f32,
171 pub timestamp: SystemTime,
172 pub context: HashMap<String, String>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
177pub enum InteractionType {
178 View,
179 Click,
180 Like,
181 Dislike,
182 Share,
183 Purchase,
184 Rating(f32),
185 DwellTime(Duration),
186 Custom(String),
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct UserFeedback {
192 pub user_id: String,
193 pub item_id: String,
194 pub feedback_type: FeedbackType,
195 pub score: f32,
196 pub timestamp: SystemTime,
197 pub metadata: HashMap<String, String>,
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
202pub enum FeedbackType {
203 Explicit(f32), Click, View, Skip, Purchase, Share, LongDwell, QuickBounce, Custom(String),
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct PersonalizedResult {
217 pub id: String,
218 pub score: f32,
219 pub personalization_score: f32,
220 pub content_score: f32,
221 pub collaborative_score: f32,
222 pub exploration_bonus: f32,
223 pub metadata: HashMap<String, String>,
224 pub explanation: Option<String>,
225}
226
227impl PersonalizedSearchEngine {
228 pub fn new_default() -> Result<Self> {
230 Self::new(PersonalizationConfig::default(), None)
231 }
232
233 pub fn new(config: PersonalizationConfig, vector_store: Option<VectorStore>) -> Result<Self> {
235 let default_store = VectorStore::new();
236 let vector_store = Arc::new(RwLock::new(vector_store.unwrap_or(default_store)));
237
238 Ok(Self {
239 config,
240 vector_store,
241 user_profiles: Arc::new(RwLock::new(HashMap::new())),
242 item_profiles: Arc::new(RwLock::new(HashMap::new())),
243 interaction_history: Arc::new(RwLock::new(Vec::new())),
244 similarity_matrix: Arc::new(RwLock::new(None)),
245 })
246 }
247
248 pub fn register_user(
250 &mut self,
251 user_id: impl Into<String>,
252 demographics: Option<UserDemographics>,
253 ) -> Result<()> {
254 let user_id = user_id.into();
255
256 let embedding = self.initialize_user_embedding(&user_id, demographics.as_ref())?;
258
259 let profile = UserProfile {
260 user_id: user_id.clone(),
261 embedding,
262 preferences: HashMap::new(),
263 interaction_count: 0,
264 last_updated: SystemTime::now(),
265 demographics,
266 similar_users: Vec::new(),
267 favorite_categories: HashMap::new(),
268 negative_items: Vec::new(),
269 };
270
271 self.user_profiles.write().insert(user_id, profile);
272
273 Ok(())
274 }
275
276 pub fn personalized_search(
278 &self,
279 user_id: impl Into<String>,
280 query: impl Into<String>,
281 k: usize,
282 ) -> Result<Vec<PersonalizedResult>> {
283 let user_id = user_id.into();
284 let query = query.into();
285
286 let user_profiles = self.user_profiles.read();
288 let user_profile = user_profiles
289 .get(&user_id)
290 .ok_or_else(|| anyhow!("User not found: {}", user_id))?;
291
292 let use_personalization = user_profile.interaction_count >= self.config.min_interactions;
294
295 let base_results = self.content_based_search(&query, k * 3)?;
297
298 let personalized_results = if use_personalization {
300 self.apply_personalization(&user_id, base_results, k)?
301 } else {
302 self.apply_cold_start_strategy(&user_id, base_results, k)?
303 };
304
305 Ok(personalized_results)
306 }
307
308 fn content_based_search(&self, query: &str, k: usize) -> Result<Vec<PersonalizedResult>> {
310 let _query_embedding = self.create_query_embedding(query)?;
312
313 let store = self.vector_store.read();
315 let results = store.similarity_search(query, k)?;
316
317 Ok(results
319 .into_iter()
320 .map(|(id, score)| PersonalizedResult {
321 id,
322 score,
323 personalization_score: 0.0,
324 content_score: score,
325 collaborative_score: 0.0,
326 exploration_bonus: 0.0,
327 metadata: HashMap::new(),
328 explanation: None,
329 })
330 .collect())
331 }
332
333 fn apply_personalization(
335 &self,
336 user_id: &str,
337 mut results: Vec<PersonalizedResult>,
338 k: usize,
339 ) -> Result<Vec<PersonalizedResult>> {
340 let user_profiles = self.user_profiles.read();
341 let user_profile = user_profiles
342 .get(user_id)
343 .ok_or_else(|| anyhow!("User not found"))?;
344
345 for result in &mut results {
347 let collab_score = self.compute_collaborative_score(user_profile, &result.id)?;
349
350 let personal_score = self.compute_personalization_score(user_profile, &result.id)?;
352
353 let exploration_bonus = if self.config.enable_bandits {
355 self.compute_exploration_bonus(user_profile, &result.id)?
356 } else {
357 0.0
358 };
359
360 result.collaborative_score = collab_score;
362 result.personalization_score = personal_score;
363 result.exploration_bonus = exploration_bonus;
364
365 result.score = self.config.content_weight * result.content_score
366 + self.config.collaborative_weight * collab_score
367 + (1.0 - self.config.content_weight - self.config.collaborative_weight)
368 * personal_score
369 + exploration_bonus;
370
371 result.explanation = Some(self.generate_explanation(result));
373 }
374
375 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
377
378 let diversified = self.apply_diversity(&results, k)?;
380
381 Ok(diversified)
382 }
383
384 fn compute_collaborative_score(
386 &self,
387 user_profile: &UserProfile,
388 item_id: &str,
389 ) -> Result<f32> {
390 let item_profiles = self.item_profiles.read();
391
392 if let Some(item_profile) = item_profiles.get(item_id) {
393 let mut collab_score = 0.0;
395 let mut total_weight = 0.0;
396
397 for (similar_user_id, similarity) in &user_profile.similar_users {
398 let interactions = self.interaction_history.read();
400 let user_interacted = interactions.iter().any(|i| {
401 &i.user_id == similar_user_id && i.item_id == item_id && i.score > 0.0
402 });
403
404 if user_interacted {
405 collab_score += similarity;
406 total_weight += similarity;
407 }
408 }
409
410 if total_weight > 0.0 {
411 collab_score /= total_weight;
412 }
413
414 collab_score += item_profile.popularity_score * 0.1;
416
417 Ok(collab_score.min(1.0))
418 } else {
419 Ok(0.0)
420 }
421 }
422
423 fn compute_personalization_score(
425 &self,
426 user_profile: &UserProfile,
427 item_id: &str,
428 ) -> Result<f32> {
429 let item_profiles = self.item_profiles.read();
430
431 if let Some(item_profile) = item_profiles.get(item_id) {
432 let similarity =
434 self.cosine_similarity(&user_profile.embedding, &item_profile.embedding);
435
436 if user_profile.negative_items.contains(&item_id.to_string()) {
438 return Ok(similarity * 0.5); }
440
441 let category_boost = item_profile
443 .categories
444 .iter()
445 .filter_map(|cat| user_profile.favorite_categories.get(cat))
446 .sum::<f32>()
447 / item_profile.categories.len().max(1) as f32;
448
449 Ok((similarity + category_boost * 0.3).min(1.0))
450 } else {
451 Ok(0.0)
452 }
453 }
454
455 fn compute_exploration_bonus(&self, user_profile: &UserProfile, item_id: &str) -> Result<f32> {
457 let item_profiles = self.item_profiles.read();
458
459 if let Some(item_profile) = item_profiles.get(item_id) {
460 let n = user_profile.interaction_count as f32;
462 let n_i = item_profile.interaction_count as f32;
463
464 if n_i == 0.0 {
465 return Ok(self.config.exploration_rate);
467 }
468
469 let exploration_bonus = self.config.exploration_rate * ((2.0 * n.ln() / n_i).sqrt());
470
471 Ok(exploration_bonus.min(0.5))
472 } else {
473 Ok(0.0)
474 }
475 }
476
477 fn apply_cold_start_strategy(
479 &self,
480 _user_id: &str,
481 mut results: Vec<PersonalizedResult>,
482 k: usize,
483 ) -> Result<Vec<PersonalizedResult>> {
484 match self.config.cold_start_strategy {
485 ColdStartStrategy::PopularityBased => {
486 let item_profiles = self.item_profiles.read();
488
489 for result in &mut results {
490 if let Some(item_profile) = item_profiles.get(&result.id) {
491 result.score += item_profile.popularity_score * 0.3;
492 }
493 }
494
495 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
496 }
497 ColdStartStrategy::RandomExploration => {
498 use scirs2_core::random::rng;
500 let mut rng_instance = rng();
501
502 for result in &mut results {
503 let random_val = (rng_instance.next_u64() as f32 / u64::MAX as f32) * 0.2;
505 result.score += random_val;
506 }
507
508 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
509 }
510 ColdStartStrategy::DemographicBased => {
511 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
513 }
514 ColdStartStrategy::Hybrid => {
515 use scirs2_core::random::rng;
517 let item_profiles = self.item_profiles.read();
518 let mut rng_instance = rng();
519
520 for result in &mut results {
521 if let Some(item_profile) = item_profiles.get(&result.id) {
522 let random_val = (rng_instance.next_u64() as f32 / u64::MAX as f32) * 0.1;
523 result.score += item_profile.popularity_score * 0.2 + random_val;
524 }
525 }
526
527 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
528 }
529 }
530
531 Ok(results.into_iter().take(k).collect())
532 }
533
534 pub fn record_feedback(&mut self, feedback: UserFeedback) -> Result<()> {
536 let interaction = UserInteraction {
538 user_id: feedback.user_id.clone(),
539 item_id: feedback.item_id.clone(),
540 interaction_type: Self::feedback_to_interaction_type(&feedback.feedback_type),
541 score: feedback.score,
542 timestamp: feedback.timestamp,
543 context: feedback.metadata.clone(),
544 };
545
546 self.interaction_history.write().push(interaction.clone());
548
549 if self.config.enable_realtime_updates {
551 self.update_user_profile(&feedback.user_id, &interaction)?;
552 }
553
554 self.update_item_profile(&feedback.item_id, &interaction)?;
556
557 Ok(())
558 }
559
560 fn update_user_profile(&mut self, user_id: &str, interaction: &UserInteraction) -> Result<()> {
562 let mut user_profiles = self.user_profiles.write();
563
564 if let Some(profile) = user_profiles.get_mut(user_id) {
565 profile.interaction_count += 1;
567 profile.last_updated = SystemTime::now();
568
569 let item_profiles = self.item_profiles.read();
571 if let Some(item_profile) = item_profiles.get(&interaction.item_id) {
572 let learning_rate = self.config.learning_rate;
574
575 for (i, emb_val) in profile.embedding.iter_mut().enumerate() {
576 if i < item_profile.embedding.len() {
577 let target = item_profile.embedding[i];
578 let gradient = (target - *emb_val) * interaction.score;
579 *emb_val += learning_rate * gradient;
580 }
581 }
582
583 let norm: f32 = profile.embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
585 if norm > 0.0 {
586 profile.embedding.iter_mut().for_each(|x| *x /= norm);
587 }
588
589 for category in &item_profile.categories {
591 let current = profile
592 .favorite_categories
593 .get(category)
594 .copied()
595 .unwrap_or(0.0);
596 let updated = current * 0.9 + interaction.score * 0.1;
597 profile
598 .favorite_categories
599 .insert(category.clone(), updated);
600 }
601
602 if interaction.score < 0.0 {
604 profile.negative_items.push(interaction.item_id.clone());
605 }
606 }
607 }
608
609 Ok(())
610 }
611
612 fn update_item_profile(&mut self, item_id: &str, interaction: &UserInteraction) -> Result<()> {
614 let mut item_profiles = self.item_profiles.write();
615
616 if let Some(profile) = item_profiles.get_mut(item_id) {
617 profile.interaction_count += 1;
618 profile.last_accessed = SystemTime::now();
619
620 let old_avg = profile.average_rating;
622 let count = profile.interaction_count as f32;
623 profile.average_rating = (old_avg * (count - 1.0) + interaction.score) / count;
624
625 profile.popularity_score = profile.popularity_score * 0.95 + interaction.score * 0.05;
627 }
628
629 Ok(())
630 }
631
632 pub fn update_user_similarities(&mut self) -> Result<()> {
634 let user_profiles = self.user_profiles.read();
635 let user_ids: Vec<String> = user_profiles.keys().cloned().collect();
636
637 for user_id in &user_ids {
638 if let Some(user_profile) = user_profiles.get(user_id) {
639 let mut similar_users = Vec::new();
640
641 for other_id in &user_ids {
643 if other_id != user_id {
644 if let Some(other_profile) = user_profiles.get(other_id) {
645 let similarity = self.cosine_similarity(
646 &user_profile.embedding,
647 &other_profile.embedding,
648 );
649
650 if similarity >= self.config.user_similarity_threshold {
651 similar_users.push((other_id.clone(), similarity));
652 }
653 }
654 }
655 }
656
657 similar_users.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
659 similar_users.truncate(10);
660
661 drop(user_profiles);
663 let mut user_profiles = self.user_profiles.write();
664 if let Some(profile) = user_profiles.get_mut(user_id) {
665 profile.similar_users = similar_users;
666 }
667
668 return Ok(()); }
670 }
671
672 Ok(())
673 }
674
675 fn apply_diversity(
677 &self,
678 results: &[PersonalizedResult],
679 k: usize,
680 ) -> Result<Vec<PersonalizedResult>> {
681 let mut diversified = Vec::new();
683 let mut remaining: Vec<PersonalizedResult> = results.to_vec();
684
685 if !remaining.is_empty() {
686 diversified.push(remaining.remove(0));
688 }
689
690 let lambda = 0.7; while diversified.len() < k && !remaining.is_empty() {
693 let mut best_idx = 0;
694 let mut best_score = f32::NEG_INFINITY;
695
696 for (i, candidate) in remaining.iter().enumerate() {
697 let mut min_similarity = 1.0f32;
699
700 for selected in &diversified {
701 let similarity = if selected.metadata.get("category")
702 == candidate.metadata.get("category")
703 {
704 0.8
705 } else {
706 0.2
707 };
708
709 min_similarity = min_similarity.min(similarity);
710 }
711
712 let mmr_score = lambda * candidate.score + (1.0 - lambda) * (1.0 - min_similarity);
714
715 if mmr_score > best_score {
716 best_score = mmr_score;
717 best_idx = i;
718 }
719 }
720
721 diversified.push(remaining.remove(best_idx));
722 }
723
724 Ok(diversified)
725 }
726
727 fn generate_explanation(&self, result: &PersonalizedResult) -> String {
729 let mut reasons = Vec::new();
730
731 if result.personalization_score > 0.5 {
732 reasons.push("matches your interests");
733 }
734
735 if result.collaborative_score > 0.5 {
736 reasons.push("liked by similar users");
737 }
738
739 if result.exploration_bonus > 0.1 {
740 reasons.push("new discovery");
741 }
742
743 if reasons.is_empty() {
744 reasons.push("relevant to your query");
745 }
746
747 format!("Recommended because: {}", reasons.join(", "))
748 }
749
750 fn initialize_user_embedding(
752 &self,
753 _user_id: &str,
754 demographics: Option<&UserDemographics>,
755 ) -> Result<Vec<f32>> {
756 use scirs2_core::random::rng;
757 let mut embedding = vec![0.0f32; self.config.user_embedding_dim];
758
759 if let Some(demo) = demographics {
760 for (_i, interest) in demo.interests.iter().enumerate().take(embedding.len() / 2) {
762 let hash = Self::hash_string(interest);
763 let idx = (hash % self.config.user_embedding_dim as u64) as usize;
764 embedding[idx] = 0.5;
765 }
766 } else {
767 let mut rng_instance = rng();
769
770 for val in &mut embedding {
771 let random_val = (rng_instance.next_u64() as f32 / u64::MAX as f32) * 0.2 - 0.1;
773 *val = random_val;
774 }
775 }
776
777 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
779 if norm > 0.0 {
780 embedding.iter_mut().for_each(|x| *x /= norm);
781 }
782
783 Ok(embedding)
784 }
785
786 fn create_query_embedding(&self, query: &str) -> Result<Vector> {
788 let tokens: Vec<String> = query
790 .to_lowercase()
791 .split_whitespace()
792 .map(String::from)
793 .collect();
794
795 let mut embedding = vec![0.0f32; 128]; for token in tokens {
798 let hash = Self::hash_string(&token);
799 let idx = (hash % embedding.len() as u64) as usize;
800 embedding[idx] += 1.0;
801 }
802
803 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
805 if norm > 0.0 {
806 embedding.iter_mut().for_each(|x| *x /= norm);
807 }
808
809 Ok(Vector::new(embedding))
810 }
811
812 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
814 if a.len() != b.len() {
815 return 0.0;
816 }
817
818 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
819 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
820 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
821
822 if norm_a == 0.0 || norm_b == 0.0 {
823 return 0.0;
824 }
825
826 dot_product / (norm_a * norm_b)
827 }
828
829 fn feedback_to_interaction_type(feedback_type: &FeedbackType) -> InteractionType {
831 match feedback_type {
832 FeedbackType::Explicit(rating) => InteractionType::Rating(*rating),
833 FeedbackType::Click => InteractionType::Click,
834 FeedbackType::View => InteractionType::View,
835 FeedbackType::Skip => InteractionType::Custom("skip".to_string()),
836 FeedbackType::Purchase => InteractionType::Purchase,
837 FeedbackType::Share => InteractionType::Share,
838 FeedbackType::LongDwell => InteractionType::DwellTime(Duration::from_secs(60)),
839 FeedbackType::QuickBounce => InteractionType::DwellTime(Duration::from_secs(5)),
840 FeedbackType::Custom(name) => InteractionType::Custom(name.clone()),
841 }
842 }
843
844 fn hash_string(s: &str) -> u64 {
846 use std::collections::hash_map::DefaultHasher;
847 use std::hash::{Hash, Hasher};
848
849 let mut hasher = DefaultHasher::new();
850 s.hash(&mut hasher);
851 hasher.finish()
852 }
853
854 pub fn get_user_profile(&self, user_id: &str) -> Option<UserProfile> {
856 self.user_profiles.read().get(user_id).cloned()
857 }
858
859 pub fn get_statistics(&self) -> PersonalizationStatistics {
861 let user_profiles = self.user_profiles.read();
862 let item_profiles = self.item_profiles.read();
863 let interactions = self.interaction_history.read();
864
865 PersonalizationStatistics {
866 total_users: user_profiles.len(),
867 total_items: item_profiles.len(),
868 total_interactions: interactions.len(),
869 average_interactions_per_user: if user_profiles.is_empty() {
870 0.0
871 } else {
872 interactions.len() as f32 / user_profiles.len() as f32
873 },
874 }
875 }
876}
877
878#[derive(Debug, Clone, Serialize, Deserialize)]
880pub struct PersonalizationStatistics {
881 pub total_users: usize,
882 pub total_items: usize,
883 pub total_interactions: usize,
884 pub average_interactions_per_user: f32,
885}
886
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_register_user() -> Result<()> {
        let mut engine = PersonalizedSearchEngine::new_default()?;

        engine.register_user("user1", None)?;

        let profile = engine.get_user_profile("user1");
        assert!(profile.is_some());

        Ok(())
    }

    #[test]
    fn test_register_user_with_demographics_normalizes_embedding() -> Result<()> {
        let mut engine = PersonalizedSearchEngine::new_default()?;
        let demographics = UserDemographics {
            age_group: None,
            location: None,
            language: Some("en".to_string()),
            interests: vec!["rust".to_string(), "search".to_string()],
        };

        engine.register_user("user1", Some(demographics))?;

        let profile = engine
            .get_user_profile("user1")
            .expect("user was just registered");
        assert_eq!(
            profile.embedding.len(),
            PersonalizationConfig::default().user_embedding_dim
        );
        let norm: f32 = profile.embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
        assert_eq!(profile.interaction_count, 0);

        Ok(())
    }

    #[test]
    fn test_feedback_recording() -> Result<()> {
        let mut engine = PersonalizedSearchEngine::new_default()?;

        engine.register_user("user1", None)?;

        let feedback = UserFeedback {
            user_id: "user1".to_string(),
            item_id: "item1".to_string(),
            feedback_type: FeedbackType::Click,
            score: 1.0,
            timestamp: SystemTime::now(),
            metadata: HashMap::new(),
        };

        engine.record_feedback(feedback)?;

        let stats = engine.get_statistics();
        assert_eq!(stats.total_interactions, 1);

        Ok(())
    }

    #[test]
    fn test_search_unknown_user_errors() -> Result<()> {
        let engine = PersonalizedSearchEngine::new_default()?;

        assert!(engine.personalized_search("ghost", "query", 5).is_err());

        Ok(())
    }

    /// Query embeddings are hashed into a fixed 128-dim space. This test was
    /// previously misnamed `test_cold_start_strategy`.
    #[test]
    fn test_query_embedding_dimensions() -> Result<()> {
        let engine = PersonalizedSearchEngine::new_default()?;

        let query_embedding = engine.create_query_embedding("test query")?;
        assert_eq!(query_embedding.dimensions, 128);

        Ok(())
    }

    #[test]
    fn test_cosine_similarity() -> Result<()> {
        let engine = PersonalizedSearchEngine::new_default()?;

        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let similarity = engine.cosine_similarity(&a, &b);
        assert!((similarity - 1.0).abs() < 0.001);

        Ok(())
    }

    #[test]
    fn test_cosine_similarity_edge_cases() -> Result<()> {
        let engine = PersonalizedSearchEngine::new_default()?;

        // Orthogonal vectors are unrelated.
        assert!(engine.cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
        // Mismatched lengths and zero vectors short-circuit to 0.0.
        assert_eq!(engine.cosine_similarity(&[1.0, 0.0], &[1.0]), 0.0);
        assert_eq!(engine.cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);

        Ok(())
    }

    #[test]
    fn test_generate_explanation_fallback() -> Result<()> {
        let engine = PersonalizedSearchEngine::new_default()?;
        let result = PersonalizedResult {
            id: "item1".to_string(),
            score: 0.5,
            personalization_score: 0.0,
            content_score: 0.5,
            collaborative_score: 0.0,
            exploration_bonus: 0.0,
            metadata: HashMap::new(),
            explanation: None,
        };

        assert_eq!(
            engine.generate_explanation(&result),
            "Recommended because: relevant to your query"
        );

        Ok(())
    }
}