Skip to main content

albert_runtime/
session.rs

1use std::collections::BTreeMap;
2use std::fmt::{Display, Formatter};
3use std::fs;
4use std::path::Path;
5
6use crate::json::{JsonError, JsonValue};
7use crate::usage::TokenUsage;
8
/// Author of a [`ConversationMessage`]; stored in session JSON as a lowercase role name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageRole {
    /// Stored as `"system"`.
    System,
    /// Stored as `"user"`.
    User,
    /// Stored as `"assistant"`.
    Assistant,
    /// Stored as `"tool"`; used by messages built with `ConversationMessage::tool_result`.
    Tool,
}
16
/// One unit of content inside a [`ConversationMessage`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ContentBlock {
    /// Plain text.
    Text { text: String },
    /// A tool invocation: call `id`, tool `name`, and its `input` kept as a raw string.
    ToolUse {
        id: String,
        name: String,
        input: String,
    },
    /// Output of a tool call; `tool_use_id` refers to the `id` of the matching `ToolUse`
    /// (as in this module's tests).
    ToolResult {
        tool_use_id: String,
        tool_name: String,
        output: String,
        is_error: bool,
    },
    /// Attached image — persisted as a placeholder text in session files.
    ///
    /// Only `media_type` is written out; the `data` payload is dropped, so a saved image
    /// reloads as a `Text` block.
    Image { media_type: String, data: String },
}
39
40#[derive(Debug, Clone, PartialEq, Eq)]
41pub struct ConversationMessage {
42    pub role: MessageRole,
43    pub blocks: Vec<ContentBlock>,
44    pub usage: Option<TokenUsage>,
45}
46
47#[derive(Debug, Clone, PartialEq, Eq)]
48pub struct Session {
49    pub version: u32,
50    pub messages: Vec<ConversationMessage>,
51}
52
53#[derive(Debug)]
54pub enum SessionError {
55    Io(std::io::Error),
56    Json(JsonError),
57    Format(String),
58}
59
60impl Display for SessionError {
61    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
62        match self {
63            Self::Io(error) => write!(f, "{error}"),
64            Self::Json(error) => write!(f, "{error}"),
65            Self::Format(error) => write!(f, "{error}"),
66        }
67    }
68}
69
70impl std::error::Error for SessionError {}
71
72impl From<std::io::Error> for SessionError {
73    fn from(value: std::io::Error) -> Self {
74        Self::Io(value)
75    }
76}
77
78impl From<JsonError> for SessionError {
79    fn from(value: JsonError) -> Self {
80        Self::Json(value)
81    }
82}
83
84impl Session {
85    #[must_use]
86    pub fn new() -> Self {
87        Self {
88            version: 1,
89            messages: Vec::new(),
90        }
91    }
92
93    pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> {
94        fs::write(path, self.to_json().render())?;
95        Ok(())
96    }
97
98    pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> {
99        let contents = fs::read_to_string(path)?;
100        Self::from_json(&JsonValue::parse(&contents)?)
101    }
102
103    #[must_use]
104    pub fn to_json(&self) -> JsonValue {
105        let mut object = BTreeMap::new();
106        object.insert(
107            "version".to_string(),
108            JsonValue::Number(i64::from(self.version)),
109        );
110        object.insert(
111            "messages".to_string(),
112            JsonValue::Array(
113                self.messages
114                    .iter()
115                    .map(ConversationMessage::to_json)
116                    .collect(),
117            ),
118        );
119        JsonValue::Object(object)
120    }
121
122    pub fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
123        let object = value
124            .as_object()
125            .ok_or_else(|| SessionError::Format("session must be an object".to_string()))?;
126        let version = object
127            .get("version")
128            .and_then(JsonValue::as_i64)
129            .ok_or_else(|| SessionError::Format("missing version".to_string()))?;
130        let version = u32::try_from(version)
131            .map_err(|_| SessionError::Format("version out of range".to_string()))?;
132        let messages = object
133            .get("messages")
134            .and_then(JsonValue::as_array)
135            .ok_or_else(|| SessionError::Format("missing messages".to_string()))?
136            .iter()
137            .map(ConversationMessage::from_json)
138            .collect::<Result<Vec<_>, _>>()?;
139        Ok(Self { version, messages })
140    }
141}
142
143impl Default for Session {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
149impl ConversationMessage {
150    #[must_use]
151    pub fn user_text(text: impl Into<String>) -> Self {
152        Self {
153            role: MessageRole::User,
154            blocks: vec![ContentBlock::Text { text: text.into() }],
155            usage: None,
156        }
157    }
158
159    #[must_use]
160    pub fn assistant(blocks: Vec<ContentBlock>) -> Self {
161        Self {
162            role: MessageRole::Assistant,
163            blocks,
164            usage: None,
165        }
166    }
167
168    #[must_use]
169    pub fn assistant_with_usage(blocks: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Self {
170        Self {
171            role: MessageRole::Assistant,
172            blocks,
173            usage,
174        }
175    }
176
177    #[must_use]
178    pub fn tool_result(
179        tool_use_id: impl Into<String>,
180        tool_name: impl Into<String>,
181        output: impl Into<String>,
182        is_error: bool,
183    ) -> Self {
184        Self {
185            role: MessageRole::Tool,
186            blocks: vec![ContentBlock::ToolResult {
187                tool_use_id: tool_use_id.into(),
188                tool_name: tool_name.into(),
189                output: output.into(),
190                is_error,
191            }],
192            usage: None,
193        }
194    }
195
196    #[must_use]
197    pub fn to_json(&self) -> JsonValue {
198        let mut object = BTreeMap::new();
199        object.insert(
200            "role".to_string(),
201            JsonValue::String(
202                match self.role {
203                    MessageRole::System => "system",
204                    MessageRole::User => "user",
205                    MessageRole::Assistant => "assistant",
206                    MessageRole::Tool => "tool",
207                }
208                .to_string(),
209            ),
210        );
211        object.insert(
212            "blocks".to_string(),
213            JsonValue::Array(self.blocks.iter().map(ContentBlock::to_json).collect()),
214        );
215        if let Some(usage) = self.usage {
216            object.insert("usage".to_string(), usage_to_json(usage));
217        }
218        JsonValue::Object(object)
219    }
220
221    fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
222        let object = value
223            .as_object()
224            .ok_or_else(|| SessionError::Format("message must be an object".to_string()))?;
225        let role = match object
226            .get("role")
227            .and_then(JsonValue::as_str)
228            .ok_or_else(|| SessionError::Format("missing role".to_string()))?
229        {
230            "system" => MessageRole::System,
231            "user" => MessageRole::User,
232            "assistant" => MessageRole::Assistant,
233            "tool" => MessageRole::Tool,
234            other => {
235                return Err(SessionError::Format(format!(
236                    "unsupported message role: {other}"
237                )))
238            }
239        };
240        let blocks = object
241            .get("blocks")
242            .and_then(JsonValue::as_array)
243            .ok_or_else(|| SessionError::Format("missing blocks".to_string()))?
244            .iter()
245            .map(ContentBlock::from_json)
246            .collect::<Result<Vec<_>, _>>()?;
247        let usage = object.get("usage").map(usage_from_json).transpose()?;
248        Ok(Self {
249            role,
250            blocks,
251            usage,
252        })
253    }
254}
255
256impl ContentBlock {
257    #[must_use]
258    pub fn to_json(&self) -> JsonValue {
259        let mut object = BTreeMap::new();
260        match self {
261            Self::Text { text } => {
262                object.insert("type".to_string(), JsonValue::String("text".to_string()));
263                object.insert("text".to_string(), JsonValue::String(text.clone()));
264            }
265            Self::ToolUse { id, name, input } => {
266                object.insert(
267                    "type".to_string(),
268                    JsonValue::String("tool_use".to_string()),
269                );
270                object.insert("id".to_string(), JsonValue::String(id.clone()));
271                object.insert("name".to_string(), JsonValue::String(name.clone()));
272                object.insert("input".to_string(), JsonValue::String(input.clone()));
273            }
274            Self::ToolResult {
275                tool_use_id,
276                tool_name,
277                output,
278                is_error,
279            } => {
280                object.insert(
281                    "type".to_string(),
282                    JsonValue::String("tool_result".to_string()),
283                );
284                object.insert(
285                    "tool_use_id".to_string(),
286                    JsonValue::String(tool_use_id.clone()),
287                );
288                object.insert(
289                    "tool_name".to_string(),
290                    JsonValue::String(tool_name.clone()),
291                );
292                object.insert("output".to_string(), JsonValue::String(output.clone()));
293                object.insert("is_error".to_string(), JsonValue::Bool(*is_error));
294            }
295            // Store only the mime type — base64 data is not persisted.
296            Self::Image { media_type, .. } => {
297                object.insert("type".to_string(), JsonValue::String("text".to_string()));
298                object.insert(
299                    "text".to_string(),
300                    JsonValue::String(format!("[image: {media_type}]")),
301                );
302            }
303        }
304        JsonValue::Object(object)
305    }
306
307    fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
308        let object = value
309            .as_object()
310            .ok_or_else(|| SessionError::Format("block must be an object".to_string()))?;
311        match object
312            .get("type")
313            .and_then(JsonValue::as_str)
314            .ok_or_else(|| SessionError::Format("missing block type".to_string()))?
315        {
316            "text" => Ok(Self::Text {
317                text: required_string(object, "text")?,
318            }),
319            "tool_use" => Ok(Self::ToolUse {
320                id: required_string(object, "id")?,
321                name: required_string(object, "name")?,
322                input: required_string(object, "input")?,
323            }),
324            "tool_result" => Ok(Self::ToolResult {
325                tool_use_id: required_string(object, "tool_use_id")?,
326                tool_name: required_string(object, "tool_name")?,
327                output: required_string(object, "output")?,
328                is_error: object
329                    .get("is_error")
330                    .and_then(JsonValue::as_bool)
331                    .ok_or_else(|| SessionError::Format("missing is_error".to_string()))?,
332            }),
333            other => Err(SessionError::Format(format!(
334                "unsupported block type: {other}"
335            ))),
336        }
337    }
338}
339
340fn usage_to_json(usage: TokenUsage) -> JsonValue {
341    let mut object = BTreeMap::new();
342    object.insert(
343        "input_tokens".to_string(),
344        JsonValue::Number(i64::from(usage.input_tokens)),
345    );
346    object.insert(
347        "output_tokens".to_string(),
348        JsonValue::Number(i64::from(usage.output_tokens)),
349    );
350    object.insert(
351        "cache_creation_input_tokens".to_string(),
352        JsonValue::Number(i64::from(usage.cache_creation_input_tokens)),
353    );
354    object.insert(
355        "cache_read_input_tokens".to_string(),
356        JsonValue::Number(i64::from(usage.cache_read_input_tokens)),
357    );
358    JsonValue::Object(object)
359}
360
361fn usage_from_json(value: &JsonValue) -> Result<TokenUsage, SessionError> {
362    let object = value
363        .as_object()
364        .ok_or_else(|| SessionError::Format("usage must be an object".to_string()))?;
365    Ok(TokenUsage {
366        input_tokens: required_u32(object, "input_tokens")?,
367        output_tokens: required_u32(object, "output_tokens")?,
368        cache_creation_input_tokens: required_u32(object, "cache_creation_input_tokens")?,
369        cache_read_input_tokens: required_u32(object, "cache_read_input_tokens")?,
370    })
371}
372
373fn required_string(
374    object: &BTreeMap<String, JsonValue>,
375    key: &str,
376) -> Result<String, SessionError> {
377    object
378        .get(key)
379        .and_then(JsonValue::as_str)
380        .map(ToOwned::to_owned)
381        .ok_or_else(|| SessionError::Format(format!("missing {key}")))
382}
383
384fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
385    let value = object
386        .get(key)
387        .and_then(JsonValue::as_i64)
388        .ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
389    u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
390}
391
#[cfg(test)]
mod tests {
    use super::{ContentBlock, ConversationMessage, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::fs;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Three-message fixture: user text, an assistant turn (text + tool call + usage), and
    /// the matching tool result.
    fn sample_session() -> Session {
        let assistant_turn = ConversationMessage::assistant_with_usage(
            vec![
                ContentBlock::Text {
                    text: "thinking".to_string(),
                },
                ContentBlock::ToolUse {
                    id: "tool-1".to_string(),
                    name: "bash".to_string(),
                    input: "echo hi".to_string(),
                },
            ],
            Some(TokenUsage {
                input_tokens: 10,
                output_tokens: 4,
                cache_creation_input_tokens: 1,
                cache_read_input_tokens: 2,
            }),
        );

        let mut session = Session::new();
        session.messages.extend(vec![
            ConversationMessage::user_text("hello"),
            assistant_turn,
            ConversationMessage::tool_result("tool-1", "bash", "hi", false),
        ]);
        session
    }

    /// Temp-dir file path made unique by the current time in nanoseconds.
    fn unique_session_path() -> std::path::PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time should be after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!("runtime-session-{nanos}.json"))
    }

    #[test]
    fn persists_and_restores_session_json() {
        let session = sample_session();
        let path = unique_session_path();

        session.save_to_path(&path).expect("session should save");
        let restored = Session::load_from_path(&path).expect("session should load");
        fs::remove_file(&path).expect("temp file should be removable");

        assert_eq!(restored, session);
        assert_eq!(restored.messages[2].role, MessageRole::Tool);
        let usage = restored.messages[1].usage.expect("usage");
        assert_eq!(usage.total_tokens(), 17);
    }
}