Skip to main content

lash_sansio/session_model/
mod.rs

1pub mod message;
2pub mod prompt;
3
4pub use message::{
5    BaseRenderCache, Message, MessageRole, MessageSequence, Part, PartAttachment, PartKind,
6    PruneState, RenderedPrompt, append_rendered_prompt, messages_are_prompt_resume_safe,
7    render_prompt, render_transcript_prompt, shared_parts,
8};
9pub use prompt::{
10    MAIN_AGENT_INTRO, PromptBuiltin, PromptLayer, PromptSlot, PromptSlotLayer, PromptTemplate,
11    PromptTemplateEntry, PromptTemplateSection, ResolvedPromptLayer, default_prompt_template,
12    resolve_prompt_layers,
13};
14
15use std::sync::Arc;
16
17use crate::ToolDefinition;
18use crate::llm::types::LlmToolSpec;
19use crate::plugin::{CheckpointKind, PluginMessage, PluginRuntimeEvent};
20use crate::{MessageOrigin, ToolCallRecord};
21
22#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
23#[allow(clippy::large_enum_variant)]
24pub enum SessionEventRecord<PE = ()> {
25    Conversation(ConversationRecord),
26    Tool(ToolEvent),
27    Protocol(PE),
28}
29
30#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
31pub struct ConversationRecord {
32    pub id: String,
33    pub role: MessageRole,
34    pub parts: Arc<Vec<Part>>,
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub origin: Option<MessageOrigin>,
37}
38
39#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
40pub struct AcceptedInjectedTurnInput {
41    #[serde(default, skip_serializing_if = "Option::is_none")]
42    pub id: Option<String>,
43    pub message: PluginMessage,
44}
45
46impl ConversationRecord {
47    pub fn from_message(message: Message) -> Self {
48        Self {
49            id: message.id,
50            role: message.role,
51            parts: message.parts,
52            origin: message.origin,
53        }
54    }
55
56    pub fn to_message(&self) -> Message {
57        Message {
58            id: self.id.clone(),
59            role: self.role,
60            parts: Arc::clone(&self.parts),
61            origin: self.origin.clone(),
62        }
63    }
64}
65
66#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
67pub enum ToolEvent {
68    Invocation {
69        stable_key: String,
70        record: ToolCallRecord,
71    },
72}
73
74/// Token usage statistics from an LLM call.
75#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
76pub struct TokenUsage {
77    pub input_tokens: i64,
78    pub output_tokens: i64,
79    pub cached_input_tokens: i64,
80    #[serde(default)]
81    pub reasoning_tokens: i64,
82}
83
84impl TokenUsage {
85    pub fn total(&self) -> i64 {
86        self.input_tokens + self.output_tokens + self.reasoning_tokens
87    }
88
89    pub fn add(&mut self, other: &TokenUsage) {
90        self.input_tokens += other.input_tokens;
91        self.output_tokens += other.output_tokens;
92        self.cached_input_tokens += other.cached_input_tokens;
93        self.reasoning_tokens += other.reasoning_tokens;
94    }
95}
96
97#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
98pub struct ErrorEnvelope {
99    pub kind: String,
100    #[serde(default, skip_serializing_if = "Option::is_none")]
101    pub code: Option<String>,
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub terminal_reason: Option<crate::llm::types::LlmTerminalReason>,
104    pub user_message: String,
105    #[serde(default, skip_serializing_if = "Option::is_none")]
106    pub raw: Option<String>,
107}
108
109#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
110#[serde(tag = "type")]
111#[allow(clippy::large_enum_variant)]
112pub enum SessionEvent {
113    #[serde(rename = "text_delta")]
114    TextDelta { content: String },
115    /// Streaming update for the model's reasoning summary ("thinking").
116    /// The UI renders these incrementally in a muted/italic style;
117    /// reasoning content is never fed back to the model on subsequent
118    /// turns.
119    #[serde(rename = "reasoning_delta")]
120    ReasoningDelta { content: String },
121    #[serde(rename = "tool_call")]
122    ToolCall {
123        #[serde(default, skip_serializing_if = "Option::is_none")]
124        call_id: Option<String>,
125        name: String,
126        args: serde_json::Value,
127        output: crate::ToolCallOutput,
128        duration_ms: u64,
129    },
130    #[serde(rename = "tool_call_start")]
131    ToolCallStart {
132        #[serde(default, skip_serializing_if = "Option::is_none")]
133        call_id: Option<String>,
134        name: String,
135        args: serde_json::Value,
136    },
137    #[serde(rename = "message")]
138    Message { text: String, kind: String },
139    #[serde(rename = "llm_request")]
140    LlmRequest {
141        protocol_iteration: usize,
142        message_count: usize,
143        tool_list: String,
144    },
145    #[serde(rename = "llm_response")]
146    LlmResponse {
147        protocol_iteration: usize,
148        content: String,
149        duration_ms: u64,
150    },
151    #[serde(rename = "token_usage")]
152    TokenUsage {
153        protocol_iteration: usize,
154        usage: TokenUsage,
155        cumulative: TokenUsage,
156    },
157    #[serde(rename = "child_token_usage")]
158    ChildTokenUsage {
159        session_id: String,
160        source: String,
161        model: String,
162        protocol_iteration: usize,
163        usage: TokenUsage,
164        cumulative: TokenUsage,
165    },
166    #[serde(rename = "retry_status")]
167    RetryStatus {
168        wait_seconds: u64,
169        attempt: usize,
170        max_attempts: usize,
171        reason: String,
172        #[serde(default, skip_serializing_if = "Option::is_none")]
173        envelope: Option<ErrorEnvelope>,
174    },
175    #[serde(rename = "injected_turn_input_accepted")]
176    InjectedTurnInputAccepted {
177        inputs: Vec<AcceptedInjectedTurnInput>,
178        checkpoint: CheckpointKind,
179    },
180    #[serde(rename = "injected_messages_committed")]
181    InjectedMessagesCommitted {
182        messages: Vec<PluginMessage>,
183        checkpoint: CheckpointKind,
184    },
185    #[serde(rename = "plugin_event")]
186    PluginEvent {
187        plugin_id: String,
188        event: PluginRuntimeEvent,
189    },
190    /// Semantic result for a completed turn. `Done` remains the machine
191    /// lifecycle marker emitted after this event.
192    #[serde(rename = "turn_outcome")]
193    TurnOutcome { outcome: TurnOutcome },
194    #[serde(rename = "done")]
195    Done,
196    #[serde(rename = "error")]
197    Error {
198        message: String,
199        #[serde(default, skip_serializing_if = "Option::is_none")]
200        envelope: Option<ErrorEnvelope>,
201    },
202}
203
204#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
205#[serde(rename_all = "snake_case")]
206pub enum TurnOutcome {
207    Finished(TurnFinish),
208    AgentFrameSwitch { frame_id: String },
209    Stopped(TurnStop),
210}
211
212#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
213#[serde(rename_all = "snake_case")]
214pub enum TurnFinish {
215    AssistantMessage {
216        text: String,
217    },
218    SubmittedValue {
219        value: serde_json::Value,
220    },
221    ToolValue {
222        tool_name: String,
223        value: serde_json::Value,
224    },
225}
226
227#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
228#[serde(rename_all = "snake_case")]
229pub enum TurnStop {
230    Cancelled,
231    Incomplete,
232    InvalidInput,
233    MaxTurns,
234    ToolFailure,
235    ProviderError,
236    PluginAbort,
237    RuntimeError,
238    SubmittedError {
239        value: serde_json::Value,
240    },
241    ToolError {
242        tool_name: String,
243        value: serde_json::Value,
244    },
245}
246
247#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
248pub struct TurnTerminationPolicyState {
249    turn_limit_final_scheduled: bool,
250}
251
252impl Default for TurnTerminationPolicyState {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258impl TurnTerminationPolicyState {
259    pub fn new() -> Self {
260        Self {
261            turn_limit_final_scheduled: false,
262        }
263    }
264
265    pub fn should_force_exit_after_grace_turn(&self) -> bool {
266        self.turn_limit_final_scheduled
267    }
268
269    pub fn turn_limit_final_to_schedule(
270        &self,
271        protocol_iteration: usize,
272        protocol_run_offset: usize,
273        max_turns: Option<usize>,
274    ) -> Option<usize> {
275        if self.turn_limit_final_scheduled {
276            return None;
277        }
278        let max = max_turns?;
279        if protocol_iteration < protocol_run_offset + max {
280            return None;
281        }
282        Some(max)
283    }
284
285    pub fn mark_turn_limit_final_scheduled(&mut self) {
286        self.turn_limit_final_scheduled = true;
287    }
288}
289
290pub fn make_error_envelope(
291    kind: &str,
292    code: Option<&str>,
293    terminal_reason: Option<crate::llm::types::LlmTerminalReason>,
294    user_message: impl Into<String>,
295    raw: Option<String>,
296) -> ErrorEnvelope {
297    let user_message = user_message.into();
298    ErrorEnvelope {
299        kind: kind.to_string(),
300        code: code.map(str::to_string),
301        terminal_reason,
302        user_message,
303        raw: raw.map(|s| truncate_raw_error(s.trim())),
304    }
305}
306
307pub fn make_error_event(
308    kind: &str,
309    code: Option<&str>,
310    user_message: impl Into<String>,
311    raw: Option<String>,
312) -> SessionEvent {
313    let user_message = user_message.into();
314    SessionEvent::Error {
315        message: user_message.clone(),
316        envelope: Some(make_error_envelope(kind, code, None, user_message, raw)),
317    }
318}
319
/// Shortens long raw error text, keeping its beginning and end.
///
/// Input of at most 4000 characters is returned unchanged. Longer input is
/// reduced to its first 2000 and last 2000 characters, joined by a marker
/// that states how many characters were dropped. Lengths are counted in
/// `char`s, so multi-byte text is never split inside a character.
pub fn truncate_raw_error(s: &str) -> String {
    const MAX_RAW: usize = 4000;
    const KEEP: usize = MAX_RAW / 2;

    let total_chars = s.chars().count();
    if total_chars <= MAX_RAW {
        return s.to_owned();
    }

    // Byte offset just past the first KEEP characters.
    let head_end = s
        .char_indices()
        .nth(KEEP)
        .map_or(s.len(), |(offset, _)| offset);
    // Byte offset where the last KEEP characters begin.
    let tail_start = s
        .char_indices()
        .rev()
        .nth(KEEP - 1)
        .map_or(0, |(offset, _)| offset);

    let head = &s[..head_end];
    let tail = &s[tail_start..];
    let omitted = total_chars.saturating_sub(KEEP * 2);
    format!("{head}\n\n... ({omitted} chars omitted) ...\n\n{tail}")
}
339
340pub fn reassign_part_ids(message_id: &str, parts: &mut [Part]) {
341    for (idx, part) in parts.iter_mut().enumerate() {
342        part.id = format!("{message_id}.p{idx}");
343    }
344}
345
346pub fn model_tool_specs_iter<'a>(
347    tools: impl IntoIterator<Item = &'a ToolDefinition>,
348) -> Vec<LlmToolSpec> {
349    tools
350        .into_iter()
351        .map(|tool| {
352            let model_tool = tool.model_tool();
353            LlmToolSpec {
354                name: model_tool.name,
355                description: model_tool.description,
356                input_schema: model_tool.input_schema,
357                output_schema: model_tool.output_schema,
358                input_schema_projections: model_tool.input_schema_projections,
359                output_schema_projections: model_tool.output_schema_projections,
360            }
361        })
362        .collect()
363}
364
365pub fn model_tool_specs(tools: &[ToolDefinition]) -> Vec<LlmToolSpec> {
366    model_tool_specs_iter(tools.iter())
367}