agent_memory/
embeddings.rs1use std::sync::Arc;
4
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6
7use crate::{MemoryError, MemoryResult};
8
/// A validated, immutable embedding: a non-empty sequence of finite `f32`
/// components.
///
/// The components are stored behind an `Arc<[f32]>`, so cloning an
/// `EmbeddingVector` shares the underlying buffer (a reference-count bump)
/// rather than copying every component.
#[derive(PartialEq, Clone)]
pub struct EmbeddingVector {
    /// Shared component storage. The public constructors reject empty and
    /// non-finite input before this field is populated.
    values: Arc<[f32]>,
}
14
15impl EmbeddingVector {
16 pub fn new(values: Vec<f32>) -> MemoryResult<Self> {
23 if values.is_empty() {
24 return Err(MemoryError::InvalidRecord(
25 "embedding vector must not be empty",
26 ));
27 }
28 if !values.iter().all(|value| value.is_finite()) {
29 return Err(MemoryError::InvalidRecord(
30 "embedding vector contains non-finite values",
31 ));
32 }
33 Ok(Self {
34 values: Arc::<[f32]>::from(values.into_boxed_slice()),
35 })
36 }
37
38 pub fn from_slice(values: &[f32]) -> MemoryResult<Self> {
45 Self::new(values.to_vec())
46 }
47
48 #[must_use]
50 pub fn as_slice(&self) -> &[f32] {
51 &self.values
52 }
53
54 #[must_use]
56 pub fn len(&self) -> usize {
57 self.values.len()
58 }
59
60 #[must_use]
64 pub fn is_empty(&self) -> bool {
65 self.values.is_empty()
66 }
67
68 pub(crate) fn dot(&self, other: &Self) -> f32 {
69 self.values
70 .iter()
71 .zip(other.values.iter())
72 .map(|(a, b)| a * b)
73 .sum()
74 }
75
76 pub(crate) fn magnitude(&self) -> f32 {
77 self.values
78 .iter()
79 .map(|value| value * value)
80 .sum::<f32>()
81 .sqrt()
82 }
83}
84
85impl std::fmt::Debug for EmbeddingVector {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("EmbeddingVector")
88 .field("dimensions", &self.len())
89 .finish()
90 }
91}
92
93impl Serialize for EmbeddingVector {
94 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
95 where
96 S: Serializer,
97 {
98 self.values.as_ref().serialize(serializer)
99 }
100}
101
102impl<'de> Deserialize<'de> for EmbeddingVector {
103 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
104 where
105 D: Deserializer<'de>,
106 {
107 let values = Vec::<f32>::deserialize(deserializer)?;
108 Self::new(values).map_err(serde::de::Error::custom)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_empty_vectors() {
        let err = EmbeddingVector::new(vec![]).expect_err("empty vector should error");
        assert!(matches!(err, MemoryError::InvalidRecord(_)));
    }

    #[test]
    fn rejects_non_finite_values() {
        let err = EmbeddingVector::new(vec![1.0, f32::NAN]).expect_err("nan not allowed");
        assert!(matches!(err, MemoryError::InvalidRecord(_)));

        // Infinities are non-finite too and must be rejected the same way.
        let err = EmbeddingVector::new(vec![f32::INFINITY]).expect_err("inf not allowed");
        assert!(matches!(err, MemoryError::InvalidRecord(_)));
        let err =
            EmbeddingVector::new(vec![0.5, f32::NEG_INFINITY]).expect_err("-inf not allowed");
        assert!(matches!(err, MemoryError::InvalidRecord(_)));
    }

    #[test]
    fn from_slice_validates_and_matches_new() {
        let components = [0.25_f32, -0.5, 1.0];
        let from_slice = EmbeddingVector::from_slice(&components).unwrap();
        let from_vec = EmbeddingVector::new(components.to_vec()).unwrap();
        assert_eq!(from_slice, from_vec);

        let err = EmbeddingVector::from_slice(&[]).expect_err("empty slice should error");
        assert!(matches!(err, MemoryError::InvalidRecord(_)));
    }

    #[test]
    fn serialization_roundtrip() {
        let embedding = EmbeddingVector::new(vec![0.1, 0.2, 0.3]).unwrap();
        let json = serde_json::to_string(&embedding).unwrap();
        let decoded: EmbeddingVector = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.as_slice(), embedding.as_slice());
    }

    #[test]
    fn deserialization_enforces_invariants() {
        // The serde path must go through the same validation as `new`.
        let decoded: Result<EmbeddingVector, _> = serde_json::from_str("[]");
        assert!(decoded.is_err());
    }

    #[test]
    fn dot_and_magnitude() {
        let a = EmbeddingVector::new(vec![3.0, 4.0]).unwrap();
        let b = EmbeddingVector::new(vec![1.0, 2.0]).unwrap();
        assert_eq!(a.magnitude(), 5.0);
        assert_eq!(a.dot(&b), 11.0);
    }

    #[test]
    fn debug_hides_components() {
        let embedding = EmbeddingVector::new(vec![0.1, 0.2, 0.3]).unwrap();
        assert_eq!(format!("{embedding:?}"), "EmbeddingVector { dimensions: 3 }");
    }
}