Skip to main content

albert_runtime/
session.rs

1use std::collections::BTreeMap;
2use std::fmt::{Display, Formatter};
3use std::fs;
4use std::path::Path;
5
6use crate::json::{JsonError, JsonValue};
7use crate::usage::TokenUsage;
8
/// Role attached to each [`ConversationMessage`].
///
/// Serialized by [`ConversationMessage::to_json`] as one of the lowercase
/// strings `"system"`, `"user"`, `"assistant"` or `"tool"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageRole {
    System,
    User,
    Assistant,
    /// Used by [`ConversationMessage::tool_result`] for tool output messages.
    Tool,
}
16
/// One unit of content inside a [`ConversationMessage`].
///
/// Each variant is persisted as a JSON object tagged with a `type` field by
/// [`ContentBlock::to_json`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ContentBlock {
    /// Plain text content.
    Text { text: String },
    /// A request to run a tool.
    ToolUse {
        id: String,
        name: String,
        /// Stored and round-tripped verbatim as a string.
        /// NOTE(review): presumably serialized tool arguments — confirm encoding with callers.
        input: String,
    },
    /// The outcome of running a tool.
    ToolResult {
        /// Presumably the `id` of the originating [`ContentBlock::ToolUse`].
        tool_use_id: String,
        tool_name: String,
        output: String,
        /// NOTE(review): presumably `true` when the tool reported a failure — confirm.
        is_error: bool,
    },
}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub struct ConversationMessage {
37    pub role: MessageRole,
38    pub blocks: Vec<ContentBlock>,
39    pub usage: Option<TokenUsage>,
40}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct Session {
44    pub version: u32,
45    pub messages: Vec<ConversationMessage>,
46}
47
48#[derive(Debug)]
49pub enum SessionError {
50    Io(std::io::Error),
51    Json(JsonError),
52    Format(String),
53}
54
55impl Display for SessionError {
56    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
57        match self {
58            Self::Io(error) => write!(f, "{error}"),
59            Self::Json(error) => write!(f, "{error}"),
60            Self::Format(error) => write!(f, "{error}"),
61        }
62    }
63}
64
65impl std::error::Error for SessionError {}
66
67impl From<std::io::Error> for SessionError {
68    fn from(value: std::io::Error) -> Self {
69        Self::Io(value)
70    }
71}
72
73impl From<JsonError> for SessionError {
74    fn from(value: JsonError) -> Self {
75        Self::Json(value)
76    }
77}
78
79impl Session {
80    #[must_use]
81    pub fn new() -> Self {
82        Self {
83            version: 1,
84            messages: Vec::new(),
85        }
86    }
87
88    pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> {
89        fs::write(path, self.to_json().render())?;
90        Ok(())
91    }
92
93    pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> {
94        let contents = fs::read_to_string(path)?;
95        Self::from_json(&JsonValue::parse(&contents)?)
96    }
97
98    #[must_use]
99    pub fn to_json(&self) -> JsonValue {
100        let mut object = BTreeMap::new();
101        object.insert(
102            "version".to_string(),
103            JsonValue::Number(i64::from(self.version)),
104        );
105        object.insert(
106            "messages".to_string(),
107            JsonValue::Array(
108                self.messages
109                    .iter()
110                    .map(ConversationMessage::to_json)
111                    .collect(),
112            ),
113        );
114        JsonValue::Object(object)
115    }
116
117    pub fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
118        let object = value
119            .as_object()
120            .ok_or_else(|| SessionError::Format("session must be an object".to_string()))?;
121        let version = object
122            .get("version")
123            .and_then(JsonValue::as_i64)
124            .ok_or_else(|| SessionError::Format("missing version".to_string()))?;
125        let version = u32::try_from(version)
126            .map_err(|_| SessionError::Format("version out of range".to_string()))?;
127        let messages = object
128            .get("messages")
129            .and_then(JsonValue::as_array)
130            .ok_or_else(|| SessionError::Format("missing messages".to_string()))?
131            .iter()
132            .map(ConversationMessage::from_json)
133            .collect::<Result<Vec<_>, _>>()?;
134        Ok(Self { version, messages })
135    }
136}
137
138impl Default for Session {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144impl ConversationMessage {
145    #[must_use]
146    pub fn user_text(text: impl Into<String>) -> Self {
147        Self {
148            role: MessageRole::User,
149            blocks: vec![ContentBlock::Text { text: text.into() }],
150            usage: None,
151        }
152    }
153
154    #[must_use]
155    pub fn assistant(blocks: Vec<ContentBlock>) -> Self {
156        Self {
157            role: MessageRole::Assistant,
158            blocks,
159            usage: None,
160        }
161    }
162
163    #[must_use]
164    pub fn assistant_with_usage(blocks: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Self {
165        Self {
166            role: MessageRole::Assistant,
167            blocks,
168            usage,
169        }
170    }
171
172    #[must_use]
173    pub fn tool_result(
174        tool_use_id: impl Into<String>,
175        tool_name: impl Into<String>,
176        output: impl Into<String>,
177        is_error: bool,
178    ) -> Self {
179        Self {
180            role: MessageRole::Tool,
181            blocks: vec![ContentBlock::ToolResult {
182                tool_use_id: tool_use_id.into(),
183                tool_name: tool_name.into(),
184                output: output.into(),
185                is_error,
186            }],
187            usage: None,
188        }
189    }
190
191    #[must_use]
192    pub fn to_json(&self) -> JsonValue {
193        let mut object = BTreeMap::new();
194        object.insert(
195            "role".to_string(),
196            JsonValue::String(
197                match self.role {
198                    MessageRole::System => "system",
199                    MessageRole::User => "user",
200                    MessageRole::Assistant => "assistant",
201                    MessageRole::Tool => "tool",
202                }
203                .to_string(),
204            ),
205        );
206        object.insert(
207            "blocks".to_string(),
208            JsonValue::Array(self.blocks.iter().map(ContentBlock::to_json).collect()),
209        );
210        if let Some(usage) = self.usage {
211            object.insert("usage".to_string(), usage_to_json(usage));
212        }
213        JsonValue::Object(object)
214    }
215
216    fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
217        let object = value
218            .as_object()
219            .ok_or_else(|| SessionError::Format("message must be an object".to_string()))?;
220        let role = match object
221            .get("role")
222            .and_then(JsonValue::as_str)
223            .ok_or_else(|| SessionError::Format("missing role".to_string()))?
224        {
225            "system" => MessageRole::System,
226            "user" => MessageRole::User,
227            "assistant" => MessageRole::Assistant,
228            "tool" => MessageRole::Tool,
229            other => {
230                return Err(SessionError::Format(format!(
231                    "unsupported message role: {other}"
232                )))
233            }
234        };
235        let blocks = object
236            .get("blocks")
237            .and_then(JsonValue::as_array)
238            .ok_or_else(|| SessionError::Format("missing blocks".to_string()))?
239            .iter()
240            .map(ContentBlock::from_json)
241            .collect::<Result<Vec<_>, _>>()?;
242        let usage = object.get("usage").map(usage_from_json).transpose()?;
243        Ok(Self {
244            role,
245            blocks,
246            usage,
247        })
248    }
249}
250
251impl ContentBlock {
252    #[must_use]
253    pub fn to_json(&self) -> JsonValue {
254        let mut object = BTreeMap::new();
255        match self {
256            Self::Text { text } => {
257                object.insert("type".to_string(), JsonValue::String("text".to_string()));
258                object.insert("text".to_string(), JsonValue::String(text.clone()));
259            }
260            Self::ToolUse { id, name, input } => {
261                object.insert(
262                    "type".to_string(),
263                    JsonValue::String("tool_use".to_string()),
264                );
265                object.insert("id".to_string(), JsonValue::String(id.clone()));
266                object.insert("name".to_string(), JsonValue::String(name.clone()));
267                object.insert("input".to_string(), JsonValue::String(input.clone()));
268            }
269            Self::ToolResult {
270                tool_use_id,
271                tool_name,
272                output,
273                is_error,
274            } => {
275                object.insert(
276                    "type".to_string(),
277                    JsonValue::String("tool_result".to_string()),
278                );
279                object.insert(
280                    "tool_use_id".to_string(),
281                    JsonValue::String(tool_use_id.clone()),
282                );
283                object.insert(
284                    "tool_name".to_string(),
285                    JsonValue::String(tool_name.clone()),
286                );
287                object.insert("output".to_string(), JsonValue::String(output.clone()));
288                object.insert("is_error".to_string(), JsonValue::Bool(*is_error));
289            }
290        }
291        JsonValue::Object(object)
292    }
293
294    fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
295        let object = value
296            .as_object()
297            .ok_or_else(|| SessionError::Format("block must be an object".to_string()))?;
298        match object
299            .get("type")
300            .and_then(JsonValue::as_str)
301            .ok_or_else(|| SessionError::Format("missing block type".to_string()))?
302        {
303            "text" => Ok(Self::Text {
304                text: required_string(object, "text")?,
305            }),
306            "tool_use" => Ok(Self::ToolUse {
307                id: required_string(object, "id")?,
308                name: required_string(object, "name")?,
309                input: required_string(object, "input")?,
310            }),
311            "tool_result" => Ok(Self::ToolResult {
312                tool_use_id: required_string(object, "tool_use_id")?,
313                tool_name: required_string(object, "tool_name")?,
314                output: required_string(object, "output")?,
315                is_error: object
316                    .get("is_error")
317                    .and_then(JsonValue::as_bool)
318                    .ok_or_else(|| SessionError::Format("missing is_error".to_string()))?,
319            }),
320            other => Err(SessionError::Format(format!(
321                "unsupported block type: {other}"
322            ))),
323        }
324    }
325}
326
327fn usage_to_json(usage: TokenUsage) -> JsonValue {
328    let mut object = BTreeMap::new();
329    object.insert(
330        "input_tokens".to_string(),
331        JsonValue::Number(i64::from(usage.input_tokens)),
332    );
333    object.insert(
334        "output_tokens".to_string(),
335        JsonValue::Number(i64::from(usage.output_tokens)),
336    );
337    object.insert(
338        "cache_creation_input_tokens".to_string(),
339        JsonValue::Number(i64::from(usage.cache_creation_input_tokens)),
340    );
341    object.insert(
342        "cache_read_input_tokens".to_string(),
343        JsonValue::Number(i64::from(usage.cache_read_input_tokens)),
344    );
345    JsonValue::Object(object)
346}
347
348fn usage_from_json(value: &JsonValue) -> Result<TokenUsage, SessionError> {
349    let object = value
350        .as_object()
351        .ok_or_else(|| SessionError::Format("usage must be an object".to_string()))?;
352    Ok(TokenUsage {
353        input_tokens: required_u32(object, "input_tokens")?,
354        output_tokens: required_u32(object, "output_tokens")?,
355        cache_creation_input_tokens: required_u32(object, "cache_creation_input_tokens")?,
356        cache_read_input_tokens: required_u32(object, "cache_read_input_tokens")?,
357    })
358}
359
360fn required_string(
361    object: &BTreeMap<String, JsonValue>,
362    key: &str,
363) -> Result<String, SessionError> {
364    object
365        .get(key)
366        .and_then(JsonValue::as_str)
367        .map(ToOwned::to_owned)
368        .ok_or_else(|| SessionError::Format(format!("missing {key}")))
369}
370
371fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
372    let value = object
373        .get(key)
374        .and_then(JsonValue::as_i64)
375        .ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
376    u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
377}
378
#[cfg(test)]
mod tests {
    use super::{ContentBlock, ConversationMessage, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::fs;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a session that exercises every block kind and an attached usage record.
    fn sample_session() -> Session {
        let assistant_blocks = vec![
            ContentBlock::Text {
                text: "thinking".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool-1".to_string(),
                name: "bash".to_string(),
                input: "echo hi".to_string(),
            },
        ];
        let usage = TokenUsage {
            input_tokens: 10,
            output_tokens: 4,
            cache_creation_input_tokens: 1,
            cache_read_input_tokens: 2,
        };

        let mut session = Session::new();
        session.messages.extend(vec![
            ConversationMessage::user_text("hello"),
            ConversationMessage::assistant_with_usage(assistant_blocks, Some(usage)),
            ConversationMessage::tool_result("tool-1", "bash", "hi", false),
        ]);
        session
    }

    #[test]
    fn persists_and_restores_session_json() {
        let session = sample_session();

        // A timestamp-derived file name keeps this run's temp file distinct.
        let stamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time should be after epoch")
            .as_nanos();
        let path = std::env::temp_dir().join(format!("runtime-session-{stamp}.json"));

        session.save_to_path(&path).expect("session should save");
        let restored = Session::load_from_path(&path).expect("session should load");
        fs::remove_file(&path).expect("temp file should be removable");

        assert_eq!(restored, session);
        assert_eq!(restored.messages[2].role, MessageRole::Tool);
        let restored_usage = restored.messages[1].usage.expect("usage");
        assert_eq!(restored_usage.total_tokens(), 17);
    }
}