1use base64::prelude::BASE64_URL_SAFE_NO_PAD;
7use base64::Engine;
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use std::fmt;
10use std::hash::{DefaultHasher, Hash, Hasher};
11
12const DEFAULT_EMBEDDING_NAME: &str = "embedding";
13
14#[derive(Clone)]
23pub struct Document {
24 pub id: String,
25 pub text: String,
26 pub embedding: Option<Vec<f32>>,
27 pub embedding_name: Option<String>,
28 pub metadata: Option<serde_json::Value>,
29}
30
31impl Document {
32 fn hash_text(text: &str) -> String {
34 let mut hasher = DefaultHasher::new();
35 text.hash(&mut hasher);
36 let hash = hasher.finish();
37 BASE64_URL_SAFE_NO_PAD.encode(hash.to_be_bytes())
38 }
39
40 pub fn new(text: String) -> Self {
42 Self {
43 id: Self::hash_text(&text),
44 text,
45 embedding: None,
46 embedding_name: None,
47 metadata: None,
48 }
49 }
50
51 #[allow(dead_code)]
53 pub fn new_with_id(id: String, text: String) -> Self {
54 Self {
55 id,
56 text,
57 embedding: None,
58 embedding_name: None,
59 metadata: None,
60 }
61 }
62
63 #[allow(dead_code)]
65 pub fn new_with_embedding(text: String, embedding: Vec<f32>, embedding_name: String) -> Self {
66 Self {
67 id: Self::hash_text(&text),
68 text,
69 embedding: Some(embedding),
70 embedding_name: Some(embedding_name),
71 metadata: None,
72 }
73 }
74}
75
76impl From<String> for Document {
77 fn from(text: String) -> Self {
78 Self::new(text)
79 }
80}
81
82impl fmt::Debug for Document {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 let embedding_preview = match &self.embedding {
85 Some(vec) if !vec.is_empty() => {
86 let preview = vec.iter().take(2).collect::<Vec<_>>();
87 let mut preview_str = format!("[{}", preview[0]);
89 if vec.len() > 1 {
90 preview_str.push_str(&format!(", {}", preview[1]));
91 }
92 if vec.len() > 2 {
93 preview_str.push_str(&format!(", ...({} more)", vec.len() - 2));
94 }
95 preview_str.push(']');
96 preview_str
97 }
98 Some(_) => "[]".to_string(),
99 None => "None".to_string(),
100 };
101
102 write!(
103 f,
104 "Document {{ id: {:?}, text: {:?}, embedding: {}, embedding_name: {:?}, metadata: {:?} }}",
105 self.id, self.text, embedding_preview, self.embedding_name, self.metadata
106 )
107 }
108}
109
110#[allow(dead_code)]
112pub struct DocCollection {
113 documents: Vec<Document>,
114}
115
116impl<T: Into<Document>> FromIterator<T> for DocCollection {
117 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
118 Self {
119 documents: iter.into_iter().map(Into::into).collect(),
120 }
121 }
122}
123
124impl From<Vec<String>> for DocCollection {
125 fn from(texts: Vec<String>) -> Self {
126 Self {
127 documents: texts.into_iter().map(Document::from).collect(),
128 }
129 }
130}
131
132impl From<DocCollection> for Vec<Document> {
133 fn from(docs: DocCollection) -> Self {
134 docs.documents
135 }
136}
137
138impl Serialize for Document {
139 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
140 where
141 S: Serializer,
142 {
143 let mut doc = serde_json::json!({
144 "text": self.text,
145 });
146
147 doc["id"] = serde_json::json!(self.id);
148
149 if let Some(embedding) = &self.embedding {
150 let embedding_field_name = self
151 .embedding_name
152 .as_deref()
153 .unwrap_or(DEFAULT_EMBEDDING_NAME);
154 doc[embedding_field_name] = serde_json::json!(embedding);
155
156 let mut metadata = self.metadata.clone().unwrap_or_default();
157 metadata["embedding_field_name"] = serde_json::json!(embedding_field_name);
158 doc["metadata"] = metadata;
159 } else if let Some(metadata) = &self.metadata {
160 doc["metadata"] = metadata.clone();
161 }
162
163 doc.serialize(serializer)
164 }
165}
166
167impl<'de> Deserialize<'de> for Document {
168 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
169 where
170 D: Deserializer<'de>,
171 {
172 let doc = serde_json::Value::deserialize(deserializer)?;
173
174 let id = doc
175 .get("id")
176 .and_then(|v| v.as_str())
177 .map(|s| s.to_string());
178
179 let text = doc
180 .get("text")
181 .and_then(|v| v.as_str())
182 .map(|s| s.to_string())
183 .ok_or_else(|| serde::de::Error::missing_field("text"))?;
184
185 let id = id.unwrap_or_else(|| Document::hash_text(&text));
186 let metadata = doc.get("metadata").cloned();
187
188 let embedding_name = metadata
189 .as_ref()
190 .and_then(|m| m.get("embedding_field_name"))
191 .and_then(|v| v.as_str())
192 .map(|s| s.to_string());
193
194 let embedding = if let Some(name) = &embedding_name {
195 doc.get(name)
196 .and_then(|v| serde_json::from_value(v.clone()).ok())
197 } else {
198 None
199 };
200
201 Ok(Document {
202 id,
203 text,
204 embedding,
205 embedding_name,
206 metadata,
207 })
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::Document;
    use serde_json::json;

    /// A fully populated document survives a serialize/deserialize round trip.
    /// The only visible change is the `embedding_field_name` marker that
    /// serialization adds to the metadata.
    #[test]
    fn test_document_serialization() {
        let original = Document {
            id: "1".to_string(),
            text: "hello".to_string(),
            embedding: Some(vec![1.0, 2.0, 3.0]),
            embedding_name: Some("embedding".to_string()),
            metadata: Some(json!({"key": "value"})),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Document = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.id, original.id);
        assert_eq!(decoded.text, original.text);
        assert_eq!(decoded.embedding, original.embedding);
        assert_eq!(decoded.embedding_name, original.embedding_name);

        let expected_metadata = json!({"key": "value", "embedding_field_name": "embedding"});
        assert_eq!(decoded.metadata, Some(expected_metadata));
    }

    /// A hand-written payload yields the embedding found under the field named by
    /// `metadata.embedding_field_name`, with metadata kept verbatim.
    #[test]
    fn test_document_deserialization() {
        let payload = r#"{
            "id": "1",
            "text": "hello",
            "embedding": [1.0, 2.0, 3.0],
            "metadata": {"key": "value", "embedding_field_name": "embedding"}
        }"#;

        let decoded: Document = serde_json::from_str(payload).unwrap();

        assert_eq!(decoded.id, "1");
        assert_eq!(decoded.text, "hello");
        assert_eq!(decoded.embedding, Some(vec![1.0, 2.0, 3.0]));
        assert_eq!(decoded.embedding_name.as_deref(), Some("embedding"));
        assert_eq!(
            decoded.metadata,
            Some(json!({"embedding_field_name": "embedding", "key": "value"}))
        );
    }
}