ai_chain/document_stores/
in_memory_document_store.rs1use std::collections::HashMap;
2
3use crate::document_stores::document_store::*;
4use crate::schema::Document;
5
6use async_trait::async_trait;
7use serde::{de::DeserializeOwned, Serialize};
8use thiserror::Error;
9
/// Storage representation of a [`Document`] held by an [`InMemoryDocumentStore`].
///
/// Trait bounds are intentionally kept off the type definition and placed on the
/// impls that need them, so merely naming `InMemoryDocument<M>` does not force
/// serde bounds on unrelated code.
#[derive(Debug, Clone)]
pub struct InMemoryDocument<M> {
    /// Page content copied from the source [`Document`].
    page_content: String,
    /// Optional metadata of caller-chosen type `M`.
    metadata: Option<M>,
}
18
19impl<M> From<&InMemoryDocument<M>> for Document<M>
20where
21 M: serde::Serialize + serde::de::DeserializeOwned,
22{
23 fn from(val: &InMemoryDocument<M>) -> Self {
24 let metadata = if let Some(m) = &val.metadata {
25 let str = serde_json::to_string(&m).unwrap();
26 let cloned = serde_json::from_str::<M>(&str).unwrap();
27 Some(cloned)
28 } else {
29 None
30 };
31
32 Document {
33 page_content: val.page_content.clone(),
34 metadata,
35 }
36 }
37}
38
39impl<M> From<&Document<M>> for InMemoryDocument<M>
40where
41 M: serde::Serialize + serde::de::DeserializeOwned,
42{
43 fn from(val: &Document<M>) -> Self {
44 let metadata = if let Some(m) = &val.metadata {
45 let str = serde_json::to_string(&m).unwrap();
46 let cloned = serde_json::from_str::<M>(&str).unwrap();
47 Some(cloned)
48 } else {
49 None
50 };
51
52 InMemoryDocument {
53 page_content: val.page_content.clone(),
54 metadata,
55 }
56 }
57}
58
59#[derive(Debug, Error)]
60pub enum InMemoryDocumentStoreError {
61 #[error("Serde Error: {0}")]
62 Serde(#[from] serde_json::Error),
63 #[error("Key \"{0}\" already exists!")]
64 KeyConflict(String),
65}
66
67impl DocumentStoreError for InMemoryDocumentStoreError {}
68
69pub struct InMemoryDocumentStore<M>
70where
71 M: Serialize + DeserializeOwned + Send + Sync,
72{
73 map: HashMap<usize, InMemoryDocument<M>>,
74}
75
76impl<M> InMemoryDocumentStore<M>
77where
78 M: Serialize + DeserializeOwned + Send + Sync,
79{
80 pub fn new() -> Self {
81 InMemoryDocumentStore {
82 map: HashMap::new(),
83 }
84 }
85}
86
87impl<M> Default for InMemoryDocumentStore<M>
88where
89 M: Serialize + DeserializeOwned + Send + Sync,
90{
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96#[async_trait]
97impl<M> DocumentStore<usize, M> for InMemoryDocumentStore<M>
98where
99 M: Serialize + DeserializeOwned + Send + Sync,
100{
101 type Error = InMemoryDocumentStoreError;
102
103 async fn get(&self, id: &usize) -> Result<Option<Document<M>>, Self::Error> {
104 Ok(self.map.get(id).map(|m| m.into()))
105 }
106
107 async fn next_id(&self) -> Result<usize, Self::Error> {
108 Ok(self.map.len())
109 }
110
111 async fn insert(&mut self, documents: &HashMap<usize, Document<M>>) -> Result<(), Self::Error> {
112 for (key, value) in documents.iter() {
113 if self.map.contains_key(key) {
114 return Err(InMemoryDocumentStoreError::KeyConflict(key.to_string()));
115 } else {
116 self.map.insert(key.clone(), value.into());
117 }
118 }
119
120 Ok(())
121 }
122}