1use crate::error::LlmError;
2use crate::types::Message;
3use bincode::{deserialize, serialize};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::Path;
7
8const CURRENT_VERSION: u32 = 2;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11struct PersistedState {
12 version: u32,
13 messages: Vec<PersistedMessage>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17struct PersistedMessage {
18 role: PersistedRole,
19 content: Option<String>,
20 tool_calls: Option<Vec<crate::types::ToolCall>>,
21 tool_call_id: Option<String>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
25enum PersistedRole {
26 User,
27 Assistant,
28 System,
29 Tool,
30}
31
32impl From<PersistedRole> for crate::types::Role {
33 fn from(role: PersistedRole) -> Self {
34 match role {
35 PersistedRole::User => crate::types::Role::User,
36 PersistedRole::Assistant => crate::types::Role::Assistant,
37 PersistedRole::System => crate::types::Role::System,
38 PersistedRole::Tool => crate::types::Role::Tool,
39 }
40 }
41}
42
43impl From<crate::types::Role> for PersistedRole {
44 fn from(role: crate::types::Role) -> Self {
45 match role {
46 crate::types::Role::User => PersistedRole::User,
47 crate::types::Role::Assistant => PersistedRole::Assistant,
48 crate::types::Role::System => PersistedRole::System,
49 crate::types::Role::Tool => PersistedRole::Tool,
50 }
51 }
52}
53
54impl From<PersistedMessage> for Message {
55 fn from(msg: PersistedMessage) -> Self {
56 Message {
57 role: msg.role.into(),
58 content: msg.content,
59 tool_calls: msg.tool_calls,
60 tool_call_id: msg.tool_call_id,
61 }
62 }
63}
64
65impl From<Message> for PersistedMessage {
66 fn from(msg: Message) -> Self {
67 PersistedMessage {
68 role: msg.role.into(),
69 content: msg.content,
70 tool_calls: msg.tool_calls,
71 tool_call_id: msg.tool_call_id,
72 }
73 }
74}
75
/// Saves a conversation history (`[Message]`) to a single bincode file and loads it back.
///
/// Derives `Debug` so it can be logged or embedded in other `Debug` types, and `Clone`
/// so several owners can share a handle to the same state path.
#[derive(Debug, Clone)]
pub struct StatePersistence {
    /// Location of the state file. It is written by `save`, and `load` treats a missing
    /// file as an empty history.
    file_path: std::path::PathBuf,
}
79
80impl StatePersistence {
81 pub fn new<P: AsRef<Path>>(file_path: P) -> Self {
82 Self {
83 file_path: file_path.as_ref().to_path_buf(),
84 }
85 }
86
87 pub fn save(&self, messages: &[Message]) -> Result<(), LlmError> {
88 let persisted_messages: Vec<PersistedMessage> =
89 messages.iter().cloned().map(|m| m.into()).collect();
90
91 let state = PersistedState {
92 version: CURRENT_VERSION,
93 messages: persisted_messages,
94 };
95
96 let serialized = serialize(&state)
97 .map_err(|e| LlmError::PersistenceError(format!("Failed to serialize state: {}", e)))?;
98
99 fs::write(&self.file_path, serialized)
100 .map_err(|e| LlmError::PersistenceError(format!("Failed to write file: {}", e)))?;
101
102 Ok(())
103 }
104
105 pub fn load(&self) -> Result<Vec<Message>, LlmError> {
106 let data = match fs::read(&self.file_path) {
107 Ok(data) => data,
108 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(vec![]),
109 Err(e) => {
110 return Err(LlmError::PersistenceError(format!(
111 "Failed to read file: {}",
112 e
113 )))
114 }
115 };
116
117 let state: PersistedState = deserialize(&data).map_err(|e| {
118 LlmError::PersistenceError(format!("Failed to deserialize state: {}", e))
119 })?;
120
121 if state.version > CURRENT_VERSION {
123 return Err(LlmError::PersistenceError(format!(
124 "Version mismatch: expected {}, found {}",
125 CURRENT_VERSION, state.version
126 )));
127 }
128
129 let messages: Vec<Message> = state.messages.into_iter().map(|m| m.into()).collect();
130
131 Ok(messages)
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Role;
    use tempfile::tempdir;

    /// Builds a plain text message with no tool metadata.
    fn text_message(role: Role, content: &str) -> Message {
        Message {
            role,
            content: Some(content.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("test_state.bin"));

        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi there!"),
        ];

        persistence.save(&messages).unwrap();
        let loaded = persistence.load().unwrap();

        assert_eq!(loaded.len(), messages.len());
        for (got, want) in loaded.iter().zip(&messages) {
            assert_eq!(got.role, want.role);
            assert_eq!(got.content, want.content);
            assert_eq!(got.tool_call_id, want.tool_call_id);
            assert!(got.tool_calls.is_none());
        }
    }

    #[test]
    fn test_save_load_with_tool_result() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("test_state.bin"));

        let messages = vec![Message {
            role: Role::Tool,
            content: Some("tool output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        }];

        persistence.save(&messages).unwrap();
        let loaded = persistence.load().unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].role, Role::Tool);
        assert_eq!(loaded[0].content.as_deref(), Some("tool output"));
        assert_eq!(loaded[0].tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_save_overwrites_previous_state() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("test_state.bin"));

        persistence.save(&[text_message(Role::User, "first")]).unwrap();
        persistence.save(&[text_message(Role::User, "second")]).unwrap();

        let loaded = persistence.load().unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content.as_deref(), Some("second"));

        // Saving must leave exactly one file behind: the state file itself.
        assert_eq!(std::fs::read_dir(dir.path()).unwrap().count(), 1);
    }

    #[test]
    fn test_load_missing_file_returns_empty() {
        let dir = tempdir().unwrap();
        let persistence = StatePersistence::new(dir.path().join("nonexistent.bin"));

        let loaded = persistence.load().unwrap();

        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_corrupted_file_returns_error() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("corrupted.bin");
        std::fs::write(&file_path, b"invalid binary data").unwrap();

        let result = StatePersistence::new(&file_path).load();

        assert!(matches!(result, Err(LlmError::PersistenceError(_))));
    }

    #[test]
    fn test_load_newer_version_returns_error() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("future.bin");

        let future = PersistedState {
            version: CURRENT_VERSION + 1,
            messages: vec![],
        };
        std::fs::write(&file_path, serialize(&future).unwrap()).unwrap();

        let result = StatePersistence::new(&file_path).load();

        assert!(matches!(result, Err(LlmError::PersistenceError(_))));
    }

    #[test]
    fn test_role_conversion() {
        assert_eq!(PersistedRole::User, Role::User.into());
        assert_eq!(PersistedRole::Assistant, Role::Assistant.into());
        assert_eq!(PersistedRole::System, Role::System.into());
        assert_eq!(PersistedRole::Tool, Role::Tool.into());

        assert_eq!(Role::User, PersistedRole::User.into());
        assert_eq!(Role::Assistant, PersistedRole::Assistant.into());
        assert_eq!(Role::System, PersistedRole::System.into());
        assert_eq!(Role::Tool, PersistedRole::Tool.into());
    }
}