1use async_trait::async_trait;
2use sqlx::{sqlite::SqlitePoolOptions, Row, SqlitePool};
3use tokio::{fs, io::AsyncWriteExt};
4
5use crate::error::{AgnoError, Result};
6use crate::message::Message;
7
8#[async_trait]
10pub trait ConversationStore: Send + Sync {
11 async fn load(&self) -> Result<Vec<Message>>;
12 async fn append(&self, message: &Message) -> Result<()>;
13 async fn clear(&self) -> Result<()>;
14}
15
/// Conversation store that keeps the transcript in a text file at `path`,
/// one JSON-serialized message per line (JSON Lines).
#[derive(Debug, Clone)]
pub struct FileConversationStore {
    /// Filesystem path of the transcript file; it is created lazily on first append.
    path: String,
}

impl FileConversationStore {
    /// Creates a store bound to `path`.
    ///
    /// No I/O happens here; the file is only touched by the trait methods.
    pub fn new(path: impl Into<String>) -> Self {
        Self { path: path.into() }
    }
}
26
27#[async_trait]
28impl ConversationStore for FileConversationStore {
29 async fn load(&self) -> Result<Vec<Message>> {
30 let content = match fs::read_to_string(&self.path).await {
31 Ok(contents) => contents,
32 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
33 Err(err) => {
34 return Err(AgnoError::Storage(format!(
35 "failed to read transcript `{}`: {err}",
36 self.path
37 )))
38 }
39 };
40
41 let mut messages = Vec::new();
42 for line in content.lines() {
43 let msg: Message = serde_json::from_str(line)?;
44 messages.push(msg);
45 }
46
47 Ok(messages)
48 }
49
50 async fn append(&self, message: &Message) -> Result<()> {
51 let mut serialized = serde_json::to_string(message)?;
52 serialized.push('\n');
53 fs::OpenOptions::new()
54 .create(true)
55 .append(true)
56 .open(&self.path)
57 .await
58 .map_err(|err| {
59 AgnoError::Storage(format!("failed to open `{}`: {err}", self.path.clone()))
60 })?
61 .write_all(serialized.as_bytes())
62 .await
63 .map_err(|err| AgnoError::Storage(format!("failed to persist message: {err}")))
64 }
65
66 async fn clear(&self) -> Result<()> {
67 fs::remove_file(&self.path)
68 .await
69 .or_else(|err| {
70 if err.kind() == std::io::ErrorKind::NotFound {
71 Ok(())
72 } else {
73 Err(err)
74 }
75 })
76 .map_err(|err| AgnoError::Storage(format!("failed clearing `{}`: {err}", self.path)))
77 }
78}
79
80pub struct SqlConversationStore {
83 pool: SqlitePool,
84}
85
86impl SqlConversationStore {
87 const INIT_STATEMENT: &'static str = r#"
88 CREATE TABLE IF NOT EXISTS messages (
89 id INTEGER PRIMARY KEY AUTOINCREMENT,
90 payload TEXT NOT NULL
91 )
92 "#;
93
94 pub async fn connect(connection_url: impl AsRef<str>) -> Result<Self> {
95 let pool = SqlitePoolOptions::new()
96 .max_connections(1)
97 .connect(connection_url.as_ref())
98 .await
99 .map_err(|err| {
100 AgnoError::Storage(format!(
101 "failed connecting to SQL backend `{}`: {err}",
102 connection_url.as_ref()
103 ))
104 })?;
105
106 sqlx::query(Self::INIT_STATEMENT)
107 .execute(&pool)
108 .await
109 .map_err(|err| AgnoError::Storage(format!("failed initializing schema: {err}")))?;
110
111 Ok(Self { pool })
112 }
113}
114
115#[async_trait]
116impl ConversationStore for SqlConversationStore {
117 async fn load(&self) -> Result<Vec<Message>> {
118 let rows = sqlx::query("SELECT payload FROM messages ORDER BY id ASC")
119 .fetch_all(&self.pool)
120 .await
121 .map_err(|err| AgnoError::Storage(format!("failed loading messages: {err}")))?;
122
123 rows.into_iter()
124 .map(|row| {
125 let payload: String = row.try_get("payload").map_err(|err| {
126 AgnoError::Storage(format!("failed decoding message payload: {err}"))
127 })?;
128 serde_json::from_str(&payload)
129 .map_err(|err| AgnoError::Storage(format!("invalid message payload: {err}")))
130 })
131 .collect()
132 }
133
134 async fn append(&self, message: &Message) -> Result<()> {
135 let payload = serde_json::to_string(message)?;
136 sqlx::query("INSERT INTO messages (payload) VALUES (?)")
137 .bind(payload)
138 .execute(&self.pool)
139 .await
140 .map(|_| ())
141 .map_err(|err| AgnoError::Storage(format!("failed writing message: {err}")))
142 }
143
144 async fn clear(&self) -> Result<()> {
145 sqlx::query("DELETE FROM messages")
146 .execute(&self.pool)
147 .await
148 .map(|_| ())
149 .map_err(|err| AgnoError::Storage(format!("failed clearing messages: {err}")))
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Role;
    use tempfile::NamedTempFile;

    /// Append one message to a temp-file store, read it back, then clear it.
    #[tokio::test]
    async fn file_store_round_trip() {
        let transcript_file = NamedTempFile::new().unwrap();
        let path = transcript_file.path().to_str().unwrap();
        let store = FileConversationStore::new(path);

        store.append(&Message::user("hello")).await.unwrap();

        let history = store.load().await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);

        store.clear().await.unwrap();
        assert!(store.load().await.unwrap().is_empty());
    }

    /// Same round trip against an in-memory SQLite database.
    #[tokio::test]
    async fn sqlite_store_round_trip() {
        let store = SqlConversationStore::connect("sqlite::memory:")
            .await
            .unwrap();

        store.append(&Message::assistant("hi from db")).await.unwrap();

        let history = store.load().await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "hi from db");

        store.clear().await.unwrap();
        assert!(store.load().await.unwrap().is_empty());
    }
}