Skip to main content

limit_llm/
persistence.rs

1use crate::error::LlmError;
2use crate::types::Message;
3use bincode::{deserialize, serialize};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::Path;
7
/// Format version stamped into every state file written by `StatePersistence::save`.
///
/// `StatePersistence::load` rejects files whose stored version is greater than this value.
/// Files with an equal or smaller version are decoded with the current layout; no migration
/// step is applied to them.
const CURRENT_VERSION: u32 = 2;
9
/// Top-level record that `StatePersistence::save` encodes with bincode and writes to disk.
///
/// bincode encodes struct fields by position, not by name. The declaration order below is
/// therefore the on-disk layout: reordering, adding, or removing fields changes the file
/// format and breaks reading files written by earlier builds.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    // Format version at save time (always `CURRENT_VERSION` when written by `save`).
    version: u32,
    // Conversation history, in the order the slice was passed to `save`.
    messages: Vec<PersistedMessage>,
}
15
/// On-disk mirror of `crate::types::Message`.
///
/// It is converted to and from `Message` by the `From` impls in this module. As with
/// `PersistedState`, field order is part of the bincode file format. `ToolCall` and
/// `CacheControl` are stored through their own serde impls, so changes to those types also
/// change the format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    role: PersistedRole,
    content: Option<String>,
    tool_calls: Option<Vec<crate::types::ToolCall>>,
    tool_call_id: Option<String>,
    cache_control: Option<crate::types::CacheControl>,
}
24
/// On-disk tag for a message role, mapped one-to-one onto `crate::types::Role`.
///
/// bincode encodes enum variants by their index. Variant order is therefore part of the file
/// format: append new variants at the end and never reorder existing ones.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    User,
    Assistant,
    System,
    Tool,
}
32
33impl From<PersistedRole> for crate::types::Role {
34    fn from(role: PersistedRole) -> Self {
35        match role {
36            PersistedRole::User => crate::types::Role::User,
37            PersistedRole::Assistant => crate::types::Role::Assistant,
38            PersistedRole::System => crate::types::Role::System,
39            PersistedRole::Tool => crate::types::Role::Tool,
40        }
41    }
42}
43
44impl From<crate::types::Role> for PersistedRole {
45    fn from(role: crate::types::Role) -> Self {
46        match role {
47            crate::types::Role::User => PersistedRole::User,
48            crate::types::Role::Assistant => PersistedRole::Assistant,
49            crate::types::Role::System => PersistedRole::System,
50            crate::types::Role::Tool => PersistedRole::Tool,
51        }
52    }
53}
54
55impl From<PersistedMessage> for Message {
56    fn from(msg: PersistedMessage) -> Self {
57        Message {
58            role: msg.role.into(),
59            content: msg.content,
60            tool_calls: msg.tool_calls,
61            tool_call_id: msg.tool_call_id,
62            cache_control: msg.cache_control,
63        }
64    }
65}
66
67impl From<Message> for PersistedMessage {
68    fn from(msg: Message) -> Self {
69        PersistedMessage {
70            role: msg.role.into(),
71            content: msg.content,
72            tool_calls: msg.tool_calls,
73            tool_call_id: msg.tool_call_id,
74            cache_control: msg.cache_control,
75        }
76    }
77}
78
/// Saves and loads conversation history to and from a single bincode-encoded state file.
///
/// The handle only stores the path; all file I/O happens in `save` and `load`.
#[derive(Debug, Clone)]
pub struct StatePersistence {
    /// Location of the state file read by `load` and written by `save`.
    file_path: std::path::PathBuf,
}
82
83impl StatePersistence {
84    pub fn new<P: AsRef<Path>>(file_path: P) -> Self {
85        Self {
86            file_path: file_path.as_ref().to_path_buf(),
87        }
88    }
89
90    pub fn save(&self, messages: &[Message]) -> Result<(), LlmError> {
91        let persisted_messages: Vec<PersistedMessage> =
92            messages.iter().cloned().map(|m| m.into()).collect();
93
94        let state = PersistedState {
95            version: CURRENT_VERSION,
96            messages: persisted_messages,
97        };
98
99        let serialized = serialize(&state)
100            .map_err(|e| LlmError::PersistenceError(format!("Failed to serialize state: {}", e)))?;
101
102        fs::write(&self.file_path, serialized)
103            .map_err(|e| LlmError::PersistenceError(format!("Failed to write file: {}", e)))?;
104
105        Ok(())
106    }
107
108    pub fn load(&self) -> Result<Vec<Message>, LlmError> {
109        let data = match fs::read(&self.file_path) {
110            Ok(data) => data,
111            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(vec![]),
112            Err(e) => {
113                return Err(LlmError::PersistenceError(format!(
114                    "Failed to read file: {}",
115                    e
116                )))
117            }
118        };
119
120        let state: PersistedState = deserialize(&data).map_err(|e| {
121            LlmError::PersistenceError(format!("Failed to deserialize state: {}", e))
122        })?;
123
124        // Handle version migration if needed
125        if state.version > CURRENT_VERSION {
126            return Err(LlmError::PersistenceError(format!(
127                "Version mismatch: expected {}, found {}",
128                CURRENT_VERSION, state.version
129            )));
130        }
131
132        let messages: Vec<Message> = state.messages.into_iter().map(|m| m.into()).collect();
133
134        Ok(messages)
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Role;
    use tempfile::tempdir;

    /// Builds a plain-text message with no tool metadata.
    fn text_message(role: Role, content: &str) -> Message {
        Message {
            role,
            content: Some(content.to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test_state.bin");

        let persistence = StatePersistence::new(&file_path);

        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi there!"),
        ];

        persistence.save(&messages).unwrap();

        let loaded = persistence.load().unwrap();

        assert_eq!(loaded.len(), messages.len());
        for (got, want) in loaded.iter().zip(&messages) {
            assert_eq!(got.role, want.role);
            assert_eq!(got.content, want.content);
            assert_eq!(got.tool_call_id, want.tool_call_id);
        }
    }

    #[test]
    fn test_save_load_with_tool_result() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test_state.bin");

        let persistence = StatePersistence::new(&file_path);

        let messages = vec![Message {
            role: Role::Tool,
            content: Some("tool output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        }];

        persistence.save(&messages).unwrap();

        let loaded = persistence.load().unwrap();

        assert_eq!(loaded[0].role, Role::Tool);
        assert_eq!(loaded[0].tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_save_replaces_previous_state() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test_state.bin");
        let persistence = StatePersistence::new(&file_path);

        persistence
            .save(&[
                text_message(Role::User, "first"),
                text_message(Role::Assistant, "second"),
            ])
            .unwrap();
        persistence.save(&[text_message(Role::User, "only")]).unwrap();

        let loaded = persistence.load().unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content.as_deref(), Some("only"));
    }

    // A path that was never written is "no saved state", not an error.
    #[test]
    fn test_load_missing_file_returns_empty() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("nonexistent.bin");

        let persistence = StatePersistence::new(&file_path);

        let loaded = persistence.load().unwrap();

        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_corrupted_file_returns_error() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("corrupted.bin");

        std::fs::write(&file_path, b"invalid binary data").unwrap();

        let persistence = StatePersistence::new(&file_path);

        let result = persistence.load();

        assert!(result.is_err());
    }

    #[test]
    fn test_load_rejects_newer_version() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("future.bin");

        let future = PersistedState {
            version: CURRENT_VERSION + 1,
            messages: Vec::new(),
        };
        std::fs::write(&file_path, serialize(&future).unwrap()).unwrap();

        match StatePersistence::new(&file_path).load() {
            Err(LlmError::PersistenceError(msg)) => assert!(
                msg.starts_with("Version mismatch"),
                "unexpected error message: {}",
                msg
            ),
            _ => panic!("expected a version-mismatch PersistenceError"),
        }
    }

    #[test]
    fn test_role_conversion() {
        assert_eq!(PersistedRole::User, Role::User.into());
        assert_eq!(PersistedRole::Assistant, Role::Assistant.into());
        assert_eq!(PersistedRole::System, Role::System.into());
        assert_eq!(PersistedRole::Tool, Role::Tool.into());

        assert_eq!(Role::User, PersistedRole::User.into());
        assert_eq!(Role::Assistant, PersistedRole::Assistant.into());
        assert_eq!(Role::System, PersistedRole::System.into());
        assert_eq!(Role::Tool, PersistedRole::Tool.into());
    }
}