Skip to main content

text_retrieval/
storage.rs

1#![doc = include_str!("../README.md")]
2
3use std::collections::BTreeMap;
4use std::fs::{self, File};
5use std::io::{BufRead, BufReader, BufWriter, Write};
6use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9use text_embeddings::{EmbeddingModelInfo, TextEmbedderBackend};
10use thiserror::Error;
11use vector_analysis_index::SerializableVectorRecord;
12
13use crate::{DocumentChunk, RetrievalIndex};
14
15#[derive(Debug, Error)]
16/// Variants describing storage error.
17pub enum StorageError {
18    #[error("I/O error: {0}")]
19    /// The I/O variant.
20    Io(#[from] std::io::Error),
21    #[error("JSON error: {0}")]
22    /// The JSON variant.
23    Json(String),
24    #[error("invalid manifest: {0}")]
25    /// The invalid manifest variant.
26    InvalidManifest(String),
27    #[error("invalid retrieval state: {0}")]
28    /// The invalid state variant.
29    InvalidState(String),
30}
31
32/// Type alias for result.
33pub type Result<T> = std::result::Result<T, StorageError>;
34
35#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
36/// Data type for retrieval file.
37pub struct RetrievalFile {
38    /// Filesystem path for this value.
39    pub path: String,
40    /// The records value.
41    pub records: u64,
42}
43
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45/// Data type for retrieval manifest.
46pub struct RetrievalManifest {
47    /// The schema version value.
48    pub schema_version: u32,
49    /// The chunk count value.
50    pub chunk_count: u64,
51    /// The vector count value.
52    pub vector_count: u64,
53    /// The dimensions value.
54    pub dimensions: Option<usize>,
55    /// The embedder value.
56    pub embedder: EmbeddingModelInfo,
57    /// The chunks file value.
58    pub chunks_file: RetrievalFile,
59    /// The vectors file value.
60    pub vectors_file: RetrievalFile,
61    /// The corpus file value.
62    pub corpus_file: String,
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66/// Data type for persisted corpus metadata.
67pub struct PersistedCorpusMetadata {
68    /// The corpus options value.
69    pub corpus_options: text_lexical::CorpusOptions,
70    /// The bm25 options value.
71    pub bm25_options: text_lexical::Bm25Options,
72}
73
74#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
75/// Data type for persisted chunk record.
76pub struct PersistedChunkRecord {
77    /// The chunk value.
78    pub chunk: DocumentChunk,
79    /// The raw text value.
80    pub raw_text: Option<String>,
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84/// Data type for persisted search index.
85pub struct PersistedSearchIndex {
86    /// The manifest value.
87    pub manifest: RetrievalManifest,
88    /// The corpus value.
89    pub corpus: PersistedCorpusMetadata,
90    /// The chunks value.
91    pub chunks: Vec<PersistedChunkRecord>,
92    /// The vectors value.
93    pub vectors: Vec<SerializableVectorRecord>,
94}
95
96impl PersistedSearchIndex {
97    /// Builds this value from index.
98    pub fn from_index<B: TextEmbedderBackend>(index: &RetrievalIndex<B>) -> Self {
99        let chunks = index
100            .chunks_iter()
101            .map(|chunk| PersistedChunkRecord {
102                chunk: chunk.clone(),
103                raw_text: index.raw_text(&chunk.chunk_id).map(ToString::to_string),
104            })
105            .collect::<Vec<_>>();
106        let vectors = index.vector_records();
107        let embedder = index.embedder_info();
108
109        Self {
110            manifest: RetrievalManifest {
111                schema_version: 1,
112                chunk_count: chunks.len() as u64,
113                vector_count: vectors.len() as u64,
114                dimensions: vectors.first().map(|record| record.vector.len()),
115                embedder,
116                chunks_file: RetrievalFile {
117                    path: "chunks.jsonl".to_string(),
118                    records: chunks.len() as u64,
119                },
120                vectors_file: RetrievalFile {
121                    path: "vectors.jsonl".to_string(),
122                    records: vectors.len() as u64,
123                },
124                corpus_file: "corpus.json".to_string(),
125            },
126            corpus: PersistedCorpusMetadata {
127                corpus_options: index.corpus_options().clone(),
128                bm25_options: index.bm25_options().clone(),
129            },
130            chunks,
131            vectors,
132        }
133    }
134
135    /// Returns save to path.
136    pub fn save_to_path(&self, path: &Path) -> Result<()> {
137        fs::create_dir_all(path)?;
138        write_json(path.join("manifest.json"), &self.manifest)?;
139        write_json(path.join(&self.manifest.corpus_file), &self.corpus)?;
140        write_jsonl(path.join(&self.manifest.chunks_file.path), &self.chunks)?;
141        write_jsonl(path.join(&self.manifest.vectors_file.path), &self.vectors)?;
142        Ok(())
143    }
144
145    /// Returns load from path.
146    pub fn load_from_path(path: &Path) -> Result<Self> {
147        let manifest = read_json::<RetrievalManifest>(path.join("manifest.json"))?;
148        let corpus = read_json::<PersistedCorpusMetadata>(path.join(&manifest.corpus_file))?;
149        let chunks = read_jsonl::<PersistedChunkRecord>(path.join(&manifest.chunks_file.path))?;
150        let vectors =
151            read_jsonl::<SerializableVectorRecord>(path.join(&manifest.vectors_file.path))?;
152
153        if chunks.len() as u64 != manifest.chunk_count {
154            return Err(StorageError::InvalidManifest(format!(
155                "manifest expected {} chunks, loaded {}",
156                manifest.chunk_count,
157                chunks.len()
158            )));
159        }
160        if vectors.len() as u64 != manifest.vector_count {
161            return Err(StorageError::InvalidManifest(format!(
162                "manifest expected {} vectors, loaded {}",
163                manifest.vector_count,
164                vectors.len()
165            )));
166        }
167        if let Some(dimensions) = manifest.dimensions {
168            if vectors
169                .iter()
170                .any(|record| record.vector.len() != dimensions)
171            {
172                return Err(StorageError::InvalidManifest(format!(
173                    "persisted vectors did not all match manifest dimension {dimensions}"
174                )));
175            }
176        }
177
178        Ok(Self {
179            manifest,
180            corpus,
181            chunks,
182            vectors,
183        })
184    }
185
186    /// Consumes this value into an index.
187    pub fn into_index<B: TextEmbedderBackend>(self, embedder: B) -> Result<RetrievalIndex<B>> {
188        validate_embedder_compatibility(&self.manifest.embedder, &embedder.model_info())?;
189        let raw_text_by_chunk_id = self
190            .chunks
191            .iter()
192            .filter_map(|record| {
193                record
194                    .raw_text
195                    .as_ref()
196                    .map(|raw_text| (record.chunk.chunk_id.clone(), raw_text.clone()))
197            })
198            .collect::<BTreeMap<_, _>>();
199        let chunks = self
200            .chunks
201            .into_iter()
202            .map(|record| record.chunk)
203            .collect::<Vec<_>>();
204
205        RetrievalIndex::from_parts(
206            embedder,
207            self.corpus.corpus_options,
208            self.corpus.bm25_options,
209            chunks,
210            raw_text_by_chunk_id,
211            self.vectors,
212        )
213        .map_err(|err| StorageError::InvalidState(err.to_string()))
214    }
215
216    /// Returns load with embedder.
217    pub fn load_with_embedder<B: TextEmbedderBackend>(
218        path: &Path,
219        embedder: B,
220    ) -> Result<RetrievalIndex<B>> {
221        Self::load_from_path(path)?.into_index(embedder)
222    }
223}
224
225fn write_json(path: impl AsRef<Path>, value: &impl Serialize) -> Result<()> {
226    let file = File::create(path)?;
227    serde_json::to_writer_pretty(BufWriter::new(file), value).map_err(json_error)
228}
229
230fn read_json<T: for<'de> Deserialize<'de>>(path: impl AsRef<Path>) -> Result<T> {
231    let file = File::open(path)?;
232    serde_json::from_reader(BufReader::new(file)).map_err(json_error)
233}
234
235fn write_jsonl<T: Serialize>(path: impl AsRef<Path>, values: &[T]) -> Result<()> {
236    let file = File::create(path)?;
237    let mut writer = BufWriter::new(file);
238    for value in values {
239        let line = serde_json::to_string(value).map_err(json_error)?;
240        writer.write_all(line.as_bytes())?;
241        writer.write_all(b"\n")?;
242    }
243    writer.flush()?;
244    Ok(())
245}
246
247fn read_jsonl<T: for<'de> Deserialize<'de>>(path: impl AsRef<Path>) -> Result<Vec<T>> {
248    let file = File::open(path)?;
249    let reader = BufReader::new(file);
250    let mut values = Vec::new();
251    for (line_index, line) in reader.lines().enumerate() {
252        let line = line?;
253        if line.trim().is_empty() {
254            continue;
255        }
256        let value = serde_json::from_str::<T>(&line)
257            .map_err(|err| StorageError::Json(format!("line {}: {err}", line_index + 1)))?;
258        values.push(value);
259    }
260    Ok(values)
261}
262
263fn validate_embedder_compatibility(
264    persisted: &EmbeddingModelInfo,
265    current: &EmbeddingModelInfo,
266) -> Result<()> {
267    if !persisted.model_name.is_empty()
268        && !current.model_name.is_empty()
269        && persisted.model_name != current.model_name
270    {
271        return Err(StorageError::InvalidState(format!(
272            "persisted embedder `{}` did not match provided embedder `{}`",
273            persisted.model_name, current.model_name
274        )));
275    }
276    if persisted.dimensions > 0
277        && current.dimensions > 0
278        && persisted.dimensions != current.dimensions
279    {
280        return Err(StorageError::InvalidState(format!(
281            "persisted embedder dimensions {} did not match provided embedder dimensions {}",
282            persisted.dimensions, current.dimensions
283        )));
284    }
285    Ok(())
286}
287
288fn json_error(error: serde_json::Error) -> StorageError {
289    StorageError::Json(error.to_string())
290}
291
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::{HybridConfig, IngestionOptions, SearchDocument, SearchQuery};
    use tempfile::tempdir;
    use text_embeddings::{
        DenseVector, HashedTextEmbedder, TextEmbeddingBackend, TextEmbeddingBackendKind,
        TextEmbeddingConfig, TextEmbeddingMetadata,
    };
    use text_lexical::CorpusOptions;

    use super::*;

    /// Hashed embedder with 32 dimensions and IDF weighting, used by the round-trip test.
    fn embedder() -> HashedTextEmbedder {
        let config = TextEmbeddingConfig {
            dimensions: 32,
            use_idf: true,
        };
        HashedTextEmbedder::new(config, CorpusOptions::default()).unwrap()
    }

    /// Builds a document with no source, provenance or annotations.
    fn document(
        id: &str,
        title: Option<&str>,
        body: &str,
        metadata: BTreeMap<String, String>,
    ) -> SearchDocument {
        SearchDocument {
            id: id.to_string(),
            title: title.map(str::to_string),
            body: body.to_string(),
            metadata,
            source: None,
            provenance: Vec::new(),
            annotations: Vec::new(),
        }
    }

    /// Metadata map tagging a document with `lang = en`.
    fn english() -> BTreeMap<String, String> {
        BTreeMap::from([("lang".to_string(), "en".to_string())])
    }

    /// Test embedder whose reported model name and dimension count are configurable; every
    /// text embeds to a vector of ones.
    #[derive(Debug, Clone)]
    struct NamedEmbedder {
        name: String,
        dimensions: usize,
    }

    impl TextEmbeddingBackend for NamedEmbedder {
        fn embed_text(&self, _text: &str) -> video_analysis_core::Result<DenseVector> {
            DenseVector::new(vec![1.0; self.dimensions])
        }

        fn metadata(&self) -> TextEmbeddingMetadata {
            TextEmbeddingMetadata {
                backend: TextEmbeddingBackendKind::Custom,
                model_name: Some(self.name.clone()),
                dimensions: Some(self.dimensions),
                ..TextEmbeddingMetadata::default()
            }
        }
    }

    #[test]
    fn persisted_index_round_trips_with_manifest_validation() {
        let documents = [
            document(
                "doc-1",
                Some("Rust Search"),
                "Rust cargo crates enable semantic search over documentation.",
                english(),
            ),
            document(
                "doc-2",
                None,
                "Music playlists and recommendation notes.",
                english(),
            ),
        ];
        let mut index = RetrievalIndex::new(embedder());
        index
            .ingest_documents(&documents, &IngestionOptions::default())
            .unwrap();

        let dir = tempdir().unwrap();
        PersistedSearchIndex::from_index(&index)
            .save_to_path(dir.path())
            .unwrap();
        let restored = PersistedSearchIndex::load_with_embedder(dir.path(), embedder()).unwrap();

        let query = SearchQuery {
            text: "rust search docs".to_string(),
            top_k: 2,
            filter: None,
            hybrid: HybridConfig::default(),
        };
        let actual = restored.search(&query).unwrap();
        let expected = index.search(&query).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn malformed_manifest_is_rejected() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        // The manifest claims 2 chunks while `chunks.jsonl` is empty.
        let files = [
            (
                "manifest.json",
                r#"{"schema_version":1,"chunk_count":2,"vector_count":1,"dimensions":2,"embedder":{"model_name":"hashed-text-embedder","backend":"hashed","dimensions":32},"chunks_file":{"path":"chunks.jsonl","records":0},"vectors_file":{"path":"vectors.jsonl","records":0},"corpus_file":"corpus.json"}"#,
            ),
            (
                "corpus.json",
                r#"{"corpus_options":{"processing":{"language":null,"lowercase":true,"normalize_unicode":true,"keep_apostrophes":true,"include_punctuation":false},"min_term_len":1,"stop_words":[],"max_terms_per_document":null},"bm25_options":{"k1":1.2,"b":0.75,"min_term_len":1,"stop_words":[]}}"#,
            ),
            ("chunks.jsonl", ""),
            ("vectors.jsonl", ""),
        ];
        for (name, contents) in files {
            fs::write(root.join(name), contents).unwrap();
        }

        let err = PersistedSearchIndex::load_from_path(root).unwrap_err();
        assert!(
            matches!(err, StorageError::InvalidManifest(message) if message.contains("chunks"))
        );
    }

    #[test]
    fn loading_rejects_incompatible_embedder_name_or_dimensions() {
        let named = |name: &str, dimensions: usize| NamedEmbedder {
            name: name.to_string(),
            dimensions,
        };

        let mut index = RetrievalIndex::new(named("persisted", 4));
        index
            .ingest_documents(
                &[document("doc-1", None, "rust cargo crates", BTreeMap::new())],
                &IngestionOptions::default(),
            )
            .unwrap();

        let dir = tempdir().unwrap();
        PersistedSearchIndex::from_index(&index)
            .save_to_path(dir.path())
            .unwrap();

        let wrong_name =
            PersistedSearchIndex::load_with_embedder(dir.path(), named("other", 4)).unwrap_err();
        assert!(
            matches!(wrong_name, StorageError::InvalidState(message) if message.contains("embedder"))
        );

        let wrong_dimensions =
            PersistedSearchIndex::load_with_embedder(dir.path(), named("persisted", 8))
                .unwrap_err();
        assert!(
            matches!(wrong_dimensions, StorageError::InvalidState(message) if message.contains("dimensions"))
        );
    }
}