// ai_chain/document_stores/in_memory_document_store.rs

1use std::collections::HashMap;
2
3use crate::document_stores::document_store::*;
4use crate::schema::Document;
5
6use async_trait::async_trait;
7use serde::{de::DeserializeOwned, Serialize};
8use thiserror::Error;
9
/// Storage representation of a [`Document`] inside [`InMemoryDocumentStore`].
///
/// Mirrors the fields of [`Document`]: the page text plus optional metadata of type `M`.
///
/// No trait bounds are placed on `M` here. The conversions and store impls that actually
/// need serde support declare those bounds themselves. The derived `Clone`/`Debug` impls
/// therefore only require `M: Clone` / `M: Debug` respectively.
#[derive(Debug, Clone)]
pub struct InMemoryDocument<M> {
    /// The document's text content.
    page_content: String,
    /// Optional user-defined metadata attached to the document.
    metadata: Option<M>,
}
18
19impl<M> From<&InMemoryDocument<M>> for Document<M>
20where
21    M: serde::Serialize + serde::de::DeserializeOwned,
22{
23    fn from(val: &InMemoryDocument<M>) -> Self {
24        let metadata = if let Some(m) = &val.metadata {
25            let str = serde_json::to_string(&m).unwrap();
26            let cloned = serde_json::from_str::<M>(&str).unwrap();
27            Some(cloned)
28        } else {
29            None
30        };
31
32        Document {
33            page_content: val.page_content.clone(),
34            metadata,
35        }
36    }
37}
38
39impl<M> From<&Document<M>> for InMemoryDocument<M>
40where
41    M: serde::Serialize + serde::de::DeserializeOwned,
42{
43    fn from(val: &Document<M>) -> Self {
44        let metadata = if let Some(m) = &val.metadata {
45            let str = serde_json::to_string(&m).unwrap();
46            let cloned = serde_json::from_str::<M>(&str).unwrap();
47            Some(cloned)
48        } else {
49            None
50        };
51
52        InMemoryDocument {
53            page_content: val.page_content.clone(),
54            metadata,
55        }
56    }
57}
58
59#[derive(Debug, Error)]
60pub enum InMemoryDocumentStoreError {
61    #[error("Serde Error: {0}")]
62    Serde(#[from] serde_json::Error),
63    #[error("Key \"{0}\" already exists!")]
64    KeyConflict(String),
65}
66
67impl DocumentStoreError for InMemoryDocumentStoreError {}
68
69pub struct InMemoryDocumentStore<M>
70where
71    M: Serialize + DeserializeOwned + Send + Sync,
72{
73    map: HashMap<usize, InMemoryDocument<M>>,
74}
75
76impl<M> InMemoryDocumentStore<M>
77where
78    M: Serialize + DeserializeOwned + Send + Sync,
79{
80    pub fn new() -> Self {
81        InMemoryDocumentStore {
82            map: HashMap::new(),
83        }
84    }
85}
86
87impl<M> Default for InMemoryDocumentStore<M>
88where
89    M: Serialize + DeserializeOwned + Send + Sync,
90{
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96#[async_trait]
97impl<M> DocumentStore<usize, M> for InMemoryDocumentStore<M>
98where
99    M: Serialize + DeserializeOwned + Send + Sync,
100{
101    type Error = InMemoryDocumentStoreError;
102
103    async fn get(&self, id: &usize) -> Result<Option<Document<M>>, Self::Error> {
104        Ok(self.map.get(id).map(|m| m.into()))
105    }
106
107    async fn next_id(&self) -> Result<usize, Self::Error> {
108        Ok(self.map.len())
109    }
110
111    async fn insert(&mut self, documents: &HashMap<usize, Document<M>>) -> Result<(), Self::Error> {
112        for (key, value) in documents.iter() {
113            if self.map.contains_key(key) {
114                return Err(InMemoryDocumentStoreError::KeyConflict(key.to_string()));
115            } else {
116                self.map.insert(key.clone(), value.into());
117            }
118        }
119
120        Ok(())
121    }
122}