Skip to main content

agent_sdk_rs/agent/
mod.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Instant;
4
5use async_stream::try_stream;
6use futures_util::{Stream, StreamExt};
7use tokio::time::{Duration, sleep};
8
9use crate::error::{AgentError, ProviderError, ToolError};
10use crate::llm::{
11    ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
12};
13use crate::tools::{DependencyMap, ToolOutcome, ToolSpec};
14
/// Tool-choice policy requested from the model on every turn of a query.
///
/// The agent forwards this as the matching `ModelToolChoice`, except that it always sends
/// "none" when no tools are registered.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AgentToolChoice {
    /// Forwarded as `ModelToolChoice::Auto` (the default).
    #[default]
    Auto,
    /// Forwarded as `ModelToolChoice::Required`.
    Required,
    /// Forwarded as `ModelToolChoice::None`.
    None,
    /// Forwarded as `ModelToolChoice::Tool` with the given tool name.
    Tool(String),
}
28
29#[derive(Debug, Clone)]
30pub struct AgentConfig {
31    pub require_done_tool: bool,
32    pub max_iterations: u32,
33    pub system_prompt: Option<String>,
34    pub tool_choice: AgentToolChoice,
35    pub llm_max_retries: u32,
36    pub llm_retry_base_delay_ms: u64,
37    pub llm_retry_max_delay_ms: u64,
38    pub hidden_user_message_prompt: Option<String>,
39}
40
41impl Default for AgentConfig {
42    fn default() -> Self {
43        Self {
44            require_done_tool: false,
45            max_iterations: 24,
46            system_prompt: None,
47            tool_choice: AgentToolChoice::Auto,
48            llm_max_retries: 5,
49            llm_retry_base_delay_ms: 1_000,
50            llm_retry_max_delay_ms: 60_000,
51            hidden_user_message_prompt: None,
52        }
53    }
54}
55
/// Author of a message announced by `AgentEvent::MessageStart`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentRole {
    /// The caller's query text.
    User,
    /// A model completion.
    Assistant,
}
61
/// Outcome of a single tool-call step, as reported by `AgentEvent::StepComplete`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepStatus {
    /// The tool ran and returned a result.
    Completed,
    /// The tool was unknown or returned an error.
    Error,
}
67
68#[derive(Debug, Clone, PartialEq)]
69pub enum AgentEvent {
70    MessageStart {
71        message_id: String,
72        role: AgentRole,
73    },
74    MessageComplete {
75        message_id: String,
76        content: String,
77    },
78    HiddenUserMessage {
79        content: String,
80    },
81    StepStart {
82        step_id: String,
83        title: String,
84        step_number: u32,
85    },
86    StepComplete {
87        step_id: String,
88        status: StepStatus,
89        duration_ms: u128,
90    },
91    Thinking {
92        content: String,
93    },
94    Text {
95        content: String,
96    },
97    ToolCall {
98        tool: String,
99        args_json: serde_json::Value,
100        tool_call_id: String,
101    },
102    ToolResult {
103        tool: String,
104        result_text: String,
105        tool_call_id: String,
106        is_error: bool,
107    },
108    FinalResponse {
109        content: String,
110    },
111}
112
113pub struct AgentBuilder {
114    model: Option<Arc<dyn ChatModel>>,
115    tools: Vec<ToolSpec>,
116    config: AgentConfig,
117    dependencies: DependencyMap,
118    dependency_overrides: DependencyMap,
119}
120
121impl Default for AgentBuilder {
122    fn default() -> Self {
123        Self {
124            model: None,
125            tools: Vec::new(),
126            config: AgentConfig::default(),
127            dependencies: DependencyMap::new(),
128            dependency_overrides: DependencyMap::new(),
129        }
130    }
131}
132
133impl AgentBuilder {
134    pub fn model<M>(mut self, model: M) -> Self
135    where
136        M: ChatModel + 'static,
137    {
138        self.model = Some(Arc::new(model));
139        self
140    }
141
142    pub fn tool(mut self, tool: ToolSpec) -> Self {
143        self.tools.push(tool);
144        self
145    }
146
147    pub fn tools(mut self, tools: Vec<ToolSpec>) -> Self {
148        self.tools.extend(tools);
149        self
150    }
151
152    pub fn config(mut self, config: AgentConfig) -> Self {
153        self.config = config;
154        self
155    }
156
157    pub fn system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
158        self.config.system_prompt = Some(system_prompt.into());
159        self
160    }
161
162    pub fn require_done_tool(mut self, require_done_tool: bool) -> Self {
163        self.config.require_done_tool = require_done_tool;
164        self
165    }
166
167    pub fn max_iterations(mut self, max_iterations: u32) -> Self {
168        self.config.max_iterations = max_iterations;
169        self
170    }
171
172    pub fn tool_choice(mut self, tool_choice: AgentToolChoice) -> Self {
173        self.config.tool_choice = tool_choice;
174        self
175    }
176
177    pub fn llm_retry_config(
178        mut self,
179        max_retries: u32,
180        base_delay_ms: u64,
181        max_delay_ms: u64,
182    ) -> Self {
183        self.config.llm_max_retries = max_retries;
184        self.config.llm_retry_base_delay_ms = base_delay_ms;
185        self.config.llm_retry_max_delay_ms = max_delay_ms;
186        self
187    }
188
189    pub fn hidden_user_message_prompt(mut self, prompt: impl Into<String>) -> Self {
190        self.config.hidden_user_message_prompt = Some(prompt.into());
191        self
192    }
193
194    pub fn dependency<T>(self, value: T) -> Self
195    where
196        T: Send + Sync + 'static,
197    {
198        self.dependencies.insert(value);
199        self
200    }
201
202    pub fn dependency_named<T>(self, key: impl Into<String>, value: T) -> Self
203    where
204        T: Send + Sync + 'static,
205    {
206        self.dependencies.insert_named(key, value);
207        self
208    }
209
210    pub fn dependency_override<T>(self, value: T) -> Self
211    where
212        T: Send + Sync + 'static,
213    {
214        self.dependency_overrides.insert(value);
215        self
216    }
217
218    pub fn dependency_override_named<T>(self, key: impl Into<String>, value: T) -> Self
219    where
220        T: Send + Sync + 'static,
221    {
222        self.dependency_overrides.insert_named(key, value);
223        self
224    }
225
226    pub fn build(self) -> Result<Agent, AgentError> {
227        let Some(model) = self.model else {
228            return Err(AgentError::Config(
229                "agent model must be configured via AgentBuilder::model(...)".to_string(),
230            ));
231        };
232
233        let mut tool_map = HashMap::new();
234        for tool in &self.tools {
235            if tool_map
236                .insert(tool.name().to_string(), tool.clone())
237                .is_some()
238            {
239                return Err(AgentError::Config(format!(
240                    "duplicate tool registered: {}",
241                    tool.name()
242                )));
243            }
244        }
245
246        Ok(Agent {
247            model,
248            tools: self.tools,
249            tool_map,
250            config: self.config,
251            dependencies: self.dependencies,
252            dependency_overrides: self.dependency_overrides,
253            history: Vec::new(),
254            next_message_id: 0,
255        })
256    }
257}
258
259pub struct Agent {
260    model: Arc<dyn ChatModel>,
261    tools: Vec<ToolSpec>,
262    tool_map: HashMap<String, ToolSpec>,
263    config: AgentConfig,
264    dependencies: DependencyMap,
265    dependency_overrides: DependencyMap,
266    history: Vec<ModelMessage>,
267    next_message_id: u64,
268}
269
270impl Agent {
271    pub fn builder() -> AgentBuilder {
272        AgentBuilder::default()
273    }
274
275    pub fn clear_history(&mut self) {
276        self.history.clear();
277        self.next_message_id = 0;
278    }
279
280    pub fn load_history(&mut self, messages: Vec<ModelMessage>) {
281        self.next_message_id = messages.len() as u64;
282        self.history = messages;
283    }
284
285    pub fn messages_len(&self) -> usize {
286        self.history.len()
287    }
288
289    pub fn messages(&self) -> &[ModelMessage] {
290        &self.history
291    }
292
293    pub async fn query(&mut self, user_message: impl Into<String>) -> Result<String, AgentError> {
294        let stream = self.query_stream(user_message);
295        futures_util::pin_mut!(stream);
296
297        let mut final_response: Option<String> = None;
298
299        while let Some(event) = stream.next().await {
300            match event? {
301                AgentEvent::FinalResponse { content } => final_response = Some(content),
302                AgentEvent::MessageStart { .. }
303                | AgentEvent::MessageComplete { .. }
304                | AgentEvent::HiddenUserMessage { .. }
305                | AgentEvent::StepStart { .. }
306                | AgentEvent::StepComplete { .. }
307                | AgentEvent::Thinking { .. }
308                | AgentEvent::Text { .. }
309                | AgentEvent::ToolCall { .. }
310                | AgentEvent::ToolResult { .. } => {}
311            }
312        }
313
314        final_response.ok_or(AgentError::MissingFinalResponse)
315    }
316
317    pub fn query_stream(
318        &mut self,
319        user_message: impl Into<String>,
320    ) -> impl Stream<Item = Result<AgentEvent, AgentError>> + '_ {
321        let user_message = user_message.into();
322
323        try_stream! {
324            if self.history.is_empty() {
325                if let Some(system_prompt) = &self.config.system_prompt {
326                    self.history.push(ModelMessage::System(system_prompt.clone()));
327                }
328            }
329
330            let user_message_id = self.next_message_id(AgentRole::User);
331            yield AgentEvent::MessageStart {
332                message_id: user_message_id.clone(),
333                role: AgentRole::User,
334            };
335            self.history.push(ModelMessage::User(user_message.clone()));
336            yield AgentEvent::MessageComplete {
337                message_id: user_message_id,
338                content: user_message,
339            };
340
341            let tool_definitions = self
342                .tools
343                .iter()
344                .map(|tool| ModelToolDefinition {
345                    name: tool.name().to_string(),
346                    description: tool.description().to_string(),
347                    parameters: tool.json_schema().clone(),
348                })
349                .collect::<Vec<_>>();
350
351            let tool_choice = self.resolve_tool_choice(!tool_definitions.is_empty());
352            let mut hidden_prompt_injected = false;
353
354            for _ in 0..self.config.max_iterations {
355                let completion = self
356                    .invoke_with_retry(&tool_definitions, tool_choice.clone())
357                    .await?;
358
359                let assistant_message_id = self.next_message_id(AgentRole::Assistant);
360                yield AgentEvent::MessageStart {
361                    message_id: assistant_message_id.clone(),
362                    role: AgentRole::Assistant,
363                };
364
365                if let Some(thinking) = completion.thinking.clone() {
366                    yield AgentEvent::Thinking { content: thinking };
367                }
368
369                self.append_assistant_message(&completion);
370
371                if let Some(text) = completion.text.clone() {
372                    if !text.is_empty() {
373                        yield AgentEvent::Text {
374                            content: text.clone(),
375                        };
376                    }
377                }
378
379                let assistant_content = completion.text.clone().unwrap_or_default();
380                yield AgentEvent::MessageComplete {
381                    message_id: assistant_message_id,
382                    content: assistant_content.clone(),
383                };
384
385                if completion.tool_calls.is_empty() {
386                    if !self.config.require_done_tool {
387                        if !hidden_prompt_injected {
388                            if let Some(hidden_prompt) = self.config.hidden_user_message_prompt.clone() {
389                                hidden_prompt_injected = true;
390                                self.history.push(ModelMessage::User(hidden_prompt.clone()));
391                                yield AgentEvent::HiddenUserMessage {
392                                    content: hidden_prompt,
393                                };
394                                continue;
395                            }
396                        }
397
398                        yield AgentEvent::FinalResponse {
399                            content: completion.text.unwrap_or_default(),
400                        };
401                        return;
402                    }
403                    continue;
404                }
405
406                let mut step_number = 0_u32;
407                for tool_call in completion.tool_calls {
408                    step_number += 1;
409                    yield AgentEvent::StepStart {
410                        step_id: tool_call.id.clone(),
411                        title: tool_call.name.clone(),
412                        step_number,
413                    };
414
415                    yield AgentEvent::ToolCall {
416                        tool: tool_call.name.clone(),
417                        args_json: tool_call.arguments.clone(),
418                        tool_call_id: tool_call.id.clone(),
419                    };
420
421                    let step_start = Instant::now();
422                    let execution = self.execute_tool_call(&tool_call).await;
423                    self.history.push(ModelMessage::ToolResult {
424                        tool_call_id: tool_call.id.clone(),
425                        tool_name: tool_call.name.clone(),
426                        content: execution.result_text.clone(),
427                        is_error: execution.is_error,
428                    });
429
430                    yield AgentEvent::ToolResult {
431                        tool: tool_call.name.clone(),
432                        result_text: execution.result_text.clone(),
433                        tool_call_id: tool_call.id.clone(),
434                        is_error: execution.is_error,
435                    };
436
437                    yield AgentEvent::StepComplete {
438                        step_id: tool_call.id.clone(),
439                        status: if execution.is_error {
440                            StepStatus::Error
441                        } else {
442                            StepStatus::Completed
443                        },
444                        duration_ms: step_start.elapsed().as_millis(),
445                    };
446
447                    if let Some(done_message) = execution.done_message {
448                        yield AgentEvent::FinalResponse {
449                            content: done_message,
450                        };
451                        return;
452                    }
453                }
454            }
455
456            Err::<(), AgentError>(AgentError::MaxIterationsReached {
457                max_iterations: self.config.max_iterations,
458            })?;
459        }
460    }
461
462    fn next_message_id(&mut self, role: AgentRole) -> String {
463        self.next_message_id += 1;
464        let role_label = match role {
465            AgentRole::User => "user",
466            AgentRole::Assistant => "assistant",
467        };
468        format!("msg_{}_{}", self.next_message_id, role_label)
469    }
470
471    fn resolve_tool_choice(&self, has_tools: bool) -> ModelToolChoice {
472        if !has_tools {
473            return ModelToolChoice::None;
474        }
475
476        match &self.config.tool_choice {
477            AgentToolChoice::Auto => ModelToolChoice::Auto,
478            AgentToolChoice::Required => ModelToolChoice::Required,
479            AgentToolChoice::None => ModelToolChoice::None,
480            AgentToolChoice::Tool(name) => ModelToolChoice::Tool(name.clone()),
481        }
482    }
483
484    async fn invoke_with_retry(
485        &self,
486        tool_definitions: &[ModelToolDefinition],
487        tool_choice: ModelToolChoice,
488    ) -> Result<ModelCompletion, AgentError> {
489        let max_retries = self.config.llm_max_retries.max(1);
490        for attempt in 0..max_retries {
491            match self
492                .model
493                .invoke(&self.history, tool_definitions, tool_choice.clone())
494                .await
495            {
496                Ok(completion) => return Ok(completion),
497                Err(err) => {
498                    let should_retry =
499                        is_retryable_provider_error(&err) && (attempt + 1) < max_retries;
500                    if !should_retry {
501                        return Err(AgentError::Provider(err));
502                    }
503
504                    let delay_ms = retry_delay_ms(
505                        attempt,
506                        self.config.llm_retry_base_delay_ms,
507                        self.config.llm_retry_max_delay_ms,
508                    );
509                    sleep(Duration::from_millis(delay_ms)).await;
510                }
511            }
512        }
513
514        Err(AgentError::Config(
515            "retry loop failed unexpectedly".to_string(),
516        ))
517    }
518
519    fn append_assistant_message(&mut self, completion: &ModelCompletion) {
520        self.history.push(ModelMessage::Assistant {
521            content: completion.text.clone(),
522            tool_calls: completion.tool_calls.clone(),
523        });
524    }
525
526    async fn execute_tool_call(&self, tool_call: &ModelToolCall) -> ToolExecutionResult {
527        let Some(tool) = self.tool_map.get(&tool_call.name) else {
528            return ToolExecutionResult {
529                result_text: format!("Unknown tool '{}'.", tool_call.name),
530                is_error: true,
531                done_message: None,
532            };
533        };
534
535        let runtime_dependencies = self.dependencies.merged_with(&self.dependency_overrides);
536
537        match tool
538            .execute(tool_call.arguments.clone(), &runtime_dependencies)
539            .await
540        {
541            Ok(ToolOutcome::Text(text)) => ToolExecutionResult {
542                result_text: text,
543                is_error: false,
544                done_message: None,
545            },
546            Ok(ToolOutcome::Done(message)) => ToolExecutionResult {
547                result_text: format!("Task completed: {message}"),
548                is_error: false,
549                done_message: Some(message),
550            },
551            Err(err) => ToolExecutionResult {
552                result_text: format_tool_error(err),
553                is_error: true,
554                done_message: None,
555            },
556        }
557    }
558}
559
560fn is_retryable_provider_error(err: &ProviderError) -> bool {
561    match err {
562        ProviderError::Request(_) => true,
563        ProviderError::Response(_) => false,
564    }
565}
566
/// Exponential backoff delay for the given zero-based retry `attempt`.
///
/// Computes `base_delay_ms * 2^attempt`, saturating at `u64::MAX` instead of overflowing, and
/// caps the result at `max_delay_ms`.
fn retry_delay_ms(attempt: u32, base_delay_ms: u64, max_delay_ms: u64) -> u64 {
    let growth = 2_u64.saturating_pow(attempt);
    let uncapped = base_delay_ms.saturating_mul(growth);
    uncapped.min(max_delay_ms)
}
574
575fn format_tool_error(err: ToolError) -> String {
576    err.to_string()
577}
578
/// Outcome of running one tool call, before it is turned into history entries and events.
struct ToolExecutionResult {
    /// Text fed back to the model as the tool result.
    result_text: String,
    /// True when the tool was unknown or returned an error.
    is_error: bool,
    /// Set when the tool signalled completion; becomes the query's final response.
    done_message: Option<String>,
}
584
585pub async fn query(
586    agent: &mut Agent,
587    user_message: impl Into<String>,
588) -> Result<String, AgentError> {
589    agent.query(user_message).await
590}
591
592pub fn query_stream(
593    agent: &mut Agent,
594    user_message: impl Into<String>,
595) -> impl Stream<Item = Result<AgentEvent, AgentError>> + '_ {
596    agent.query_stream(user_message)
597}
598
599#[cfg(test)]
600mod tests;