use crate::audit::AuditProvenance;
use crate::llm::{ContentBlock, ContentSource};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use time::OffsetDateTime;
use uuid::Uuid;
/// Identifier of a conversation thread: a newtype over `String`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ThreadId(pub String);

impl ThreadId {
    /// Creates a fresh identifier from a random (version 4) UUID.
    #[must_use]
    pub fn new() -> Self {
        let random = Uuid::new_v4();
        Self(random.to_string())
    }

    /// Wraps an existing identifier string as-is; no validation is performed.
    #[must_use]
    pub fn from_string(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        Self(owned)
    }
}

impl Default for ThreadId {
    /// Equivalent to [`ThreadId::new`], so every default is a new random id.
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for ThreadId {
    /// Prints the raw inner string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Static configuration for an agent.
#[derive(Clone, Debug)]
pub struct AgentConfig {
    /// Turn cap; defaults to `None` (NOTE(review): presumably "no cap" — confirm with the loop).
    pub max_turns: Option<usize>,
    /// Token cap passed on to the model; defaults to `None`.
    pub max_tokens: Option<u32>,
    /// System prompt; defaults to the empty string.
    pub system_prompt: String,
    /// Model identifier; see `Default` for the built-in value.
    pub model: String,
    /// Retry policy; defaults to [`RetryConfig::default`].
    pub retry: RetryConfig,
    /// Whether streaming is requested; defaults to `false`.
    pub streaming: bool,
    /// Per-tool timeout in milliseconds; defaults to `None`.
    pub tool_timeout_ms: Option<u64>,
}

impl Default for AgentConfig {
    fn default() -> Self {
        let model = "claude-sonnet-4-5-20250929".to_owned();
        Self {
            system_prompt: String::new(),
            model,
            retry: RetryConfig::default(),
            streaming: false,
            max_turns: None,
            max_tokens: None,
            tool_timeout_ms: None,
        }
    }
}
/// Retry policy: how many retries to attempt and the delay bounds, in milliseconds.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    pub max_retries: u32,
    pub base_delay_ms: u64,
    pub max_delay_ms: u64,
}

impl RetryConfig {
    /// Single place that assembles a policy from its three parts.
    const fn from_parts(max_retries: u32, base_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_retries,
            base_delay_ms,
            max_delay_ms,
        }
    }

    /// A policy that never retries: zero retries and zero delays.
    #[must_use]
    pub const fn no_retry() -> Self {
        Self::from_parts(0, 0, 0)
    }

    /// Same retry count as the default (5), but with a 10 ms base delay and a
    /// 100 ms cap — useful where real backoff would be too slow (e.g. tests).
    #[must_use]
    pub const fn fast() -> Self {
        Self::from_parts(5, 10, 100)
    }
}

impl Default for RetryConfig {
    /// 5 retries, 1 s base delay, 120 s maximum delay.
    fn default() -> Self {
        Self::from_parts(5, 1_000, 120_000)
    }
}
/// Token counters for one or more model calls.
///
/// The two cache counters are `#[serde(default)]`, so serialized usage that
/// omits them deserializes with zero.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    #[serde(default)]
    pub cached_input_tokens: u32,
    #[serde(default)]
    pub cache_creation_input_tokens: u32,
}

impl TokenUsage {
    /// Adds every counter of `other` into `self`.
    ///
    /// Each counter saturates at `u32::MAX` instead of wrapping or panicking.
    pub const fn add(&mut self, other: &Self) {
        // All fields are `Copy`, so destructuring through the reference copies them.
        let Self {
            input_tokens,
            output_tokens,
            cached_input_tokens,
            cache_creation_input_tokens,
        } = *other;
        self.input_tokens = self.input_tokens.saturating_add(input_tokens);
        self.output_tokens = self.output_tokens.saturating_add(output_tokens);
        self.cached_input_tokens = self.cached_input_tokens.saturating_add(cached_input_tokens);
        self.cache_creation_input_tokens = self
            .cache_creation_input_tokens
            .saturating_add(cache_creation_input_tokens);
    }
}
/// Result of running a tool: a success flag, text output and optional extras.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolResult {
    pub success: bool,
    pub output: String,
    pub data: Option<serde_json::Value>,
    /// Attached documents; omitted from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub documents: Vec<ContentSource>,
    pub duration_ms: Option<u64>,
}

impl ToolResult {
    /// Shared constructor: no documents attached and no duration recorded.
    fn assemble(success: bool, output: String, data: Option<serde_json::Value>) -> Self {
        Self {
            success,
            output,
            data,
            documents: Vec::new(),
            duration_ms: None,
        }
    }

    /// A successful result carrying only text output.
    #[must_use]
    pub fn success(output: impl Into<String>) -> Self {
        Self::assemble(true, output.into(), None)
    }

    /// A successful result carrying text output plus structured JSON `data`.
    #[must_use]
    pub fn success_with_data(output: impl Into<String>, data: serde_json::Value) -> Self {
        Self::assemble(true, output.into(), Some(data))
    }

    /// A failed result whose `output` is the error message.
    #[must_use]
    pub fn error(message: impl Into<String>) -> Self {
        Self::assemble(false, message.into(), None)
    }

    /// Builder-style setter for the measured duration, in milliseconds.
    #[must_use]
    pub const fn with_duration(mut self, duration_ms: u64) -> Self {
        self.duration_ms = Some(duration_ms);
        self
    }

    /// Builder-style setter that replaces any previously attached documents.
    #[must_use]
    pub fn with_documents(mut self, documents: Vec<ContentSource>) -> Self {
        self.documents = documents;
        self
    }
}
/// Approval tier of a tool call.
///
/// NOTE(review): from the names, `Observe` presumably runs without user approval
/// and `Confirm` pauses for it (cf. `AgentRunState::AwaitingConfirmation`) — confirm
/// against the runtime.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ToolTier {
    Observe,
    Confirm,
}
/// Per-thread bookkeeping; serializable and embedded in `AgentContinuation`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentState {
    pub thread_id: ThreadId,
    pub turn_count: usize,
    pub total_usage: TokenUsage,
    /// Free-form JSON metadata keyed by string.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation time, serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub created_at: OffsetDateTime,
}

impl AgentState {
    /// Fresh state for `thread_id`: zero turns, zero usage, no metadata, and
    /// `created_at` set to the current UTC time.
    #[must_use]
    pub fn new(thread_id: ThreadId) -> Self {
        let created_at = OffsetDateTime::now_utc();
        let metadata = HashMap::new();
        Self {
            thread_id,
            turn_count: 0,
            total_usage: TokenUsage::default(),
            metadata,
            created_at,
        }
    }
}
/// Error raised by an agent, tagged with whether it is recoverable.
#[derive(Debug, Clone)]
pub struct AgentError {
    /// Human-readable description; `Display` prints exactly this text.
    pub message: String,
    /// Caller-supplied recoverability flag.
    pub recoverable: bool,
}

impl AgentError {
    /// Builds an error from any string-like `message`.
    #[must_use]
    pub fn new(message: impl Into<String>, recoverable: bool) -> Self {
        let message: String = message.into();
        Self {
            message,
            recoverable,
        }
    }
}

impl std::fmt::Display for AgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for AgentError {}
/// State reported when an agent run stops, either finally or paused for input.
///
/// `#[non_exhaustive]`: matches outside this crate must include a wildcard arm.
#[derive(Debug)]
#[non_exhaustive]
pub enum AgentRunState {
    /// The run finished.
    Done {
        total_turns: u32,
        total_usage: TokenUsage,
    },
    /// The run stopped on a refusal (NOTE(review): presumably the model's refusal stop reason — confirm).
    Refusal {
        total_turns: u32,
        total_usage: TokenUsage,
    },
    /// The run failed.
    Error(AgentError),
    /// The run paused before executing the named tool call and carries a
    /// versioned `continuation` (NOTE(review): presumably resumed through
    /// `AgentInput::Resume` — confirm).
    AwaitingConfirmation {
        tool_call_id: String,
        tool_name: String,
        display_name: String,
        input: serde_json::Value,
        description: String,
        continuation: Box<ContinuationEnvelope>,
    },
    /// The run was cancelled.
    Cancelled {
        total_turns: u32,
        total_usage: TokenUsage,
    },
}
/// Serializable description of a requested tool call, as stored in
/// `AgentContinuation::pending_tool_calls`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PendingToolCallInfo {
    pub id: String,
    pub name: String,
    pub display_name: String,
    /// Falls back to `ToolTier::Confirm` (via `default_pending_tier`) when the
    /// serialized payload has no `tier` key.
    #[serde(default = "default_pending_tier")]
    pub tier: ToolTier,
    pub input: serde_json::Value,
    /// Defaults to JSON `null` when absent from the serialized payload.
    /// NOTE(review): presumably `input` after any rewriting — confirm.
    #[serde(default)]
    pub effective_input: serde_json::Value,
    /// Omitted from the serialized form when `None`; `None` when missing on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub listen_context: Option<ListenExecutionContext>,
}
/// Serde default for `PendingToolCallInfo::tier`: a missing tier is treated as
/// `ToolTier::Confirm`.
const fn default_pending_tier() -> ToolTier {
    ToolTier::Confirm
}
/// In-memory description of one tool call, carrying both the requested input
/// and the effective input.
///
/// NOTE(review): `effective_input` presumably is `requested_input` after any
/// rewriting by the runtime — confirm at the construction site.
#[derive(Clone, Debug)]
pub struct ToolInvocation {
    pub tool_call_id: String,
    pub tool_name: String,
    pub display_name: String,
    pub tier: ToolTier,
    pub requested_input: serde_json::Value,
    pub effective_input: serde_json::Value,
    pub listen_context: Option<ListenExecutionContext>,
}
/// Extra context that can accompany a tool call: an operation id, a revision
/// number, a JSON snapshot and an optional expiry.
///
/// NOTE(review): the precise meaning of "listen" and of `revision` is defined
/// by the caller — confirm there.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ListenExecutionContext {
    pub operation_id: String,
    pub revision: u64,
    pub snapshot: serde_json::Value,
    /// Serialized as RFC 3339; omitted when `None`. `default` lets payloads
    /// without the key deserialize, which the `with` adapter alone would reject.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "time::serde::rfc3339::option"
    )]
    pub expires_at: Option<OffsetDateTime>,
}
/// Serializable snapshot of a run paused mid-turn, carried inside a
/// `ContinuationEnvelope`.
///
/// The three trailing fields are `#[serde(default, skip_serializing_if = ...)]`:
/// they are omitted when empty, and payloads that lack them still deserialize.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentContinuation {
    pub thread_id: ThreadId,
    pub turn: usize,
    pub total_usage: TokenUsage,
    pub turn_usage: TokenUsage,
    pub pending_tool_calls: Vec<PendingToolCallInfo>,
    /// NOTE(review): presumably an index into `pending_tool_calls` of the call
    /// awaiting a decision — confirm.
    pub awaiting_index: usize,
    /// NOTE(review): presumably `(tool_call_id, result)` pairs — confirm.
    pub completed_results: Vec<(String, ToolResult)>,
    pub state: AgentState,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_id: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<crate::llm::StopReason>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub response_content: Vec<crate::llm::ContentBlock>,
}
/// Schema version written into every [`ContinuationEnvelope`] built by [`ContinuationEnvelope::wrap`].
pub const CONTINUATION_VERSION: u32 = 1;

/// Version-tagged wrapper around an [`AgentContinuation`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ContinuationEnvelope {
    pub version: u32,
    pub payload: AgentContinuation,
}

impl ContinuationEnvelope {
    /// Stamps `payload` with the current [`CONTINUATION_VERSION`].
    #[must_use]
    pub const fn wrap(payload: AgentContinuation) -> Self {
        Self {
            payload,
            version: CONTINUATION_VERSION,
        }
    }

    /// Consumes the envelope and returns its payload if the version matches.
    ///
    /// # Errors
    ///
    /// Returns a message naming both the found and the expected version when
    /// `version` is not [`CONTINUATION_VERSION`].
    pub fn unwrap_validated(self) -> Result<AgentContinuation, String> {
        match self.version {
            CONTINUATION_VERSION => Ok(self.payload),
            found => Err(format!(
                "Unsupported continuation version {}: expected {}",
                found, CONTINUATION_VERSION,
            )),
        }
    }
}
/// A tool result produced outside the agent, keyed by the tool call it answers;
/// submitted in bulk through `AgentInput::SubmitToolResults`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExternalToolResult {
    pub tool_call_id: String,
    pub result: ToolResult,
}
/// Input that drives the next step of an agent run.
#[derive(Debug)]
pub enum AgentInput {
    /// Plain-text user input.
    Text(String),
    /// Structured user input as content blocks.
    Message(Vec<ContentBlock>),
    /// Answers a paused confirmation for `tool_call_id`, optionally with a
    /// rejection reason (NOTE(review): presumably only meaningful when
    /// `confirmed` is `false` — confirm).
    Resume {
        continuation: Box<ContinuationEnvelope>,
        tool_call_id: String,
        confirmed: bool,
        rejection_reason: Option<String>,
    },
    /// Supplies results for tool calls that were executed outside the agent.
    SubmitToolResults {
        continuation: Box<ContinuationEnvelope>,
        results: Vec<ExternalToolResult>,
    },
    /// NOTE(review): presumably continues the run without new input — confirm.
    Continue,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ToolOutcome {
Success(ToolResult),
Failed(ToolResult),
InProgress {
operation_id: String,
message: String,
},
}
impl ToolOutcome {
#[must_use]
pub fn success(output: impl Into<String>) -> Self {
Self::Success(ToolResult::success(output))
}
#[must_use]
pub fn failed(message: impl Into<String>) -> Self {
Self::Failed(ToolResult::error(message))
}
#[must_use]
pub fn in_progress(operation_id: impl Into<String>, message: impl Into<String>) -> Self {
Self::InProgress {
operation_id: operation_id.into(),
message: message.into(),
}
}
#[must_use]
pub const fn is_in_progress(&self) -> bool {
matches!(self, Self::InProgress { .. })
}
}
/// Lifecycle state of a `ToolExecution`: set to `InFlight` by
/// `ToolExecution::new_in_flight` and to `Completed` by `ToolExecution::complete`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionStatus {
    InFlight,
    Completed,
}
/// Record of one tool execution, from dispatch to completion.
///
/// Records start as [`ExecutionStatus::InFlight`] via [`ToolExecution::new_in_flight`]
/// and become [`ExecutionStatus::Completed`] when [`ToolExecution::complete`] stores
/// the result.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolExecution {
    pub tool_call_id: String,
    pub thread_id: ThreadId,
    pub tool_name: String,
    pub display_name: String,
    pub input: serde_json::Value,
    pub status: ExecutionStatus,
    /// Set by [`ToolExecution::complete`]; `None` while in flight.
    pub result: Option<ToolResult>,
    /// Set by [`ToolExecution::set_operation_id`].
    pub operation_id: Option<String>,
    /// Serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub started_at: OffsetDateTime,
    /// Completion time (UTC), set by [`ToolExecution::complete`]; RFC 3339 when present.
    ///
    /// `default` is required alongside `with`: a custom (de)serializer disables
    /// serde's implicit "missing `Option` field means `None`" rule, so without it
    /// any payload that omits this key (e.g. from another producer or an older
    /// schema) fails to deserialize. Serialization is unchanged (`None` -> `null`).
    #[serde(default, with = "time::serde::rfc3339::option")]
    pub completed_at: Option<OffsetDateTime>,
}

impl ToolExecution {
    /// Creates an in-flight record with no result, no operation id and no
    /// completion time.
    #[must_use]
    pub fn new_in_flight(
        tool_call_id: impl Into<String>,
        thread_id: ThreadId,
        tool_name: impl Into<String>,
        display_name: impl Into<String>,
        input: serde_json::Value,
        started_at: OffsetDateTime,
    ) -> Self {
        Self {
            tool_call_id: tool_call_id.into(),
            thread_id,
            tool_name: tool_name.into(),
            display_name: display_name.into(),
            input,
            status: ExecutionStatus::InFlight,
            result: None,
            operation_id: None,
            started_at,
            completed_at: None,
        }
    }

    /// Marks the execution completed, stores `result` and stamps `completed_at`
    /// with the current UTC time. A second call overwrites both.
    pub fn complete(&mut self, result: ToolResult) {
        self.status = ExecutionStatus::Completed;
        self.result = Some(result);
        self.completed_at = Some(OffsetDateTime::now_utc());
    }

    /// Records (or replaces) the operation id associated with this execution.
    pub fn set_operation_id(&mut self, operation_id: impl Into<String>) {
        self.operation_id = Some(operation_id.into());
    }

    /// `true` while the status is [`ExecutionStatus::InFlight`].
    #[must_use]
    pub const fn is_in_flight(&self) -> bool {
        matches!(self.status, ExecutionStatus::InFlight)
    }

    /// `true` once the status is [`ExecutionStatus::Completed`].
    #[must_use]
    pub const fn is_completed(&self) -> bool {
        matches!(self.status, ExecutionStatus::Completed)
    }
}
/// Result of running a single turn.
///
/// Every variant except [`TurnOutcome::Error`] carries a [`TurnSummary`].
#[derive(Debug)]
pub enum TurnOutcome {
    NeedsMoreTurns {
        turn: usize,
        turn_usage: TokenUsage,
        total_usage: TokenUsage,
        summary: TurnSummary,
    },
    Done {
        total_turns: u32,
        total_usage: TokenUsage,
        summary: TurnSummary,
    },
    AwaitingConfirmation {
        tool_call_id: String,
        tool_name: String,
        display_name: String,
        input: serde_json::Value,
        description: String,
        continuation: Box<ContinuationEnvelope>,
        summary: TurnSummary,
    },
    Refusal {
        total_turns: u32,
        total_usage: TokenUsage,
        summary: TurnSummary,
    },
    Cancelled {
        total_turns: u32,
        total_usage: TokenUsage,
        summary: TurnSummary,
    },
    Error(AgentError),
    PendingToolCalls {
        turn: usize,
        turn_usage: TokenUsage,
        total_usage: TokenUsage,
        tool_calls: Vec<PendingToolCallInfo>,
        continuation: Box<ContinuationEnvelope>,
        summary: TurnSummary,
    },
}

impl TurnOutcome {
    /// Borrows the turn summary, or returns `None` for [`TurnOutcome::Error`].
    #[must_use]
    pub const fn summary(&self) -> Option<&TurnSummary> {
        let summary = match self {
            Self::Error(_) => return None,
            Self::NeedsMoreTurns { summary, .. }
            | Self::Done { summary, .. }
            | Self::AwaitingConfirmation { summary, .. }
            | Self::Refusal { summary, .. }
            | Self::Cancelled { summary, .. }
            | Self::PendingToolCalls { summary, .. } => summary,
        };
        Some(summary)
    }
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TurnSummary {
pub thread_id: ThreadId,
pub turn: usize,
pub total_turns: u32,
pub turn_usage: TokenUsage,
pub total_usage: TokenUsage,
pub provenance: AuditProvenance,
pub response_id: Option<String>,
pub stop_reason: Option<crate::llm::StopReason>,
pub tool_call_count: usize,
pub duration_ms: u64,
pub tool_runtime: ToolRuntime,
pub strict_durability: bool,
}
impl TurnSummary {
#[must_use]
pub fn new(
thread_id: ThreadId,
turn: usize,
provenance: AuditProvenance,
options: &TurnOptions,
) -> Self {
Self {
thread_id,
turn,
total_turns: 0,
turn_usage: TokenUsage::default(),
total_usage: TokenUsage::default(),
provenance,
response_id: None,
stop_reason: None,
tool_call_count: 0,
duration_ms: 0,
tool_runtime: options.tool_runtime.clone(),
strict_durability: options.strict_durability,
}
}
}
/// How tool calls are run for a turn; serialized in snake_case (`"inline"`,
/// `"external"`) and defaulting to `Inline`.
///
/// NOTE(review): `External` presumably means the caller runs tools and returns
/// results via `AgentInput::SubmitToolResults` — confirm.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolRuntime {
    #[default]
    Inline,
    External,
}
/// Per-turn options; `TurnSummary::new` copies both fields into the summary.
#[derive(Debug, Clone, Default)]
pub struct TurnOptions {
    pub tool_runtime: ToolRuntime,
    /// NOTE(review): semantics of "strict durability" are defined by the turn
    /// runner, not here — confirm there.
    pub strict_durability: bool,
}
/// Optional per-run metadata; every field is empty/`None` by default.
///
/// NOTE(review): the field names suggest these feed a tracing/observability
/// backend (session, user, trace name/tags/metadata, release, environment) — confirm.
#[derive(Clone, Debug, Default)]
pub struct RunOptions {
    pub session_id: Option<String>,
    pub user_id: Option<String>,
    pub trace_name: Option<String>,
    pub trace_tags: Vec<String>,
    pub trace_metadata: serde_json::Map<String, serde_json::Value>,
    pub release: Option<String>,
    pub environment: Option<String>,
    /// NOTE(review): presumably a character cap on text recorded in traces — confirm.
    pub trace_text_max_chars: Option<usize>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::StopReason;

    /// A fully populated summary shared by the serialization and accessor tests.
    fn sample_summary() -> TurnSummary {
        TurnSummary {
            thread_id: ThreadId::from_string("t-summary"),
            turn: 2,
            total_turns: 2,
            turn_usage: TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                ..Default::default()
            },
            total_usage: TokenUsage {
                input_tokens: 200,
                output_tokens: 75,
                ..Default::default()
            },
            provenance: AuditProvenance::new("anthropic", "claude-sonnet-4-5-20250929"),
            response_id: Some("resp_123".into()),
            stop_reason: Some(StopReason::ToolUse),
            tool_call_count: 3,
            duration_ms: 1_234,
            tool_runtime: ToolRuntime::External,
            strict_durability: true,
        }
    }

    /// Boxed envelope around `sample_continuation`, for variants that carry one.
    fn sample_envelope() -> Box<ContinuationEnvelope> {
        Box::new(ContinuationEnvelope::wrap(sample_continuation()))
    }

    #[test]
    fn turn_summary_round_trips_through_json() {
        let original = sample_summary();
        let json = serde_json::to_string(&original).expect("serialize");
        let recovered: TurnSummary = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(recovered, original);
    }

    #[test]
    fn turn_summary_json_has_expected_keys() {
        let summary = sample_summary();
        let value = serde_json::to_value(&summary).unwrap();
        for key in [
            "thread_id",
            "turn",
            "total_turns",
            "turn_usage",
            "total_usage",
            "provenance",
            "response_id",
            "stop_reason",
            "tool_call_count",
            "duration_ms",
            "tool_runtime",
            "strict_durability",
        ] {
            assert!(value.get(key).is_some(), "missing key {key}");
        }
        assert_eq!(value["tool_runtime"], serde_json::json!("external"));
        assert_eq!(value["stop_reason"], serde_json::json!("tool_use"));
    }

    #[test]
    fn turn_outcome_summary_accessor_works_for_every_variant() {
        let summary = sample_summary();
        // One outcome per summary-carrying variant; `Error` is checked below.
        let outcomes = vec![
            TurnOutcome::NeedsMoreTurns {
                turn: 1,
                turn_usage: TokenUsage::default(),
                total_usage: TokenUsage::default(),
                summary: summary.clone(),
            },
            TurnOutcome::Done {
                total_turns: 1,
                total_usage: TokenUsage::default(),
                summary: summary.clone(),
            },
            TurnOutcome::AwaitingConfirmation {
                tool_call_id: "call_1".into(),
                tool_name: "echo".into(),
                display_name: "Echo".into(),
                input: serde_json::json!({"message": "hi"}),
                description: "Echo a message".into(),
                continuation: sample_envelope(),
                summary: summary.clone(),
            },
            TurnOutcome::Refusal {
                total_turns: 1,
                total_usage: TokenUsage::default(),
                summary: summary.clone(),
            },
            TurnOutcome::Cancelled {
                total_turns: 1,
                total_usage: TokenUsage::default(),
                summary: summary.clone(),
            },
            TurnOutcome::PendingToolCalls {
                turn: 1,
                turn_usage: TokenUsage::default(),
                total_usage: TokenUsage::default(),
                tool_calls: sample_continuation().pending_tool_calls,
                continuation: sample_envelope(),
                summary: summary.clone(),
            },
        ];
        for outcome in &outcomes {
            let got = outcome.summary().expect("summary must be present");
            assert_eq!(got, &summary);
        }
        let error_outcome = TurnOutcome::Error(AgentError::new("boom", false));
        assert!(error_outcome.summary().is_none());
    }

    #[test]
    fn empty_turn_summary_new_captures_options_and_provenance() {
        let opts = TurnOptions {
            tool_runtime: ToolRuntime::External,
            strict_durability: true,
        };
        let provenance = AuditProvenance::new("openai", "gpt-5");
        let summary =
            TurnSummary::new(ThreadId::from_string("t-new"), 7, provenance.clone(), &opts);
        assert_eq!(summary.thread_id, ThreadId::from_string("t-new"));
        assert_eq!(summary.turn, 7);
        assert_eq!(summary.total_turns, 0);
        assert_eq!(summary.provenance, provenance);
        assert_eq!(summary.tool_runtime, ToolRuntime::External);
        assert!(summary.strict_durability);
        assert!(summary.response_id.is_none());
        assert!(summary.stop_reason.is_none());
        assert_eq!(summary.tool_call_count, 0);
        assert_eq!(summary.duration_ms, 0);
    }

    #[test]
    fn stop_reason_as_str_matches_serde_representation() {
        let cases = [
            (StopReason::EndTurn, "end_turn"),
            (StopReason::ToolUse, "tool_use"),
            (StopReason::MaxTokens, "max_tokens"),
            (StopReason::StopSequence, "stop_sequence"),
            (StopReason::Refusal, "refusal"),
            (
                StopReason::ModelContextWindowExceeded,
                "model_context_window_exceeded",
            ),
        ];
        for (variant, expected) in cases {
            assert_eq!(variant.as_str(), expected);
            let json = serde_json::to_value(variant).unwrap();
            assert_eq!(json, serde_json::json!(expected));
        }
    }

    #[test]
    fn token_usage_add_accumulates_and_saturates() {
        let mut usage = TokenUsage {
            input_tokens: u32::MAX - 1,
            output_tokens: 3,
            cached_input_tokens: 5,
            cache_creation_input_tokens: 0,
        };
        usage.add(&TokenUsage {
            input_tokens: 10,
            output_tokens: 4,
            cached_input_tokens: 2,
            cache_creation_input_tokens: 1,
        });
        assert_eq!(
            usage,
            TokenUsage {
                input_tokens: u32::MAX,
                output_tokens: 7,
                cached_input_tokens: 7,
                cache_creation_input_tokens: 1,
            }
        );
    }

    /// A continuation with one pending tool call and LLM metadata populated.
    fn sample_continuation() -> AgentContinuation {
        let thread = ThreadId::from_string("t-continuation");
        AgentContinuation {
            thread_id: thread.clone(),
            turn: 4,
            total_usage: TokenUsage {
                input_tokens: 200,
                output_tokens: 80,
                ..Default::default()
            },
            turn_usage: TokenUsage {
                input_tokens: 50,
                output_tokens: 40,
                ..Default::default()
            },
            pending_tool_calls: vec![PendingToolCallInfo {
                id: "call_1".into(),
                name: "echo".into(),
                display_name: "Echo".into(),
                tier: ToolTier::Confirm,
                input: serde_json::json!({"message": "hi"}),
                effective_input: serde_json::json!({"message": "hi"}),
                listen_context: None,
            }],
            awaiting_index: 0,
            completed_results: Vec::new(),
            state: AgentState::new(thread),
            response_id: Some("resp_7914".into()),
            stop_reason: Some(StopReason::ToolUse),
            response_content: Vec::new(),
        }
    }

    #[test]
    fn continuation_envelope_accepts_current_version() {
        let envelope = ContinuationEnvelope::wrap(sample_continuation());
        assert_eq!(envelope.version, CONTINUATION_VERSION);
        let payload = envelope
            .unwrap_validated()
            .expect("current version must be accepted");
        assert_eq!(payload.turn, 4);
    }

    #[test]
    fn continuation_envelope_rejects_mismatched_version() {
        let mut envelope = ContinuationEnvelope::wrap(sample_continuation());
        envelope.version = CONTINUATION_VERSION + 1;
        let err = envelope
            .unwrap_validated()
            .expect_err("mismatched version must be rejected");
        assert!(
            err.contains("Unsupported continuation version"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn agent_continuation_round_trips_llm_metadata() {
        let original = sample_continuation();
        let json = serde_json::to_string(&original).expect("serialize");
        let value: serde_json::Value = serde_json::from_str(&json).expect("to value");
        assert_eq!(value["response_id"], serde_json::json!("resp_7914"));
        assert_eq!(value["stop_reason"], serde_json::json!("tool_use"));
        let recovered: AgentContinuation = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(recovered.response_id.as_deref(), Some("resp_7914"));
        assert_eq!(recovered.stop_reason, Some(StopReason::ToolUse));
    }

    #[test]
    fn agent_continuation_deserializes_legacy_payload_without_llm_metadata() {
        let thread = ThreadId::from_string("t-legacy");
        let legacy_json = serde_json::json!({
            "thread_id": thread,
            "turn": 1,
            "total_usage": { "input_tokens": 10, "output_tokens": 5 },
            "turn_usage": { "input_tokens": 10, "output_tokens": 5 },
            "pending_tool_calls": [],
            "awaiting_index": 0,
            "completed_results": [],
            "state": AgentState::new(thread.clone()),
        });
        let recovered: AgentContinuation =
            serde_json::from_value(legacy_json).expect("legacy payload deserialises");
        assert_eq!(recovered.thread_id, thread);
        assert_eq!(recovered.turn, 1);
        assert!(
            recovered.response_id.is_none(),
            "legacy payloads default to None",
        );
        assert!(
            recovered.stop_reason.is_none(),
            "legacy payloads default to None",
        );
        assert!(
            recovered.response_content.is_empty(),
            "legacy payloads default to no content",
        );
    }

    #[test]
    fn agent_continuation_omits_llm_metadata_when_none() {
        let thread = ThreadId::from_string("t-omit");
        let cont = AgentContinuation {
            thread_id: thread.clone(),
            turn: 1,
            total_usage: TokenUsage::default(),
            turn_usage: TokenUsage::default(),
            pending_tool_calls: Vec::new(),
            awaiting_index: 0,
            completed_results: Vec::new(),
            state: AgentState::new(thread),
            response_id: None,
            stop_reason: None,
            response_content: Vec::new(),
        };
        let value = serde_json::to_value(&cont).unwrap();
        assert!(value.get("response_id").is_none());
        assert!(value.get("stop_reason").is_none());
        assert!(value.get("response_content").is_none());
    }
}