Skip to main content

albert_runtime/
session.rs

1use std::collections::BTreeMap;
2use std::fmt::{Display, Formatter};
3use std::fs;
4use std::path::Path;
5
6use crate::json::{JsonError, JsonValue};
7use crate::usage::TokenUsage;
8
/// Author of a [`ConversationMessage`].
///
/// `ConversationMessage::to_json` writes each variant as its lowercase name
/// (`"system"`, `"user"`, `"assistant"`, `"tool"`), and decoding accepts only
/// those four strings.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageRole {
    /// System-level message; this module has no dedicated constructor for it.
    System,
    /// Role produced by [`ConversationMessage::user_text`].
    User,
    /// Role produced by [`ConversationMessage::assistant`] and
    /// [`ConversationMessage::assistant_with_usage`].
    Assistant,
    /// Role produced by [`ConversationMessage::tool_result`].
    Tool,
}
16
/// One piece of content inside a [`ConversationMessage`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ContentBlock {
    /// Plain text.
    Text { text: String },
    /// A tool invocation; `input` is stored as an opaque string.
    ToolUse {
        id: String,
        name: String,
        input: String,
    },
    /// The outcome of a tool invocation; `tool_use_id` names the `ToolUse::id`
    /// it answers.
    ToolResult {
        tool_use_id: String,
        tool_name: String,
        output: String,
        is_error: bool,
    },
    /// Attached image. Session files keep only a `[image: <media_type>]` text
    /// placeholder; `data` is not persisted.
    Image { media_type: String, data: String },
    /// Gemini reasoning/thinking block — carries thoughtSignature for tool-call
    /// continuations.
    Thinking {
        text: String,
        thought_signature: Option<String>,
    },
}
44
45#[derive(Debug, Clone, PartialEq, Eq)]
46pub struct ConversationMessage {
47    pub role: MessageRole,
48    pub blocks: Vec<ContentBlock>,
49    pub usage: Option<TokenUsage>,
50}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct Session {
54    pub version: u32,
55    pub messages: Vec<ConversationMessage>,
56}
57
58#[derive(Debug)]
59pub enum SessionError {
60    Io(std::io::Error),
61    Json(JsonError),
62    Format(String),
63}
64
65impl Display for SessionError {
66    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67        match self {
68            Self::Io(error) => write!(f, "{error}"),
69            Self::Json(error) => write!(f, "{error}"),
70            Self::Format(error) => write!(f, "{error}"),
71        }
72    }
73}
74
75impl std::error::Error for SessionError {}
76
77impl From<std::io::Error> for SessionError {
78    fn from(value: std::io::Error) -> Self {
79        Self::Io(value)
80    }
81}
82
83impl From<JsonError> for SessionError {
84    fn from(value: JsonError) -> Self {
85        Self::Json(value)
86    }
87}
88
89impl Session {
90    #[must_use]
91    pub fn new() -> Self {
92        Self {
93            version: 1,
94            messages: Vec::new(),
95        }
96    }
97
98    pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> {
99        fs::write(path, self.to_json().render())?;
100        Ok(())
101    }
102
103    pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> {
104        let contents = fs::read_to_string(path)?;
105        Self::from_json(&JsonValue::parse(&contents)?)
106    }
107
108    #[must_use]
109    pub fn to_json(&self) -> JsonValue {
110        let mut object = BTreeMap::new();
111        object.insert(
112            "version".to_string(),
113            JsonValue::Number(i64::from(self.version)),
114        );
115        object.insert(
116            "messages".to_string(),
117            JsonValue::Array(
118                self.messages
119                    .iter()
120                    .map(ConversationMessage::to_json)
121                    .collect(),
122            ),
123        );
124        JsonValue::Object(object)
125    }
126
127    pub fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
128        let object = value
129            .as_object()
130            .ok_or_else(|| SessionError::Format("session must be an object".to_string()))?;
131        let version = object
132            .get("version")
133            .and_then(JsonValue::as_i64)
134            .ok_or_else(|| SessionError::Format("missing version".to_string()))?;
135        let version = u32::try_from(version)
136            .map_err(|_| SessionError::Format("version out of range".to_string()))?;
137        let messages = object
138            .get("messages")
139            .and_then(JsonValue::as_array)
140            .ok_or_else(|| SessionError::Format("missing messages".to_string()))?
141            .iter()
142            .map(ConversationMessage::from_json)
143            .collect::<Result<Vec<_>, _>>()?;
144        Ok(Self { version, messages })
145    }
146}
147
148impl Default for Session {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl ConversationMessage {
155    #[must_use]
156    pub fn user_text(text: impl Into<String>) -> Self {
157        Self {
158            role: MessageRole::User,
159            blocks: vec![ContentBlock::Text { text: text.into() }],
160            usage: None,
161        }
162    }
163
164    #[must_use]
165    pub fn assistant(blocks: Vec<ContentBlock>) -> Self {
166        Self {
167            role: MessageRole::Assistant,
168            blocks,
169            usage: None,
170        }
171    }
172
173    #[must_use]
174    pub fn assistant_with_usage(blocks: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Self {
175        Self {
176            role: MessageRole::Assistant,
177            blocks,
178            usage,
179        }
180    }
181
182    #[must_use]
183    pub fn tool_result(
184        tool_use_id: impl Into<String>,
185        tool_name: impl Into<String>,
186        output: impl Into<String>,
187        is_error: bool,
188    ) -> Self {
189        Self {
190            role: MessageRole::Tool,
191            blocks: vec![ContentBlock::ToolResult {
192                tool_use_id: tool_use_id.into(),
193                tool_name: tool_name.into(),
194                output: output.into(),
195                is_error,
196            }],
197            usage: None,
198        }
199    }
200
201    #[must_use]
202    pub fn to_json(&self) -> JsonValue {
203        let mut object = BTreeMap::new();
204        object.insert(
205            "role".to_string(),
206            JsonValue::String(
207                match self.role {
208                    MessageRole::System => "system",
209                    MessageRole::User => "user",
210                    MessageRole::Assistant => "assistant",
211                    MessageRole::Tool => "tool",
212                }
213                .to_string(),
214            ),
215        );
216        object.insert(
217            "blocks".to_string(),
218            JsonValue::Array(self.blocks.iter().map(ContentBlock::to_json).collect()),
219        );
220        if let Some(usage) = self.usage {
221            object.insert("usage".to_string(), usage_to_json(usage));
222        }
223        JsonValue::Object(object)
224    }
225
226    fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
227        let object = value
228            .as_object()
229            .ok_or_else(|| SessionError::Format("message must be an object".to_string()))?;
230        let role = match object
231            .get("role")
232            .and_then(JsonValue::as_str)
233            .ok_or_else(|| SessionError::Format("missing role".to_string()))?
234        {
235            "system" => MessageRole::System,
236            "user" => MessageRole::User,
237            "assistant" => MessageRole::Assistant,
238            "tool" => MessageRole::Tool,
239            other => {
240                return Err(SessionError::Format(format!(
241                    "unsupported message role: {other}"
242                )))
243            }
244        };
245        let blocks = object
246            .get("blocks")
247            .and_then(JsonValue::as_array)
248            .ok_or_else(|| SessionError::Format("missing blocks".to_string()))?
249            .iter()
250            .map(ContentBlock::from_json)
251            .collect::<Result<Vec<_>, _>>()?;
252        let usage = object.get("usage").map(usage_from_json).transpose()?;
253        Ok(Self {
254            role,
255            blocks,
256            usage,
257        })
258    }
259}
260
261impl ContentBlock {
262    #[must_use]
263    pub fn to_json(&self) -> JsonValue {
264        let mut object = BTreeMap::new();
265        match self {
266            Self::Text { text } => {
267                object.insert("type".to_string(), JsonValue::String("text".to_string()));
268                object.insert("text".to_string(), JsonValue::String(text.clone()));
269            }
270            Self::ToolUse { id, name, input } => {
271                object.insert(
272                    "type".to_string(),
273                    JsonValue::String("tool_use".to_string()),
274                );
275                object.insert("id".to_string(), JsonValue::String(id.clone()));
276                object.insert("name".to_string(), JsonValue::String(name.clone()));
277                object.insert("input".to_string(), JsonValue::String(input.clone()));
278            }
279            Self::ToolResult {
280                tool_use_id,
281                tool_name,
282                output,
283                is_error,
284            } => {
285                object.insert(
286                    "type".to_string(),
287                    JsonValue::String("tool_result".to_string()),
288                );
289                object.insert(
290                    "tool_use_id".to_string(),
291                    JsonValue::String(tool_use_id.clone()),
292                );
293                object.insert(
294                    "tool_name".to_string(),
295                    JsonValue::String(tool_name.clone()),
296                );
297                object.insert("output".to_string(), JsonValue::String(output.clone()));
298                object.insert("is_error".to_string(), JsonValue::Bool(*is_error));
299            }
300            // Store only the mime type — base64 data is not persisted.
301            Self::Image { media_type, .. } => {
302                object.insert("type".to_string(), JsonValue::String("text".to_string()));
303                object.insert(
304                    "text".to_string(),
305                    JsonValue::String(format!("[image: {media_type}]")),
306                );
307            }
308            Self::Thinking { text, thought_signature } => {
309                object.insert("type".to_string(), JsonValue::String("thinking".to_string()));
310                object.insert("text".to_string(), JsonValue::String(text.clone()));
311                if let Some(sig) = thought_signature {
312                    object.insert("thought_signature".to_string(), JsonValue::String(sig.clone()));
313                }
314            }
315        }
316        JsonValue::Object(object)
317    }
318
319    fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
320        let object = value
321            .as_object()
322            .ok_or_else(|| SessionError::Format("block must be an object".to_string()))?;
323        match object
324            .get("type")
325            .and_then(JsonValue::as_str)
326            .ok_or_else(|| SessionError::Format("missing block type".to_string()))?
327        {
328            "text" => Ok(Self::Text {
329                text: required_string(object, "text")?,
330            }),
331            "tool_use" => Ok(Self::ToolUse {
332                id: required_string(object, "id")?,
333                name: required_string(object, "name")?,
334                input: required_string(object, "input")?,
335            }),
336            "tool_result" => Ok(Self::ToolResult {
337                tool_use_id: required_string(object, "tool_use_id")?,
338                tool_name: required_string(object, "tool_name")?,
339                output: required_string(object, "output")?,
340                is_error: object
341                    .get("is_error")
342                    .and_then(JsonValue::as_bool)
343                    .ok_or_else(|| SessionError::Format("missing is_error".to_string()))?,
344            }),
345            "thinking" => Ok(Self::Thinking {
346                text: required_string(object, "text")?,
347                thought_signature: object.get("thought_signature").and_then(JsonValue::as_str).map(String::from),
348            }),
349            other => Err(SessionError::Format(format!(
350                "unsupported block type: {other}"
351            ))),
352        }
353    }
354}
355
356fn usage_to_json(usage: TokenUsage) -> JsonValue {
357    let mut object = BTreeMap::new();
358    object.insert(
359        "input_tokens".to_string(),
360        JsonValue::Number(i64::from(usage.input_tokens)),
361    );
362    object.insert(
363        "output_tokens".to_string(),
364        JsonValue::Number(i64::from(usage.output_tokens)),
365    );
366    object.insert(
367        "cache_creation_input_tokens".to_string(),
368        JsonValue::Number(i64::from(usage.cache_creation_input_tokens)),
369    );
370    object.insert(
371        "cache_read_input_tokens".to_string(),
372        JsonValue::Number(i64::from(usage.cache_read_input_tokens)),
373    );
374    JsonValue::Object(object)
375}
376
377fn usage_from_json(value: &JsonValue) -> Result<TokenUsage, SessionError> {
378    let object = value
379        .as_object()
380        .ok_or_else(|| SessionError::Format("usage must be an object".to_string()))?;
381    Ok(TokenUsage {
382        input_tokens: required_u32(object, "input_tokens")?,
383        output_tokens: required_u32(object, "output_tokens")?,
384        cache_creation_input_tokens: required_u32(object, "cache_creation_input_tokens")?,
385        cache_read_input_tokens: required_u32(object, "cache_read_input_tokens")?,
386    })
387}
388
389fn required_string(
390    object: &BTreeMap<String, JsonValue>,
391    key: &str,
392) -> Result<String, SessionError> {
393    object
394        .get(key)
395        .and_then(JsonValue::as_str)
396        .map(ToOwned::to_owned)
397        .ok_or_else(|| SessionError::Format(format!("missing {key}")))
398}
399
400fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
401    let value = object
402        .get(key)
403        .and_then(JsonValue::as_i64)
404        .ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
405    u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
406}
407
#[cfg(test)]
mod tests {
    use super::{ContentBlock, ConversationMessage, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::fs;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Owns a temporary file path and deletes the file on drop, so a failing
    /// `expect`/assertion does not leave stray session files in the temp dir.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: the file may already have been removed by the test,
            // or never created if saving failed.
            let _ = fs::remove_file(&self.0);
        }
    }

    /// Builds a temp-dir path unique to this process, label, and instant.
    fn unique_temp_path(label: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time should be after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "runtime-session-{label}-{}-{nanos}.json",
            std::process::id()
        ))
    }

    #[test]
    fn persists_and_restores_session_json() {
        let mut session = Session::new();
        session
            .messages
            .push(ConversationMessage::user_text("hello"));
        session
            .messages
            .push(ConversationMessage::assistant_with_usage(
                vec![
                    ContentBlock::Text {
                        text: "thinking".to_string(),
                    },
                    ContentBlock::ToolUse {
                        id: "tool-1".to_string(),
                        name: "bash".to_string(),
                        input: "echo hi".to_string(),
                    },
                ],
                Some(TokenUsage {
                    input_tokens: 10,
                    output_tokens: 4,
                    cache_creation_input_tokens: 1,
                    cache_read_input_tokens: 2,
                }),
            ));
        session.messages.push(ConversationMessage::tool_result(
            "tool-1", "bash", "hi", false,
        ));

        let temp = TempFile(unique_temp_path("roundtrip"));
        session.save_to_path(&temp.0).expect("session should save");
        let restored = Session::load_from_path(&temp.0).expect("session should load");
        fs::remove_file(&temp.0).expect("temp file should be removable");

        assert_eq!(restored, session);
        assert_eq!(restored.messages[2].role, MessageRole::Tool);
        assert_eq!(
            restored.messages[1].usage.expect("usage").total_tokens(),
            17
        );
    }

    #[test]
    fn round_trips_thinking_blocks_with_and_without_signature() {
        let mut session = Session::new();
        session.messages.push(ConversationMessage::assistant(vec![
            ContentBlock::Thinking {
                text: "plan".to_string(),
                thought_signature: Some("sig-123".to_string()),
            },
            ContentBlock::Thinking {
                text: "more".to_string(),
                thought_signature: None,
            },
        ]));

        let restored = Session::from_json(&session.to_json()).expect("session should decode");

        assert_eq!(restored, session);
    }

    #[test]
    fn persists_images_as_text_placeholders() {
        let mut session = Session::new();
        session.messages.push(ConversationMessage {
            role: MessageRole::User,
            blocks: vec![ContentBlock::Image {
                media_type: "image/png".to_string(),
                data: "iVBORw0KGgo=".to_string(),
            }],
            usage: None,
        });

        let restored = Session::from_json(&session.to_json()).expect("session should decode");

        // Image payloads are intentionally not persisted; only a text placeholder survives.
        assert_eq!(
            restored.messages[0].blocks,
            vec![ContentBlock::Text {
                text: "[image: image/png]".to_string(),
            }]
        );
    }
}