anchor_chain/vector/
document.rs

1//! Structures for managing documents in vector databases.
2//!
3//! This module provides the `Document` and `DocCollection` structs for handling and managing
4//! documents in vector databases.
5
6use base64::prelude::BASE64_URL_SAFE_NO_PAD;
7use base64::Engine;
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use std::fmt;
10use std::hash::{DefaultHasher, Hash, Hasher};
11
12const DEFAULT_EMBEDDING_NAME: &str = "embedding";
13
14/// Document structure for serializing and deserializing when working with vector databases.
15///
16/// The `id` field is a unique identifier for the document. If not provided, it will be generated
17/// using a hash of the `text` field. The `text` field is the main content of the document. The
18/// `embedding` field is an optional field that can be used to store a vector embedding of the
19/// document. The `embedding_name` field is the name of the field that the embedding is stored in.
20/// The `metadata` field is an optional field that can be used to store additional metadata about
21/// the document.
22#[derive(Clone)]
23pub struct Document {
24    pub id: String,
25    pub text: String,
26    pub embedding: Option<Vec<f32>>,
27    pub embedding_name: Option<String>,
28    pub metadata: Option<serde_json::Value>,
29}
30
31impl Document {
32    /// Generate a unique identifier for a document based on its text.
33    fn hash_text(text: &str) -> String {
34        let mut hasher = DefaultHasher::new();
35        text.hash(&mut hasher);
36        let hash = hasher.finish();
37        BASE64_URL_SAFE_NO_PAD.encode(hash.to_be_bytes())
38    }
39
40    /// Create a new document with the given text.
41    pub fn new(text: String) -> Self {
42        Self {
43            id: Self::hash_text(&text),
44            text,
45            embedding: None,
46            embedding_name: None,
47            metadata: None,
48        }
49    }
50
51    /// Create a new document with the given id and text.
52    #[allow(dead_code)]
53    pub fn new_with_id(id: String, text: String) -> Self {
54        Self {
55            id,
56            text,
57            embedding: None,
58            embedding_name: None,
59            metadata: None,
60        }
61    }
62
63    /// Create a new document with the given text and embedding.
64    #[allow(dead_code)]
65    pub fn new_with_embedding(text: String, embedding: Vec<f32>, embedding_name: String) -> Self {
66        Self {
67            id: Self::hash_text(&text),
68            text,
69            embedding: Some(embedding),
70            embedding_name: Some(embedding_name),
71            metadata: None,
72        }
73    }
74}
75
76impl From<String> for Document {
77    fn from(text: String) -> Self {
78        Self::new(text)
79    }
80}
81
82impl fmt::Debug for Document {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        let embedding_preview = match &self.embedding {
85            Some(vec) if !vec.is_empty() => {
86                let preview = vec.iter().take(2).collect::<Vec<_>>();
87                // Output [preview[0], preview[1], ... (n more)]
88                let mut preview_str = format!("[{}", preview[0]);
89                if vec.len() > 1 {
90                    preview_str.push_str(&format!(", {}", preview[1]));
91                }
92                if vec.len() > 2 {
93                    preview_str.push_str(&format!(", ...({} more)", vec.len() - 2));
94                }
95                preview_str.push(']');
96                preview_str
97            }
98            Some(_) => "[]".to_string(),
99            None => "None".to_string(),
100        };
101
102        write!(
103            f,
104            "Document {{ id: {:?}, text: {:?}, embedding: {}, embedding_name: {:?}, metadata: {:?} }}",
105            self.id, self.text, embedding_preview, self.embedding_name, self.metadata
106        )
107    }
108}
109
110/// A struct representing a collection of documents.
111#[allow(dead_code)]
112pub struct DocCollection {
113    documents: Vec<Document>,
114}
115
116impl<T: Into<Document>> FromIterator<T> for DocCollection {
117    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
118        Self {
119            documents: iter.into_iter().map(Into::into).collect(),
120        }
121    }
122}
123
124impl From<Vec<String>> for DocCollection {
125    fn from(texts: Vec<String>) -> Self {
126        Self {
127            documents: texts.into_iter().map(Document::from).collect(),
128        }
129    }
130}
131
132impl From<DocCollection> for Vec<Document> {
133    fn from(docs: DocCollection) -> Self {
134        docs.documents
135    }
136}
137
138impl Serialize for Document {
139    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
140    where
141        S: Serializer,
142    {
143        let mut doc = serde_json::json!({
144            "text": self.text,
145        });
146
147        doc["id"] = serde_json::json!(self.id);
148
149        if let Some(embedding) = &self.embedding {
150            let embedding_field_name = self
151                .embedding_name
152                .as_deref()
153                .unwrap_or(DEFAULT_EMBEDDING_NAME);
154            doc[embedding_field_name] = serde_json::json!(embedding);
155
156            let mut metadata = self.metadata.clone().unwrap_or_default();
157            metadata["embedding_field_name"] = serde_json::json!(embedding_field_name);
158            doc["metadata"] = metadata;
159        } else if let Some(metadata) = &self.metadata {
160            doc["metadata"] = metadata.clone();
161        }
162
163        doc.serialize(serializer)
164    }
165}
166
167impl<'de> Deserialize<'de> for Document {
168    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
169    where
170        D: Deserializer<'de>,
171    {
172        let doc = serde_json::Value::deserialize(deserializer)?;
173
174        let id = doc
175            .get("id")
176            .and_then(|v| v.as_str())
177            .map(|s| s.to_string());
178
179        let text = doc
180            .get("text")
181            .and_then(|v| v.as_str())
182            .map(|s| s.to_string())
183            .ok_or_else(|| serde::de::Error::missing_field("text"))?;
184
185        let id = id.unwrap_or_else(|| Document::hash_text(&text));
186        let metadata = doc.get("metadata").cloned();
187
188        let embedding_name = metadata
189            .as_ref()
190            .and_then(|m| m.get("embedding_field_name"))
191            .and_then(|v| v.as_str())
192            .map(|s| s.to_string());
193
194        let embedding = if let Some(name) = &embedding_name {
195            doc.get(name)
196                .and_then(|v| serde_json::from_value(v.clone()).ok())
197        } else {
198            None
199        };
200
201        Ok(Document {
202            id,
203            text,
204            embedding,
205            embedding_name,
206            metadata,
207        })
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::Document;
    use serde_json::json;

    /// A document with an embedding survives a JSON round trip. Serialization adds the
    /// embedding's field name to the metadata.
    #[test]
    fn test_document_serialization() {
        let original = Document {
            id: "1".to_string(),
            text: "hello".to_string(),
            embedding: Some(vec![1.0, 2.0, 3.0]),
            embedding_name: Some("embedding".to_string()),
            metadata: Some(json!({"key": "value"})),
        };

        let json_text = serde_json::to_string(&original).unwrap();
        let round_tripped: Document = serde_json::from_str(&json_text).unwrap();

        assert_eq!(round_tripped.id, original.id);
        assert_eq!(round_tripped.text, original.text);
        assert_eq!(round_tripped.embedding, original.embedding);
        assert_eq!(round_tripped.embedding_name, original.embedding_name);
        assert_eq!(
            round_tripped.metadata,
            Some(json!({"key": "value", "embedding_field_name": "embedding"}))
        );
    }

    /// Hand-written JSON whose metadata names the embedding field is parsed into every field.
    #[test]
    fn test_document_deserialization() {
        let input = r#"{
            "id": "1",
            "text": "hello",
            "embedding": [1.0, 2.0, 3.0],
            "metadata": {"key": "value", "embedding_field_name": "embedding"}
        }"#;

        let parsed: Document = serde_json::from_str(input).unwrap();

        assert_eq!(parsed.id, "1");
        assert_eq!(parsed.text, "hello");
        assert_eq!(parsed.embedding, Some(vec![1.0, 2.0, 3.0]));
        assert_eq!(parsed.embedding_name.as_deref(), Some("embedding"));
        assert_eq!(
            parsed.metadata,
            Some(json!({"embedding_field_name": "embedding", "key": "value"}))
        );
    }
}