use super::similarity::SimilaritySearchResult;
use crate::Result;
use crate::episode::Episode;
use crate::episode::PatternId;
use crate::pattern::Pattern;
use async_trait::async_trait;
use uuid::Uuid;
/// Storage backend for episode and pattern embedding vectors, with
/// threshold-filtered similarity search over the stored vectors.
///
/// All methods take `&self` and the trait requires `Send + Sync`, so one
/// instance can be shared between tasks (e.g. behind an `Arc`).
#[async_trait]
pub trait EmbeddingStorageBackend: Send + Sync {
/// Store `embedding` for the episode identified by `episode_id`.
async fn store_episode_embedding(&self, episode_id: Uuid, embedding: Vec<f32>) -> Result<()>;
/// Store `embedding` for the pattern identified by `pattern_id`.
async fn store_pattern_embedding(
&self,
pattern_id: PatternId,
embedding: Vec<f32>,
) -> Result<()>;
/// Return the embedding stored for `episode_id`, or `Ok(None)` if none is stored.
async fn get_episode_embedding(&self, episode_id: Uuid) -> Result<Option<Vec<f32>>>;
/// Return the embedding stored for `pattern_id`, or `Ok(None)` if none is stored.
async fn get_pattern_embedding(&self, pattern_id: PatternId) -> Result<Option<Vec<f32>>>;
/// Find episodes whose embedding is similar to `query_embedding`.
///
/// Implementations are expected to return at most `limit` results whose
/// similarity is at least `threshold`. NOTE(review): result ordering is not
/// fixed by the trait itself; the in-memory implementation sorts by
/// descending similarity — confirm other backends do the same.
async fn find_similar_episodes(
&self,
query_embedding: Vec<f32>,
limit: usize,
threshold: f32,
) -> Result<Vec<SimilaritySearchResult<Episode>>>;
/// Find patterns whose embedding is similar to `query_embedding`.
///
/// Same `limit` / `threshold` contract as `find_similar_episodes`.
async fn find_similar_patterns(
&self,
query_embedding: Vec<f32>,
limit: usize,
threshold: f32,
) -> Result<Vec<SimilaritySearchResult<Pattern>>>;
}
/// In-memory [`EmbeddingStorageBackend`] that keeps everything in `HashMap`s
/// behind shared `tokio::sync::RwLock`s.
///
/// Embedding vectors and the episodes/patterns they describe live in separate
/// maps. Similarity search only returns items that have BOTH a stored
/// embedding and an entry registered via `add_episode` / `add_pattern`.
pub struct InMemoryEmbeddingStorage {
/// Embedding vectors keyed by episode id.
episode_embeddings:
std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<Uuid, Vec<f32>>>>,
/// Embedding vectors keyed by pattern id.
pattern_embeddings:
std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<PatternId, Vec<f32>>>>,
/// Episodes registered via `add_episode`, keyed by `episode_id`.
episodes: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<Uuid, Episode>>>,
/// Patterns registered via `add_pattern`, keyed by the id chosen at insertion.
patterns: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<PatternId, Pattern>>>,
}
impl InMemoryEmbeddingStorage {
#[must_use]
pub fn new() -> Self {
Self {
episode_embeddings: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
pattern_embeddings: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
episodes: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
patterns: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
}
}
pub async fn add_episode(&self, episode: Episode) {
let mut episodes = self.episodes.write().await;
episodes.insert(episode.episode_id, episode);
}
pub async fn add_pattern(&self, pattern: Pattern) {
let mut patterns = self.patterns.write().await;
let pattern_id = match &pattern {
Pattern::ToolSequence { id, .. } => *id,
Pattern::DecisionPoint { .. }
| Pattern::ErrorRecovery { .. }
| Pattern::ContextPattern { .. } => {
uuid::Uuid::new_v4() }
};
patterns.insert(pattern_id, pattern);
}
}
impl Default for InMemoryEmbeddingStorage {
fn default() -> Self {
Self::new()
}
}
/// In-memory implementation: writes never fail, and similarity search does a
/// linear scan over all stored embeddings.
#[async_trait]
impl EmbeddingStorageBackend for InMemoryEmbeddingStorage {
    /// Insert or replace the embedding for `episode_id`. Always returns `Ok`.
    async fn store_episode_embedding(&self, episode_id: Uuid, embedding: Vec<f32>) -> Result<()> {
        let mut embeddings = self.episode_embeddings.write().await;
        embeddings.insert(episode_id, embedding);
        Ok(())
    }

    /// Insert or replace the embedding for `pattern_id`. Always returns `Ok`.
    async fn store_pattern_embedding(
        &self,
        pattern_id: PatternId,
        embedding: Vec<f32>,
    ) -> Result<()> {
        let mut embeddings = self.pattern_embeddings.write().await;
        embeddings.insert(pattern_id, embedding);
        Ok(())
    }

    /// Return a copy of the stored episode embedding, if any.
    async fn get_episode_embedding(&self, episode_id: Uuid) -> Result<Option<Vec<f32>>> {
        let embeddings = self.episode_embeddings.read().await;
        Ok(embeddings.get(&episode_id).cloned())
    }

    /// Return a copy of the stored pattern embedding, if any.
    async fn get_pattern_embedding(&self, pattern_id: PatternId) -> Result<Option<Vec<f32>>> {
        let embeddings = self.pattern_embeddings.read().await;
        Ok(embeddings.get(&pattern_id).cloned())
    }

    /// Rank registered episodes by cosine similarity to `query_embedding`.
    ///
    /// Episodes with an embedding but no `add_episode` registration (or vice
    /// versa) are skipped. Results are sorted by descending similarity.
    async fn find_similar_episodes(
        &self,
        query_embedding: Vec<f32>,
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<SimilaritySearchResult<Episode>>> {
        let embeddings = self.episode_embeddings.read().await;
        let episodes = self.episodes.read().await;
        Ok(rank_by_similarity(
            &query_embedding,
            &*embeddings,
            &*episodes,
            limit,
            threshold,
        ))
    }

    /// Rank registered patterns by cosine similarity to `query_embedding`.
    ///
    /// Patterns with an embedding but no registration under the same id (or
    /// vice versa) are skipped. Results are sorted by descending similarity.
    async fn find_similar_patterns(
        &self,
        query_embedding: Vec<f32>,
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<SimilaritySearchResult<Pattern>>> {
        let embeddings = self.pattern_embeddings.read().await;
        let patterns = self.patterns.read().await;
        Ok(rank_by_similarity(
            &query_embedding,
            &*embeddings,
            &*patterns,
            limit,
            threshold,
        ))
    }
}

/// Score every key that has both an embedding and an item, keep scores
/// `>= threshold`, and return the best `limit` in descending similarity order.
///
/// Items are cloned only after sorting and truncation, so matches beyond
/// `limit` are never copied. A NaN score never passes the `>=` check and is
/// therefore dropped before sorting.
fn rank_by_similarity<K, T>(
    query: &[f32],
    embeddings: &std::collections::HashMap<K, Vec<f32>>,
    items: &std::collections::HashMap<K, T>,
    limit: usize,
    threshold: f32,
) -> Vec<SimilaritySearchResult<T>>
where
    K: Eq + std::hash::Hash,
    T: Clone,
{
    let mut scored: Vec<(f32, &T)> = embeddings
        .iter()
        .filter_map(|(key, embedding)| {
            let item = items.get(key)?;
            let similarity = super::similarity::cosine_similarity(query, embedding);
            (similarity >= threshold).then_some((similarity, item))
        })
        .collect();

    scored.sort_by(|a, b| b.0.total_cmp(&a.0));
    scored.truncate(limit);

    scored
        .into_iter()
        .map(|(similarity, item)| SimilaritySearchResult {
            item: item.clone(),
            similarity,
            metadata: super::similarity::SimilarityMetadata {
                embedding_model: "unknown".to_string(),
                embedding_timestamp: None,
                context: serde_json::json!({}),
            },
        })
        .collect()
}
/// Test double: every write succeeds without storing anything, and every read
/// or search comes back empty.
#[cfg(test)]
pub struct MockEmbeddingStorage;

#[cfg(test)]
#[async_trait]
impl EmbeddingStorageBackend for MockEmbeddingStorage {
    /// Accepts and discards the episode embedding.
    async fn store_episode_embedding(&self, _episode_id: Uuid, _embedding: Vec<f32>) -> Result<()> {
        Ok(())
    }

    /// Never has an episode embedding.
    async fn get_episode_embedding(&self, _episode_id: Uuid) -> Result<Option<Vec<f32>>> {
        Ok(None)
    }

    /// Never finds any similar episodes.
    async fn find_similar_episodes(
        &self,
        _query_embedding: Vec<f32>,
        _limit: usize,
        _threshold: f32,
    ) -> Result<Vec<SimilaritySearchResult<Episode>>> {
        Ok(vec![])
    }

    /// Accepts and discards the pattern embedding.
    async fn store_pattern_embedding(
        &self,
        _pattern_id: PatternId,
        _embedding: Vec<f32>,
    ) -> Result<()> {
        Ok(())
    }

    /// Never has a pattern embedding.
    async fn get_pattern_embedding(&self, _pattern_id: PatternId) -> Result<Option<Vec<f32>>> {
        Ok(None)
    }

    /// Never finds any similar patterns.
    async fn find_similar_patterns(
        &self,
        _query_embedding: Vec<f32>,
        _limit: usize,
        _threshold: f32,
    ) -> Result<Vec<SimilaritySearchResult<Pattern>>> {
        Ok(vec![])
    }
}
/// Embedding store that uses a primary backend `T` and falls back to a
/// private in-memory store when the primary fails.
pub struct EmbeddingStorage<T: crate::storage::StorageBackend + EmbeddingStorageBackend> {
/// Primary backend; tried first for every operation.
storage: std::sync::Arc<T>,
/// Secondary in-memory store, consulted when the primary returns an error
/// (and, for `get_*` reads, also when the primary has no entry).
fallback: InMemoryEmbeddingStorage,
}
impl<T: crate::storage::StorageBackend + EmbeddingStorageBackend> EmbeddingStorage<T> {
    /// Wrap `storage` as the primary backend, paired with a fresh, empty
    /// in-memory fallback.
    #[must_use]
    pub fn new(storage: std::sync::Arc<T>) -> Self {
        Self {
            storage,
            fallback: InMemoryEmbeddingStorage::new(),
        }
    }
}
/// Primary-then-fallback implementation.
///
/// Writes go to the primary backend; the embedding is written to the
/// in-memory fallback only if that fails. Reads and searches try the primary
/// first and consult the fallback on error (reads also on a miss). Every
/// primary failure is logged with `tracing::warn!`, so degraded operation is
/// visible instead of silently looking like "no data".
///
/// NOTE(review): nothing in this impl registers episodes or patterns on the
/// fallback (`add_episode` / `add_pattern`). Fallback similarity search can
/// therefore only return items registered elsewhere, which is likely none.
/// Embeddings held only by the fallback are also not merged into successful
/// primary searches.
#[async_trait]
impl<T: crate::storage::StorageBackend + EmbeddingStorageBackend> EmbeddingStorageBackend
    for EmbeddingStorage<T>
{
    /// Store in the primary backend; on failure log and store in the fallback.
    async fn store_episode_embedding(&self, episode_id: Uuid, embedding: Vec<f32>) -> Result<()> {
        if let Err(e) = self
            .storage
            .store_episode_embedding(episode_id, embedding.clone())
            .await
        {
            tracing::warn!("Failed to store episode embedding in main storage: {}", e);
            self.fallback
                .store_episode_embedding(episode_id, embedding)
                .await?;
        }
        Ok(())
    }

    /// Store in the primary backend; on failure log and store in the fallback.
    async fn store_pattern_embedding(
        &self,
        pattern_id: PatternId,
        embedding: Vec<f32>,
    ) -> Result<()> {
        if let Err(e) = self
            .storage
            .store_pattern_embedding(pattern_id, embedding.clone())
            .await
        {
            tracing::warn!("Failed to store pattern embedding in main storage: {}", e);
            self.fallback
                .store_pattern_embedding(pattern_id, embedding)
                .await?;
        }
        Ok(())
    }

    /// Read from the primary backend; on a miss or error consult the fallback.
    async fn get_episode_embedding(&self, episode_id: Uuid) -> Result<Option<Vec<f32>>> {
        match self.storage.get_episode_embedding(episode_id).await {
            Ok(Some(embedding)) => return Ok(Some(embedding)),
            // A miss is normal: the embedding may have been written to the
            // fallback after an earlier primary write failure.
            Ok(None) => {}
            Err(e) => {
                tracing::warn!("Failed to read episode embedding from main storage: {}", e);
            }
        }
        self.fallback.get_episode_embedding(episode_id).await
    }

    /// Read from the primary backend; on a miss or error consult the fallback.
    async fn get_pattern_embedding(&self, pattern_id: PatternId) -> Result<Option<Vec<f32>>> {
        match self.storage.get_pattern_embedding(pattern_id).await {
            Ok(Some(embedding)) => return Ok(Some(embedding)),
            Ok(None) => {}
            Err(e) => {
                tracing::warn!("Failed to read pattern embedding from main storage: {}", e);
            }
        }
        self.fallback.get_pattern_embedding(pattern_id).await
    }

    /// Search the primary backend; on error log and search the fallback.
    async fn find_similar_episodes(
        &self,
        query_embedding: Vec<f32>,
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<SimilaritySearchResult<Episode>>> {
        match self
            .storage
            .find_similar_episodes(query_embedding.clone(), limit, threshold)
            .await
        {
            Ok(results) => Ok(results),
            Err(e) => {
                tracing::warn!("Failed to search similar episodes in main storage: {}", e);
                self.fallback
                    .find_similar_episodes(query_embedding, limit, threshold)
                    .await
            }
        }
    }

    /// Search the primary backend; on error log and search the fallback.
    async fn find_similar_patterns(
        &self,
        query_embedding: Vec<f32>,
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<SimilaritySearchResult<Pattern>>> {
        match self
            .storage
            .find_similar_patterns(query_embedding.clone(), limit, threshold)
            .await
        {
            Ok(results) => Ok(results),
            Err(e) => {
                tracing::warn!("Failed to search similar patterns in main storage: {}", e);
                self.fallback
                    .find_similar_patterns(query_embedding, limit, threshold)
                    .await
            }
        }
    }
}