Skip to main content

lash_sansio/session_model/
mod.rs

1pub mod message;
2pub mod prompt;
3
4pub use message::{
5    BaseRenderCache, Message, MessageRole, MessageSequence, Part, PartAttachment, PartKind,
6    PruneState, RenderedPrompt, append_rendered_prompt, messages_are_prompt_resume_safe,
7    render_prompt, render_transcript_prompt, shared_parts,
8};
9pub use prompt::{
10    MAIN_AGENT_INTRO, PromptBuiltin, PromptLayer, PromptSlot, PromptSlotLayer, PromptTemplate,
11    PromptTemplateEntry, PromptTemplateSection, ResolvedPromptLayer, default_prompt_template,
12    resolve_prompt_layers,
13};
14
15use std::sync::Arc;
16
17use crate::MessageOrigin;
18use crate::ToolDefinition;
19use crate::llm::types::LlmToolSpec;
20use crate::plugin::{CheckpointKind, PluginMessage, PluginRuntimeEvent};
21
22#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
23#[allow(clippy::large_enum_variant)]
24pub enum SessionEventRecord<PE = ()> {
25    Conversation(ConversationRecord),
26    Protocol(PE),
27}
28
29#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
30pub struct ConversationRecord {
31    pub id: String,
32    pub role: MessageRole,
33    pub parts: Arc<Vec<Part>>,
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub origin: Option<MessageOrigin>,
36}
37
38#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
39pub struct AcceptedInjectedTurnInput {
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub id: Option<String>,
42    pub message: PluginMessage,
43}
44
45impl ConversationRecord {
46    pub fn from_message(message: Message) -> Self {
47        Self {
48            id: message.id,
49            role: message.role,
50            parts: message.parts,
51            origin: message.origin,
52        }
53    }
54
55    pub fn to_message(&self) -> Message {
56        Message {
57            id: self.id.clone(),
58            role: self.role,
59            parts: Arc::clone(&self.parts),
60            origin: self.origin.clone(),
61        }
62    }
63}
64
65/// Token usage statistics from an LLM call.
66#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
67pub struct TokenUsage {
68    pub input_tokens: i64,
69    pub output_tokens: i64,
70    pub cache_read_input_tokens: i64,
71    pub cache_write_input_tokens: i64,
72    pub reasoning_output_tokens: i64,
73}
74
75impl TokenUsage {
76    pub fn total(&self) -> i64 {
77        self.input_tokens
78            + self.output_tokens
79            + self.cache_read_input_tokens
80            + self.cache_write_input_tokens
81    }
82
83    pub fn input_total(&self) -> i64 {
84        self.input_tokens + self.cache_read_input_tokens + self.cache_write_input_tokens
85    }
86
87    pub fn add(&mut self, other: &TokenUsage) {
88        self.input_tokens += other.input_tokens;
89        self.output_tokens += other.output_tokens;
90        self.cache_read_input_tokens += other.cache_read_input_tokens;
91        self.cache_write_input_tokens += other.cache_write_input_tokens;
92        self.reasoning_output_tokens += other.reasoning_output_tokens;
93    }
94}
95
96/// Structured error payload carried on [`SessionEvent::Error`] (and
97/// [`SessionEvent::RetryStatus`]).
98///
99/// Durability: this type appears inside persisted session snapshots and turn
100/// checkpoints, so every field added after the initial shape must stay
101/// additive — `#[serde(default)]` on decode and
102/// `#[serde(skip_serializing_if = "Option::is_none")]` on encode — to keep
103/// old snapshots decodable and new snapshots readable by older readers.
104#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
105pub struct ErrorEnvelope {
106    pub kind: String,
107    #[serde(default, skip_serializing_if = "Option::is_none")]
108    pub code: Option<String>,
109    #[serde(default, skip_serializing_if = "Option::is_none")]
110    pub terminal_reason: Option<crate::llm::types::LlmTerminalReason>,
111    pub user_message: String,
112    #[serde(default, skip_serializing_if = "Option::is_none")]
113    pub raw: Option<String>,
114    /// Whether the failing operation is safe to retry, when the source
115    /// carries a typed signal (provider transports classify retryability).
116    /// `None` means the source did not know.
117    #[serde(default, skip_serializing_if = "Option::is_none")]
118    pub retryable: Option<bool>,
119    /// Typed provider-failure classification, set only when the error came
120    /// from an LLM provider/transport failure whose kind was classified
121    /// (an unclassified `Unknown` kind is surfaced as `None`).
122    #[serde(default, skip_serializing_if = "Option::is_none")]
123    pub provider_failure_kind: Option<crate::llm::types::ProviderFailureKind>,
124}
125
126#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
127#[serde(tag = "type")]
128#[allow(clippy::large_enum_variant)]
129pub enum SessionEvent {
130    #[serde(rename = "text_delta")]
131    TextDelta { content: String },
132    /// Streaming update for the model's reasoning summary ("thinking").
133    /// The UI renders these incrementally in a muted/italic style;
134    /// reasoning content is never fed back to the model on subsequent
135    /// turns.
136    #[serde(rename = "reasoning_delta")]
137    ReasoningDelta { content: String },
138    #[serde(rename = "tool_call")]
139    ToolCall {
140        #[serde(default, skip_serializing_if = "Option::is_none")]
141        call_id: Option<String>,
142        name: String,
143        args: serde_json::Value,
144        output: crate::ToolCallOutput,
145        duration_ms: u64,
146    },
147    #[serde(rename = "tool_call_start")]
148    ToolCallStart {
149        #[serde(default, skip_serializing_if = "Option::is_none")]
150        call_id: Option<String>,
151        name: String,
152        args: serde_json::Value,
153    },
154    #[serde(rename = "message")]
155    Message { text: String, kind: String },
156    #[serde(rename = "llm_request")]
157    LlmRequest {
158        protocol_iteration: usize,
159        message_count: usize,
160        tool_list: String,
161    },
162    #[serde(rename = "llm_response")]
163    LlmResponse {
164        protocol_iteration: usize,
165        content: String,
166        duration_ms: u64,
167    },
168    #[serde(rename = "token_usage")]
169    TokenUsage {
170        protocol_iteration: usize,
171        usage: TokenUsage,
172        cumulative: TokenUsage,
173    },
174    #[serde(rename = "child_token_usage")]
175    ChildTokenUsage {
176        session_id: String,
177        source: String,
178        model: String,
179        protocol_iteration: usize,
180        usage: TokenUsage,
181        cumulative: TokenUsage,
182    },
183    #[serde(rename = "retry_status")]
184    RetryStatus {
185        wait_seconds: u64,
186        attempt: usize,
187        max_attempts: usize,
188        reason: String,
189        #[serde(default, skip_serializing_if = "Option::is_none")]
190        envelope: Option<ErrorEnvelope>,
191    },
192    #[serde(rename = "injected_turn_input_accepted")]
193    InjectedTurnInputAccepted {
194        inputs: Vec<AcceptedInjectedTurnInput>,
195        checkpoint: CheckpointKind,
196    },
197    #[serde(rename = "injected_messages_committed")]
198    InjectedMessagesCommitted {
199        messages: Vec<PluginMessage>,
200        checkpoint: CheckpointKind,
201    },
202    #[serde(rename = "plugin_event")]
203    PluginEvent {
204        plugin_id: String,
205        event: PluginRuntimeEvent,
206    },
207    /// Semantic result for a completed turn. `Done` remains the machine
208    /// lifecycle marker emitted after this event.
209    #[serde(rename = "turn_outcome")]
210    TurnOutcome { outcome: TurnOutcome },
211    #[serde(rename = "done")]
212    Done,
213    #[serde(rename = "error")]
214    Error {
215        message: String,
216        #[serde(default, skip_serializing_if = "Option::is_none")]
217        envelope: Option<ErrorEnvelope>,
218    },
219}
220
221#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
222#[serde(rename_all = "snake_case")]
223pub enum TurnOutcome {
224    Finished(TurnFinish),
225    AgentFrameSwitch { frame_id: String, task: String },
226    Stopped(TurnStop),
227}
228
229#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
230#[serde(rename_all = "snake_case")]
231pub enum TurnFinish {
232    AssistantMessage {
233        text: String,
234    },
235    FinalValue {
236        value: serde_json::Value,
237    },
238    ToolValue {
239        tool_name: String,
240        value: serde_json::Value,
241    },
242}
243
244#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
245#[serde(rename_all = "snake_case")]
246pub enum TurnStop {
247    Cancelled,
248    Incomplete,
249    InvalidInput,
250    MaxTurns,
251    ToolFailure,
252    ProviderError,
253    PluginAbort,
254    RuntimeError,
255    SubmittedError {
256        value: serde_json::Value,
257    },
258    ToolError {
259        tool_name: String,
260        value: serde_json::Value,
261    },
262}
263
264#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
265pub struct TurnTerminationPolicyState {
266    turn_limit_final_scheduled: bool,
267}
268
269impl Default for TurnTerminationPolicyState {
270    fn default() -> Self {
271        Self::new()
272    }
273}
274
275impl TurnTerminationPolicyState {
276    pub fn new() -> Self {
277        Self {
278            turn_limit_final_scheduled: false,
279        }
280    }
281
282    pub fn should_force_exit_after_grace_turn(&self) -> bool {
283        self.turn_limit_final_scheduled
284    }
285
286    pub fn turn_limit_final_to_schedule(
287        &self,
288        protocol_iteration: usize,
289        protocol_run_offset: usize,
290        max_turns: Option<usize>,
291    ) -> Option<usize> {
292        if self.turn_limit_final_scheduled {
293            return None;
294        }
295        let max = max_turns?;
296        if protocol_iteration < protocol_run_offset + max {
297            return None;
298        }
299        Some(max)
300    }
301
302    pub fn mark_turn_limit_final_scheduled(&mut self) {
303        self.turn_limit_final_scheduled = true;
304    }
305}
306
307pub fn make_error_envelope(
308    kind: &str,
309    code: Option<&str>,
310    terminal_reason: Option<crate::llm::types::LlmTerminalReason>,
311    user_message: impl Into<String>,
312    raw: Option<String>,
313) -> ErrorEnvelope {
314    let user_message = user_message.into();
315    ErrorEnvelope {
316        kind: kind.to_string(),
317        code: code.map(str::to_string),
318        terminal_reason,
319        user_message,
320        raw: raw.map(|s| truncate_raw_error(s.trim())),
321        retryable: None,
322        provider_failure_kind: None,
323    }
324}
325
326pub fn make_error_event(
327    kind: &str,
328    code: Option<&str>,
329    user_message: impl Into<String>,
330    raw: Option<String>,
331) -> SessionEvent {
332    let user_message = user_message.into();
333    SessionEvent::Error {
334        message: user_message.clone(),
335        envelope: Some(make_error_envelope(kind, code, None, user_message, raw)),
336    }
337}
338
/// Caps raw error text at 4000 characters.
///
/// Lengths are counted in `char`s (Unicode scalar values), not bytes. Input
/// of at most 4000 chars is returned unchanged. Longer input keeps its first
/// and last 2000 chars, joined by a `... (N chars omitted) ...` marker, where
/// `N` is the number of chars dropped. Cuts always land on `char` boundaries.
pub fn truncate_raw_error(s: &str) -> String {
    const MAX_RAW: usize = 4000;
    let char_count = s.chars().count();
    if char_count > MAX_RAW {
        let keep = MAX_RAW / 2;
        // Byte offset just past the first `keep` chars.
        let head_end = s.char_indices().nth(keep).map_or(s.len(), |(at, _)| at);
        // Byte offset where the last `keep` chars begin.
        let tail_start = s
            .char_indices()
            .rev()
            .nth(keep - 1)
            .map_or(0, |(at, _)| at);
        let (head, tail) = (&s[..head_end], &s[tail_start..]);
        let omitted = char_count.saturating_sub(keep * 2);
        format!("{head}\n\n... ({omitted} chars omitted) ...\n\n{tail}")
    } else {
        s.to_owned()
    }
}
358
359pub fn reassign_part_ids(message_id: &str, parts: &mut [Part]) {
360    for (idx, part) in parts.iter_mut().enumerate() {
361        part.id = format!("{message_id}.p{idx}");
362    }
363}
364
365pub fn model_tool_specs_iter<'a>(
366    tools: impl IntoIterator<Item = &'a ToolDefinition>,
367) -> Vec<LlmToolSpec> {
368    tools
369        .into_iter()
370        .map(|tool| {
371            let model_tool = tool.model_tool();
372            LlmToolSpec {
373                name: model_tool.name,
374                description: model_tool.description,
375                input_schema: model_tool.input_schema,
376                output_schema: model_tool.output_schema,
377            }
378        })
379        .collect()
380}
381
382pub fn model_tool_specs(tools: &[ToolDefinition]) -> Vec<LlmToolSpec> {
383    model_tool_specs_iter(tools.iter())
384}
385
#[cfg(test)]
mod tests {
    use super::{ErrorEnvelope, SessionEvent};
    use crate::llm::types::{LlmTerminalReason, ProviderFailureKind};

    // ─── ErrorEnvelope durable-snapshot compatibility ──────────────────
    //
    // Envelopes live inside persisted session snapshots and turn
    // checkpoints. The retryability fields (`retryable`,
    // `provider_failure_kind`) were added after the first shape. They must
    // decode as `None` from legacy JSON that lacks them, and they must be
    // left off the wire when unset, so older readers can still decode new
    // snapshots.

    #[test]
    fn error_envelope_decodes_legacy_snapshot_without_retryability_fields() {
        // A bare envelope as written before the retryability fields existed.
        let legacy_envelope_json = r#"{
            "kind":"llm_provider",
            "code":"429",
            "terminal_reason":"provider_error",
            "user_message":"LLM error: rate limited",
            "raw":"{\"error\":\"rate_limited\"}"
        }"#;
        let decoded =
            serde_json::from_str::<ErrorEnvelope>(legacy_envelope_json).expect("legacy envelope");
        assert_eq!(decoded.kind, "llm_provider");
        assert_eq!(decoded.retryable, None);
        assert_eq!(decoded.provider_failure_kind, None);

        // The same legacy shape nested in a persisted `SessionEvent::Error`.
        let legacy_event_json = r#"{
            "type":"error",
            "message":"LLM error: rate limited",
            "envelope":{"kind":"llm_provider","user_message":"LLM error: rate limited"}
        }"#;
        let event =
            serde_json::from_str::<SessionEvent>(legacy_event_json).expect("legacy event");
        let nested = match event {
            SessionEvent::Error { envelope, .. } => envelope.expect("envelope"),
            other => panic!("expected error event, got {other:?}"),
        };
        assert_eq!(nested.retryable, None);
        assert_eq!(nested.provider_failure_kind, None);
    }

    #[test]
    fn error_envelope_roundtrips_retryability_fields() {
        let original = ErrorEnvelope {
            kind: "llm_provider".to_string(),
            code: Some("429".to_string()),
            terminal_reason: Some(LlmTerminalReason::ProviderError),
            user_message: "LLM error: rate limited".to_string(),
            raw: None,
            retryable: Some(true),
            provider_failure_kind: Some(ProviderFailureKind::Quota),
        };
        let encoded = serde_json::to_value(&original).expect("serialize envelope");
        assert_eq!(encoded["retryable"], serde_json::json!(true));
        assert_eq!(encoded["provider_failure_kind"], serde_json::json!("quota"));

        let decoded =
            serde_json::from_value::<ErrorEnvelope>(encoded).expect("decode envelope");
        assert_eq!(decoded.retryable, Some(true));
        assert_eq!(
            decoded.provider_failure_kind,
            Some(ProviderFailureKind::Quota)
        );
    }

    #[test]
    fn error_envelope_omits_unset_retryability_fields_on_the_wire() {
        let unset = ErrorEnvelope {
            kind: "plugin".to_string(),
            code: Some("plugin_abort".to_string()),
            terminal_reason: None,
            user_message: "stopped".to_string(),
            raw: None,
            retryable: None,
            provider_failure_kind: None,
        };
        let encoded = serde_json::to_value(&unset).expect("serialize envelope");
        let fields = encoded.as_object().expect("object");
        for key in ["retryable", "provider_failure_kind"] {
            assert!(!fields.contains_key(key));
        }
    }

    #[test]
    fn provider_failure_kind_decodes_unknown_future_codes() {
        // Forward compatibility: a kind written by a newer runtime that this
        // build does not recognize must decode as `Unknown`.
        let future = serde_json::from_value::<ProviderFailureKind>(serde_json::json!(
            "some_future_kind"
        ))
        .expect("future kind");
        assert_eq!(future, ProviderFailureKind::Unknown);

        // Every known kind serializes to its `code()` and round-trips.
        let known = [
            ProviderFailureKind::Transport,
            ProviderFailureKind::Timeout,
            ProviderFailureKind::Http,
            ProviderFailureKind::Stream,
            ProviderFailureKind::Auth,
            ProviderFailureKind::Validation,
            ProviderFailureKind::Quota,
            ProviderFailureKind::Unsupported,
            ProviderFailureKind::Unknown,
        ];
        for kind in known {
            let encoded = serde_json::to_value(kind).expect("serialize kind");
            assert_eq!(encoded, serde_json::json!(kind.code()));
            let decoded =
                serde_json::from_value::<ProviderFailureKind>(encoded).expect("decode kind");
            assert_eq!(decoded, kind);
        }
    }
}