// agent_sdk_rs/agent/mod.rs

use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::time::Instant;

use async_stream::try_stream;
use futures_util::{Stream, StreamExt};
use tokio::time::{Duration, sleep};

use crate::error::{AgentError, ProviderError, ToolError};
use crate::llm::{
    ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
};
use crate::tools::{DependencyMap, ToolOutcome, ToolSpec};

/// Tool-calling policy sent to the underlying model.
///
/// The default policy is [`AgentToolChoice::Auto`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AgentToolChoice {
    /// Model decides whether to call tools.
    #[default]
    Auto,
    /// Model must call at least one tool.
    Required,
    /// Model must not call tools.
    None,
    /// Model must call the named tool.
    Tool(String),
}

34#[derive(Debug, Clone)]
35/// Runtime configuration for an [`Agent`].
36pub struct AgentConfig {
37    /// Require explicit completion via `ToolOutcome::Done`.
38    pub require_done_tool: bool,
39    /// Maximum model-tool iterations per user query.
40    pub max_iterations: u32,
41    /// Optional system prompt injected at the start of empty history.
42    pub system_prompt: Option<String>,
43    /// Tool-choice policy passed to the model adapter.
44    pub tool_choice: AgentToolChoice,
45    /// Maximum number of retries for request-level provider failures.
46    pub llm_max_retries: u32,
47    /// Initial retry delay in milliseconds.
48    pub llm_retry_base_delay_ms: u64,
49    /// Maximum retry delay in milliseconds.
50    pub llm_retry_max_delay_ms: u64,
51    /// Optional hidden follow-up user message injected once before finishing.
52    pub hidden_user_message_prompt: Option<String>,
53}
54
55impl Default for AgentConfig {
56    fn default() -> Self {
57        Self {
58            require_done_tool: false,
59            max_iterations: 24,
60            system_prompt: None,
61            tool_choice: AgentToolChoice::Auto,
62            llm_max_retries: 5,
63            llm_retry_base_delay_ms: 1_000,
64            llm_retry_max_delay_ms: 60_000,
65            hidden_user_message_prompt: None,
66        }
67    }
68}
69
/// Author of a message reported through [`AgentEvent::MessageStart`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentRole {
    /// The message was written by the end user.
    User,
    /// The message was produced by the assistant/model.
    Assistant,
}

/// Outcome of a single tool execution step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepStatus {
    /// The tool ran and did not report an error.
    Completed,
    /// The tool failed (or could not be resolved).
    Error,
}

88#[derive(Debug, Clone, PartialEq)]
89/// Streamed events emitted by [`Agent::query_stream`].
90pub enum AgentEvent {
91    /// A new message started.
92    MessageStart {
93        /// Deterministic message id generated by the SDK.
94        message_id: String,
95        /// Role for the message.
96        role: AgentRole,
97    },
98    /// A message finished.
99    MessageComplete {
100        /// Message id.
101        message_id: String,
102        /// Rendered message content.
103        content: String,
104    },
105    /// Hidden user prompt injected by config.
106    HiddenUserMessage {
107        /// Hidden prompt content.
108        content: String,
109    },
110    /// A tool execution step started.
111    StepStart {
112        /// Tool-call id.
113        step_id: String,
114        /// Tool name.
115        title: String,
116        /// 1-based step number for this assistant turn.
117        step_number: u32,
118    },
119    /// A tool execution step finished.
120    StepComplete {
121        /// Tool-call id.
122        step_id: String,
123        /// Step completion status.
124        status: StepStatus,
125        /// Execution duration in milliseconds.
126        duration_ms: u128,
127    },
128    /// Model returned reasoning/thinking text.
129    Thinking {
130        /// Thinking content.
131        content: String,
132    },
133    /// Model returned regular text content.
134    Text {
135        /// Text content.
136        content: String,
137    },
138    /// Model requested a tool call.
139    ToolCall {
140        /// Tool name.
141        tool: String,
142        /// Raw JSON arguments.
143        args_json: serde_json::Value,
144        /// Provider/tool-call id.
145        tool_call_id: String,
146    },
147    /// Tool execution result was recorded.
148    ToolResult {
149        /// Tool name.
150        tool: String,
151        /// Result text returned to the model.
152        result_text: String,
153        /// Provider/tool-call id.
154        tool_call_id: String,
155        /// Whether this tool result represents an error.
156        is_error: bool,
157    },
158    /// Final response for the query.
159    FinalResponse {
160        /// Final assistant output.
161        content: String,
162    },
163}
164
165/// Builder for [`Agent`].
166pub struct AgentBuilder {
167    model: Option<Arc<dyn ChatModel>>,
168    tools: Vec<ToolSpec>,
169    config: AgentConfig,
170    dependencies: DependencyMap,
171    dependency_overrides: DependencyMap,
172}
173
174impl Default for AgentBuilder {
175    fn default() -> Self {
176        Self {
177            model: None,
178            tools: Vec::new(),
179            config: AgentConfig::default(),
180            dependencies: DependencyMap::new(),
181            dependency_overrides: DependencyMap::new(),
182        }
183    }
184}
185
186impl AgentBuilder {
187    /// Sets the model adapter used by the agent.
188    pub fn model<M>(mut self, model: M) -> Self
189    where
190        M: ChatModel + 'static,
191    {
192        self.model = Some(Arc::new(model));
193        self
194    }
195
196    /// Adds one tool to the registry.
197    pub fn tool(mut self, tool: ToolSpec) -> Self {
198        self.tools.push(tool);
199        self
200    }
201
202    /// Adds multiple tools to the registry.
203    pub fn tools(mut self, tools: Vec<ToolSpec>) -> Self {
204        self.tools.extend(tools);
205        self
206    }
207
208    /// Replaces the full agent config.
209    pub fn config(mut self, config: AgentConfig) -> Self {
210        self.config = config;
211        self
212    }
213
214    /// Sets the system prompt.
215    pub fn system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
216        self.config.system_prompt = Some(system_prompt.into());
217        self
218    }
219
220    /// Enables or disables explicit `done` completion mode.
221    pub fn require_done_tool(mut self, require_done_tool: bool) -> Self {
222        self.config.require_done_tool = require_done_tool;
223        self
224    }
225
226    /// Sets max iterations for each query.
227    pub fn max_iterations(mut self, max_iterations: u32) -> Self {
228        self.config.max_iterations = max_iterations;
229        self
230    }
231
232    /// Sets tool-choice policy for model invocations.
233    pub fn tool_choice(mut self, tool_choice: AgentToolChoice) -> Self {
234        self.config.tool_choice = tool_choice;
235        self
236    }
237
238    /// Configures request retry behavior (exponential backoff).
239    pub fn llm_retry_config(
240        mut self,
241        max_retries: u32,
242        base_delay_ms: u64,
243        max_delay_ms: u64,
244    ) -> Self {
245        self.config.llm_max_retries = max_retries;
246        self.config.llm_retry_base_delay_ms = base_delay_ms;
247        self.config.llm_retry_max_delay_ms = max_delay_ms;
248        self
249    }
250
251    /// Sets a hidden user prompt injected once if model returns no tool calls.
252    pub fn hidden_user_message_prompt(mut self, prompt: impl Into<String>) -> Self {
253        self.config.hidden_user_message_prompt = Some(prompt.into());
254        self
255    }
256
257    /// Inserts a typed runtime dependency.
258    pub fn dependency<T>(self, value: T) -> Self
259    where
260        T: Send + Sync + 'static,
261    {
262        self.dependencies.insert(value);
263        self
264    }
265
266    /// Inserts a named runtime dependency.
267    pub fn dependency_named<T>(self, key: impl Into<String>, value: T) -> Self
268    where
269        T: Send + Sync + 'static,
270    {
271        self.dependencies.insert_named(key, value);
272        self
273    }
274
275    /// Inserts a typed dependency override.
276    pub fn dependency_override<T>(self, value: T) -> Self
277    where
278        T: Send + Sync + 'static,
279    {
280        self.dependency_overrides.insert(value);
281        self
282    }
283
284    /// Inserts a named dependency override.
285    pub fn dependency_override_named<T>(self, key: impl Into<String>, value: T) -> Self
286    where
287        T: Send + Sync + 'static,
288    {
289        self.dependency_overrides.insert_named(key, value);
290        self
291    }
292
293    /// Builds an [`Agent`] and validates required config.
294    pub fn build(self) -> Result<Agent, AgentError> {
295        let Some(model) = self.model else {
296            return Err(AgentError::Config(
297                "agent model must be configured via AgentBuilder::model(...)".to_string(),
298            ));
299        };
300
301        let mut tool_map = HashMap::new();
302        for tool in &self.tools {
303            if tool_map
304                .insert(tool.name().to_string(), tool.clone())
305                .is_some()
306            {
307                return Err(AgentError::Config(format!(
308                    "duplicate tool registered: {}",
309                    tool.name()
310                )));
311            }
312        }
313
314        Ok(Agent {
315            model,
316            tools: self.tools,
317            tool_map,
318            config: self.config,
319            dependencies: self.dependencies,
320            dependency_overrides: self.dependency_overrides,
321            history: Vec::new(),
322            next_message_id: 0,
323        })
324    }
325}
326
327/// Stateful agent runtime with conversation history and tool registry.
328pub struct Agent {
329    model: Arc<dyn ChatModel>,
330    tools: Vec<ToolSpec>,
331    tool_map: HashMap<String, ToolSpec>,
332    config: AgentConfig,
333    dependencies: DependencyMap,
334    dependency_overrides: DependencyMap,
335    history: Vec<ModelMessage>,
336    next_message_id: u64,
337}
338
339impl Agent {
340    /// Creates a new builder.
341    pub fn builder() -> AgentBuilder {
342        AgentBuilder::default()
343    }
344
345    /// Clears conversation history and resets message-id counter.
346    pub fn clear_history(&mut self) {
347        self.history.clear();
348        self.next_message_id = 0;
349    }
350
351    /// Replaces history with a preloaded message sequence.
352    pub fn load_history(&mut self, messages: Vec<ModelMessage>) {
353        self.next_message_id = messages.len() as u64;
354        self.history = messages;
355    }
356
357    /// Returns number of history messages.
358    pub fn messages_len(&self) -> usize {
359        self.history.len()
360    }
361
362    /// Returns current history slice.
363    pub fn messages(&self) -> &[ModelMessage] {
364        &self.history
365    }
366
367    /// Runs one user query and returns the final response text.
368    pub async fn query(&mut self, user_message: impl Into<String>) -> Result<String, AgentError> {
369        let stream = self.query_stream(user_message);
370        futures_util::pin_mut!(stream);
371
372        let mut final_response: Option<String> = None;
373
374        while let Some(event) = stream.next().await {
375            match event? {
376                AgentEvent::FinalResponse { content } => final_response = Some(content),
377                AgentEvent::MessageStart { .. }
378                | AgentEvent::MessageComplete { .. }
379                | AgentEvent::HiddenUserMessage { .. }
380                | AgentEvent::StepStart { .. }
381                | AgentEvent::StepComplete { .. }
382                | AgentEvent::Thinking { .. }
383                | AgentEvent::Text { .. }
384                | AgentEvent::ToolCall { .. }
385                | AgentEvent::ToolResult { .. } => {}
386            }
387        }
388
389        final_response.ok_or(AgentError::MissingFinalResponse)
390    }
391
392    /// Runs one user query and streams intermediate events.
393    pub fn query_stream(
394        &mut self,
395        user_message: impl Into<String>,
396    ) -> impl Stream<Item = Result<AgentEvent, AgentError>> + '_ {
397        let user_message = user_message.into();
398
399        try_stream! {
400            if self.history.is_empty() {
401                if let Some(system_prompt) = &self.config.system_prompt {
402                    self.history.push(ModelMessage::System(system_prompt.clone()));
403                }
404            }
405
406            let user_message_id = self.next_message_id(AgentRole::User);
407            yield AgentEvent::MessageStart {
408                message_id: user_message_id.clone(),
409                role: AgentRole::User,
410            };
411            self.history.push(ModelMessage::User(user_message.clone()));
412            yield AgentEvent::MessageComplete {
413                message_id: user_message_id,
414                content: user_message,
415            };
416
417            let tool_definitions = self
418                .tools
419                .iter()
420                .map(|tool| ModelToolDefinition {
421                    name: tool.name().to_string(),
422                    description: tool.description().to_string(),
423                    parameters: tool.json_schema().clone(),
424                })
425                .collect::<Vec<_>>();
426
427            let tool_choice = self.resolve_tool_choice(!tool_definitions.is_empty());
428            let mut hidden_prompt_injected = false;
429
430            for _ in 0..self.config.max_iterations {
431                let completion = self
432                    .invoke_with_retry(&tool_definitions, tool_choice.clone())
433                    .await?;
434
435                let assistant_message_id = self.next_message_id(AgentRole::Assistant);
436                yield AgentEvent::MessageStart {
437                    message_id: assistant_message_id.clone(),
438                    role: AgentRole::Assistant,
439                };
440
441                if let Some(thinking) = completion.thinking.clone() {
442                    yield AgentEvent::Thinking { content: thinking };
443                }
444
445                self.append_assistant_message(&completion);
446
447                if let Some(text) = completion.text.clone() {
448                    if !text.is_empty() {
449                        yield AgentEvent::Text {
450                            content: text.clone(),
451                        };
452                    }
453                }
454
455                let assistant_content = completion.text.clone().unwrap_or_default();
456                yield AgentEvent::MessageComplete {
457                    message_id: assistant_message_id,
458                    content: assistant_content.clone(),
459                };
460
461                if completion.tool_calls.is_empty() {
462                    if !self.config.require_done_tool {
463                        if !hidden_prompt_injected {
464                            if let Some(hidden_prompt) = self.config.hidden_user_message_prompt.clone() {
465                                hidden_prompt_injected = true;
466                                self.history.push(ModelMessage::User(hidden_prompt.clone()));
467                                yield AgentEvent::HiddenUserMessage {
468                                    content: hidden_prompt,
469                                };
470                                continue;
471                            }
472                        }
473
474                        yield AgentEvent::FinalResponse {
475                            content: completion.text.unwrap_or_default(),
476                        };
477                        return;
478                    }
479                    continue;
480                }
481
482                let mut step_number = 0_u32;
483                for tool_call in completion.tool_calls {
484                    step_number += 1;
485                    yield AgentEvent::StepStart {
486                        step_id: tool_call.id.clone(),
487                        title: tool_call.name.clone(),
488                        step_number,
489                    };
490
491                    yield AgentEvent::ToolCall {
492                        tool: tool_call.name.clone(),
493                        args_json: tool_call.arguments.clone(),
494                        tool_call_id: tool_call.id.clone(),
495                    };
496
497                    let step_start = Instant::now();
498                    let execution = self.execute_tool_call(&tool_call).await;
499                    self.history.push(ModelMessage::ToolResult {
500                        tool_call_id: tool_call.id.clone(),
501                        tool_name: tool_call.name.clone(),
502                        content: execution.result_text.clone(),
503                        is_error: execution.is_error,
504                    });
505
506                    yield AgentEvent::ToolResult {
507                        tool: tool_call.name.clone(),
508                        result_text: execution.result_text.clone(),
509                        tool_call_id: tool_call.id.clone(),
510                        is_error: execution.is_error,
511                    };
512
513                    yield AgentEvent::StepComplete {
514                        step_id: tool_call.id.clone(),
515                        status: if execution.is_error {
516                            StepStatus::Error
517                        } else {
518                            StepStatus::Completed
519                        },
520                        duration_ms: step_start.elapsed().as_millis(),
521                    };
522
523                    if let Some(done_message) = execution.done_message {
524                        yield AgentEvent::FinalResponse {
525                            content: done_message,
526                        };
527                        return;
528                    }
529                }
530            }
531
532            Err::<(), AgentError>(AgentError::MaxIterationsReached {
533                max_iterations: self.config.max_iterations,
534            })?;
535        }
536    }
537
538    fn next_message_id(&mut self, role: AgentRole) -> String {
539        self.next_message_id += 1;
540        let role_label = match role {
541            AgentRole::User => "user",
542            AgentRole::Assistant => "assistant",
543        };
544        format!("msg_{}_{}", self.next_message_id, role_label)
545    }
546
547    fn resolve_tool_choice(&self, has_tools: bool) -> ModelToolChoice {
548        if !has_tools {
549            return ModelToolChoice::None;
550        }
551
552        match &self.config.tool_choice {
553            AgentToolChoice::Auto => ModelToolChoice::Auto,
554            AgentToolChoice::Required => ModelToolChoice::Required,
555            AgentToolChoice::None => ModelToolChoice::None,
556            AgentToolChoice::Tool(name) => ModelToolChoice::Tool(name.clone()),
557        }
558    }
559
560    async fn invoke_with_retry(
561        &self,
562        tool_definitions: &[ModelToolDefinition],
563        tool_choice: ModelToolChoice,
564    ) -> Result<ModelCompletion, AgentError> {
565        let max_retries = self.config.llm_max_retries.max(1);
566        for attempt in 0..max_retries {
567            match self
568                .model
569                .invoke(&self.history, tool_definitions, tool_choice.clone())
570                .await
571            {
572                Ok(completion) => return Ok(completion),
573                Err(err) => {
574                    let should_retry =
575                        is_retryable_provider_error(&err) && (attempt + 1) < max_retries;
576                    if !should_retry {
577                        return Err(AgentError::Provider(err));
578                    }
579
580                    let delay_ms = retry_delay_ms(
581                        attempt,
582                        self.config.llm_retry_base_delay_ms,
583                        self.config.llm_retry_max_delay_ms,
584                    );
585                    sleep(Duration::from_millis(delay_ms)).await;
586                }
587            }
588        }
589
590        Err(AgentError::Config(
591            "retry loop failed unexpectedly".to_string(),
592        ))
593    }
594
595    fn append_assistant_message(&mut self, completion: &ModelCompletion) {
596        self.history.push(ModelMessage::Assistant {
597            content: completion.text.clone(),
598            tool_calls: completion.tool_calls.clone(),
599        });
600    }
601
602    async fn execute_tool_call(&self, tool_call: &ModelToolCall) -> ToolExecutionResult {
603        let Some(tool) = self.tool_map.get(&tool_call.name) else {
604            return ToolExecutionResult {
605                result_text: format!("Unknown tool '{}'.", tool_call.name),
606                is_error: true,
607                done_message: None,
608            };
609        };
610
611        let runtime_dependencies = self.dependencies.merged_with(&self.dependency_overrides);
612
613        match tool
614            .execute(tool_call.arguments.clone(), &runtime_dependencies)
615            .await
616        {
617            Ok(ToolOutcome::Text(text)) => ToolExecutionResult {
618                result_text: text,
619                is_error: false,
620                done_message: None,
621            },
622            Ok(ToolOutcome::Done(message)) => ToolExecutionResult {
623                result_text: format!("Task completed: {message}"),
624                is_error: false,
625                done_message: Some(message),
626            },
627            Err(err) => ToolExecutionResult {
628                result_text: format_tool_error(err),
629                is_error: true,
630                done_message: None,
631            },
632        }
633    }
634}
635
636fn is_retryable_provider_error(err: &ProviderError) -> bool {
637    match err {
638        ProviderError::Request(_) => true,
639        ProviderError::Response(_) => false,
640    }
641}
642
/// Computes the backoff delay for a zero-based retry `attempt`.
///
/// The delay is `base_delay_ms * 2^attempt`, saturating instead of
/// overflowing, and never exceeds `max_delay_ms`.
fn retry_delay_ms(attempt: u32, base_delay_ms: u64, max_delay_ms: u64) -> u64 {
    let multiplier = 2_u64.saturating_pow(attempt);
    base_delay_ms.saturating_mul(multiplier).min(max_delay_ms)
}

651fn format_tool_error(err: ToolError) -> String {
652    err.to_string()
653}
654
/// Normalized outcome of one tool call, whether it succeeded or failed.
struct ToolExecutionResult {
    /// Text recorded in history and reported as the tool's result.
    result_text: String,
    /// `true` when the tool was unknown or its execution failed.
    is_error: bool,
    /// Final answer, set only when the tool signalled completion.
    done_message: Option<String>,
}

661/// Convenience wrapper around [`Agent::query`].
662pub async fn query(
663    agent: &mut Agent,
664    user_message: impl Into<String>,
665) -> Result<String, AgentError> {
666    agent.query(user_message).await
667}
668
669/// Convenience wrapper around [`Agent::query_stream`].
670pub fn query_stream(
671    agent: &mut Agent,
672    user_message: impl Into<String>,
673) -> impl Stream<Item = Result<AgentEvent, AgentError>> + '_ {
674    agent.query_stream(user_message)
675}
676
// Unit tests are kept in the separate `tests` submodule.
#[cfg(test)]
mod tests;