1mod builder;
4mod result;
5mod state;
6
7use std::sync::Arc;
8
9use futures::StreamExt;
10
11use crate::conversation::ConversationManager;
12use crate::hooks::HookRegistry;
13use crate::models::Model;
14use crate::telemetry::EventLoopMetrics;
15use crate::tools::{InvocationState, ToolRegistry};
16use crate::types::content::{ContentBlock, Message, Messages, Role};
17use crate::types::errors::{Result, StrandsError};
18
19pub use builder::AgentBuilder;
20pub use result::AgentResult;
21pub use state::AgentState;
22
23pub enum AgentInput {
25 Text(String),
26 ContentBlocks(Vec<ContentBlock>),
27 Messages(Messages),
28 None,
29}
30
31impl From<&str> for AgentInput {
32 fn from(s: &str) -> Self { AgentInput::Text(s.to_string()) }
33}
34
35impl From<String> for AgentInput {
36 fn from(s: String) -> Self { AgentInput::Text(s) }
37}
38
39impl From<Vec<ContentBlock>> for AgentInput {
40 fn from(blocks: Vec<ContentBlock>) -> Self { AgentInput::ContentBlocks(blocks) }
41}
42
43impl From<Messages> for AgentInput {
44 fn from(messages: Messages) -> Self { AgentInput::Messages(messages) }
45}
46
47impl<T: Into<String>> From<Option<T>> for AgentInput {
48 fn from(opt: Option<T>) -> Self {
49 match opt {
50 Some(s) => AgentInput::Text(s.into()),
51 None => AgentInput::None,
52 }
53 }
54}
55
56pub struct ToolCaller<'a> {
58 agent: &'a mut Agent,
59}
60
61impl<'a> ToolCaller<'a> {
62 pub async fn invoke(
64 &mut self,
65 tool_name: &str,
66 input: serde_json::Value,
67 ) -> Result<crate::types::tools::ToolResult> {
68 self.invoke_with_options(tool_name, input, None, None).await
69 }
70
71 pub async fn invoke_with_options(
73 &mut self,
74 tool_name: &str,
75 input: serde_json::Value,
76 user_message_override: Option<&str>,
77 record_direct_tool_call: Option<bool>,
78 ) -> Result<crate::types::tools::ToolResult> {
79 use crate::types::tools::{ToolResult, ToolUse};
80 use crate::tools::ToolContext;
81
82 if self.agent.interrupt_state.activated {
83 return Err(StrandsError::EventLoopError {
84 message: "cannot directly call tool during interrupt".to_string(),
85 });
86 }
87
88 let tool = self.agent.tool_registry.get(tool_name)
89 .ok_or_else(|| StrandsError::ToolNotFound {
90 tool_name: tool_name.to_string(),
91 })?;
92
93 let tool_id = format!("tooluse_{}_{}", tool_name, uuid::Uuid::new_v4());
94 let tool_use = ToolUse {
95 name: tool_name.to_string(),
96 tool_use_id: tool_id.clone(),
97 input: input.clone(),
98 };
99
100 let context = ToolContext::with_state(InvocationState::new());
101 let result = match tool.invoke(input.clone(), &context).await {
102 Ok(r) => ToolResult {
103 tool_use_id: tool_id.clone(),
104 status: r.status,
105 content: r.content,
106 },
107 Err(e) => ToolResult::error(&tool_id, e),
108 };
109
110 let should_record = record_direct_tool_call
111 .unwrap_or(self.agent.record_direct_tool_call);
112
113 if should_record {
114 self.record_tool_execution(&tool_use, &result, user_message_override).await?;
115 }
116
117 self.agent.conversation_manager.apply_management(&mut self.agent.messages);
118
119 Ok(result)
120 }
121
122 async fn record_tool_execution(
123 &mut self,
124 tool_use: &crate::types::tools::ToolUse,
125 tool_result: &crate::types::tools::ToolResult,
126 user_message_override: Option<&str>,
127 ) -> Result<()> {
128 let input_json = serde_json::to_string(&tool_use.input)
129 .unwrap_or_else(|_| "<<non-serializable>>".to_string());
130
131 let mut user_content = Vec::new();
132 if let Some(msg) = user_message_override {
133 user_content.push(ContentBlock::text(format!("{}\n", msg)));
134 }
135 user_content.push(ContentBlock::text(format!(
136 "agent.tool.{} direct tool call.\nInput parameters: {}\n",
137 tool_use.name, input_json
138 )));
139
140 let user_msg = Message { role: Role::User, content: user_content };
141 let tool_use_msg = Message {
142 role: Role::Assistant,
143 content: vec![ContentBlock::tool_use(tool_use.clone())],
144 };
145 let tool_result_msg = Message {
146 role: Role::User,
147 content: vec![ContentBlock::tool_result(tool_result.clone())],
148 };
149 let assistant_msg = Message {
150 role: Role::Assistant,
151 content: vec![ContentBlock::text(format!("agent.tool.{} was called.", tool_use.name))],
152 };
153
154 self.agent.messages.push(user_msg);
155 self.agent.messages.push(tool_use_msg);
156 self.agent.messages.push(tool_result_msg);
157 self.agent.messages.push(assistant_msg);
158
159 Ok(())
160 }
161}
162
163pub struct Agent {
165 pub(crate) model: Arc<dyn Model>,
166 pub(crate) messages: Messages,
167 pub(crate) system_prompt: Option<String>,
168 pub(crate) tool_registry: ToolRegistry,
169 agent_name: Option<String>,
170 pub agent_id: String,
171 pub description: Option<String>,
172 pub state: AgentState,
173 pub(crate) hooks: HookRegistry,
174 pub(crate) conversation_manager: Box<dyn ConversationManager>,
175 interrupt_state: crate::types::interrupt::InterruptState,
176 pub record_direct_tool_call: bool,
178 pub trace_attributes: std::collections::HashMap<String, String>,
180 pub max_tool_calls: Option<usize>,
182 pub(crate) structured_output_context: Option<crate::tools::structured_output::StructuredOutputContext>,
184}
185
186impl Agent {
187 pub fn builder() -> AgentBuilder { AgentBuilder::new() }
189
190 pub fn name(&self) -> Option<&String> { self.agent_name.as_ref() }
192
193 pub fn set_name(&mut self, name: impl Into<String>) {
195 self.agent_name = Some(name.into());
196 }
197
198 pub fn system_prompt(&self) -> Option<&str> { self.system_prompt.as_deref() }
200
201 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
203 self.system_prompt = Some(prompt.into());
204 }
205
206 pub fn messages(&self) -> &Messages { &self.messages }
208
209 pub fn add_message(&mut self, message: Message) {
211 self.messages.push(message);
212 }
213
214 pub fn clear_messages(&mut self) {
216 self.messages.clear();
217 }
218
219 pub fn tool_registry(&self) -> &ToolRegistry { &self.tool_registry }
221
222 pub fn tool_registry_mut(&mut self) -> &mut ToolRegistry { &mut self.tool_registry }
224
225 pub fn tool_names(&self) -> Vec<&str> { self.tool_registry.tool_names() }
227
228 pub fn agent_id(&self) -> Option<&str> { Some(&self.agent_id) }
230
231 pub fn hooks(&self) -> &HookRegistry { &self.hooks }
233
234 pub fn hooks_mut(&mut self) -> &mut HookRegistry { &mut self.hooks }
236
237 pub fn conversation_manager(&self) -> &dyn ConversationManager { self.conversation_manager.as_ref() }
239
240 pub fn conversation_manager_mut(&mut self) -> &mut dyn ConversationManager { self.conversation_manager.as_mut() }
242
243 pub fn state(&self) -> &AgentState { &self.state }
245
246 pub fn state_mut(&mut self) -> &mut AgentState { &mut self.state }
248
249 pub fn interrupt_state(&self) -> &crate::types::interrupt::InterruptState { &self.interrupt_state }
251
252 pub fn interrupt_state_mut(&mut self) -> &mut crate::types::interrupt::InterruptState { &mut self.interrupt_state }
254
255 pub fn set_interrupt_state(&mut self, state: crate::types::interrupt::InterruptState) {
257 self.interrupt_state = state;
258 }
259
260 pub fn is_interrupted(&self) -> bool { self.interrupt_state.activated }
262
263 pub fn set_messages(&mut self, messages: Messages) {
265 self.messages = messages;
266 }
267
268 pub fn tool(&mut self) -> ToolCaller<'_> {
272 ToolCaller { agent: self }
273 }
274
275 pub fn trace_attributes(&self) -> &std::collections::HashMap<String, String> {
277 &self.trace_attributes
278 }
279
280 pub fn set_trace_attribute(&mut self, key: impl Into<String>, value: impl Into<String>) {
282 self.trace_attributes.insert(key.into(), value.into());
283 }
284
285 pub fn max_tool_calls(&self) -> Option<usize> {
287 self.max_tool_calls
288 }
289
290 pub fn set_max_tool_calls(&mut self, max: Option<usize>) {
292 self.max_tool_calls = max;
293 }
294
295 pub fn call(&mut self, prompt: impl Into<AgentInput>) -> Result<AgentResult> {
297 tokio::task::block_in_place(|| {
298 tokio::runtime::Handle::current().block_on(self.invoke_async(prompt))
299 })
300 }
301
302 pub async fn invoke_async(&mut self, prompt: impl Into<AgentInput>) -> Result<AgentResult> {
304 let input = prompt.into();
305 let new_messages = self.convert_input_to_messages(input)?;
306
307 for msg in new_messages {
308 self.messages.push(msg);
309 }
310
311 self.run_event_loop().await
312 }
313
314 pub async fn stream_async(
316 &mut self,
317 prompt: impl Into<AgentInput>,
318 ) -> impl futures::Stream<Item = Result<crate::event_loop::TypedEvent>> + '_ {
319 let input = prompt.into();
320
321 async_stream::stream! {
322 let new_messages = match self.convert_input_to_messages(input) {
323 Ok(msgs) => msgs,
324 Err(e) => {
325 yield Err(e);
326 return;
327 }
328 };
329
330 for msg in new_messages {
331 self.messages.push(msg);
332 }
333
334 match self.run_event_loop().await {
335 Ok(result) => yield Ok(crate::event_loop::TypedEvent::agent_result(result)),
336 Err(e) => yield Err(e),
337 }
338 }
339 }
340
341 fn convert_input_to_messages(&self, input: AgentInput) -> Result<Messages> {
342 match input {
343 AgentInput::Text(text) => Ok(vec![Message { role: Role::User, content: vec![ContentBlock::text(text)] }]),
344 AgentInput::ContentBlocks(blocks) => Ok(vec![Message { role: Role::User, content: blocks }]),
345 AgentInput::Messages(messages) => Ok(messages),
346 AgentInput::None => Ok(vec![]),
347 }
348 }
349
350 async fn run_event_loop(&mut self) -> Result<AgentResult> {
351 use crate::hooks::{BeforeInvocationEvent, AfterInvocationEvent, HookEvent};
352 let invocation_state = InvocationState::new();
353 self.hooks.invoke(&HookEvent::BeforeInvocation(BeforeInvocationEvent)).await;
354
355 let mut structured_output_ctx = self.structured_output_context.clone();
356
357 if let Some(ref ctx) = structured_output_ctx {
358 ctx.register_tool(&mut self.tool_registry);
359 }
360
361 let result = self.event_loop_inner(&invocation_state, &mut structured_output_ctx).await;
362
363 if let Some(ref ctx) = structured_output_ctx {
364 ctx.cleanup(&mut self.tool_registry);
365 }
366
367 self.conversation_manager.apply_management(&mut self.messages);
368 let agent_result = result.as_ref().ok().cloned();
369 self.hooks.invoke(&HookEvent::AfterInvocation(AfterInvocationEvent::new(agent_result))).await;
370 result
371 }
372
373 async fn event_loop_inner(
374 &mut self,
375 invocation_state: &InvocationState,
376 structured_output_ctx: &mut Option<crate::tools::structured_output::StructuredOutputContext>,
377 ) -> Result<AgentResult> {
378 use crate::hooks::{
379 BeforeModelCallEvent, AfterModelCallEvent, HookEvent, MessageAddedEvent,
380 };
381 use crate::types::streaming::{StopReason, Usage};
382
383 loop {
384 let tool_specs = self.tool_registry.get_all_tool_specs();
385 let messages_snapshot = self.messages.clone();
386 let system_prompt_snapshot = self.system_prompt.clone();
387
388 let tool_specs_ref: Option<&[_]> = if tool_specs.is_empty() { None } else { Some(&tool_specs) };
389
390 self.hooks.invoke(&HookEvent::BeforeModelCall(BeforeModelCallEvent)).await;
391
392 let stream = self.model.stream(
393 &messages_snapshot,
394 tool_specs_ref,
395 system_prompt_snapshot.as_deref(),
396 None,
397 None,
398 );
399
400 let mut response_content: Vec<ContentBlock> = Vec::new();
401 let mut stop_reason = StopReason::EndTurn;
402 let mut usage = Usage::default();
403 let mut current_tool_use: Option<crate::types::tools::ToolUse> = None;
404 let mut tool_input_buffer = String::new();
405
406 futures::pin_mut!(stream);
407 while let Some(event_result) = stream.next().await {
408 let event = event_result?;
409
410 if let Some(ref delta_event) = event.content_block_delta {
411 if let Some(ref delta) = delta_event.delta {
412 if let Some(ref text) = delta.text {
413 if let Some(block) = response_content.last_mut() {
414 if block.text.is_some() {
415 block.text.as_mut().unwrap().push_str(text);
416 } else {
417 response_content.push(ContentBlock::text(text));
418 }
419 } else {
420 response_content.push(ContentBlock::text(text));
421 }
422 }
423 if let Some(ref tool_delta) = delta.tool_use {
424 tool_input_buffer.push_str(&tool_delta.input);
425 }
426 }
427 }
428
429 if let Some(ref start_event) = event.content_block_start {
430 if let Some(ref start) = start_event.start {
431 if let Some(ref tu) = start.tool_use {
432 current_tool_use = Some(crate::types::tools::ToolUse {
433 name: tu.name.clone(),
434 tool_use_id: tu.tool_use_id.clone(),
435 input: serde_json::Value::Null,
436 });
437 tool_input_buffer.clear();
438 }
439 }
440 }
441
442 if event.content_block_stop.is_some() {
443 if let Some(mut tu) = current_tool_use.take() {
444 tu.input = serde_json::from_str(&tool_input_buffer).unwrap_or(serde_json::Value::Null);
445 response_content.push(ContentBlock::tool_use(tu));
446 tool_input_buffer.clear();
447 }
448 }
449
450 if let Some(ref stop_event) = event.message_stop {
451 if let Some(sr) = stop_event.stop_reason {
452 stop_reason = sr;
453 }
454 }
455
456 if let Some(ref meta) = event.metadata {
457 if let Some(ref u) = meta.usage {
458 usage = u.clone();
459 }
460 }
461 }
462
463 let assistant_message = Message { role: Role::Assistant, content: response_content.clone() };
464 self.messages.push(assistant_message.clone());
465
466 self.hooks.invoke(&HookEvent::MessageAdded(MessageAddedEvent::new(assistant_message.clone()))).await;
467
468 self.hooks.invoke(&HookEvent::AfterModelCall(AfterModelCallEvent::success(
469 assistant_message.clone(),
470 stop_reason.clone(),
471 ))).await;
472
473 match stop_reason {
474 StopReason::EndTurn | StopReason::StopSequence => {
475 if let Some(ref mut ctx) = structured_output_ctx {
476 if ctx.is_enabled() {
477 if ctx.force_attempted {
478 return Err(StrandsError::StructuredOutputError {
479 message: "The model failed to invoke the structured output tool even after it was forced.".to_string(),
480 });
481 }
482
483 ctx.set_forced_mode();
484 tracing::debug!("Forcing structured output tool");
485
486 let force_message = Message {
487 role: Role::User,
488 content: vec![ContentBlock::text("You must format the previous response as structured output.")],
489 };
490 self.messages.push(force_message);
491
492 continue;
493 }
494 }
495
496 return Ok(AgentResult {
497 stop_reason,
498 message: assistant_message,
499 usage,
500 metrics: EventLoopMetrics::default(),
501 state: invocation_state.clone(),
502 interrupts: None,
503 structured_output: None,
504 });
505 }
506 StopReason::ToolUse => {
507 let (tool_results, extracted_output) = self.execute_tools_with_structured_output(
508 &response_content,
509 invocation_state,
510 structured_output_ctx,
511 ).await?;
512
513 let tool_result_message = Message {
514 role: Role::User,
515 content: tool_results.into_iter().map(ContentBlock::tool_result).collect(),
516 };
517 self.messages.push(tool_result_message.clone());
518
519 self.hooks.invoke(&HookEvent::MessageAdded(MessageAddedEvent::new(tool_result_message))).await;
520
521 let should_stop = invocation_state.stop_event_loop
522 || structured_output_ctx.as_ref().map(|c| c.stop_loop).unwrap_or(false);
523
524 if should_stop {
525 return Ok(AgentResult {
526 stop_reason: StopReason::EndTurn,
527 message: assistant_message,
528 usage,
529 metrics: EventLoopMetrics::default(),
530 state: invocation_state.clone(),
531 interrupts: None,
532 structured_output: extracted_output,
533 });
534 }
535 }
536 StopReason::MaxTokens => return Err(StrandsError::MaxTokensReached),
537 StopReason::ContentFiltered => return Err(StrandsError::ContentFiltered { message: "Content was filtered".to_string() }),
538 StopReason::GuardrailIntervention => return Err(StrandsError::GuardrailIntervention { message: "Guardrail intervention".to_string() }),
539 StopReason::Interrupt => return Err(StrandsError::Interrupted { message: "Agent was interrupted".to_string() }),
540 }
541 }
542 }
543
544 async fn execute_tools_with_structured_output(
545 &self,
546 content: &[ContentBlock],
547 invocation_state: &InvocationState,
548 structured_output_ctx: &mut Option<crate::tools::structured_output::StructuredOutputContext>,
549 ) -> Result<(Vec<crate::types::tools::ToolResult>, Option<serde_json::Value>)> {
550 use crate::types::tools::ToolResult;
551 use crate::tools::ToolContext;
552
553 let mut results = Vec::new();
554 let mut extracted_output: Option<serde_json::Value> = None;
555
556 let expected_tool_name = structured_output_ctx
557 .as_ref()
558 .and_then(|ctx| ctx.expected_tool_name().map(|s| s.to_string()));
559
560 for block in content {
561 if let Some(ref tool_use) = block.tool_use {
562 let tool = self.tool_registry.get(&tool_use.name);
563
564 let is_structured_output_tool = expected_tool_name
565 .as_ref()
566 .map(|expected| expected == &tool_use.name)
567 .unwrap_or(false);
568
569 let result = match tool {
570 Some(tool) => {
571 let context = ToolContext::with_state(invocation_state.clone());
572 match tool.invoke(tool_use.input.clone(), &context).await {
573 Ok(r) => {
574 if is_structured_output_tool {
575 if let Some(ref mut ctx) = structured_output_ctx {
576 ctx.store_result(&tool_use.tool_use_id, tool_use.input.clone());
577 ctx.stop_loop = true;
578 extracted_output = Some(tool_use.input.clone());
579 tracing::debug!(
580 "Extracted structured output for tool: {}",
581 tool_use.name
582 );
583 }
584 }
585
586 ToolResult {
587 tool_use_id: tool_use.tool_use_id.clone(),
588 status: r.status,
589 content: r.content,
590 }
591 }
592 Err(e) => ToolResult::error(&tool_use.tool_use_id, e),
593 }
594 }
595 None => ToolResult::error(&tool_use.tool_use_id, format!("Tool not found: {}", tool_use.name)),
596 };
597
598 results.push(result);
599 }
600 }
601
602 Ok((results, extracted_output))
603 }
604}