use crate::types::{AgentId, MemoryId};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Errors produced when constructing, decoding, comparing, or generating
/// embeddings.
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingError {
    /// An embedding was requested with zero components.
    EmptyVector,
    /// A byte buffer could not be decoded because its length is not a
    /// multiple of 4 (the size of one `f32`).
    InvalidByteLength {
        /// Length of the rejected buffer, in bytes.
        length: usize,
    },
    /// Two embeddings with different numbers of components were combined.
    DimensionMismatch {
        /// The dimension that was required (the receiver's, when comparing).
        expected: usize,
        /// The dimension that was actually supplied.
        actual: usize,
    },
    /// An embedding provider failed to produce a vector.
    GenerationFailed {
        /// Name of the provider that failed.
        provider: String,
        /// Provider-supplied description of the failure.
        message: String,
    },
}

impl fmt::Display for EmbeddingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyVector => f.write_str("embedding vector cannot be empty"),
            Self::InvalidByteLength { length } => write!(
                f,
                "invalid byte length {length} for embedding; must be multiple of 4"
            ),
            Self::DimensionMismatch { expected, actual } => write!(
                f,
                "embedding dimension mismatch: expected {expected}, got {actual}"
            ),
            Self::GenerationFailed { provider, message } => {
                write!(f, "embedding provider '{provider}' failed: {message}")
            }
        }
    }
}

/// No variant wraps an underlying error, so the default `source()` (`None`)
/// is correct for every case.
impl std::error::Error for EmbeddingError {}
/// A dense embedding vector of `f32` components.
///
/// Always holds at least one component: every constructor rejects empty
/// input. Component values are stored exactly as supplied; they are not
/// normalized and not checked for NaN or infinity.
#[derive(Debug, Clone, PartialEq)]
pub struct Embedding {
    values: Vec<f32>,
    /// Cached `values.len()`, set alongside `values` by every constructor.
    dimension: usize,
}

impl Embedding {
    /// Creates an embedding from raw component values.
    ///
    /// # Errors
    ///
    /// Returns [`EmbeddingError::EmptyVector`] if `values` is empty.
    pub fn new(values: Vec<f32>) -> Result<Self, EmbeddingError> {
        if values.is_empty() {
            return Err(EmbeddingError::EmptyVector);
        }
        let dimension = values.len();
        Ok(Self { values, dimension })
    }

    /// Number of components in the vector (always at least 1).
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Borrowed view of the raw component values.
    #[must_use]
    pub fn values(&self) -> &[f32] {
        &self.values
    }

    /// Encodes the vector as consecutive little-endian `f32`s
    /// (4 bytes per component, no header or length prefix).
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        // `flat_map` loses the exact size hint, so reserve the known size up front.
        let mut bytes = Vec::with_capacity(self.values.len() * 4);
        for value in &self.values {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    }

    /// Decodes a buffer produced by [`Embedding::to_bytes`].
    ///
    /// # Errors
    ///
    /// - [`EmbeddingError::EmptyVector`] if `bytes` is empty.
    /// - [`EmbeddingError::InvalidByteLength`] if `bytes.len()` is not a multiple of 4.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, EmbeddingError> {
        if bytes.is_empty() {
            return Err(EmbeddingError::EmptyVector);
        }
        let chunks = bytes.chunks_exact(4);
        // A non-empty remainder means a trailing partial f32; reject it rather
        // than silently dropping bytes.
        if !chunks.remainder().is_empty() {
            return Err(EmbeddingError::InvalidByteLength {
                length: bytes.len(),
            });
        }
        let values: Vec<f32> = chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Self::new(values)
    }

    /// Cosine similarity between `self` and `other`, in `[-1.0, 1.0]`.
    ///
    /// Returns `0.0` when either vector has zero magnitude. NaN components
    /// propagate to a NaN result.
    ///
    /// # Errors
    ///
    /// Returns [`EmbeddingError::DimensionMismatch`] (with `expected` set to
    /// `self`'s dimension) if the two vectors differ in length.
    pub fn cosine_similarity(&self, other: &Self) -> Result<f32, EmbeddingError> {
        if self.dimension != other.dimension {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimension,
                actual: other.dimension,
            });
        }

        // Accumulate in f64. Squaring an f32 above ~1.8e19 overflows f32 to
        // infinity (making the ratio inf/inf = NaN), and summing hundreds of
        // products in f32 loses precision. Any f32 squared fits in f64.
        let (dot, norm_sq_a, norm_sq_b) = self.values.iter().zip(&other.values).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(dot, sq_a, sq_b), (&a, &b)| {
                let (a, b) = (f64::from(a), f64::from(b));
                (dot + a * b, sq_a + a * a, sq_b + b * b)
            },
        );

        // A zero vector has no direction; report "no similarity" instead of dividing by zero.
        if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
            return Ok(0.0);
        }

        let similarity = dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt());
        // Rounding can land marginally outside [-1, 1]; clamp to the
        // mathematical range. `clamp` leaves NaN unchanged.
        Ok(similarity.clamp(-1.0, 1.0) as f32)
    }

    /// Returns a copy scaled to unit (L2) length.
    ///
    /// A zero vector is returned unchanged, because it has no direction to preserve.
    #[must_use]
    pub fn normalize(&self) -> Self {
        // Sum in f64. An f32 sum of squares overflows for large components,
        // which would divide every value by infinity and collapse the result to zeros.
        let magnitude = self
            .values
            .iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt();
        if magnitude == 0.0 {
            return self.clone();
        }
        let values: Vec<f32> = self
            .values
            .iter()
            .map(|&x| (f64::from(x) / magnitude) as f32)
            .collect();
        // The output has the same length as `self`, so the non-empty invariant
        // already holds and no fallible re-validation is needed.
        Self {
            values,
            dimension: self.dimension,
        }
    }
}
/// A source of text embeddings.
///
/// `Send + Sync` is required so implementors can be shared between threads
/// (for example behind an `Arc`). `#[async_trait]` rewrites `embed` to return
/// a boxed future, which keeps the trait usable as `dyn EmbeddingProvider`.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Short identifier for this provider (the stub in this module returns `"stub"`).
    fn name(&self) -> &str;

    /// Number of components in the embeddings this provider produces.
    fn dimension(&self) -> usize;

    /// Converts `text` into an [`Embedding`].
    ///
    /// # Errors
    ///
    /// Implementation-defined. [`EmbeddingError::GenerationFailed`] is
    /// presumably the variant intended for provider-side failures — confirm
    /// with concrete implementations.
    async fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError>;
}
/// Offline, deterministic embedding provider for tests and local development.
///
/// Vectors are derived from a hash of the input text rather than a model,
/// so they carry no semantic meaning.
#[derive(Debug, Clone)]
pub struct StubEmbeddingProvider {
    /// Number of components in every generated embedding.
    dimension: usize,
}

impl StubEmbeddingProvider {
    /// Dimension used by the `Default` implementation.
    const DEFAULT_DIMENSION: usize = 384;

    /// Creates a provider that emits embeddings with `dimension` components.
    ///
    /// The value is stored as given. Zero is accepted here, although no
    /// non-empty embedding can then be produced.
    #[must_use]
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}

impl Default for StubEmbeddingProvider {
    /// A provider with 384-component embeddings.
    fn default() -> Self {
        Self::new(Self::DEFAULT_DIMENSION)
    }
}
/// Hash-based embedding generation for [`StubEmbeddingProvider`].
///
/// The text is hashed with `DefaultHasher` to obtain a 64-bit seed. Component
/// `i` is then `sin(2π · m / u64::MAX)`, where
/// `m = (seed + i) * 0x5851_f42d_4c95_7f2d` with wrapping arithmetic, so every
/// component lies in `[-1, 1]`.
///
/// Identical text yields identical vectors only within one Rust release: std
/// does not guarantee `DefaultHasher` output stays the same across releases.
/// Avoid persisting stub vectors for comparison across toolchain upgrades.
#[async_trait]
impl EmbeddingProvider for StubEmbeddingProvider {
    /// Returns `Err(EmbeddingError::EmptyVector)` when the configured dimension is 0.
    async fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // Fixed odd 64-bit multiplier (6364136223846793005, the MMIX LCG
        // constant) that spreads consecutive `seed + index` values apart.
        const MIXER: u64 = 0x5851_f42d_4c95_7f2d;

        let seed = {
            let mut state = DefaultHasher::new();
            text.hash(&mut state);
            state.finish()
        };

        let mut values = Vec::with_capacity(self.dimension);
        for index in 0..self.dimension {
            let mixed = seed.wrapping_add(index as u64).wrapping_mul(MIXER);
            let fraction = mixed as f64 / u64::MAX as f64;
            let angle = fraction * std::f64::consts::PI * 2.0;
            values.push(angle.sin() as f32);
        }

        Embedding::new(values)
    }

    /// The dimension supplied at construction.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Always `"stub"`.
    fn name(&self) -> &str {
        "stub"
    }
}
/// A piece of text content recorded for an agent, optionally with an embedding.
///
/// `embedding` is skipped by serde in both directions. It never appears in
/// serialized output, and a deserialized `Memory` has `embedding: None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Memory {
    /// Identifier produced by `MemoryId::new()` when the memory is constructed.
    pub id: MemoryId,
    /// The agent this memory is associated with.
    pub agent_id: AgentId,
    /// The recorded text.
    pub content: String,
    /// Embedding attached via `with_embedding`, if any (not checked against `content`).
    #[serde(skip)]
    pub embedding: Option<Embedding>,
    /// Creation time as a decimal count of seconds since the Unix epoch.
    pub created_at: String,
}

impl Memory {
    /// Creates a memory with no embedding, a fresh id, and the current timestamp.
    #[must_use]
    pub fn new(agent_id: AgentId, content: impl Into<String>) -> Self {
        let memory = Self {
            id: MemoryId::new(),
            agent_id,
            content: content.into(),
            embedding: None,
            created_at: current_timestamp(),
        };
        memory
    }

    /// Creates a memory carrying a precomputed `embedding`.
    ///
    /// Identical to [`Memory::new`] except that the embedding is attached.
    /// It is stored as given and is not validated against `content`.
    #[must_use]
    pub fn with_embedding(
        agent_id: AgentId,
        content: impl Into<String>,
        embedding: Embedding,
    ) -> Self {
        Self {
            embedding: Some(embedding),
            ..Self::new(agent_id, content)
        }
    }
}
/// A [`Memory`] paired with a numeric relevance score.
///
/// NOTE(review): the scale and direction of `score` are not fixed in this
/// module; it is presumably a similarity such as cosine similarity where
/// higher means more relevant — confirm with the code that produces it.
#[derive(Clone, Debug)]
pub struct ScoredMemory {
    /// The scored memory.
    pub memory: Memory,
    /// Relevance score attached to `memory`.
    pub score: f32,
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// rendered in decimal (e.g. `"1700000000"`).
///
/// If the system clock reports a time before the epoch, `duration_since` fails.
/// That error is deliberately replaced by a zero duration, so `"0"` is
/// returned instead of panicking.
fn current_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Embedding construction ----

    #[test]
    fn embedding_new_valid() {
        let embedding = Embedding::new(vec![0.1, 0.2, 0.3]).unwrap();
        assert_eq!(embedding.dimension(), 3);
        assert_eq!(embedding.values(), &[0.1, 0.2, 0.3]);
    }

    #[test]
    fn embedding_new_empty_fails() {
        let result = Embedding::new(vec![]);
        assert!(matches!(result, Err(EmbeddingError::EmptyVector)));
    }

    // ---- Byte encoding / decoding ----

    #[test]
    fn embedding_to_bytes_roundtrip() {
        let original = Embedding::new(vec![1.0, -0.5, 0.0, 0.25]).unwrap();
        let bytes = original.to_bytes();
        let restored = Embedding::from_bytes(&bytes).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn embedding_to_bytes_is_little_endian_f32() {
        let embedding = Embedding::new(vec![1.0, -2.0]).unwrap();
        let mut expected = 1.0_f32.to_le_bytes().to_vec();
        expected.extend_from_slice(&(-2.0_f32).to_le_bytes());
        assert_eq!(embedding.to_bytes(), expected);
    }

    #[test]
    fn embedding_from_bytes_empty_fails() {
        let result = Embedding::from_bytes(&[]);
        assert!(matches!(result, Err(EmbeddingError::EmptyVector)));
    }

    #[test]
    fn embedding_from_bytes_invalid_length() {
        let result = Embedding::from_bytes(&[1, 2, 3]);
        assert!(matches!(
            result,
            Err(EmbeddingError::InvalidByteLength { length: 3 })
        ));
    }

    #[test]
    fn embedding_from_bytes_rejects_trailing_partial_value() {
        // Two whole f32s plus one stray byte must be rejected, not truncated.
        let result = Embedding::from_bytes(&[0; 9]);
        assert!(matches!(
            result,
            Err(EmbeddingError::InvalidByteLength { length: 9 })
        ));
    }

    // ---- Cosine similarity ----

    #[test]
    fn embedding_cosine_similarity_same_vector() {
        let a = Embedding::new(vec![1.0, 0.0, 0.0]).unwrap();
        let b = Embedding::new(vec![1.0, 0.0, 0.0]).unwrap();
        let similarity = a.cosine_similarity(&b).unwrap();
        assert!((similarity - 1.0).abs() < 0.0001);
    }

    #[test]
    fn embedding_cosine_similarity_orthogonal() {
        let a = Embedding::new(vec![1.0, 0.0]).unwrap();
        let b = Embedding::new(vec![0.0, 1.0]).unwrap();
        let similarity = a.cosine_similarity(&b).unwrap();
        assert!(similarity.abs() < 0.0001);
    }

    #[test]
    fn embedding_cosine_similarity_opposite() {
        let a = Embedding::new(vec![1.0, 0.0]).unwrap();
        let b = Embedding::new(vec![-1.0, 0.0]).unwrap();
        let similarity = a.cosine_similarity(&b).unwrap();
        assert!((similarity + 1.0).abs() < 0.0001);
    }

    #[test]
    fn embedding_cosine_similarity_is_symmetric() {
        let a = Embedding::new(vec![0.3, -1.2, 2.5]).unwrap();
        let b = Embedding::new(vec![1.0, 0.4, -0.7]).unwrap();
        let ab = a.cosine_similarity(&b).unwrap();
        let ba = b.cosine_similarity(&a).unwrap();
        assert!((ab - ba).abs() < 0.0001);
    }

    #[test]
    fn embedding_cosine_similarity_dimension_mismatch() {
        let a = Embedding::new(vec![1.0, 0.0]).unwrap();
        let b = Embedding::new(vec![1.0, 0.0, 0.0]).unwrap();
        let result = a.cosine_similarity(&b);
        assert!(matches!(
            result,
            Err(EmbeddingError::DimensionMismatch {
                expected: 2,
                actual: 3
            })
        ));
    }

    #[test]
    fn embedding_cosine_similarity_zero_vector() {
        let a = Embedding::new(vec![0.0, 0.0]).unwrap();
        let b = Embedding::new(vec![1.0, 0.0]).unwrap();
        let similarity = a.cosine_similarity(&b).unwrap();
        assert_eq!(similarity, 0.0);
    }

    // ---- Normalization ----

    #[test]
    fn embedding_normalize() {
        let a = Embedding::new(vec![3.0, 4.0]).unwrap();
        let normalized = a.normalize();
        let magnitude: f32 = normalized
            .values()
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();
        assert!((magnitude - 1.0).abs() < 0.0001);
        assert!((normalized.values()[0] - 0.6).abs() < 0.0001);
        assert!((normalized.values()[1] - 0.8).abs() < 0.0001);
    }

    #[test]
    fn embedding_normalize_preserves_dimension() {
        let a = Embedding::new(vec![1.0, 2.0, 2.0]).unwrap();
        let normalized = a.normalize();
        assert_eq!(normalized.dimension(), a.dimension());
        assert!((normalized.values()[0] - 1.0 / 3.0).abs() < 0.0001);
    }

    #[test]
    fn embedding_normalize_zero_vector() {
        let a = Embedding::new(vec![0.0, 0.0]).unwrap();
        let normalized = a.normalize();
        assert_eq!(normalized.values(), &[0.0, 0.0]);
    }

    // ---- Error display ----

    #[test]
    fn embedding_error_display_empty_vector() {
        let err = EmbeddingError::EmptyVector;
        assert_eq!(err.to_string(), "embedding vector cannot be empty");
    }

    #[test]
    fn embedding_error_display_invalid_byte_length() {
        let err = EmbeddingError::InvalidByteLength { length: 7 };
        assert!(err.to_string().contains("7"));
        assert!(err.to_string().contains("multiple of 4"));
    }

    #[test]
    fn embedding_error_display_dimension_mismatch() {
        let err = EmbeddingError::DimensionMismatch {
            expected: 384,
            actual: 768,
        };
        assert!(err.to_string().contains("384"));
        assert!(err.to_string().contains("768"));
    }

    #[test]
    fn embedding_error_display_generation_failed() {
        let err = EmbeddingError::GenerationFailed {
            provider: "openai".to_string(),
            message: "rate limited".to_string(),
        };
        assert!(err.to_string().contains("openai"));
        assert!(err.to_string().contains("rate limited"));
    }

    // ---- Stub provider ----

    #[tokio::test]
    async fn stub_provider_embed_same_text_same_embedding() {
        let provider = StubEmbeddingProvider::default();
        let e1 = provider.embed("hello world").await.unwrap();
        let e2 = provider.embed("hello world").await.unwrap();
        assert_eq!(e1, e2);
    }

    #[tokio::test]
    async fn stub_provider_embed_different_text_different_embedding() {
        let provider = StubEmbeddingProvider::default();
        let e1 = provider.embed("hello").await.unwrap();
        let e2 = provider.embed("goodbye").await.unwrap();
        assert_ne!(e1, e2);
    }

    #[tokio::test]
    async fn stub_provider_dimension() {
        let provider = StubEmbeddingProvider::new(512);
        assert_eq!(provider.dimension(), 512);
        let embedding = provider.embed("test").await.unwrap();
        assert_eq!(embedding.dimension(), 512);
    }

    #[tokio::test]
    async fn stub_provider_values_are_bounded() {
        let provider = StubEmbeddingProvider::new(64);
        let embedding = provider.embed("bounded").await.unwrap();
        assert!(embedding
            .values()
            .iter()
            .all(|v| (-1.0..=1.0).contains(v)));
    }

    #[tokio::test]
    async fn stub_provider_zero_dimension_fails() {
        let provider = StubEmbeddingProvider::new(0);
        let result = provider.embed("anything").await;
        assert!(matches!(result, Err(EmbeddingError::EmptyVector)));
    }

    // These never await, so they do not need an async runtime.
    #[test]
    fn stub_provider_name() {
        let provider = StubEmbeddingProvider::default();
        assert_eq!(provider.name(), "stub");
    }

    #[test]
    fn stub_provider_default_dimension() {
        let provider = StubEmbeddingProvider::default();
        assert_eq!(provider.dimension(), 384);
    }

    // ---- Memory ----

    #[test]
    fn memory_new_creates_valid_memory() {
        let agent_id = AgentId::new();
        let memory = Memory::new(agent_id.clone(), "test content");
        assert_eq!(memory.agent_id, agent_id);
        assert_eq!(memory.content, "test content");
        assert!(memory.embedding.is_none());
        assert!(!memory.created_at.is_empty());
        assert!(memory.created_at.bytes().all(|b| b.is_ascii_digit()));
    }

    #[test]
    fn memory_with_embedding() {
        let agent_id = AgentId::new();
        let embedding = Embedding::new(vec![0.1, 0.2, 0.3]).unwrap();
        let memory = Memory::with_embedding(agent_id.clone(), "test", embedding.clone());
        assert_eq!(memory.agent_id, agent_id);
        assert_eq!(memory.content, "test");
        assert_eq!(memory.embedding, Some(embedding));
    }

    #[test]
    fn memory_unique_ids() {
        let agent_id = AgentId::new();
        let m1 = Memory::new(agent_id.clone(), "one");
        let m2 = Memory::new(agent_id, "two");
        assert_ne!(m1.id, m2.id);
    }

    #[test]
    fn memory_serialization_excludes_embedding() {
        let agent_id = AgentId::new();
        let embedding = Embedding::new(vec![0.1, 0.2, 0.3]).unwrap();
        let memory = Memory::with_embedding(agent_id, "test", embedding);
        let json = serde_json::to_string(&memory).unwrap();
        assert!(!json.contains("embedding"));
    }

    #[test]
    fn memory_deserialization_leaves_embedding_none() {
        let agent_id = AgentId::new();
        let embedding = Embedding::new(vec![0.1, 0.2, 0.3]).unwrap();
        let memory = Memory::with_embedding(agent_id, "test", embedding);
        let json = serde_json::to_string(&memory).unwrap();
        let restored: Memory = serde_json::from_str(&json).unwrap();
        assert!(restored.embedding.is_none());
        assert_eq!(restored.id, memory.id);
        assert_eq!(restored.content, memory.content);
        assert_eq!(restored.created_at, memory.created_at);
    }

    #[test]
    fn scored_memory_creation() {
        let agent_id = AgentId::new();
        let memory = Memory::new(agent_id, "test");
        let scored = ScoredMemory {
            memory,
            score: 0.85,
        };
        assert_eq!(scored.memory.content, "test");
        assert!((scored.score - 0.85).abs() < 0.0001);
    }
}