1pub mod message;
2pub mod prompt;
3
4pub use message::{
5 BaseRenderCache, Message, MessageRole, MessageSequence, Part, PartAttachment, PartKind,
6 PruneState, RenderedPrompt, append_rendered_prompt, messages_are_prompt_resume_safe,
7 render_prompt, render_transcript_prompt, shared_parts,
8};
9pub use prompt::{
10 MAIN_AGENT_INTRO, PromptBuiltin, PromptLayer, PromptSlot, PromptSlotLayer, PromptTemplate,
11 PromptTemplateEntry, PromptTemplateSection, ResolvedPromptLayer, default_prompt_template,
12 resolve_prompt_layers,
13};
14
15use std::sync::Arc;
16
17use crate::MessageOrigin;
18use crate::ToolDefinition;
19use crate::llm::types::LlmToolSpec;
20use crate::plugin::{CheckpointKind, PluginMessage, PluginRuntimeEvent};
21
/// A persisted session record: either a conversation message or a protocol-level
/// event whose payload type `PE` is chosen by the caller (defaults to `()`).
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
// Variant sizes can differ a lot; boxing was presumably judged not worth it.
// NOTE(review): confirm this lint suppression is still intentional.
#[allow(clippy::large_enum_variant)]
pub enum SessionEventRecord<PE = ()> {
    /// A conversation message, stored as a [`ConversationRecord`].
    Conversation(ConversationRecord),
    /// A protocol-specific event payload.
    Protocol(PE),
}
28
/// Serializable form of a [`Message`].
///
/// Convert with [`ConversationRecord::from_message`] and [`ConversationRecord::to_message`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ConversationRecord {
    /// Message identifier.
    pub id: String,
    /// Role of the message author.
    pub role: MessageRole,
    /// Message parts, held behind an `Arc` so conversions to/from [`Message`]
    /// share the same allocation instead of copying the parts.
    pub parts: Arc<Vec<Part>>,
    /// Optional origin of the message; omitted when `None` and defaulted on decode.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub origin: Option<MessageOrigin>,
}
37
/// One injected turn input that was accepted; carried by
/// `SessionEvent::InjectedTurnInputAccepted`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AcceptedInjectedTurnInput {
    /// Optional identifier for the input; omitted when `None` and defaulted on decode.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// The plugin message that was accepted.
    pub message: PluginMessage,
}
44
45impl ConversationRecord {
46 pub fn from_message(message: Message) -> Self {
47 Self {
48 id: message.id,
49 role: message.role,
50 parts: message.parts,
51 origin: message.origin,
52 }
53 }
54
55 pub fn to_message(&self) -> Message {
56 Message {
57 id: self.id.clone(),
58 role: self.role,
59 parts: Arc::clone(&self.parts),
60 origin: self.origin.clone(),
61 }
62 }
63}
64
/// Token counters for one or more LLM calls.
///
/// `Default` yields all-zero counters; see `TokenUsage::add` for accumulation.
#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Non-cached input tokens.
    pub input_tokens: i64,
    /// Output tokens.
    pub output_tokens: i64,
    /// Input tokens served from a cache.
    pub cache_read_input_tokens: i64,
    /// Input tokens written to a cache.
    pub cache_write_input_tokens: i64,
    /// Reasoning output tokens (not added by `TokenUsage::total`).
    pub reasoning_output_tokens: i64,
}
74
75impl TokenUsage {
76 pub fn total(&self) -> i64 {
77 self.input_tokens
78 + self.output_tokens
79 + self.cache_read_input_tokens
80 + self.cache_write_input_tokens
81 }
82
83 pub fn input_total(&self) -> i64 {
84 self.input_tokens + self.cache_read_input_tokens + self.cache_write_input_tokens
85 }
86
87 pub fn add(&mut self, other: &TokenUsage) {
88 self.input_tokens += other.input_tokens;
89 self.output_tokens += other.output_tokens;
90 self.cache_read_input_tokens += other.cache_read_input_tokens;
91 self.cache_write_input_tokens += other.cache_write_input_tokens;
92 self.reasoning_output_tokens += other.reasoning_output_tokens;
93 }
94}
95
/// Structured description of an error, attached to `SessionEvent::Error` and
/// `SessionEvent::RetryStatus`.
///
/// All `Option` fields are omitted from the serialized form when `None` and default to
/// `None` on decode. This lets snapshots written before `retryable` and
/// `provider_failure_kind` existed still deserialize.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ErrorEnvelope {
    /// Error category string (e.g. `"llm_provider"`, `"plugin"` in the tests below).
    pub kind: String,
    /// Optional machine-readable code (e.g. an HTTP status such as `"429"`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
    /// Optional LLM terminal reason associated with the error.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub terminal_reason: Option<crate::llm::types::LlmTerminalReason>,
    /// Human-readable message.
    pub user_message: String,
    /// Optional raw error payload. `make_error_envelope` trims it and passes it
    /// through `truncate_raw_error`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw: Option<String>,
    /// Whether retrying may help; `None` when unknown or not recorded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retryable: Option<bool>,
    /// Provider failure classification; `None` when unknown or not recorded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider_failure_kind: Option<crate::llm::types::ProviderFailureKind>,
}
125
/// Event emitted while a session runs.
///
/// Serialized as an internally tagged object: the `"type"` field holds the name given by
/// each variant's `#[serde(rename = ...)]`, and the variant's fields sit alongside it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type")]
// Variant sizes differ widely; boxing was presumably judged not worth it.
// NOTE(review): confirm this lint suppression is still intentional.
#[allow(clippy::large_enum_variant)]
pub enum SessionEvent {
    /// Incremental chunk of text content (presumably streamed assistant output).
    #[serde(rename = "text_delta")]
    TextDelta { content: String },
    /// Incremental chunk of reasoning content.
    #[serde(rename = "reasoning_delta")]
    ReasoningDelta { content: String },
    /// A tool invocation together with its output and `duration_ms`.
    #[serde(rename = "tool_call")]
    ToolCall {
        /// Optional call identifier; omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        call_id: Option<String>,
        name: String,
        args: serde_json::Value,
        output: crate::ToolCallOutput,
        duration_ms: u64,
    },
    /// A tool invocation that has started; carries no output yet.
    #[serde(rename = "tool_call_start")]
    ToolCallStart {
        /// Optional call identifier; omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        call_id: Option<String>,
        name: String,
        args: serde_json::Value,
    },
    /// Free-form text message labelled with a `kind` string.
    #[serde(rename = "message")]
    Message { text: String, kind: String },
    /// An LLM request at the given protocol iteration.
    #[serde(rename = "llm_request")]
    LlmRequest {
        protocol_iteration: usize,
        message_count: usize,
        tool_list: String,
    },
    /// An LLM response at the given protocol iteration.
    #[serde(rename = "llm_response")]
    LlmResponse {
        protocol_iteration: usize,
        content: String,
        duration_ms: u64,
    },
    /// Token usage for one protocol iteration plus the running `cumulative` total.
    #[serde(rename = "token_usage")]
    TokenUsage {
        protocol_iteration: usize,
        usage: TokenUsage,
        cumulative: TokenUsage,
    },
    /// Token usage reported for another session identified by `session_id`.
    #[serde(rename = "child_token_usage")]
    ChildTokenUsage {
        session_id: String,
        source: String,
        model: String,
        protocol_iteration: usize,
        usage: TokenUsage,
        cumulative: TokenUsage,
    },
    /// A retry is pending: `attempt` of `max_attempts`, after waiting `wait_seconds`.
    #[serde(rename = "retry_status")]
    RetryStatus {
        wait_seconds: u64,
        attempt: usize,
        max_attempts: usize,
        reason: String,
        /// Optional structured error; omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        envelope: Option<ErrorEnvelope>,
    },
    /// Injected turn inputs that were accepted at `checkpoint`.
    #[serde(rename = "injected_turn_input_accepted")]
    InjectedTurnInputAccepted {
        inputs: Vec<AcceptedInjectedTurnInput>,
        checkpoint: CheckpointKind,
    },
    /// Injected plugin messages that were committed at `checkpoint`.
    #[serde(rename = "injected_messages_committed")]
    InjectedMessagesCommitted {
        messages: Vec<PluginMessage>,
        checkpoint: CheckpointKind,
    },
    /// A runtime event raised by the plugin `plugin_id`.
    #[serde(rename = "plugin_event")]
    PluginEvent {
        plugin_id: String,
        event: PluginRuntimeEvent,
    },
    /// How the turn ended.
    #[serde(rename = "turn_outcome")]
    TurnOutcome { outcome: TurnOutcome },
    /// Completion marker carrying no data.
    #[serde(rename = "done")]
    Done,
    /// An error, with a display `message` and an optional structured envelope.
    #[serde(rename = "error")]
    Error {
        message: String,
        /// Optional structured error; omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        envelope: Option<ErrorEnvelope>,
    },
}
220
/// How a turn ended; serialized with snake_case variant names.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TurnOutcome {
    /// The turn finished and produced a result.
    Finished(TurnFinish),
    /// The turn ended by switching to the agent frame `frame_id` with `task`.
    AgentFrameSwitch { frame_id: String, task: String },
    /// The turn stopped without a normal finish.
    Stopped(TurnStop),
}
228
/// The result a finished turn produced; serialized with snake_case variant names.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TurnFinish {
    /// A plain assistant text message.
    AssistantMessage {
        text: String,
    },
    /// A final JSON value.
    FinalValue {
        value: serde_json::Value,
    },
    /// A JSON value produced by the tool `tool_name`.
    ToolValue {
        tool_name: String,
        value: serde_json::Value,
    },
}
243
/// Reason a turn stopped without a normal finish; serialized with snake_case names.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TurnStop {
    Cancelled,
    Incomplete,
    InvalidInput,
    MaxTurns,
    ToolFailure,
    ProviderError,
    PluginAbort,
    RuntimeError,
    /// Stopped with a submitted error payload.
    SubmittedError {
        value: serde_json::Value,
    },
    /// Stopped with an error payload attributed to the tool `tool_name`.
    ToolError {
        tool_name: String,
        value: serde_json::Value,
    },
}
263
/// Tracks whether the turn-limit final turn has already been scheduled.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct TurnTerminationPolicyState {
    /// Set once by `mark_turn_limit_final_scheduled`; never cleared in this module.
    turn_limit_final_scheduled: bool,
}
268
269impl Default for TurnTerminationPolicyState {
270 fn default() -> Self {
271 Self::new()
272 }
273}
274
275impl TurnTerminationPolicyState {
276 pub fn new() -> Self {
277 Self {
278 turn_limit_final_scheduled: false,
279 }
280 }
281
282 pub fn should_force_exit_after_grace_turn(&self) -> bool {
283 self.turn_limit_final_scheduled
284 }
285
286 pub fn turn_limit_final_to_schedule(
287 &self,
288 protocol_iteration: usize,
289 protocol_run_offset: usize,
290 max_turns: Option<usize>,
291 ) -> Option<usize> {
292 if self.turn_limit_final_scheduled {
293 return None;
294 }
295 let max = max_turns?;
296 if protocol_iteration < protocol_run_offset + max {
297 return None;
298 }
299 Some(max)
300 }
301
302 pub fn mark_turn_limit_final_scheduled(&mut self) {
303 self.turn_limit_final_scheduled = true;
304 }
305}
306
307pub fn make_error_envelope(
308 kind: &str,
309 code: Option<&str>,
310 terminal_reason: Option<crate::llm::types::LlmTerminalReason>,
311 user_message: impl Into<String>,
312 raw: Option<String>,
313) -> ErrorEnvelope {
314 let user_message = user_message.into();
315 ErrorEnvelope {
316 kind: kind.to_string(),
317 code: code.map(str::to_string),
318 terminal_reason,
319 user_message,
320 raw: raw.map(|s| truncate_raw_error(s.trim())),
321 retryable: None,
322 provider_failure_kind: None,
323 }
324}
325
326pub fn make_error_event(
327 kind: &str,
328 code: Option<&str>,
329 user_message: impl Into<String>,
330 raw: Option<String>,
331) -> SessionEvent {
332 let user_message = user_message.into();
333 SessionEvent::Error {
334 message: user_message.clone(),
335 envelope: Some(make_error_envelope(kind, code, None, user_message, raw)),
336 }
337}
338
/// Caps an error payload at 4000 characters.
///
/// - Inputs of at most 4000 `char`s are returned unchanged.
/// - Longer inputs keep the first and last 2000 `char`s, joined by a marker that states
///   how many characters were omitted.
///
/// Lengths count Unicode scalar values, and the cut points are `char` boundaries, so
/// the output is always valid UTF-8.
pub fn truncate_raw_error(s: &str) -> String {
    const MAX_RAW: usize = 4000;
    let raw_len = s.chars().count();
    if raw_len <= MAX_RAW {
        return s.to_string();
    }
    let keep = MAX_RAW / 2;
    // Byte offset where char number `keep` (0-based) starts: end of the head slice.
    // It always exists because raw_len > MAX_RAW > keep.
    let head_end = s.char_indices().nth(keep).map_or(s.len(), |(idx, _)| idx);
    // Byte offset where the last `keep` chars begin: start of the tail slice.
    let tail_start = s
        .char_indices()
        .nth_back(keep - 1)
        .map_or(0, |(idx, _)| idx);
    let head = &s[..head_end];
    let tail = &s[tail_start..];
    let omitted = raw_len.saturating_sub(keep * 2);
    format!("{head}\n\n... ({omitted} chars omitted) ...\n\n{tail}")
}
358
359pub fn reassign_part_ids(message_id: &str, parts: &mut [Part]) {
360 for (idx, part) in parts.iter_mut().enumerate() {
361 part.id = format!("{message_id}.p{idx}");
362 }
363}
364
365pub fn model_tool_specs_iter<'a>(
366 tools: impl IntoIterator<Item = &'a ToolDefinition>,
367) -> Vec<LlmToolSpec> {
368 tools
369 .into_iter()
370 .map(|tool| {
371 let model_tool = tool.model_tool();
372 LlmToolSpec {
373 name: model_tool.name,
374 description: model_tool.description,
375 input_schema: model_tool.input_schema,
376 output_schema: model_tool.output_schema,
377 }
378 })
379 .collect()
380}
381
382pub fn model_tool_specs(tools: &[ToolDefinition]) -> Vec<LlmToolSpec> {
383 model_tool_specs_iter(tools.iter())
384}
385
// Wire-compatibility tests for `ErrorEnvelope`, `SessionEvent::Error`, and the
// `ProviderFailureKind` codes they carry.
#[cfg(test)]
mod tests {
    use super::{ErrorEnvelope, SessionEvent};
    use crate::llm::types::{LlmTerminalReason, ProviderFailureKind};

    // Snapshots written before `retryable` / `provider_failure_kind` existed must still
    // decode, with both fields `None`. This is checked both standalone and nested inside
    // an `error` event.
    #[test]
    fn error_envelope_decodes_legacy_snapshot_without_retryability_fields() {
        let legacy = r#"{
            "kind":"llm_provider",
            "code":"429",
            "terminal_reason":"provider_error",
            "user_message":"LLM error: rate limited",
            "raw":"{\"error\":\"rate_limited\"}"
        }"#;
        let envelope: ErrorEnvelope = serde_json::from_str(legacy).expect("legacy envelope");
        assert_eq!(envelope.kind, "llm_provider");
        assert_eq!(envelope.retryable, None);
        assert_eq!(envelope.provider_failure_kind, None);

        let legacy_event = r#"{
            "type":"error",
            "message":"LLM error: rate limited",
            "envelope":{"kind":"llm_provider","user_message":"LLM error: rate limited"}
        }"#;
        let event: SessionEvent = serde_json::from_str(legacy_event).expect("legacy event");
        match event {
            SessionEvent::Error { envelope, .. } => {
                let envelope = envelope.expect("envelope");
                assert_eq!(envelope.retryable, None);
                assert_eq!(envelope.provider_failure_kind, None);
            }
            other => panic!("expected error event, got {other:?}"),
        }
    }

    // When set, the retryability fields appear on the wire under their field names
    // (`provider_failure_kind` as its string code) and survive a decode round trip.
    #[test]
    fn error_envelope_roundtrips_retryability_fields() {
        let envelope = ErrorEnvelope {
            kind: "llm_provider".to_string(),
            code: Some("429".to_string()),
            terminal_reason: Some(LlmTerminalReason::ProviderError),
            user_message: "LLM error: rate limited".to_string(),
            raw: None,
            retryable: Some(true),
            provider_failure_kind: Some(ProviderFailureKind::Quota),
        };
        let json = serde_json::to_value(&envelope).expect("serialize envelope");
        assert_eq!(json["retryable"], serde_json::json!(true));
        assert_eq!(json["provider_failure_kind"], serde_json::json!("quota"));
        let decoded: ErrorEnvelope = serde_json::from_value(json).expect("decode envelope");
        assert_eq!(decoded.retryable, Some(true));
        assert_eq!(
            decoded.provider_failure_kind,
            Some(ProviderFailureKind::Quota)
        );
    }

    // Unset retryability fields are skipped entirely rather than serialized as `null`.
    #[test]
    fn error_envelope_omits_unset_retryability_fields_on_the_wire() {
        let envelope = ErrorEnvelope {
            kind: "plugin".to_string(),
            code: Some("plugin_abort".to_string()),
            terminal_reason: None,
            user_message: "stopped".to_string(),
            raw: None,
            retryable: None,
            provider_failure_kind: None,
        };
        let json = serde_json::to_value(&envelope).expect("serialize envelope");
        let object = json.as_object().expect("object");
        assert!(!object.contains_key("retryable"));
        assert!(!object.contains_key("provider_failure_kind"));
    }

    // Unrecognized codes decode to `Unknown`. Every listed kind serializes to the string
    // returned by `kind.code()` and decodes back to itself.
    #[test]
    fn provider_failure_kind_decodes_unknown_future_codes() {
        let decoded: ProviderFailureKind =
            serde_json::from_value(serde_json::json!("some_future_kind")).expect("future kind");
        assert_eq!(decoded, ProviderFailureKind::Unknown);
        for kind in [
            ProviderFailureKind::Transport,
            ProviderFailureKind::Timeout,
            ProviderFailureKind::Http,
            ProviderFailureKind::Stream,
            ProviderFailureKind::Auth,
            ProviderFailureKind::Validation,
            ProviderFailureKind::Quota,
            ProviderFailureKind::Unsupported,
            ProviderFailureKind::Unknown,
        ] {
            let json = serde_json::to_value(kind).expect("serialize kind");
            assert_eq!(json, serde_json::json!(kind.code()));
            let round: ProviderFailureKind = serde_json::from_value(json).expect("decode kind");
            assert_eq!(round, kind);
        }
    }
}