Skip to main content

limit_llm/
persistence.rs

1use crate::error::LlmError;
2use crate::types::Message;
3use bincode::{deserialize, serialize};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::Path;
7
/// Format version stamped into every state file written by `StatePersistence::save`.
///
/// `StatePersistence::load` rejects files whose stored version is greater than this value.
const CURRENT_VERSION: u32 = 2;
9
/// Top-level record that is serialized into the state file.
///
/// NOTE(review): bincode presumably encodes these fields positionally, so reordering or
/// inserting fields would change the on-disk layout — confirm before editing.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    /// Format version the file was written with (see `CURRENT_VERSION`).
    version: u32,
    /// The saved messages, in the order they were handed to `save`.
    messages: Vec<PersistedMessage>,
}
15
/// Serializable mirror of `crate::types::Message`.
///
/// The `From` impls in this file copy each field across one-to-one, converting only the role
/// between `PersistedRole` and `crate::types::Role`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    /// Stored counterpart of `Message::role`.
    role: PersistedRole,
    /// Copied verbatim from `Message::content`.
    content: Option<String>,
    /// Copied verbatim from `Message::tool_calls`.
    tool_calls: Option<Vec<crate::types::ToolCall>>,
    /// Copied verbatim from `Message::tool_call_id`.
    tool_call_id: Option<String>,
}
23
/// Serializable mirror of `crate::types::Role`, with one variant per runtime role.
///
/// NOTE(review): the variant order is presumably part of the stored encoding — confirm before
/// reordering or inserting variants.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    /// Maps to `Role::User`.
    User,
    /// Maps to `Role::Assistant`.
    Assistant,
    /// Maps to `Role::System`.
    System,
    /// Maps to `Role::Tool`.
    Tool,
}
31
32impl From<PersistedRole> for crate::types::Role {
33    fn from(role: PersistedRole) -> Self {
34        match role {
35            PersistedRole::User => crate::types::Role::User,
36            PersistedRole::Assistant => crate::types::Role::Assistant,
37            PersistedRole::System => crate::types::Role::System,
38            PersistedRole::Tool => crate::types::Role::Tool,
39        }
40    }
41}
42
43impl From<crate::types::Role> for PersistedRole {
44    fn from(role: crate::types::Role) -> Self {
45        match role {
46            crate::types::Role::User => PersistedRole::User,
47            crate::types::Role::Assistant => PersistedRole::Assistant,
48            crate::types::Role::System => PersistedRole::System,
49            crate::types::Role::Tool => PersistedRole::Tool,
50        }
51    }
52}
53
54impl From<PersistedMessage> for Message {
55    fn from(msg: PersistedMessage) -> Self {
56        Message {
57            role: msg.role.into(),
58            content: msg.content,
59            tool_calls: msg.tool_calls,
60            tool_call_id: msg.tool_call_id,
61        }
62    }
63}
64
65impl From<Message> for PersistedMessage {
66    fn from(msg: Message) -> Self {
67        PersistedMessage {
68            role: msg.role.into(),
69            content: msg.content,
70            tool_calls: msg.tool_calls,
71            tool_call_id: msg.tool_call_id,
72        }
73    }
74}
75
/// Saves a list of messages to, and loads it back from, a single file on disk.
pub struct StatePersistence {
    /// Location of the state file that `save` writes and `load` reads.
    file_path: std::path::PathBuf,
}
79
80impl StatePersistence {
81    pub fn new<P: AsRef<Path>>(file_path: P) -> Self {
82        Self {
83            file_path: file_path.as_ref().to_path_buf(),
84        }
85    }
86
87    pub fn save(&self, messages: &[Message]) -> Result<(), LlmError> {
88        let persisted_messages: Vec<PersistedMessage> =
89            messages.iter().cloned().map(|m| m.into()).collect();
90
91        let state = PersistedState {
92            version: CURRENT_VERSION,
93            messages: persisted_messages,
94        };
95
96        let serialized = serialize(&state)
97            .map_err(|e| LlmError::PersistenceError(format!("Failed to serialize state: {}", e)))?;
98
99        fs::write(&self.file_path, serialized)
100            .map_err(|e| LlmError::PersistenceError(format!("Failed to write file: {}", e)))?;
101
102        Ok(())
103    }
104
105    pub fn load(&self) -> Result<Vec<Message>, LlmError> {
106        let data = match fs::read(&self.file_path) {
107            Ok(data) => data,
108            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(vec![]),
109            Err(e) => {
110                return Err(LlmError::PersistenceError(format!(
111                    "Failed to read file: {}",
112                    e
113                )))
114            }
115        };
116
117        let state: PersistedState = deserialize(&data).map_err(|e| {
118            LlmError::PersistenceError(format!("Failed to deserialize state: {}", e))
119        })?;
120
121        // Handle version migration if needed
122        if state.version > CURRENT_VERSION {
123            return Err(LlmError::PersistenceError(format!(
124                "Version mismatch: expected {}, found {}",
125                CURRENT_VERSION, state.version
126            )));
127        }
128
129        let messages: Vec<Message> = state.messages.into_iter().map(|m| m.into()).collect();
130
131        Ok(messages)
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Role;
    use tempfile::tempdir;

    /// Builds a plain text message with no tool metadata attached.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// Saved messages come back in the same order with the same role and content.
    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("test_state.bin"));

        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi there!"),
        ];

        persistence.save(&messages).unwrap();
        let loaded = persistence.load().unwrap();

        assert_eq!(loaded.len(), messages.len());
        for (restored, original) in loaded.iter().zip(&messages) {
            assert_eq!(restored.role, original.role);
            assert_eq!(restored.content, original.content);
        }
    }

    /// Tool-result messages keep their role and `tool_call_id` across a save/load cycle.
    #[test]
    fn test_save_load_with_tool_result() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("test_state.bin"));

        let messages = vec![Message {
            role: Role::Tool,
            content: Some("tool output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        }];

        persistence.save(&messages).unwrap();
        let loaded = persistence.load().unwrap();

        assert_eq!(loaded[0].role, Role::Tool);
        assert_eq!(loaded[0].tool_call_id, Some("call_123".to_string()));
    }

    /// A state file that does not exist yet loads as an empty message list.
    #[test]
    fn test_load_missing_file_returns_empty() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("nonexistent.bin"));

        let loaded = persistence.load().unwrap();

        assert!(loaded.is_empty());
    }

    /// Bytes that are not a valid serialized state are reported as an error.
    #[test]
    fn test_load_corrupted_file_returns_error() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("corrupted.bin");
        std::fs::write(&file_path, b"invalid binary data").unwrap();

        let result = StatePersistence::new(&file_path).load();

        assert!(result.is_err());
    }

    /// A file stamped with a version above `CURRENT_VERSION` is rejected as a mismatch.
    #[test]
    fn test_load_newer_version_returns_error() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("future.bin");

        let future_state = PersistedState {
            version: CURRENT_VERSION + 1,
            messages: Vec::new(),
        };
        std::fs::write(&file_path, serialize(&future_state).unwrap()).unwrap();

        let result = StatePersistence::new(&file_path).load();

        assert!(matches!(
            &result,
            Err(LlmError::PersistenceError(msg)) if msg.contains("Version mismatch")
        ));
    }

    /// Every role maps to its same-named counterpart in both directions.
    #[test]
    fn test_role_conversion() {
        assert_eq!(PersistedRole::User, Role::User.into());
        assert_eq!(PersistedRole::Assistant, Role::Assistant.into());
        assert_eq!(PersistedRole::System, Role::System.into());
        assert_eq!(PersistedRole::Tool, Role::Tool.into());

        assert_eq!(Role::User, PersistedRole::User.into());
        assert_eq!(Role::Assistant, PersistedRole::Assistant.into());
        assert_eq!(Role::System, PersistedRole::System.into());
        assert_eq!(Role::Tool, PersistedRole::Tool.into());
    }
}