1#![doc = include_str!("../README.md")]
2
3use std::collections::BTreeMap;
4use std::fs::{self, File};
5use std::io::{BufRead, BufReader, BufWriter, Write};
6use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9use text_embeddings::{EmbeddingModelInfo, TextEmbedderBackend};
10use thiserror::Error;
11use vector_analysis_index::SerializableVectorRecord;
12
13use crate::{DocumentChunk, RetrievalIndex};
14
15#[derive(Debug, Error)]
16pub enum StorageError {
18 #[error("I/O error: {0}")]
19 Io(#[from] std::io::Error),
21 #[error("JSON error: {0}")]
22 Json(String),
24 #[error("invalid manifest: {0}")]
25 InvalidManifest(String),
27 #[error("invalid retrieval state: {0}")]
28 InvalidState(String),
30}
31
32pub type Result<T> = std::result::Result<T, StorageError>;
34
35#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
36pub struct RetrievalFile {
38 pub path: String,
40 pub records: u64,
42}
43
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45pub struct RetrievalManifest {
47 pub schema_version: u32,
49 pub chunk_count: u64,
51 pub vector_count: u64,
53 pub dimensions: Option<usize>,
55 pub embedder: EmbeddingModelInfo,
57 pub chunks_file: RetrievalFile,
59 pub vectors_file: RetrievalFile,
61 pub corpus_file: String,
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66pub struct PersistedCorpusMetadata {
68 pub corpus_options: text_lexical::CorpusOptions,
70 pub bm25_options: text_lexical::Bm25Options,
72}
73
74#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
75pub struct PersistedChunkRecord {
77 pub chunk: DocumentChunk,
79 pub raw_text: Option<String>,
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub struct PersistedSearchIndex {
86 pub manifest: RetrievalManifest,
88 pub corpus: PersistedCorpusMetadata,
90 pub chunks: Vec<PersistedChunkRecord>,
92 pub vectors: Vec<SerializableVectorRecord>,
94}
95
96impl PersistedSearchIndex {
97 pub fn from_index<B: TextEmbedderBackend>(index: &RetrievalIndex<B>) -> Self {
99 let chunks = index
100 .chunks_iter()
101 .map(|chunk| PersistedChunkRecord {
102 chunk: chunk.clone(),
103 raw_text: index.raw_text(&chunk.chunk_id).map(ToString::to_string),
104 })
105 .collect::<Vec<_>>();
106 let vectors = index.vector_records();
107 let embedder = index.embedder_info();
108
109 Self {
110 manifest: RetrievalManifest {
111 schema_version: 1,
112 chunk_count: chunks.len() as u64,
113 vector_count: vectors.len() as u64,
114 dimensions: vectors.first().map(|record| record.vector.len()),
115 embedder,
116 chunks_file: RetrievalFile {
117 path: "chunks.jsonl".to_string(),
118 records: chunks.len() as u64,
119 },
120 vectors_file: RetrievalFile {
121 path: "vectors.jsonl".to_string(),
122 records: vectors.len() as u64,
123 },
124 corpus_file: "corpus.json".to_string(),
125 },
126 corpus: PersistedCorpusMetadata {
127 corpus_options: index.corpus_options().clone(),
128 bm25_options: index.bm25_options().clone(),
129 },
130 chunks,
131 vectors,
132 }
133 }
134
135 pub fn save_to_path(&self, path: &Path) -> Result<()> {
137 fs::create_dir_all(path)?;
138 write_json(path.join("manifest.json"), &self.manifest)?;
139 write_json(path.join(&self.manifest.corpus_file), &self.corpus)?;
140 write_jsonl(path.join(&self.manifest.chunks_file.path), &self.chunks)?;
141 write_jsonl(path.join(&self.manifest.vectors_file.path), &self.vectors)?;
142 Ok(())
143 }
144
145 pub fn load_from_path(path: &Path) -> Result<Self> {
147 let manifest = read_json::<RetrievalManifest>(path.join("manifest.json"))?;
148 let corpus = read_json::<PersistedCorpusMetadata>(path.join(&manifest.corpus_file))?;
149 let chunks = read_jsonl::<PersistedChunkRecord>(path.join(&manifest.chunks_file.path))?;
150 let vectors =
151 read_jsonl::<SerializableVectorRecord>(path.join(&manifest.vectors_file.path))?;
152
153 if chunks.len() as u64 != manifest.chunk_count {
154 return Err(StorageError::InvalidManifest(format!(
155 "manifest expected {} chunks, loaded {}",
156 manifest.chunk_count,
157 chunks.len()
158 )));
159 }
160 if vectors.len() as u64 != manifest.vector_count {
161 return Err(StorageError::InvalidManifest(format!(
162 "manifest expected {} vectors, loaded {}",
163 manifest.vector_count,
164 vectors.len()
165 )));
166 }
167 if let Some(dimensions) = manifest.dimensions {
168 if vectors
169 .iter()
170 .any(|record| record.vector.len() != dimensions)
171 {
172 return Err(StorageError::InvalidManifest(format!(
173 "persisted vectors did not all match manifest dimension {dimensions}"
174 )));
175 }
176 }
177
178 Ok(Self {
179 manifest,
180 corpus,
181 chunks,
182 vectors,
183 })
184 }
185
186 pub fn into_index<B: TextEmbedderBackend>(self, embedder: B) -> Result<RetrievalIndex<B>> {
188 validate_embedder_compatibility(&self.manifest.embedder, &embedder.model_info())?;
189 let raw_text_by_chunk_id = self
190 .chunks
191 .iter()
192 .filter_map(|record| {
193 record
194 .raw_text
195 .as_ref()
196 .map(|raw_text| (record.chunk.chunk_id.clone(), raw_text.clone()))
197 })
198 .collect::<BTreeMap<_, _>>();
199 let chunks = self
200 .chunks
201 .into_iter()
202 .map(|record| record.chunk)
203 .collect::<Vec<_>>();
204
205 RetrievalIndex::from_parts(
206 embedder,
207 self.corpus.corpus_options,
208 self.corpus.bm25_options,
209 chunks,
210 raw_text_by_chunk_id,
211 self.vectors,
212 )
213 .map_err(|err| StorageError::InvalidState(err.to_string()))
214 }
215
216 pub fn load_with_embedder<B: TextEmbedderBackend>(
218 path: &Path,
219 embedder: B,
220 ) -> Result<RetrievalIndex<B>> {
221 Self::load_from_path(path)?.into_index(embedder)
222 }
223}
224
225fn write_json(path: impl AsRef<Path>, value: &impl Serialize) -> Result<()> {
226 let file = File::create(path)?;
227 serde_json::to_writer_pretty(BufWriter::new(file), value).map_err(json_error)
228}
229
230fn read_json<T: for<'de> Deserialize<'de>>(path: impl AsRef<Path>) -> Result<T> {
231 let file = File::open(path)?;
232 serde_json::from_reader(BufReader::new(file)).map_err(json_error)
233}
234
235fn write_jsonl<T: Serialize>(path: impl AsRef<Path>, values: &[T]) -> Result<()> {
236 let file = File::create(path)?;
237 let mut writer = BufWriter::new(file);
238 for value in values {
239 let line = serde_json::to_string(value).map_err(json_error)?;
240 writer.write_all(line.as_bytes())?;
241 writer.write_all(b"\n")?;
242 }
243 writer.flush()?;
244 Ok(())
245}
246
247fn read_jsonl<T: for<'de> Deserialize<'de>>(path: impl AsRef<Path>) -> Result<Vec<T>> {
248 let file = File::open(path)?;
249 let reader = BufReader::new(file);
250 let mut values = Vec::new();
251 for (line_index, line) in reader.lines().enumerate() {
252 let line = line?;
253 if line.trim().is_empty() {
254 continue;
255 }
256 let value = serde_json::from_str::<T>(&line)
257 .map_err(|err| StorageError::Json(format!("line {}: {err}", line_index + 1)))?;
258 values.push(value);
259 }
260 Ok(values)
261}
262
263fn validate_embedder_compatibility(
264 persisted: &EmbeddingModelInfo,
265 current: &EmbeddingModelInfo,
266) -> Result<()> {
267 if !persisted.model_name.is_empty()
268 && !current.model_name.is_empty()
269 && persisted.model_name != current.model_name
270 {
271 return Err(StorageError::InvalidState(format!(
272 "persisted embedder `{}` did not match provided embedder `{}`",
273 persisted.model_name, current.model_name
274 )));
275 }
276 if persisted.dimensions > 0
277 && current.dimensions > 0
278 && persisted.dimensions != current.dimensions
279 {
280 return Err(StorageError::InvalidState(format!(
281 "persisted embedder dimensions {} did not match provided embedder dimensions {}",
282 persisted.dimensions, current.dimensions
283 )));
284 }
285 Ok(())
286}
287
288fn json_error(error: serde_json::Error) -> StorageError {
289 StorageError::Json(error.to_string())
290}
291
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::{HybridConfig, IngestionOptions, SearchDocument, SearchQuery};
    use tempfile::tempdir;
    use text_embeddings::{
        DenseVector, HashedTextEmbedder, TextEmbeddingBackend, TextEmbeddingBackendKind,
        TextEmbeddingConfig, TextEmbeddingMetadata,
    };
    use text_lexical::CorpusOptions;

    use super::*;

    /// Manifest fixture whose `chunk_count` of 2 disagrees with the empty chunks file.
    const INCONSISTENT_MANIFEST: &str = r#"{"schema_version":1,"chunk_count":2,"vector_count":1,"dimensions":2,"embedder":{"model_name":"hashed-text-embedder","backend":"hashed","dimensions":32},"chunks_file":{"path":"chunks.jsonl","records":0},"vectors_file":{"path":"vectors.jsonl","records":0},"corpus_file":"corpus.json"}"#;

    /// Corpus metadata fixture written alongside `INCONSISTENT_MANIFEST`.
    const CORPUS_METADATA: &str = r#"{"corpus_options":{"processing":{"language":null,"lowercase":true,"normalize_unicode":true,"keep_apostrophes":true,"include_punctuation":false},"min_term_len":1,"stop_words":[],"max_terms_per_document":null},"bm25_options":{"k1":1.2,"b":0.75,"min_term_len":1,"stop_words":[]}}"#;

    /// Hashed embedder with 32 dimensions and IDF weighting enabled.
    fn embedder() -> HashedTextEmbedder {
        let config = TextEmbeddingConfig {
            dimensions: 32,
            use_idf: true,
        };
        HashedTextEmbedder::new(config, CorpusOptions::default()).unwrap()
    }

    /// Builds a document with no source, provenance, or annotations.
    fn document(
        id: &str,
        title: Option<&str>,
        body: &str,
        metadata: BTreeMap<String, String>,
    ) -> SearchDocument {
        SearchDocument {
            id: id.to_string(),
            title: title.map(str::to_string),
            body: body.to_string(),
            metadata,
            source: None,
            provenance: Vec::new(),
            annotations: Vec::new(),
        }
    }

    /// Metadata map tagging a document as English.
    fn english_metadata() -> BTreeMap<String, String> {
        BTreeMap::from([("lang".to_string(), "en".to_string())])
    }

    /// Test backend reporting a configurable model name and dimension count; every text
    /// embeds to an all-ones vector.
    #[derive(Debug, Clone)]
    struct NamedEmbedder {
        name: String,
        dimensions: usize,
    }

    impl NamedEmbedder {
        fn new(name: &str, dimensions: usize) -> Self {
            Self {
                name: name.to_string(),
                dimensions,
            }
        }
    }

    impl TextEmbeddingBackend for NamedEmbedder {
        fn embed_text(&self, _text: &str) -> video_analysis_core::Result<DenseVector> {
            let ones = vec![1.0; self.dimensions];
            DenseVector::new(ones)
        }

        fn metadata(&self) -> TextEmbeddingMetadata {
            TextEmbeddingMetadata {
                backend: TextEmbeddingBackendKind::Custom,
                model_name: Some(self.name.clone()),
                dimensions: Some(self.dimensions),
                ..TextEmbeddingMetadata::default()
            }
        }
    }

    #[test]
    fn persisted_index_round_trips_with_manifest_validation() {
        let documents = [
            document(
                "doc-1",
                Some("Rust Search"),
                "Rust cargo crates enable semantic search over documentation.",
                english_metadata(),
            ),
            document(
                "doc-2",
                None,
                "Music playlists and recommendation notes.",
                english_metadata(),
            ),
        ];
        let mut original = RetrievalIndex::new(embedder());
        original
            .ingest_documents(&documents, &IngestionOptions::default())
            .unwrap();

        let dir = tempdir().unwrap();
        PersistedSearchIndex::from_index(&original)
            .save_to_path(dir.path())
            .unwrap();
        let restored = PersistedSearchIndex::load_with_embedder(dir.path(), embedder()).unwrap();

        let query = SearchQuery {
            text: "rust search docs".to_string(),
            top_k: 2,
            filter: None,
            hybrid: HybridConfig::default(),
        };
        assert_eq!(
            restored.search(&query).unwrap(),
            original.search(&query).unwrap()
        );
    }

    #[test]
    fn malformed_manifest_is_rejected() {
        let dir = tempdir().unwrap();
        let fixtures = [
            ("manifest.json", INCONSISTENT_MANIFEST),
            ("corpus.json", CORPUS_METADATA),
            ("chunks.jsonl", ""),
            ("vectors.jsonl", ""),
        ];
        for (file_name, contents) in fixtures {
            fs::write(dir.path().join(file_name), contents).unwrap();
        }

        let err = PersistedSearchIndex::load_from_path(dir.path()).unwrap_err();
        assert!(
            matches!(&err, StorageError::InvalidManifest(message) if message.contains("chunks")),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn loading_rejects_incompatible_embedder_name_or_dimensions() {
        let mut index = RetrievalIndex::new(NamedEmbedder::new("persisted", 4));
        index
            .ingest_documents(
                &[document("doc-1", None, "rust cargo crates", BTreeMap::new())],
                &IngestionOptions::default(),
            )
            .unwrap();

        let dir = tempdir().unwrap();
        PersistedSearchIndex::from_index(&index)
            .save_to_path(dir.path())
            .unwrap();

        // Each case pairs a mismatching embedder with a fragment its error must mention.
        let cases = [
            (NamedEmbedder::new("other", 4), "embedder"),
            (NamedEmbedder::new("persisted", 8), "dimensions"),
        ];
        for (candidate, expected_fragment) in cases {
            let err = PersistedSearchIndex::load_with_embedder(dir.path(), candidate).unwrap_err();
            assert!(
                matches!(&err, StorageError::InvalidState(message) if message.contains(expected_fragment)),
                "expected InvalidState mentioning `{expected_fragment}`, got {err:?}"
            );
        }
    }
}