agent_memory/
embeddings.rs

1//! Embedding vector utilities shared across memory components.
2
3use std::sync::Arc;
4
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6
7use crate::{MemoryError, MemoryResult};
8
/// Immutable embedding made of `f32` components held in shared storage.
///
/// The derived [`Clone`] duplicates only the [`Arc`] handle, so copies share
/// one underlying buffer. The derived [`PartialEq`] compares the components
/// element by element.
#[derive(Clone, PartialEq)]
pub struct EmbeddingVector {
    /// Reference-counted buffer holding the embedding components.
    values: Arc<[f32]>,
}
14
15impl EmbeddingVector {
16    /// Creates a new embedding from owned values.
17    ///
18    /// # Errors
19    ///
20    /// Returns [`MemoryError::InvalidRecord`] when the supplied vector is empty
21    /// or contains non-finite values.
22    pub fn new(values: Vec<f32>) -> MemoryResult<Self> {
23        if values.is_empty() {
24            return Err(MemoryError::InvalidRecord(
25                "embedding vector must not be empty",
26            ));
27        }
28        if !values.iter().all(|value| value.is_finite()) {
29            return Err(MemoryError::InvalidRecord(
30                "embedding vector contains non-finite values",
31            ));
32        }
33        Ok(Self {
34            values: Arc::<[f32]>::from(values.into_boxed_slice()),
35        })
36    }
37
38    /// Creates an embedding by copying the provided slice.
39    ///
40    /// # Errors
41    ///
42    /// Returns [`MemoryError::InvalidRecord`] if the slice is empty or contains
43    /// non-finite values.
44    pub fn from_slice(values: &[f32]) -> MemoryResult<Self> {
45        Self::new(values.to_vec())
46    }
47
48    /// Returns an immutable view of the embedding data.
49    #[must_use]
50    pub fn as_slice(&self) -> &[f32] {
51        &self.values
52    }
53
54    /// Returns the dimensionality of the embedding.
55    #[must_use]
56    pub fn len(&self) -> usize {
57        self.values.len()
58    }
59
60    /// Returns whether the embedding is empty. This should never be the case
61    /// because [`EmbeddingVector::new`] rejects empty inputs, but the helper is
62    /// provided for completeness.
63    #[must_use]
64    pub fn is_empty(&self) -> bool {
65        self.values.is_empty()
66    }
67
68    pub(crate) fn dot(&self, other: &Self) -> f32 {
69        self.values
70            .iter()
71            .zip(other.values.iter())
72            .map(|(a, b)| a * b)
73            .sum()
74    }
75
76    pub(crate) fn magnitude(&self) -> f32 {
77        self.values
78            .iter()
79            .map(|value| value * value)
80            .sum::<f32>()
81            .sqrt()
82    }
83}
84
85impl std::fmt::Debug for EmbeddingVector {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("EmbeddingVector")
88            .field("dimensions", &self.len())
89            .finish()
90    }
91}
92
93impl Serialize for EmbeddingVector {
94    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
95    where
96        S: Serializer,
97    {
98        self.values.as_ref().serialize(serializer)
99    }
100}
101
102impl<'de> Deserialize<'de> for EmbeddingVector {
103    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
104    where
105        D: Deserializer<'de>,
106    {
107        let values = Vec::<f32>::deserialize(deserializer)?;
108        Self::new(values).map_err(serde::de::Error::custom)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty input must be refused with `InvalidRecord`.
    #[test]
    fn rejects_empty_vectors() {
        let outcome = EmbeddingVector::new(vec![]);
        let failure = outcome.expect_err("empty vector should error");
        assert!(matches!(failure, MemoryError::InvalidRecord(_)));
    }

    /// A NaN component must be refused with `InvalidRecord`.
    #[test]
    fn rejects_non_finite_values() {
        let outcome = EmbeddingVector::new(vec![1.0, f32::NAN]);
        let failure = outcome.expect_err("nan not allowed");
        assert!(matches!(failure, MemoryError::InvalidRecord(_)));
    }

    /// Encoding to JSON and decoding again must preserve every component.
    #[test]
    fn serialization_roundtrip() {
        let original = EmbeddingVector::new(vec![0.1, 0.2, 0.3]).unwrap();
        let encoded = serde_json::to_string(&original).unwrap();
        let restored: EmbeddingVector = serde_json::from_str(&encoded).unwrap();
        assert_eq!(restored.as_slice(), original.as_slice());
    }
}