Skip to main content

limit_llm/
persistence.rs

1use crate::error::LlmError;
2use crate::types::{Message, MessageContent};
3use serde::{Deserialize, Serialize};
4use std::fs;
5use std::path::Path;
6
/// Schema version stamped into every file written by `StatePersistence::save`.
/// `StatePersistence::load` rejects files whose version is greater than this value.
const CURRENT_VERSION: u32 = 2;
8
/// Top-level JSON document written by `StatePersistence::save` and read back by
/// `StatePersistence::load`.
///
/// Field order is significant: serde emits fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    // Schema version of this document; compared against `CURRENT_VERSION` on load.
    version: u32,
    // The conversation, in the order it was passed to `save`.
    messages: Vec<PersistedMessage>,
}
14
/// On-disk form of a `Message`.
///
/// Unlike `Message`, the body is stored as plain text (`Option<String>`). It is
/// produced with `MessageContent::to_text` on save and rebuilt as
/// `MessageContent::Text` on load.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    role: PersistedRole,
    // Text body, if any.
    content: Option<String>,
    // Tool calls carried over unchanged from `Message::tool_calls`.
    tool_calls: Option<Vec<crate::types::ToolCall>>,
    // Id linking this message to a tool call (e.g. on tool-result messages).
    tool_call_id: Option<String>,
    // Carried over unchanged from `Message::cache_control`.
    cache_control: Option<crate::types::CacheControl>,
}
23
/// Serializable mirror of `crate::types::Role` used in the state file.
///
/// Variants map one-to-one onto `Role` via the `From` impls below. The variant
/// names are what appear in the JSON, so renaming them breaks existing files.
/// NOTE(review): presumably kept separate from `Role` so the on-disk format is
/// insulated from changes to `Role`'s own serde representation — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    User,
    Assistant,
    System,
    Tool,
}
31
32impl From<PersistedRole> for crate::types::Role {
33    fn from(role: PersistedRole) -> Self {
34        match role {
35            PersistedRole::User => crate::types::Role::User,
36            PersistedRole::Assistant => crate::types::Role::Assistant,
37            PersistedRole::System => crate::types::Role::System,
38            PersistedRole::Tool => crate::types::Role::Tool,
39        }
40    }
41}
42
43impl From<crate::types::Role> for PersistedRole {
44    fn from(role: crate::types::Role) -> Self {
45        match role {
46            crate::types::Role::User => PersistedRole::User,
47            crate::types::Role::Assistant => PersistedRole::Assistant,
48            crate::types::Role::System => PersistedRole::System,
49            crate::types::Role::Tool => PersistedRole::Tool,
50        }
51    }
52}
53
54impl From<PersistedMessage> for Message {
55    fn from(msg: PersistedMessage) -> Self {
56        Message {
57            role: msg.role.into(),
58            content: msg.content.map(MessageContent::Text),
59            tool_calls: msg.tool_calls,
60            tool_call_id: msg.tool_call_id,
61            cache_control: msg.cache_control,
62        }
63    }
64}
65
66impl From<Message> for PersistedMessage {
67    fn from(msg: Message) -> Self {
68        PersistedMessage {
69            role: msg.role.into(),
70            content: msg.content.map(|c| c.to_text()),
71            tool_calls: msg.tool_calls,
72            tool_call_id: msg.tool_call_id,
73            cache_control: msg.cache_control,
74        }
75    }
76}
77
/// Saves and loads a conversation (a list of `Message`s) as a versioned JSON file
/// at a fixed path.
pub struct StatePersistence {
    // Location of the JSON state file; set once in `new`.
    file_path: std::path::PathBuf,
}
81
82impl StatePersistence {
83    pub fn new<P: AsRef<Path>>(file_path: P) -> Self {
84        Self {
85            file_path: file_path.as_ref().to_path_buf(),
86        }
87    }
88
89    pub fn save(&self, messages: &[Message]) -> Result<(), LlmError> {
90        let persisted_messages: Vec<PersistedMessage> =
91            messages.iter().cloned().map(|m| m.into()).collect();
92
93        let state = PersistedState {
94            version: CURRENT_VERSION,
95            messages: persisted_messages,
96        };
97
98        let serialized = serde_json::to_string_pretty(&state)
99            .map_err(|e| LlmError::PersistenceError(format!("Failed to serialize state: {}", e)))?;
100
101        fs::write(&self.file_path, serialized)
102            .map_err(|e| LlmError::PersistenceError(format!("Failed to write file: {}", e)))?;
103
104        Ok(())
105    }
106
107    pub fn load(&self) -> Result<Vec<Message>, LlmError> {
108        let data = match fs::read_to_string(&self.file_path) {
109            Ok(data) => data,
110            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(vec![]),
111            Err(e) => {
112                return Err(LlmError::PersistenceError(format!(
113                    "Failed to read file: {}",
114                    e
115                )))
116            }
117        };
118
119        let state: PersistedState = serde_json::from_str(&data).map_err(|e| {
120            LlmError::PersistenceError(format!("Failed to deserialize state: {}", e))
121        })?;
122
123        if state.version > CURRENT_VERSION {
124            return Err(LlmError::PersistenceError(format!(
125                "Version mismatch: expected {}, found {}",
126                CURRENT_VERSION, state.version
127            )));
128        }
129
130        let messages: Vec<Message> = state.messages.into_iter().map(|m| m.into()).collect();
131
132        Ok(messages)
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Role;
    use tempfile::tempdir;

    /// Builds a plain text message with no tool metadata.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: Some(crate::MessageContent::text(text)),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test_state.json");

        let persistence = StatePersistence::new(&file_path);

        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi there!"),
        ];

        persistence.save(&messages).unwrap();

        let loaded = persistence.load().unwrap();

        assert_eq!(loaded.len(), messages.len());
        assert_eq!(loaded[0].role, messages[0].role);
        assert_eq!(loaded[0].content, messages[0].content);
        assert_eq!(loaded[1].role, messages[1].role);
        assert_eq!(loaded[1].content, messages[1].content);
    }

    #[test]
    fn test_save_load_with_tool_result() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test_state.json");

        let persistence = StatePersistence::new(&file_path);

        let messages = vec![Message {
            role: Role::Tool,
            content: Some(crate::MessageContent::text("tool output")),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        }];

        persistence.save(&messages).unwrap();

        let loaded = persistence.load().unwrap();

        assert_eq!(loaded[0].role, Role::Tool);
        assert_eq!(loaded[0].tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_save_overwrites_previous_state() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test_state.json");

        let persistence = StatePersistence::new(&file_path);

        persistence
            .save(&[
                text_message(Role::User, "first"),
                text_message(Role::Assistant, "a much longer second message"),
            ])
            .unwrap();
        persistence
            .save(&[text_message(Role::User, "only")])
            .unwrap();

        let loaded = persistence.load().unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content, Some(crate::MessageContent::text("only")));
    }

    #[test]
    fn test_load_missing_file_returns_empty() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("nonexistent.json");

        let persistence = StatePersistence::new(&file_path);

        let loaded = persistence.load().unwrap();

        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_corrupted_file_returns_error() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("corrupted.json");

        std::fs::write(&file_path, b"invalid json data").unwrap();

        let persistence = StatePersistence::new(&file_path);

        let result = persistence.load();

        assert!(result.is_err());
    }

    #[test]
    fn test_load_rejects_newer_version() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("future.json");

        let json = format!(
            "{{\"version\": {}, \"messages\": []}}",
            CURRENT_VERSION + 1
        );
        std::fs::write(&file_path, json).unwrap();

        let persistence = StatePersistence::new(&file_path);

        assert!(persistence.load().is_err());
    }

    #[test]
    fn test_load_accepts_older_version() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("old.json");

        let json = format!(
            "{{\"version\": {}, \"messages\": []}}",
            CURRENT_VERSION - 1
        );
        std::fs::write(&file_path, json).unwrap();

        let persistence = StatePersistence::new(&file_path);

        assert!(persistence.load().unwrap().is_empty());
    }

    #[test]
    fn test_role_conversion() {
        assert_eq!(PersistedRole::User, Role::User.into());
        assert_eq!(PersistedRole::Assistant, Role::Assistant.into());
        assert_eq!(PersistedRole::System, Role::System.into());
        assert_eq!(PersistedRole::Tool, Role::Tool.into());

        assert_eq!(Role::User, PersistedRole::User.into());
        assert_eq!(Role::Assistant, PersistedRole::Assistant.into());
        assert_eq!(Role::System, PersistedRole::System.into());
        assert_eq!(Role::Tool, PersistedRole::Tool.into());
    }
}