// lash_sansio/session_model/mod.rs

pub mod message;
pub mod prompt;

pub use message::{
    BaseRenderCache, Message, MessageRole, MessageSequence, Part, PartAttachment, PartKind,
    PruneState, RenderedPrompt, append_rendered_prompt, messages_are_prompt_resume_safe,
    render_prompt, render_transcript_prompt, shared_parts,
};
pub use prompt::{
    MAIN_AGENT_INTRO, PromptBuiltin, PromptLayer, PromptSlot, PromptSlotLayer, PromptTemplate,
    PromptTemplateEntry, PromptTemplateSection, ResolvedPromptLayer, default_prompt_template,
    resolve_prompt_layers,
};

use std::sync::Arc;

use crate::MessageOrigin;
use crate::ToolDefinition;
use crate::llm::types::LlmToolSpec;
use crate::plugin::{CheckpointKind, PluginMessage, PluginRuntimeEvent};

22#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
23#[allow(clippy::large_enum_variant)]
24pub enum SessionEventRecord<PE = ()> {
25    Conversation(ConversationRecord),
26    Protocol(PE),
27}
28
29#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
30pub struct ConversationRecord {
31    pub id: String,
32    pub role: MessageRole,
33    pub parts: Arc<Vec<Part>>,
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub origin: Option<MessageOrigin>,
36}
37
38#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
39pub struct AcceptedInjectedTurnInput {
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub id: Option<String>,
42    pub message: PluginMessage,
43}
44
45impl ConversationRecord {
46    pub fn from_message(message: Message) -> Self {
47        Self {
48            id: message.id,
49            role: message.role,
50            parts: message.parts,
51            origin: message.origin,
52        }
53    }
54
55    pub fn to_message(&self) -> Message {
56        Message {
57            id: self.id.clone(),
58            role: self.role,
59            parts: Arc::clone(&self.parts),
60            origin: self.origin.clone(),
61        }
62    }
63}
64
65/// Token usage statistics from an LLM call.
66#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
67pub struct TokenUsage {
68    pub input_tokens: i64,
69    pub output_tokens: i64,
70    pub cached_input_tokens: i64,
71    #[serde(default)]
72    pub reasoning_tokens: i64,
73}
74
75impl TokenUsage {
76    pub fn total(&self) -> i64 {
77        self.input_tokens + self.output_tokens + self.reasoning_tokens
78    }
79
80    pub fn add(&mut self, other: &TokenUsage) {
81        self.input_tokens += other.input_tokens;
82        self.output_tokens += other.output_tokens;
83        self.cached_input_tokens += other.cached_input_tokens;
84        self.reasoning_tokens += other.reasoning_tokens;
85    }
86}
87
88#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
89pub struct ErrorEnvelope {
90    pub kind: String,
91    #[serde(default, skip_serializing_if = "Option::is_none")]
92    pub code: Option<String>,
93    #[serde(default, skip_serializing_if = "Option::is_none")]
94    pub terminal_reason: Option<crate::llm::types::LlmTerminalReason>,
95    pub user_message: String,
96    #[serde(default, skip_serializing_if = "Option::is_none")]
97    pub raw: Option<String>,
98}
99
100#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
101#[serde(tag = "type")]
102#[allow(clippy::large_enum_variant)]
103pub enum SessionEvent {
104    #[serde(rename = "text_delta")]
105    TextDelta { content: String },
106    /// Streaming update for the model's reasoning summary ("thinking").
107    /// The UI renders these incrementally in a muted/italic style;
108    /// reasoning content is never fed back to the model on subsequent
109    /// turns.
110    #[serde(rename = "reasoning_delta")]
111    ReasoningDelta { content: String },
112    #[serde(rename = "tool_call")]
113    ToolCall {
114        #[serde(default, skip_serializing_if = "Option::is_none")]
115        call_id: Option<String>,
116        name: String,
117        args: serde_json::Value,
118        output: crate::ToolCallOutput,
119        duration_ms: u64,
120    },
121    #[serde(rename = "tool_call_start")]
122    ToolCallStart {
123        #[serde(default, skip_serializing_if = "Option::is_none")]
124        call_id: Option<String>,
125        name: String,
126        args: serde_json::Value,
127    },
128    #[serde(rename = "message")]
129    Message { text: String, kind: String },
130    #[serde(rename = "llm_request")]
131    LlmRequest {
132        protocol_iteration: usize,
133        message_count: usize,
134        tool_list: String,
135    },
136    #[serde(rename = "llm_response")]
137    LlmResponse {
138        protocol_iteration: usize,
139        content: String,
140        duration_ms: u64,
141    },
142    #[serde(rename = "token_usage")]
143    TokenUsage {
144        protocol_iteration: usize,
145        usage: TokenUsage,
146        cumulative: TokenUsage,
147    },
148    #[serde(rename = "child_token_usage")]
149    ChildTokenUsage {
150        session_id: String,
151        source: String,
152        model: String,
153        protocol_iteration: usize,
154        usage: TokenUsage,
155        cumulative: TokenUsage,
156    },
157    #[serde(rename = "retry_status")]
158    RetryStatus {
159        wait_seconds: u64,
160        attempt: usize,
161        max_attempts: usize,
162        reason: String,
163        #[serde(default, skip_serializing_if = "Option::is_none")]
164        envelope: Option<ErrorEnvelope>,
165    },
166    #[serde(rename = "injected_turn_input_accepted")]
167    InjectedTurnInputAccepted {
168        inputs: Vec<AcceptedInjectedTurnInput>,
169        checkpoint: CheckpointKind,
170    },
171    #[serde(rename = "injected_messages_committed")]
172    InjectedMessagesCommitted {
173        messages: Vec<PluginMessage>,
174        checkpoint: CheckpointKind,
175    },
176    #[serde(rename = "plugin_event")]
177    PluginEvent {
178        plugin_id: String,
179        event: PluginRuntimeEvent,
180    },
181    /// Semantic result for a completed turn. `Done` remains the machine
182    /// lifecycle marker emitted after this event.
183    #[serde(rename = "turn_outcome")]
184    TurnOutcome { outcome: TurnOutcome },
185    #[serde(rename = "done")]
186    Done,
187    #[serde(rename = "error")]
188    Error {
189        message: String,
190        #[serde(default, skip_serializing_if = "Option::is_none")]
191        envelope: Option<ErrorEnvelope>,
192    },
193}
194
195#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
196#[serde(rename_all = "snake_case")]
197pub enum TurnOutcome {
198    Finished(TurnFinish),
199    AgentFrameSwitch { frame_id: String },
200    Stopped(TurnStop),
201}
202
203#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
204#[serde(rename_all = "snake_case")]
205pub enum TurnFinish {
206    AssistantMessage {
207        text: String,
208    },
209    SubmittedValue {
210        value: serde_json::Value,
211    },
212    ToolValue {
213        tool_name: String,
214        value: serde_json::Value,
215    },
216}
217
218#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
219#[serde(rename_all = "snake_case")]
220pub enum TurnStop {
221    Cancelled,
222    Incomplete,
223    InvalidInput,
224    MaxTurns,
225    ToolFailure,
226    ProviderError,
227    PluginAbort,
228    RuntimeError,
229    SubmittedError {
230        value: serde_json::Value,
231    },
232    ToolError {
233        tool_name: String,
234        value: serde_json::Value,
235    },
236}
237
238#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
239pub struct TurnTerminationPolicyState {
240    turn_limit_final_scheduled: bool,
241}
242
243impl Default for TurnTerminationPolicyState {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
249impl TurnTerminationPolicyState {
250    pub fn new() -> Self {
251        Self {
252            turn_limit_final_scheduled: false,
253        }
254    }
255
256    pub fn should_force_exit_after_grace_turn(&self) -> bool {
257        self.turn_limit_final_scheduled
258    }
259
260    pub fn turn_limit_final_to_schedule(
261        &self,
262        protocol_iteration: usize,
263        protocol_run_offset: usize,
264        max_turns: Option<usize>,
265    ) -> Option<usize> {
266        if self.turn_limit_final_scheduled {
267            return None;
268        }
269        let max = max_turns?;
270        if protocol_iteration < protocol_run_offset + max {
271            return None;
272        }
273        Some(max)
274    }
275
276    pub fn mark_turn_limit_final_scheduled(&mut self) {
277        self.turn_limit_final_scheduled = true;
278    }
279}
280
281pub fn make_error_envelope(
282    kind: &str,
283    code: Option<&str>,
284    terminal_reason: Option<crate::llm::types::LlmTerminalReason>,
285    user_message: impl Into<String>,
286    raw: Option<String>,
287) -> ErrorEnvelope {
288    let user_message = user_message.into();
289    ErrorEnvelope {
290        kind: kind.to_string(),
291        code: code.map(str::to_string),
292        terminal_reason,
293        user_message,
294        raw: raw.map(|s| truncate_raw_error(s.trim())),
295    }
296}
297
298pub fn make_error_event(
299    kind: &str,
300    code: Option<&str>,
301    user_message: impl Into<String>,
302    raw: Option<String>,
303) -> SessionEvent {
304    let user_message = user_message.into();
305    SessionEvent::Error {
306        message: user_message.clone(),
307        envelope: Some(make_error_envelope(kind, code, None, user_message, raw)),
308    }
309}
310
/// Shortens an overlong raw error string by cutting out its middle.
///
/// Strings of at most 4000 characters are returned unchanged. Longer strings
/// keep their first 2000 and last 2000 characters, joined by a marker of the
/// form `\n\n... (N chars omitted) ...\n\n`, where `N` is the number of
/// characters removed. Lengths are measured in `char`s, not bytes, so
/// multi-byte text is never split inside a character.
pub fn truncate_raw_error(s: &str) -> String {
    const MAX_RAW: usize = 4000;
    const KEEP: usize = MAX_RAW / 2;

    // A string of at most MAX_RAW bytes cannot hold more than MAX_RAW chars,
    // so the common short case skips the full character count.
    if s.len() <= MAX_RAW {
        return s.to_string();
    }
    let raw_len = s.chars().count();
    if raw_len <= MAX_RAW {
        return s.to_string();
    }

    // `raw_len > MAX_RAW` guarantees both lookups find a character; the
    // fallbacks only keep the code free of panics.
    let head_end = s
        .char_indices()
        .nth(KEEP)
        .map_or(s.len(), |(idx, _)| idx);
    let tail_start = s
        .char_indices()
        .rev()
        .nth(KEEP - 1)
        .map_or(0, |(idx, _)| idx);

    let head = &s[..head_end];
    let tail = &s[tail_start..];
    let omitted = raw_len.saturating_sub(KEEP * 2);
    format!("{head}\n\n... ({omitted} chars omitted) ...\n\n{tail}")
}

331pub fn reassign_part_ids(message_id: &str, parts: &mut [Part]) {
332    for (idx, part) in parts.iter_mut().enumerate() {
333        part.id = format!("{message_id}.p{idx}");
334    }
335}
336
337pub fn model_tool_specs_iter<'a>(
338    tools: impl IntoIterator<Item = &'a ToolDefinition>,
339) -> Vec<LlmToolSpec> {
340    tools
341        .into_iter()
342        .map(|tool| {
343            let model_tool = tool.model_tool();
344            LlmToolSpec {
345                name: model_tool.name,
346                description: model_tool.description,
347                input_schema: model_tool.input_schema,
348                output_schema: model_tool.output_schema,
349                input_schema_projections: model_tool.input_schema_projections,
350                output_schema_projections: model_tool.output_schema_projections,
351            }
352        })
353        .collect()
354}
355
356pub fn model_tool_specs(tools: &[ToolDefinition]) -> Vec<LlmToolSpec> {
357    model_tool_specs_iter(tools.iter())
358}