use anyhow::Result;
use serde::{Deserialize, Serialize};
use tracing;
/// Settings for embedding generation and similarity search.
///
/// NOTE(review): nothing in this file reads these fields (e.g. `find_similar_texts`
/// takes its threshold as an argument); the per-field descriptions below are
/// inferred from the names — confirm against the consumers.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingConfig {
    /// Similarity cutoff, presumably on the `[0, 1]` scale produced by `cosine_similarity`.
    pub similarity_threshold: f32,
    /// Batch size, presumably the number of texts embedded per request.
    pub batch_size: usize,
    /// Whether computed embeddings may be cached.
    pub cache_embeddings: bool,
}

impl Default for EmbeddingConfig {
    /// Threshold `0.7`, batch size `32`, caching enabled.
    fn default() -> Self {
        const DEFAULT_THRESHOLD: f32 = 0.7;
        const DEFAULT_BATCH_SIZE: usize = 32;

        Self {
            similarity_threshold: DEFAULT_THRESHOLD,
            batch_size: DEFAULT_BATCH_SIZE,
            cache_embeddings: true,
        }
    }
}
/// Cosine similarity of `a` and `b`, rescaled from `[-1, 1]` to `[0, 1]`.
///
/// The raw cosine `a·b / (|a|·|b|)` is mapped through `(cos + 1) / 2`. Identical
/// directions score `1.0`, orthogonal vectors score `0.5` and opposite directions
/// score `0.0`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero
/// magnitude. Note that this sentinel is indistinguishable from the score of two
/// exactly opposite vectors.
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let dot_product: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let magnitude_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let magnitude_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }
    // f32 rounding can push the ratio marginally outside [-1, 1] (e.g. 1.0000001
    // for identical vectors). Clamp it so the rescaled score honours [0, 1].
    let similarity = (dot_product / (magnitude_a * magnitude_b)).clamp(-1.0, 1.0);
    (similarity + 1.0) / 2.0
}
/// Builds a deterministic, unit-length, 384-dimensional pseudo-embedding for `text`.
///
/// **Not semantic.** `text` is hashed with `DefaultHasher`. The hash seeds a linear
/// congruential generator that fills 384 components in `[-1, 1)`, and the vector is
/// then L2-normalized. Equal strings give equal vectors, but unrelated texts come out
/// roughly orthogonal no matter what they mean.
///
/// `DefaultHasher`'s algorithm is not guaranteed to stay the same across Rust
/// releases. Do not persist or cache these vectors across toolchain upgrades.
///
/// A production warning, which includes the first 20 characters of the text, is
/// logged on the first call only. Bulk searches therefore do not flood the log.
#[must_use]
pub fn text_to_embedding(text: &str) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::Once;

    const DIMENSION: usize = 384;
    static WARN_ONCE: Once = Once::new();

    WARN_ONCE.call_once(|| {
        tracing::warn!(
            "PRODUCTION WARNING: Using hash-based pseudo-embeddings - semantic search will not work correctly! \
             Text: '{}'. Use real embedding models for production.",
            text.chars().take(20).collect::<String>()
        );
    });

    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    let mut state = hasher.finish();

    let mut embedding: Vec<f32> = (0..DIMENSION)
        .map(|_| {
            // Classic `rand()` LCG constants, stepped in u64 with wrapping arithmetic.
            state = state.wrapping_mul(1_103_515_245).wrapping_add(12345);
            // Keep exactly 16 bits (0..=65535) so `/ 32768 - 1` lands in [-1, 1).
            // Without the mask, `state >> 16` spans 48 bits. Every component would then
            // be a huge positive number, all vectors would point the same way, and
            // unrelated texts would get a raw cosine of about 0.75.
            let bits = ((state >> 16) & 0xFFFF) as f32;
            bits / 32768.0 - 1.0
        })
        .collect();

    let magnitude = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        embedding.iter_mut().for_each(|x| *x /= magnitude);
    }
    embedding
}
/// Test-only twin of `text_to_embedding` that produces the same 384-dimensional,
/// unit-length pseudo-embedding but never logs.
///
/// Components are drawn from an LCG seeded by `DefaultHasher` and lie in `[-1, 1)`
/// before normalization. The vectors are deterministic for a given build, but they
/// carry no semantic meaning.
#[cfg(test)]
#[must_use]
pub fn text_to_embedding_test(text: &str) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    const DIMENSION: usize = 384;

    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    let mut state = hasher.finish();

    let mut embedding: Vec<f32> = (0..DIMENSION)
        .map(|_| {
            state = state.wrapping_mul(1_103_515_245).wrapping_add(12345);
            // Mask to 16 bits so the component lands in [-1, 1) rather than being
            // a huge positive number (the unmasked shift keeps 48 bits).
            let bits = ((state >> 16) & 0xFFFF) as f32;
            bits / 32768.0 - 1.0
        })
        .collect();

    let magnitude = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        embedding.iter_mut().for_each(|x| *x /= magnitude);
    }
    embedding
}
/// Ranks `candidates` by pseudo-embedding similarity to `query`.
///
/// Each candidate is embedded with `text_to_embedding` and scored against the query
/// with `cosine_similarity`. Candidates scoring below `threshold` are dropped. The
/// remainder are sorted by descending score, with equal scores keeping their
/// original order, and at most `limit` are returned.
///
/// Returns `(index into candidates, score, candidate text)` tuples. Returns an empty
/// vector when `limit` is `0`.
///
/// The embeddings are hash-based, so the ranking is not semantically meaningful.
/// A warning to that effect is logged on every call.
#[must_use]
pub fn find_similar_texts(
    query: &str,
    candidates: &[String],
    limit: usize,
    threshold: f32,
) -> Vec<(usize, f32, String)> {
    tracing::warn!(
        "Using mock embeddings for semantic search - results are not semantically meaningful!"
    );
    if limit == 0 {
        return Vec::new();
    }
    let query_embedding = text_to_embedding(query);

    // Score by index and clone only the texts that survive filtering and
    // truncation, instead of cloning every candidate up front.
    let mut scored: Vec<(usize, f32)> = candidates
        .iter()
        .enumerate()
        .map(|(i, text)| (i, cosine_similarity(&query_embedding, &text_to_embedding(text))))
        .filter(|&(_, similarity)| similarity >= threshold)
        .collect();

    // Stable sort: ties keep candidate order. NaN scores are already removed by the
    // `>=` filter, but fall back to `Equal` defensively.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(limit);

    scored
        .into_iter()
        .map(|(i, similarity)| (i, similarity, candidates[i].clone()))
        .collect()
}
/// Logs a walkthrough of `find_similar_texts` and `cosine_similarity` over
/// hash-based pseudo-embeddings.
///
/// The disclaimers and closing notes are logged at `warn`/`info`. The per-query
/// rankings (top 3, threshold `0.5`) and the pairwise scores are logged at `debug`,
/// so they only show up when the subscriber enables `debug` level.
///
/// # Errors
///
/// Currently always returns `Ok(())`.
pub fn demonstrate_semantic_search() -> Result<()> {
    tracing::warn!("🧠 Semantic Search Demonstration (Mock Embeddings)");
    tracing::warn!("WARNING: This demonstration uses hash-based pseudo-embeddings");
    tracing::warn!("that are NOT semantically meaningful. Similarity scores are");
    tracing::warn!("essentially random and do not reflect actual semantic similarity.");
    tracing::warn!("For production semantic search, use real embedding models.");
    tracing::info!("Enable with: cargo run --features local-embeddings");

    // `find_similar_texts` takes `&[String]`, so the corpus holds owned strings.
    // Fixed-size arrays avoid needless `Vec` allocations (clippy::useless_vec).
    let episodes = [
        "Implement user authentication with JWT tokens".to_string(),
        "Build REST API endpoints for user management".to_string(),
        "Create data validation middleware for API requests".to_string(),
        "Add rate limiting to prevent API abuse".to_string(),
        "Implement OAuth2 authentication flow".to_string(),
        "Design database schema for user profiles".to_string(),
        "Write unit tests for authentication module".to_string(),
        "Deploy API to production with Docker".to_string(),
        "Monitor API performance and error rates".to_string(),
        "Document API endpoints with OpenAPI spec".to_string(),
    ];
    let queries = [
        "How to secure API with authentication?",
        "Need to create user management endpoints",
        "Add validation to API requests",
        "Prevent API abuse and rate limiting",
    ];

    for query in &queries {
        tracing::debug!("Query: \"{}\"", query);
        let results = find_similar_texts(query, &episodes, 3, 0.5);
        tracing::debug!("Top {} similar episodes:", results.len());
        for (rank, (idx, similarity, text)) in results.iter().enumerate() {
            tracing::debug!(
                " {}. [{}] {} (similarity: {:.3})",
                rank + 1,
                idx,
                text,
                similarity
            );
        }
    }

    tracing::debug!("Direct Similarity Examples:");
    let pairs = [
        ("user authentication", "login system"),
        ("REST API", "web service endpoints"),
        ("data validation", "input verification"),
        ("rate limiting", "API throttling"),
    ];
    for &(text1, text2) in &pairs {
        let emb1 = text_to_embedding(text1);
        let emb2 = text_to_embedding(text2);
        let similarity = cosine_similarity(&emb1, &emb2);
        tracing::debug!(" \"{}\" <-> \"{}\" = {:.3}", text1, text2, similarity);
    }

    tracing::info!("For real semantic search, use memory-core::embeddings modules");
    tracing::info!("with proper ONNX models and sentence transformers.");
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance when comparing `f32` similarity scores.
    const EPS: f32 = 1e-3;

    #[test]
    fn test_cosine_similarity() {
        // Same direction -> 1.0 after rescaling.
        let vec1 = vec![1.0, 2.0, 3.0];
        let vec2 = vec![1.0, 2.0, 3.0];
        let similarity = cosine_similarity(&vec1, &vec2);
        assert!((similarity - 1.0).abs() < EPS);
        // Orthogonal -> 0.5.
        let vec3 = vec![1.0, 0.0];
        let vec4 = vec![0.0, 1.0];
        let similarity = cosine_similarity(&vec3, &vec4);
        assert!((similarity - 0.5).abs() < EPS);
        // Opposite -> 0.0.
        let similarity = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!(similarity.abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Empty, mismatched-length and zero-magnitude inputs all hit the 0.0 sentinel.
        assert!(cosine_similarity(&[], &[]).abs() < EPS);
        assert!(cosine_similarity(&[1.0, 2.0], &[1.0]).abs() < EPS);
        assert!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]).abs() < EPS);
    }

    #[test]
    fn test_text_to_embedding() {
        let embedding1 = text_to_embedding("hello world");
        let embedding2 = text_to_embedding("hello world");
        let embedding3 = text_to_embedding("different text");
        assert_eq!(embedding1.len(), 384);
        assert_eq!(embedding1, embedding2);
        assert_ne!(embedding1, embedding3);
        let magnitude1: f32 = embedding1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude1 - 1.0).abs() < EPS);
    }

    #[test]
    fn test_find_similar_texts() {
        let candidates = vec![
            "implement user authentication".to_string(),
            "create REST API endpoints".to_string(),
            "add input validation".to_string(),
            "deploy with Docker".to_string(),
        ];
        let results = find_similar_texts("user login system", &candidates, 2, 0.0);
        assert!(results.len() <= 2);
        assert!(results.windows(2).all(|w| w[0].1 >= w[1].1));
        for (idx, score, text) in &results {
            assert_eq!(text, &candidates[*idx]);
            assert!(*score >= 0.0);
        }
    }

    #[test]
    fn test_find_similar_texts_limit_and_threshold() {
        let candidates = vec!["alpha".to_string(), "beta".to_string()];
        // Nothing requested, nothing returned.
        assert!(find_similar_texts("query", &candidates, 0, 0.0).is_empty());
        // No candidates, nothing returned.
        assert!(find_similar_texts("query", &[], 10, 0.0).is_empty());
        // Scores are rescaled into [0, 1], so a threshold of 2.0 rejects everything...
        assert!(find_similar_texts("query", &candidates, 10, 2.0).is_empty());
        // ...and -1.0 accepts everything, still capped by `limit`.
        assert_eq!(find_similar_texts("query", &candidates, 10, -1.0).len(), 2);
        assert_eq!(find_similar_texts("query", &candidates, 1, -1.0).len(), 1);
    }

    #[test]
    fn test_embedding_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.similarity_threshold, 0.7);
        assert_eq!(config.batch_size, 32);
        assert!(config.cache_embeddings);
    }
}