sayr_engine/
storage.rs

1use async_trait::async_trait;
2use sqlx::{sqlite::SqlitePoolOptions, Row, SqlitePool};
3use tokio::{fs, io::AsyncWriteExt};
4
5use crate::error::{AgnoError, Result};
6use crate::message::Message;
7
8/// Generic persistence contract for conversation state.
9#[async_trait]
10pub trait ConversationStore: Send + Sync {
11    async fn load(&self) -> Result<Vec<Message>>;
12    async fn append(&self, message: &Message) -> Result<()>;
13    async fn clear(&self) -> Result<()>;
14}
15
16/// A simple JSONL-based store that writes messages to disk.
17pub struct FileConversationStore {
18    path: String,
19}
20
21impl FileConversationStore {
22    pub fn new(path: impl Into<String>) -> Self {
23        Self { path: path.into() }
24    }
25}
26
27#[async_trait]
28impl ConversationStore for FileConversationStore {
29    async fn load(&self) -> Result<Vec<Message>> {
30        let content = match fs::read_to_string(&self.path).await {
31            Ok(contents) => contents,
32            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
33            Err(err) => {
34                return Err(AgnoError::Storage(format!(
35                    "failed to read transcript `{}`: {err}",
36                    self.path
37                )))
38            }
39        };
40
41        let mut messages = Vec::new();
42        for line in content.lines() {
43            let msg: Message = serde_json::from_str(line)?;
44            messages.push(msg);
45        }
46
47        Ok(messages)
48    }
49
50    async fn append(&self, message: &Message) -> Result<()> {
51        let mut serialized = serde_json::to_string(message)?;
52        serialized.push('\n');
53        fs::OpenOptions::new()
54            .create(true)
55            .append(true)
56            .open(&self.path)
57            .await
58            .map_err(|err| {
59                AgnoError::Storage(format!("failed to open `{}`: {err}", self.path.clone()))
60            })?
61            .write_all(serialized.as_bytes())
62            .await
63            .map_err(|err| AgnoError::Storage(format!("failed to persist message: {err}")))
64    }
65
66    async fn clear(&self) -> Result<()> {
67        fs::remove_file(&self.path)
68            .await
69            .or_else(|err| {
70                if err.kind() == std::io::ErrorKind::NotFound {
71                    Ok(())
72                } else {
73                    Err(err)
74                }
75            })
76            .map_err(|err| AgnoError::Storage(format!("failed clearing `{}`: {err}", self.path)))
77    }
78}
79
80/// Placeholder for SQL-based backends. The type compiles without requiring the
81/// database drivers and can be swapped out once the feature lands.
82pub struct SqlConversationStore {
83    pool: SqlitePool,
84}
85
86impl SqlConversationStore {
87    const INIT_STATEMENT: &'static str = r#"
88        CREATE TABLE IF NOT EXISTS messages (
89            id INTEGER PRIMARY KEY AUTOINCREMENT,
90            payload TEXT NOT NULL
91        )
92    "#;
93
94    pub async fn connect(connection_url: impl AsRef<str>) -> Result<Self> {
95        let pool = SqlitePoolOptions::new()
96            .max_connections(1)
97            .connect(connection_url.as_ref())
98            .await
99            .map_err(|err| {
100                AgnoError::Storage(format!(
101                    "failed connecting to SQL backend `{}`: {err}",
102                    connection_url.as_ref()
103                ))
104            })?;
105
106        sqlx::query(Self::INIT_STATEMENT)
107            .execute(&pool)
108            .await
109            .map_err(|err| AgnoError::Storage(format!("failed initializing schema: {err}")))?;
110
111        Ok(Self { pool })
112    }
113}
114
115#[async_trait]
116impl ConversationStore for SqlConversationStore {
117    async fn load(&self) -> Result<Vec<Message>> {
118        let rows = sqlx::query("SELECT payload FROM messages ORDER BY id ASC")
119            .fetch_all(&self.pool)
120            .await
121            .map_err(|err| AgnoError::Storage(format!("failed loading messages: {err}")))?;
122
123        rows.into_iter()
124            .map(|row| {
125                let payload: String = row.try_get("payload").map_err(|err| {
126                    AgnoError::Storage(format!("failed decoding message payload: {err}"))
127                })?;
128                serde_json::from_str(&payload)
129                    .map_err(|err| AgnoError::Storage(format!("invalid message payload: {err}")))
130            })
131            .collect()
132    }
133
134    async fn append(&self, message: &Message) -> Result<()> {
135        let payload = serde_json::to_string(message)?;
136        sqlx::query("INSERT INTO messages (payload) VALUES (?)")
137            .bind(payload)
138            .execute(&self.pool)
139            .await
140            .map(|_| ())
141            .map_err(|err| AgnoError::Storage(format!("failed writing message: {err}")))
142    }
143
144    async fn clear(&self) -> Result<()> {
145        sqlx::query("DELETE FROM messages")
146            .execute(&self.pool)
147            .await
148            .map(|_| ())
149            .map_err(|err| AgnoError::Storage(format!("failed clearing messages: {err}")))
150    }
151}
152
153#[cfg(test)]
154mod tests {
155    use super::*;
156    use crate::message::Role;
157    use tempfile::NamedTempFile;
158
159    #[tokio::test]
160    async fn file_store_round_trip() {
161        let file = NamedTempFile::new().unwrap();
162        let store = FileConversationStore::new(file.path().to_str().unwrap());
163
164        let msg = Message::user("hello");
165        store.append(&msg).await.unwrap();
166
167        let loaded = store.load().await.unwrap();
168        assert_eq!(loaded.len(), 1);
169        assert_eq!(loaded[0].role, Role::User);
170
171        store.clear().await.unwrap();
172        let cleared = store.load().await.unwrap();
173        assert!(cleared.is_empty());
174    }
175
176    #[tokio::test]
177    async fn sqlite_store_round_trip() {
178        let store = SqlConversationStore::connect("sqlite::memory:")
179            .await
180            .unwrap();
181
182        let msg = Message::assistant("hi from db");
183        store.append(&msg).await.unwrap();
184
185        let loaded = store.load().await.unwrap();
186        assert_eq!(loaded.len(), 1);
187        assert_eq!(loaded[0].content, "hi from db");
188
189        store.clear().await.unwrap();
190        let cleared = store.load().await.unwrap();
191        assert!(cleared.is_empty());
192    }
193}