// strands_agents/agent/mod.rs

//! Agent implementation for conversational AI: the [`Agent`] type, its input conversions, direct tool invocation, and the model/tool event loop.
2
3mod builder;
4mod result;
5mod state;
6
7use std::sync::Arc;
8
9use futures::StreamExt;
10
11use crate::conversation::ConversationManager;
12use crate::hooks::HookRegistry;
13use crate::models::Model;
14use crate::telemetry::EventLoopMetrics;
15use crate::tools::{InvocationState, ToolRegistry};
16use crate::types::content::{ContentBlock, Message, Messages, Role};
17use crate::types::errors::{Result, StrandsError};
18
19pub use builder::AgentBuilder;
20pub use result::AgentResult;
21pub use state::AgentState;
22
23/// Input types for agent invocation.
24pub enum AgentInput {
25    Text(String),
26    ContentBlocks(Vec<ContentBlock>),
27    Messages(Messages),
28    None,
29}
30
31impl From<&str> for AgentInput {
32    fn from(s: &str) -> Self { AgentInput::Text(s.to_string()) }
33}
34
35impl From<String> for AgentInput {
36    fn from(s: String) -> Self { AgentInput::Text(s) }
37}
38
39impl From<Vec<ContentBlock>> for AgentInput {
40    fn from(blocks: Vec<ContentBlock>) -> Self { AgentInput::ContentBlocks(blocks) }
41}
42
43impl From<Messages> for AgentInput {
44    fn from(messages: Messages) -> Self { AgentInput::Messages(messages) }
45}
46
47impl<T: Into<String>> From<Option<T>> for AgentInput {
48    fn from(opt: Option<T>) -> Self {
49        match opt {
50            Some(s) => AgentInput::Text(s.into()),
51            None => AgentInput::None,
52        }
53    }
54}
55
56/// Tool caller for direct tool invocation.
57pub struct ToolCaller<'a> {
58    agent: &'a mut Agent,
59}
60
61impl<'a> ToolCaller<'a> {
62    /// Invoke a tool by name with the given input.
63    pub async fn invoke(
64        &mut self,
65        tool_name: &str,
66        input: serde_json::Value,
67    ) -> Result<crate::types::tools::ToolResult> {
68        self.invoke_with_options(tool_name, input, None, None).await
69    }
70
71    /// Invoke a tool by name with options.
72    pub async fn invoke_with_options(
73        &mut self,
74        tool_name: &str,
75        input: serde_json::Value,
76        user_message_override: Option<&str>,
77        record_direct_tool_call: Option<bool>,
78    ) -> Result<crate::types::tools::ToolResult> {
79        use crate::types::tools::{ToolResult, ToolUse};
80        use crate::tools::ToolContext;
81
82        if self.agent.interrupt_state.activated {
83            return Err(StrandsError::EventLoopError {
84                message: "cannot directly call tool during interrupt".to_string(),
85            });
86        }
87
88        let tool = self.agent.tool_registry.get(tool_name)
89            .ok_or_else(|| StrandsError::ToolNotFound {
90                tool_name: tool_name.to_string(),
91            })?;
92
93        let tool_id = format!("tooluse_{}_{}", tool_name, uuid::Uuid::new_v4());
94        let tool_use = ToolUse {
95            name: tool_name.to_string(),
96            tool_use_id: tool_id.clone(),
97            input: input.clone(),
98        };
99
100        let context = ToolContext::with_state(InvocationState::new());
101        let result = match tool.invoke(input.clone(), &context).await {
102            Ok(r) => ToolResult {
103                tool_use_id: tool_id.clone(),
104                status: r.status,
105                content: r.content,
106            },
107            Err(e) => ToolResult::error(&tool_id, e),
108        };
109
110        let should_record = record_direct_tool_call
111            .unwrap_or(self.agent.record_direct_tool_call);
112
113        if should_record {
114            self.record_tool_execution(&tool_use, &result, user_message_override).await?;
115        }
116
117        self.agent.conversation_manager.apply_management(&mut self.agent.messages);
118
119        Ok(result)
120    }
121
122    async fn record_tool_execution(
123        &mut self,
124        tool_use: &crate::types::tools::ToolUse,
125        tool_result: &crate::types::tools::ToolResult,
126        user_message_override: Option<&str>,
127    ) -> Result<()> {
128        let input_json = serde_json::to_string(&tool_use.input)
129            .unwrap_or_else(|_| "<<non-serializable>>".to_string());
130
131        let mut user_content = Vec::new();
132        if let Some(msg) = user_message_override {
133            user_content.push(ContentBlock::text(format!("{}\n", msg)));
134        }
135        user_content.push(ContentBlock::text(format!(
136            "agent.tool.{} direct tool call.\nInput parameters: {}\n",
137            tool_use.name, input_json
138        )));
139
140        let user_msg = Message { role: Role::User, content: user_content };
141        let tool_use_msg = Message {
142            role: Role::Assistant,
143            content: vec![ContentBlock::tool_use(tool_use.clone())],
144        };
145        let tool_result_msg = Message {
146            role: Role::User,
147            content: vec![ContentBlock::tool_result(tool_result.clone())],
148        };
149        let assistant_msg = Message {
150            role: Role::Assistant,
151            content: vec![ContentBlock::text(format!("agent.tool.{} was called.", tool_use.name))],
152        };
153
154        self.agent.messages.push(user_msg);
155        self.agent.messages.push(tool_use_msg);
156        self.agent.messages.push(tool_result_msg);
157        self.agent.messages.push(assistant_msg);
158
159        Ok(())
160    }
161}
162
163/// The main agent struct for conversational AI.
164pub struct Agent {
165    pub(crate) model: Arc<dyn Model>,
166    pub(crate) messages: Messages,
167    pub(crate) system_prompt: Option<String>,
168    pub(crate) tool_registry: ToolRegistry,
169    agent_name: Option<String>,
170    pub agent_id: String,
171    pub description: Option<String>,
172    pub state: AgentState,
173    pub(crate) hooks: HookRegistry,
174    pub(crate) conversation_manager: Box<dyn ConversationManager>,
175    interrupt_state: crate::types::interrupt::InterruptState,
176    /// Whether to record direct tool calls in message history.
177    pub record_direct_tool_call: bool,
178    /// Custom trace attributes for OpenTelemetry.
179    pub trace_attributes: std::collections::HashMap<String, String>,
180    /// Maximum number of tool calls per cycle (None = unlimited).
181    pub max_tool_calls: Option<usize>,
182    /// Structured output context for typed responses.
183    pub(crate) structured_output_context: Option<crate::tools::structured_output::StructuredOutputContext>,
184}
185
186impl Agent {
187    /// Creates a new agent builder.
188    pub fn builder() -> AgentBuilder { AgentBuilder::new() }
189
190    /// Returns the agent name.
191    pub fn name(&self) -> Option<&String> { self.agent_name.as_ref() }
192
193    /// Sets the agent name.
194    pub fn set_name(&mut self, name: impl Into<String>) {
195        self.agent_name = Some(name.into());
196    }
197
198    /// Returns the system prompt.
199    pub fn system_prompt(&self) -> Option<&str> { self.system_prompt.as_deref() }
200
201    /// Sets the system prompt.
202    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
203        self.system_prompt = Some(prompt.into());
204    }
205
206    /// Returns the current messages.
207    pub fn messages(&self) -> &Messages { &self.messages }
208
209    /// Adds a message to the conversation.
210    pub fn add_message(&mut self, message: Message) {
211        self.messages.push(message);
212    }
213
214    /// Clears all messages.
215    pub fn clear_messages(&mut self) {
216        self.messages.clear();
217    }
218
219    /// Returns the tool registry.
220    pub fn tool_registry(&self) -> &ToolRegistry { &self.tool_registry }
221
222    /// Returns a mutable reference to the tool registry.
223    pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry { &mut self.tool_registry }
224
225    /// Returns the names of all registered tools.
226    pub fn tool_names(&self) -> Vec<&str> { self.tool_registry.tool_names() }
227
228    /// Returns the agent ID.
229    pub fn agent_id(&self) -> Option<&str> { Some(&self.agent_id) }
230
231    /// Returns the hooks registry.
232    pub fn hooks(&self) -> &HookRegistry { &self.hooks }
233
234    /// Returns a mutable reference to the hooks registry.
235    pub fn hooks_mut(&mut self) -> &mut HookRegistry { &mut self.hooks }
236
237    /// Returns the conversation manager.
238    pub fn conversation_manager(&self) -> &dyn ConversationManager { self.conversation_manager.as_ref() }
239
240    /// Returns a mutable reference to the conversation manager.
241    pub fn conversation_manager_mut(&mut self) -> &mut dyn ConversationManager { self.conversation_manager.as_mut() }
242
243    /// Returns a reference to the agent state.
244    pub fn state(&self) -> &AgentState { &self.state }
245
246    /// Returns a mutable reference to the agent state.
247    pub fn state_mut(&mut self) -> &mut AgentState { &mut self.state }
248
249    /// Returns a reference to the interrupt state.
250    pub fn interrupt_state(&self) -> &crate::types::interrupt::InterruptState { &self.interrupt_state }
251
252    /// Returns a mutable reference to the interrupt state.
253    pub fn interrupt_state_mut(&mut self) -> &mut crate::types::interrupt::InterruptState { &mut self.interrupt_state }
254
255    /// Sets the interrupt state.
256    pub fn set_interrupt_state(&mut self, state: crate::types::interrupt::InterruptState) {
257        self.interrupt_state = state;
258    }
259
260    /// Returns whether the agent is currently in an interrupt state.
261    pub fn is_interrupted(&self) -> bool { self.interrupt_state.activated }
262
263    /// Sets the messages directly.
264    pub fn set_messages(&mut self, messages: Messages) {
265        self.messages = messages;
266    }
267
268    /// Returns a tool caller for direct tool invocation.
269    ///
270    /// This enables the pattern: `agent.tool().invoke("tool_name", input).await`
271    pub fn tool(&mut self) -> ToolCaller<'_> {
272        ToolCaller { agent: self }
273    }
274
275    /// Returns the trace attributes.
276    pub fn trace_attributes(&self) -> &std::collections::HashMap<String, String> {
277        &self.trace_attributes
278    }
279
280    /// Sets a trace attribute.
281    pub fn set_trace_attribute(&mut self, key: impl Into<String>, value: impl Into<String>) {
282        self.trace_attributes.insert(key.into(), value.into());
283    }
284
285    /// Returns the max tool calls setting.
286    pub fn max_tool_calls(&self) -> Option<usize> {
287        self.max_tool_calls
288    }
289
290    /// Sets the max tool calls.
291    pub fn set_max_tool_calls(&mut self, max: Option<usize>) {
292        self.max_tool_calls = max;
293    }
294
295    /// Invokes the agent synchronously (blocking).
296    pub fn call(&mut self, prompt: impl Into<AgentInput>) -> Result<AgentResult> {
297        tokio::task::block_in_place(|| {
298            tokio::runtime::Handle::current().block_on(self.invoke_async(prompt))
299        })
300    }
301
302    /// Invokes the agent asynchronously.
303    pub async fn invoke_async(&mut self, prompt: impl Into<AgentInput>) -> Result<AgentResult> {
304        let input = prompt.into();
305        let new_messages = self.convert_input_to_messages(input)?;
306
307        for msg in new_messages {
308            self.messages.push(msg);
309        }
310
311        self.run_event_loop().await
312    }
313
314    /// Returns a stream of events from the agent.
315    pub async fn stream_async(
316        &mut self,
317        prompt: impl Into<AgentInput>,
318    ) -> impl futures::Stream<Item = Result<crate::event_loop::TypedEvent>> + '_ {
319        let input = prompt.into();
320
321        async_stream::stream! {
322            let new_messages = match self.convert_input_to_messages(input) {
323                Ok(msgs) => msgs,
324                Err(e) => {
325                    yield Err(e);
326                    return;
327                }
328            };
329
330            for msg in new_messages {
331                self.messages.push(msg);
332            }
333
334            match self.run_event_loop().await {
335                Ok(result) => yield Ok(crate::event_loop::TypedEvent::agent_result(result)),
336                Err(e) => yield Err(e),
337            }
338        }
339    }
340
341    fn convert_input_to_messages(&self, input: AgentInput) -> Result<Messages> {
342        match input {
343            AgentInput::Text(text) => Ok(vec![Message { role: Role::User, content: vec![ContentBlock::text(text)] }]),
344            AgentInput::ContentBlocks(blocks) => Ok(vec![Message { role: Role::User, content: blocks }]),
345            AgentInput::Messages(messages) => Ok(messages),
346            AgentInput::None => Ok(vec![]),
347        }
348    }
349
350    async fn run_event_loop(&mut self) -> Result<AgentResult> {
351        use crate::hooks::{BeforeInvocationEvent, AfterInvocationEvent, HookEvent};
352        let invocation_state = InvocationState::new();
353        self.hooks.invoke(&HookEvent::BeforeInvocation(BeforeInvocationEvent)).await;
354
355        let mut structured_output_ctx = self.structured_output_context.clone();
356
357        if let Some(ref ctx) = structured_output_ctx {
358            ctx.register_tool(&mut self.tool_registry);
359        }
360
361        let result = self.event_loop_inner(&invocation_state, &mut structured_output_ctx).await;
362
363        if let Some(ref ctx) = structured_output_ctx {
364            ctx.cleanup(&mut self.tool_registry);
365        }
366
367        self.conversation_manager.apply_management(&mut self.messages);
368        let agent_result = result.as_ref().ok().cloned();
369        self.hooks.invoke(&HookEvent::AfterInvocation(AfterInvocationEvent::new(agent_result))).await;
370        result
371    }
372
373    async fn event_loop_inner(
374        &mut self,
375        invocation_state: &InvocationState,
376        structured_output_ctx: &mut Option<crate::tools::structured_output::StructuredOutputContext>,
377    ) -> Result<AgentResult> {
378        use crate::hooks::{
379            BeforeModelCallEvent, AfterModelCallEvent, HookEvent, MessageAddedEvent,
380        };
381        use crate::types::streaming::{StopReason, Usage};
382
383        loop {
384            let tool_specs = self.tool_registry.get_all_tool_specs();
385            let messages_snapshot = self.messages.clone();
386            let system_prompt_snapshot = self.system_prompt.clone();
387
388            let tool_specs_ref: Option<&[_]> = if tool_specs.is_empty() { None } else { Some(&tool_specs) };
389
390            self.hooks.invoke(&HookEvent::BeforeModelCall(BeforeModelCallEvent)).await;
391
392            let stream = self.model.stream(
393                &messages_snapshot,
394                tool_specs_ref,
395                system_prompt_snapshot.as_deref(),
396                None,
397                None,
398            );
399
400            let mut response_content: Vec<ContentBlock> = Vec::new();
401            let mut stop_reason = StopReason::EndTurn;
402            let mut usage = Usage::default();
403            let mut current_tool_use: Option<crate::types::tools::ToolUse> = None;
404            let mut tool_input_buffer = String::new();
405
406            futures::pin_mut!(stream);
407            while let Some(event_result) = stream.next().await {
408                let event = event_result?;
409
410                if let Some(ref delta_event) = event.content_block_delta {
411                    if let Some(ref delta) = delta_event.delta {
412                        if let Some(ref text) = delta.text {
413                            if let Some(block) = response_content.last_mut() {
414                                if block.text.is_some() {
415                                    block.text.as_mut().unwrap().push_str(text);
416                                } else {
417                                    response_content.push(ContentBlock::text(text));
418                                }
419                            } else {
420                                response_content.push(ContentBlock::text(text));
421                            }
422                        }
423                        if let Some(ref tool_delta) = delta.tool_use {
424                            tool_input_buffer.push_str(&tool_delta.input);
425                        }
426                    }
427                }
428
429                if let Some(ref start_event) = event.content_block_start {
430                    if let Some(ref start) = start_event.start {
431                        if let Some(ref tu) = start.tool_use {
432                            current_tool_use = Some(crate::types::tools::ToolUse {
433                                name: tu.name.clone(),
434                                tool_use_id: tu.tool_use_id.clone(),
435                                input: serde_json::Value::Null,
436                            });
437                            tool_input_buffer.clear();
438                        }
439                    }
440                }
441
442                if event.content_block_stop.is_some() {
443                    if let Some(mut tu) = current_tool_use.take() {
444                        tu.input = serde_json::from_str(&tool_input_buffer).unwrap_or(serde_json::Value::Null);
445                        response_content.push(ContentBlock::tool_use(tu));
446                        tool_input_buffer.clear();
447                    }
448                }
449
450                if let Some(ref stop_event) = event.message_stop {
451                    if let Some(sr) = stop_event.stop_reason {
452                        stop_reason = sr;
453                    }
454                }
455
456                if let Some(ref meta) = event.metadata {
457                    if let Some(ref u) = meta.usage {
458                        usage = u.clone();
459                    }
460                }
461            }
462
463            let assistant_message = Message { role: Role::Assistant, content: response_content.clone() };
464            self.messages.push(assistant_message.clone());
465
466            self.hooks.invoke(&HookEvent::MessageAdded(MessageAddedEvent::new(assistant_message.clone()))).await;
467
468            self.hooks.invoke(&HookEvent::AfterModelCall(AfterModelCallEvent::success(
469                assistant_message.clone(),
470                stop_reason.clone(),
471            ))).await;
472
473            match stop_reason {
474                StopReason::EndTurn | StopReason::StopSequence => {
475                    if let Some(ref mut ctx) = structured_output_ctx {
476                        if ctx.is_enabled() {
477                            if ctx.force_attempted {
478                                return Err(StrandsError::StructuredOutputError {
479                                    message: "The model failed to invoke the structured output tool even after it was forced.".to_string(),
480                                });
481                            }
482
483                            ctx.set_forced_mode();
484                            tracing::debug!("Forcing structured output tool");
485
486                            let force_message = Message {
487                                role: Role::User,
488                                content: vec![ContentBlock::text("You must format the previous response as structured output.")],
489                            };
490                            self.messages.push(force_message);
491
492                            continue;
493                        }
494                    }
495
496                    return Ok(AgentResult {
497                        stop_reason,
498                        message: assistant_message,
499                        usage,
500                        metrics: EventLoopMetrics::default(),
501                        state: invocation_state.clone(),
502                        interrupts: None,
503                        structured_output: None,
504                    });
505                }
506                StopReason::ToolUse => {
507                    let (tool_results, extracted_output) = self.execute_tools_with_structured_output(
508                        &response_content,
509                        invocation_state,
510                        structured_output_ctx,
511                    ).await?;
512
513                    let tool_result_message = Message {
514                        role: Role::User,
515                        content: tool_results.into_iter().map(ContentBlock::tool_result).collect(),
516                    };
517                    self.messages.push(tool_result_message.clone());
518
519                    self.hooks.invoke(&HookEvent::MessageAdded(MessageAddedEvent::new(tool_result_message))).await;
520
521                    let should_stop = invocation_state.stop_event_loop
522                        || structured_output_ctx.as_ref().map(|c| c.stop_loop).unwrap_or(false);
523
524                    if should_stop {
525                        return Ok(AgentResult {
526                            stop_reason: StopReason::EndTurn,
527                            message: assistant_message,
528                            usage,
529                            metrics: EventLoopMetrics::default(),
530                            state: invocation_state.clone(),
531                            interrupts: None,
532                            structured_output: extracted_output,
533                        });
534                    }
535                }
536                StopReason::MaxTokens => return Err(StrandsError::MaxTokensReached),
537                StopReason::ContentFiltered => return Err(StrandsError::ContentFiltered { message: "Content was filtered".to_string() }),
538                StopReason::GuardrailIntervention => return Err(StrandsError::GuardrailIntervention { message: "Guardrail intervention".to_string() }),
539                StopReason::Interrupt => return Err(StrandsError::Interrupted { message: "Agent was interrupted".to_string() }),
540            }
541        }
542    }
543
544    async fn execute_tools_with_structured_output(
545        &self,
546        content: &[ContentBlock],
547        invocation_state: &InvocationState,
548        structured_output_ctx: &mut Option<crate::tools::structured_output::StructuredOutputContext>,
549    ) -> Result<(Vec<crate::types::tools::ToolResult>, Option<serde_json::Value>)> {
550        use crate::types::tools::ToolResult;
551        use crate::tools::ToolContext;
552
553        let mut results = Vec::new();
554        let mut extracted_output: Option<serde_json::Value> = None;
555
556        let expected_tool_name = structured_output_ctx
557            .as_ref()
558            .and_then(|ctx| ctx.expected_tool_name().map(|s| s.to_string()));
559
560        for block in content {
561            if let Some(ref tool_use) = block.tool_use {
562                let tool = self.tool_registry.get(&tool_use.name);
563
564                let is_structured_output_tool = expected_tool_name
565                    .as_ref()
566                    .map(|expected| expected == &tool_use.name)
567                    .unwrap_or(false);
568
569                let result = match tool {
570                    Some(tool) => {
571                        let context = ToolContext::with_state(invocation_state.clone());
572                        match tool.invoke(tool_use.input.clone(), &context).await {
573                            Ok(r) => {
574                                if is_structured_output_tool {
575                                    if let Some(ref mut ctx) = structured_output_ctx {
576                                        ctx.store_result(&tool_use.tool_use_id, tool_use.input.clone());
577                                        ctx.stop_loop = true;
578                                        extracted_output = Some(tool_use.input.clone());
579                                        tracing::debug!(
580                                            "Extracted structured output for tool: {}",
581                                            tool_use.name
582                                        );
583                                    }
584                                }
585
586                                ToolResult {
587                                    tool_use_id: tool_use.tool_use_id.clone(),
588                                    status: r.status,
589                                    content: r.content,
590                                }
591                            }
592                            Err(e) => ToolResult::error(&tool_use.tool_use_id, e),
593                        }
594                    }
595                    None => ToolResult::error(&tool_use.tool_use_id, format!("Tool not found: {}", tool_use.name)),
596                };
597
598                results.push(result);
599            }
600        }
601
602        Ok((results, extracted_output))
603    }
604}