Skip to main content

agent_sdk/
types.rs

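//! Core types for the agent SDK.
//!
//! This module contains the fundamental types used throughout the SDK:
//!
//! - [`ThreadId`]: Unique identifier for conversation threads
//! - [`AgentConfig`]: Configuration for the agent loop
//! - [`RetryConfig`]: Retry/backoff settings for transient errors
//! - [`TokenUsage`]: Token consumption statistics
//! - [`ToolResult`]: Result returned from tool execution
//! - [`ToolOutcome`]: Tool execution outcome, including in-progress async operations
//! - [`ToolTier`]: Permission tiers for tools
//! - [`AgentError`]: Error reported by the agent loop
//! - [`AgentRunState`]: Outcome of running the agent loop (looping mode)
//! - [`TurnOutcome`]: Outcome of running a single turn (single-turn mode)
//! - [`AgentInput`]: Input to start or resume an agent run
//! - [`AgentContinuation`]: State for resuming after confirmation (treat as opaque)
//! - [`PendingToolCallInfo`]: A tool call extracted from the LLM response
//! - [`AgentState`]: Checkpointable agent state
//! - [`ToolExecution`] and [`ExecutionStatus`]: Records used to make tool execution idempotent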
15
16use crate::llm::ThinkingConfig;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use time::OffsetDateTime;
20use uuid::Uuid;
21
/// Unique identifier for a conversation thread.
///
/// Newtype wrapper around a `String`; the inner value is public so callers can
/// read it directly. Create one with [`ThreadId::new`] (generated from a v4 UUID)
/// or wrap an existing identifier with [`ThreadId::from_string`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ThreadId(pub String);
25
26impl ThreadId {
27    #[must_use]
28    pub fn new() -> Self {
29        Self(Uuid::new_v4().to_string())
30    }
31
32    #[must_use]
33    pub fn from_string(s: impl Into<String>) -> Self {
34        Self(s.into())
35    }
36}
37
38impl Default for ThreadId {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl std::fmt::Display for ThreadId {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "{}", self.0)
47    }
48}
49
50/// Configuration for the agent loop
51#[derive(Clone, Debug)]
52pub struct AgentConfig {
53    /// Maximum number of turns (LLM round-trips) before stopping
54    pub max_turns: usize,
55    /// Maximum tokens per response
56    pub max_tokens: u32,
57    /// System prompt for the agent
58    pub system_prompt: String,
59    /// Model identifier
60    pub model: String,
61    /// Retry configuration for transient errors
62    pub retry: RetryConfig,
63    /// Optional extended thinking configuration
64    pub thinking: Option<ThinkingConfig>,
65    /// Enable streaming responses from the LLM.
66    ///
67    /// When `true`, emits `TextDelta` and `Thinking` events as text arrives
68    /// in real-time. When `false` (default), waits for the complete response
69    /// before emitting `Text` and `Thinking` events.
70    pub streaming: bool,
71}
72
73impl Default for AgentConfig {
74    fn default() -> Self {
75        Self {
76            max_turns: 10,
77            max_tokens: 4096,
78            system_prompt: String::new(),
79            model: String::from("claude-sonnet-4-20250514"),
80            retry: RetryConfig::default(),
81            thinking: None,
82            streaming: false,
83        }
84    }
85}
86
/// Configuration for retry behavior on transient errors.
///
/// All delays are in milliseconds. [`RetryConfig::default`] allows up to 5
/// retries with a 1 s base delay capped at 120 s; [`RetryConfig::no_retry`] and
/// [`RetryConfig::fast`] are presets intended for tests.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Base delay in milliseconds for exponential backoff
    pub base_delay_ms: u64,
    /// Maximum delay cap in milliseconds
    pub max_delay_ms: u64,
}

impl RetryConfig {
    /// Preset that never retries and never waits (for testing).
    #[must_use]
    pub const fn no_retry() -> Self {
        Self {
            max_delay_ms: 0,
            base_delay_ms: 0,
            max_retries: 0,
        }
    }

    /// Preset with 5 retries but millisecond-scale delays (for testing).
    #[must_use]
    pub const fn fast() -> Self {
        Self {
            max_delay_ms: 100,
            base_delay_ms: 10,
            max_retries: 5,
        }
    }
}

impl Default for RetryConfig {
    /// Production defaults: 5 retries, 1 000 ms base delay, 120 000 ms cap.
    fn default() -> Self {
        Self {
            max_delay_ms: 120_000,
            base_delay_ms: 1_000,
            max_retries: 5,
        }
    }
}
129
/// Token usage statistics.
///
/// Holds input and output token counts as `u32`; one record can be folded into
/// another with [`TokenUsage::add`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Number of input tokens.
    pub input_tokens: u32,
    /// Number of output tokens.
    pub output_tokens: u32,
}
136
137impl TokenUsage {
138    pub const fn add(&mut self, other: &Self) {
139        self.input_tokens += other.input_tokens;
140        self.output_tokens += other.output_tokens;
141    }
142}
143
/// Result of a tool execution.
///
/// Built with [`ToolResult::success`], [`ToolResult::success_with_data`] or
/// [`ToolResult::error`], optionally followed by [`ToolResult::with_duration`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolResult {
    /// Whether the tool execution succeeded
    pub success: bool,
    /// Output content (displayed to user and fed back to LLM)
    pub output: String,
    /// Optional structured data
    pub data: Option<serde_json::Value>,
    /// Duration of the tool execution in milliseconds
    pub duration_ms: Option<u64>,
}
156
157impl ToolResult {
158    #[must_use]
159    pub fn success(output: impl Into<String>) -> Self {
160        Self {
161            success: true,
162            output: output.into(),
163            data: None,
164            duration_ms: None,
165        }
166    }
167
168    #[must_use]
169    pub fn success_with_data(output: impl Into<String>, data: serde_json::Value) -> Self {
170        Self {
171            success: true,
172            output: output.into(),
173            data: Some(data),
174            duration_ms: None,
175        }
176    }
177
178    #[must_use]
179    pub fn error(message: impl Into<String>) -> Self {
180        Self {
181            success: false,
182            output: message.into(),
183            data: None,
184            duration_ms: None,
185        }
186    }
187
188    #[must_use]
189    pub const fn with_duration(mut self, duration_ms: u64) -> Self {
190        self.duration_ms = Some(duration_ms);
191        self
192    }
193}
194
/// Permission tier for tools.
///
/// Decides whether a tool may run immediately or needs a confirmation decision
/// from the application first.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ToolTier {
    /// Read-only, always allowed (e.g., `get_balance`)
    Observe,
    /// Requires confirmation before execution.
    /// The application determines the confirmation type (normal, PIN, biometric).
    Confirm,
}
204
/// Snapshot of agent state for checkpointing.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentState {
    /// Thread this state belongs to.
    pub thread_id: ThreadId,
    /// Number of turns taken so far (0 for a state built by `AgentState::new`).
    pub turn_count: usize,
    /// Accumulated token usage.
    pub total_usage: TokenUsage,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation timestamp, (de)serialized as an RFC 3339 string.
    #[serde(with = "time::serde::rfc3339")]
    pub created_at: OffsetDateTime,
}
215
216impl AgentState {
217    #[must_use]
218    pub fn new(thread_id: ThreadId) -> Self {
219        Self {
220            thread_id,
221            turn_count: 0,
222            total_usage: TokenUsage::default(),
223            metadata: HashMap::new(),
224            created_at: OffsetDateTime::now_utc(),
225        }
226    }
227}
228
/// Error from the agent loop.
///
/// Carries a message plus a flag telling the caller whether the failure is
/// potentially recoverable.
#[derive(Debug, Clone)]
pub struct AgentError {
    /// Error message
    pub message: String,
    /// Whether the error is potentially recoverable
    pub recoverable: bool,
}

impl AgentError {
    /// Creates an error from a message and a recoverability flag.
    #[must_use]
    pub fn new(message: impl Into<String>, recoverable: bool) -> Self {
        Self {
            message: message.into(),
            recoverable,
        }
    }
}

/// Displays only the message text (the `recoverable` flag is not shown).
///
/// Delegates to [`std::fmt::Formatter::pad`] so width, fill, alignment and
/// precision flags are honored instead of being silently dropped.
impl std::fmt::Display for AgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.message)
    }
}

impl std::error::Error for AgentError {}
255
/// Outcome of running the agent loop (looping mode).
///
/// See [`TurnOutcome`] for the single-turn counterpart.
#[derive(Debug)]
pub enum AgentRunState {
    /// Agent completed successfully.
    Done {
        /// Total turns executed
        total_turns: u32,
        /// Total input tokens consumed
        input_tokens: u64,
        /// Total output tokens consumed
        output_tokens: u64,
    },

    /// Agent encountered an error.
    Error(AgentError),

    /// Agent is awaiting confirmation for a tool call.
    /// The application should present this to the user and resume with
    /// [`AgentInput::Resume`], passing back `continuation`.
    AwaitingConfirmation {
        /// ID of the pending tool call (from LLM)
        tool_call_id: String,
        /// Tool name string (for LLM protocol)
        tool_name: String,
        /// Human-readable display name
        display_name: String,
        /// Tool input parameters
        input: serde_json::Value,
        /// Description of what confirmation is needed
        description: String,
        /// Continuation state for resuming (boxed for enum size efficiency)
        continuation: Box<AgentContinuation>,
    },
}
286
/// Information about a pending tool call that was extracted from the LLM response.
///
/// Collected in [`AgentContinuation::pending_tool_calls`] while a turn is paused
/// for confirmation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PendingToolCallInfo {
    /// Unique ID for this tool call (from LLM)
    pub id: String,
    /// Tool name string (for LLM protocol)
    pub name: String,
    /// Human-readable display name
    pub display_name: String,
    /// Tool input parameters
    pub input: serde_json::Value,
}
299
/// Continuation state that allows resuming the agent loop.
///
/// This contains all the internal state needed to continue execution
/// after receiving a confirmation decision. Pass this back when resuming.
///
/// It is produced (boxed) by the `AwaitingConfirmation` variants of
/// [`AgentRunState`] and [`TurnOutcome`] and handed back via
/// [`AgentInput::Resume`]. All fields are public, but callers should treat the
/// value as opaque and return it unchanged.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentContinuation {
    /// Thread ID (used for validation on resume)
    pub thread_id: ThreadId,
    /// Current turn number
    pub turn: usize,
    /// Total token usage so far
    pub total_usage: TokenUsage,
    /// Token usage for this specific turn (from the LLM call that generated tool calls)
    pub turn_usage: TokenUsage,
    /// All pending tool calls from this turn
    pub pending_tool_calls: Vec<PendingToolCallInfo>,
    /// Index of the tool call awaiting confirmation
    /// (presumably an index into `pending_tool_calls` — confirm against the loop).
    pub awaiting_index: usize,
    /// Tool results already collected (for tools before the awaiting one).
    /// NOTE(review): the `String` is presumably the tool call ID — verify.
    pub completed_results: Vec<(String, ToolResult)>,
    /// Agent state snapshot
    pub state: AgentState,
}
323
/// Input to start or resume an agent run.
///
/// `Text` starts a run, `Resume` answers an `AwaitingConfirmation` outcome, and
/// `Continue` advances single-turn mode after [`TurnOutcome::NeedsMoreTurns`].
#[derive(Debug)]
pub enum AgentInput {
    /// Start a new conversation with user text.
    Text(String),

    /// Resume after a confirmation decision.
    Resume {
        /// The continuation state from `AwaitingConfirmation` (boxed for enum size efficiency).
        continuation: Box<AgentContinuation>,
        /// ID of the tool call being confirmed/rejected.
        tool_call_id: String,
        /// Whether the user confirmed the action.
        confirmed: bool,
        /// Optional reason if rejected.
        rejection_reason: Option<String>,
    },

    /// Continue to the next turn (for single-turn mode).
    ///
    /// Use this after `TurnOutcome::NeedsMoreTurns` to execute the next turn.
    /// The message history already contains tool results from the previous turn.
    Continue,
}
348
/// Result of tool execution - may indicate async operation in progress.
///
/// `Success` and `Failed` wrap a finished [`ToolResult`]; `InProgress` means the
/// tool started an asynchronous operation identified by `operation_id`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ToolOutcome {
    /// Tool completed synchronously with success
    Success(ToolResult),

    /// Tool completed synchronously with failure
    Failed(ToolResult),

    /// Tool started an async operation - must stream status to completion
    InProgress {
        /// Identifier for the operation (to query status)
        operation_id: String,
        /// Initial message for the user
        message: String,
    },
}
366
367impl ToolOutcome {
368    #[must_use]
369    pub fn success(output: impl Into<String>) -> Self {
370        Self::Success(ToolResult::success(output))
371    }
372
373    #[must_use]
374    pub fn failed(message: impl Into<String>) -> Self {
375        Self::Failed(ToolResult::error(message))
376    }
377
378    #[must_use]
379    pub fn in_progress(operation_id: impl Into<String>, message: impl Into<String>) -> Self {
380        Self::InProgress {
381            operation_id: operation_id.into(),
382            message: message.into(),
383        }
384    }
385
386    /// Returns true if operation is still in progress
387    #[must_use]
388    pub const fn is_in_progress(&self) -> bool {
389        matches!(self, Self::InProgress { .. })
390    }
391}
392
393// ============================================================================
394// Tool Execution Idempotency Types
395// ============================================================================
396
/// Status of a tool execution for idempotency tracking.
///
/// A [`ToolExecution`] starts as `InFlight` and moves to `Completed` once a
/// result has been recorded.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionStatus {
    /// Execution started but not yet completed
    InFlight,
    /// Execution completed (success or failure)
    Completed,
}
405
/// Record of a tool execution for idempotency.
///
/// This struct tracks tool executions to prevent duplicate execution when
/// the agent loop retries after a failure. The write-ahead pattern ensures
/// that execution intent is recorded BEFORE calling the tool, and updated
/// with results AFTER completion.
///
/// Records are created with `ToolExecution::new_in_flight` and finished with
/// `ToolExecution::complete`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolExecution {
    /// The tool call ID from the LLM (unique per invocation)
    pub tool_call_id: String,
    /// Thread this execution belongs to
    pub thread_id: ThreadId,
    /// Tool name
    pub tool_name: String,
    /// Display name
    pub display_name: String,
    /// Input parameters (for verification)
    pub input: serde_json::Value,
    /// Current status
    pub status: ExecutionStatus,
    /// Result if completed
    pub result: Option<ToolResult>,
    /// For async tools: the operation ID returned by `execute()`
    pub operation_id: Option<String>,
    /// Timestamp when execution started ((de)serialized as RFC 3339)
    #[serde(with = "time::serde::rfc3339")]
    pub started_at: OffsetDateTime,
    /// Timestamp when execution completed ((de)serialized as optional RFC 3339)
    #[serde(with = "time::serde::rfc3339::option")]
    pub completed_at: Option<OffsetDateTime>,
}
437
438impl ToolExecution {
439    /// Create a new in-flight execution record.
440    #[must_use]
441    pub fn new_in_flight(
442        tool_call_id: impl Into<String>,
443        thread_id: ThreadId,
444        tool_name: impl Into<String>,
445        display_name: impl Into<String>,
446        input: serde_json::Value,
447        started_at: OffsetDateTime,
448    ) -> Self {
449        Self {
450            tool_call_id: tool_call_id.into(),
451            thread_id,
452            tool_name: tool_name.into(),
453            display_name: display_name.into(),
454            input,
455            status: ExecutionStatus::InFlight,
456            result: None,
457            operation_id: None,
458            started_at,
459            completed_at: None,
460        }
461    }
462
463    /// Mark this execution as completed with a result.
464    pub fn complete(&mut self, result: ToolResult) {
465        self.status = ExecutionStatus::Completed;
466        self.result = Some(result);
467        self.completed_at = Some(OffsetDateTime::now_utc());
468    }
469
470    /// Set the operation ID for async tool tracking.
471    pub fn set_operation_id(&mut self, operation_id: impl Into<String>) {
472        self.operation_id = Some(operation_id.into());
473    }
474
475    /// Returns true if this execution is still in flight.
476    #[must_use]
477    pub fn is_in_flight(&self) -> bool {
478        self.status == ExecutionStatus::InFlight
479    }
480
481    /// Returns true if this execution has completed.
482    #[must_use]
483    pub fn is_completed(&self) -> bool {
484        self.status == ExecutionStatus::Completed
485    }
486}
487
/// Outcome of running a single turn.
///
/// This is returned by `run_turn` to indicate what happened and what to do next.
/// It is the single-turn-mode counterpart of [`AgentRunState`]; note that
/// `NeedsMoreTurns::turn` is a `usize` while `Done::total_turns` is a `u32`.
#[derive(Debug)]
pub enum TurnOutcome {
    /// Turn completed successfully, but more turns are needed.
    ///
    /// Tools were executed and their results are stored in the message history.
    /// Call `run_turn` again with `AgentInput::Continue` to proceed.
    NeedsMoreTurns {
        /// The turn number that just completed
        turn: usize,
        /// Token usage for this turn
        turn_usage: TokenUsage,
        /// Cumulative token usage so far
        total_usage: TokenUsage,
    },

    /// Agent completed successfully (no more tool calls).
    Done {
        /// Total turns executed
        total_turns: u32,
        /// Total input tokens consumed
        input_tokens: u64,
        /// Total output tokens consumed
        output_tokens: u64,
    },

    /// A tool requires user confirmation.
    ///
    /// Present this to the user and call `run_turn` with `AgentInput::Resume`
    /// to continue.
    AwaitingConfirmation {
        /// ID of the pending tool call (from LLM)
        tool_call_id: String,
        /// Tool name string (for LLM protocol)
        tool_name: String,
        /// Human-readable display name
        display_name: String,
        /// Tool input parameters
        input: serde_json::Value,
        /// Description of what confirmation is needed
        description: String,
        /// Continuation state for resuming (boxed for enum size efficiency)
        continuation: Box<AgentContinuation>,
    },

    /// An error occurred.
    Error(AgentError),
}