1use crate::error::LlmError;
2use crate::types::Message;
3use bincode::{deserialize, serialize};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::Path;
7
/// Format version stamped into every state file by `StatePersistence::save`.
/// `StatePersistence::load` refuses files whose version is greater than this.
const CURRENT_VERSION: u32 = 2;
9
/// Top-level record written to and read from the state file.
///
/// Encoded with bincode, which stores struct fields positionally in
/// declaration order without field names, so the order of the fields below
/// *is* the on-disk layout.
/// NOTE(review): changing, reordering, or adding fields presumably calls for
/// bumping `CURRENT_VERSION` — confirm with the owners of the file format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    /// Format version; `StatePersistence::save` always writes
    /// `CURRENT_VERSION`, and `StatePersistence::load` rejects larger values.
    version: u32,
    /// The saved messages, in the order they were passed to `save`.
    messages: Vec<PersistedMessage>,
}
15
/// On-disk mirror of [`Message`]: the same five fields, with `role` stored as
/// a [`PersistedRole`]. Converted to and from `Message` by the `From` impls in
/// this file.
///
/// Field order is part of the bincode encoding; changing it alters the file
/// layout.
/// NOTE(review): presumably kept separate from `Message` so the file layout
/// does not depend on `Message`'s own serde attributes — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    role: PersistedRole,
    content: Option<String>,
    tool_calls: Option<Vec<crate::types::ToolCall>>,
    tool_call_id: Option<String>,
    cache_control: Option<crate::types::CacheControl>,
}
24
/// On-disk mirror of `crate::types::Role`, converted by the `From` impls below.
///
/// bincode encodes an enum variant by its index (declaration order), so the
/// existing variants must keep their order; appending a new variant is the
/// only change that keeps previously written files decodable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    User,
    Assistant,
    System,
    Tool,
}
32
33impl From<PersistedRole> for crate::types::Role {
34 fn from(role: PersistedRole) -> Self {
35 match role {
36 PersistedRole::User => crate::types::Role::User,
37 PersistedRole::Assistant => crate::types::Role::Assistant,
38 PersistedRole::System => crate::types::Role::System,
39 PersistedRole::Tool => crate::types::Role::Tool,
40 }
41 }
42}
43
44impl From<crate::types::Role> for PersistedRole {
45 fn from(role: crate::types::Role) -> Self {
46 match role {
47 crate::types::Role::User => PersistedRole::User,
48 crate::types::Role::Assistant => PersistedRole::Assistant,
49 crate::types::Role::System => PersistedRole::System,
50 crate::types::Role::Tool => PersistedRole::Tool,
51 }
52 }
53}
54
55impl From<PersistedMessage> for Message {
56 fn from(msg: PersistedMessage) -> Self {
57 Message {
58 role: msg.role.into(),
59 content: msg.content,
60 tool_calls: msg.tool_calls,
61 tool_call_id: msg.tool_call_id,
62 cache_control: msg.cache_control,
63 }
64 }
65}
66
67impl From<Message> for PersistedMessage {
68 fn from(msg: Message) -> Self {
69 PersistedMessage {
70 role: msg.role.into(),
71 content: msg.content,
72 tool_calls: msg.tool_calls,
73 tool_call_id: msg.tool_call_id,
74 cache_control: msg.cache_control,
75 }
76 }
77}
78
/// Saves and loads a list of [`Message`]s to and from a single bincode-encoded
/// file.
///
/// Only the file path is stored; the file itself is touched exclusively by
/// `save` and `load`.
#[derive(Debug, Clone)]
pub struct StatePersistence {
    file_path: std::path::PathBuf,
}
82
83impl StatePersistence {
84 pub fn new<P: AsRef<Path>>(file_path: P) -> Self {
85 Self {
86 file_path: file_path.as_ref().to_path_buf(),
87 }
88 }
89
90 pub fn save(&self, messages: &[Message]) -> Result<(), LlmError> {
91 let persisted_messages: Vec<PersistedMessage> =
92 messages.iter().cloned().map(|m| m.into()).collect();
93
94 let state = PersistedState {
95 version: CURRENT_VERSION,
96 messages: persisted_messages,
97 };
98
99 let serialized = serialize(&state)
100 .map_err(|e| LlmError::PersistenceError(format!("Failed to serialize state: {}", e)))?;
101
102 fs::write(&self.file_path, serialized)
103 .map_err(|e| LlmError::PersistenceError(format!("Failed to write file: {}", e)))?;
104
105 Ok(())
106 }
107
108 pub fn load(&self) -> Result<Vec<Message>, LlmError> {
109 let data = match fs::read(&self.file_path) {
110 Ok(data) => data,
111 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(vec![]),
112 Err(e) => {
113 return Err(LlmError::PersistenceError(format!(
114 "Failed to read file: {}",
115 e
116 )))
117 }
118 };
119
120 let state: PersistedState = deserialize(&data).map_err(|e| {
121 LlmError::PersistenceError(format!("Failed to deserialize state: {}", e))
122 })?;
123
124 if state.version > CURRENT_VERSION {
126 return Err(LlmError::PersistenceError(format!(
127 "Version mismatch: expected {}, found {}",
128 CURRENT_VERSION, state.version
129 )));
130 }
131
132 let messages: Vec<Message> = state.messages.into_iter().map(|m| m.into()).collect();
133
134 Ok(messages)
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Role;
    use tempfile::tempdir;

    /// Builds a message with only `role` and `content` set.
    fn text_message(role: Role, content: &str) -> Message {
        Message {
            role,
            content: Some(content.to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("test_state.bin"));

        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi there!"),
        ];

        persistence.save(&messages).unwrap();
        let loaded = persistence.load().unwrap();

        assert_eq!(loaded.len(), messages.len());
        for (got, want) in loaded.iter().zip(&messages) {
            assert_eq!(got.role, want.role);
            assert_eq!(got.content, want.content);
            assert_eq!(got.tool_call_id, want.tool_call_id);
            assert!(got.tool_calls.is_none());
            assert!(got.cache_control.is_none());
        }
    }

    #[test]
    fn test_save_load_with_tool_result() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test_state.bin");

        let persistence = StatePersistence::new(&file_path);

        let messages = vec![Message {
            role: Role::Tool,
            content: Some("tool output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        }];

        persistence.save(&messages).unwrap();

        let loaded = persistence.load().unwrap();

        assert_eq!(loaded[0].role, Role::Tool);
        assert_eq!(loaded[0].content.as_deref(), Some("tool output"));
        assert_eq!(loaded[0].tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_save_overwrites_previous_state() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("test_state.bin"));

        persistence
            .save(&[
                text_message(Role::User, "first"),
                text_message(Role::Assistant, "second"),
            ])
            .unwrap();
        persistence
            .save(&[text_message(Role::System, "only")])
            .unwrap();

        let loaded = persistence.load().unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].role, Role::System);
        assert_eq!(loaded[0].content.as_deref(), Some("only"));
    }

    #[test]
    fn test_save_leaves_only_state_file() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("test_state.bin"));

        persistence.save(&[text_message(Role::User, "hi")]).unwrap();

        let entries: Vec<std::ffi::OsString> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(entries, vec![std::ffi::OsString::from("test_state.bin")]);
    }

    #[test]
    fn test_load_missing_file_returns_empty() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("nonexistent.bin");

        let persistence = StatePersistence::new(&file_path);

        let loaded = persistence.load().unwrap();

        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_corrupted_file_returns_error() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("corrupted.bin");

        std::fs::write(&file_path, b"invalid binary data").unwrap();

        let persistence = StatePersistence::new(&file_path);

        let result = persistence.load();

        assert!(result.is_err());
    }

    #[test]
    fn test_load_newer_version_returns_error() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("future.bin");

        let future_state = PersistedState {
            version: CURRENT_VERSION + 1,
            messages: vec![],
        };
        std::fs::write(&file_path, serialize(&future_state).unwrap()).unwrap();

        let result = StatePersistence::new(&file_path).load();

        assert!(result.is_err());
    }

    #[test]
    fn test_role_conversion() {
        assert_eq!(PersistedRole::User, Role::User.into());
        assert_eq!(PersistedRole::Assistant, Role::Assistant.into());
        assert_eq!(PersistedRole::System, Role::System.into());
        assert_eq!(PersistedRole::Tool, Role::Tool.into());

        assert_eq!(Role::User, PersistedRole::User.into());
        assert_eq!(Role::Assistant, PersistedRole::Assistant.into());
        assert_eq!(Role::System, PersistedRole::System.into());
        assert_eq!(Role::Tool, PersistedRole::Tool.into());
    }
}