strands_agents/models/
bedrock.rs

1//! Amazon Bedrock model provider.
2
3use async_trait::async_trait;
4use aws_config::BehaviorVersion;
5use aws_sdk_bedrockruntime::types::{
6    ContentBlock as BedrockContentBlock, ConversationRole, InferenceConfiguration, Message as BedrockMessage,
7    SystemContentBlock as BedrockSystemContentBlock, Tool, ToolConfiguration, ToolInputSchema, ToolSpecification,
8};
9use aws_sdk_bedrockruntime::Client;
10use aws_smithy_types::Document;
11
12use super::{Model, ModelConfig, StreamEventStream};
13use crate::types::{
14    content::{ContentBlock, Message, Role, SystemContentBlock},
15    errors::StrandsError,
16    streaming::{
17        ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockDeltaToolUse, ContentBlockStart,
18        ContentBlockStartEvent, ContentBlockStartToolUse, ContentBlockStopEvent, MessageStartEvent,
19        MessageStopEvent, MetadataEvent, Metrics, StopReason, StreamEvent, Usage,
20    },
21    tools::{ToolChoice, ToolSpec},
22};
23
/// Model identifier that [`BedrockModel::default`] passes to [`BedrockModel::new`].
const DEFAULT_MODEL_ID: &str = "us.anthropic.claude-sonnet-4-20250514-v1:0";
25
26/// Amazon Bedrock model provider.
27#[derive(Debug, Clone)]
28pub struct BedrockModel {
29    config: ModelConfig,
30    region: Option<String>,
31}
32
33impl BedrockModel {
34    pub fn new(model_id: impl Into<String>) -> Self {
35        Self {
36            config: ModelConfig::new(model_id),
37            region: None,
38        }
39    }
40
41    pub fn with_config(config: ModelConfig) -> Self {
42        Self { config, region: None }
43    }
44
45    pub fn with_region(mut self, region: impl Into<String>) -> Self {
46        self.region = Some(region.into());
47        self
48    }
49
50    fn format_messages(&self, messages: &[Message]) -> Vec<BedrockMessage> {
51        messages
52            .iter()
53            .map(|msg| {
54                let role = match msg.role {
55                    Role::User => ConversationRole::User,
56                    Role::Assistant => ConversationRole::Assistant,
57                };
58
59                let content_blocks: Vec<BedrockContentBlock> = msg
60                    .content
61                    .iter()
62                    .filter_map(|block| self.format_content_block(block))
63                    .collect();
64
65                BedrockMessage::builder()
66                    .role(role)
67                    .set_content(Some(content_blocks))
68                    .build()
69                    .expect("valid message")
70            })
71            .collect()
72    }
73
74    fn json_to_document(value: &serde_json::Value) -> Document {
75        match value {
76            serde_json::Value::Null => Document::Null,
77            serde_json::Value::Bool(b) => Document::Bool(*b),
78            serde_json::Value::Number(n) => {
79                if let Some(i) = n.as_i64() {
80                    Document::Number(aws_smithy_types::Number::NegInt(i))
81                } else if let Some(f) = n.as_f64() {
82                    Document::Number(aws_smithy_types::Number::Float(f))
83                } else {
84                    Document::Null
85                }
86            }
87            serde_json::Value::String(s) => Document::String(s.clone()),
88            serde_json::Value::Array(arr) => {
89                Document::Array(arr.iter().map(Self::json_to_document).collect())
90            }
91            serde_json::Value::Object(obj) => {
92                Document::Object(obj.iter().map(|(k, v)| (k.clone(), Self::json_to_document(v))).collect())
93            }
94        }
95    }
96
97    fn format_content_block(&self, block: &ContentBlock) -> Option<BedrockContentBlock> {
98        if let Some(ref text) = block.text {
99            return Some(BedrockContentBlock::Text(text.clone()));
100        }
101
102        if let Some(ref tool_use) = block.tool_use {
103            let input_doc = Self::json_to_document(&tool_use.input);
104
105            return Some(BedrockContentBlock::ToolUse(
106                aws_sdk_bedrockruntime::types::ToolUseBlock::builder()
107                    .tool_use_id(&tool_use.tool_use_id)
108                    .name(&tool_use.name)
109                    .input(input_doc)
110                    .build()
111                    .expect("valid tool use"),
112            ));
113        }
114
115        if let Some(ref tool_result) = block.tool_result {
116            let content: Vec<aws_sdk_bedrockruntime::types::ToolResultContentBlock> = tool_result
117                .content
118                .iter()
119                .filter_map(|c| {
120                    if let Some(ref text) = c.text {
121                        Some(aws_sdk_bedrockruntime::types::ToolResultContentBlock::Text(text.clone()))
122                    } else if let Some(ref json_val) = c.json {
123                        Some(aws_sdk_bedrockruntime::types::ToolResultContentBlock::Text(
124                            serde_json::to_string(json_val).unwrap_or_default(),
125                        ))
126                    } else {
127                        None
128                    }
129                })
130                .collect();
131
132            let status = match tool_result.status {
133                crate::types::tools::ToolResultStatus::Success => {
134                    aws_sdk_bedrockruntime::types::ToolResultStatus::Success
135                }
136                crate::types::tools::ToolResultStatus::Error => {
137                    aws_sdk_bedrockruntime::types::ToolResultStatus::Error
138                }
139            };
140
141            return Some(BedrockContentBlock::ToolResult(
142                aws_sdk_bedrockruntime::types::ToolResultBlock::builder()
143                    .tool_use_id(&tool_result.tool_use_id)
144                    .set_content(Some(content))
145                    .status(status)
146                    .build()
147                    .expect("valid tool result"),
148            ));
149        }
150
151        None
152    }
153
154    fn format_tool_specs(&self, tool_specs: &[ToolSpec]) -> Vec<Tool> {
155        tool_specs
156            .iter()
157            .map(|spec| {
158                let input_schema_doc = Self::json_to_document(&spec.input_schema.json);
159
160                Tool::ToolSpec(
161                    ToolSpecification::builder()
162                        .name(&spec.name)
163                        .description(&spec.description)
164                        .input_schema(ToolInputSchema::Json(input_schema_doc))
165                        .build()
166                        .expect("valid tool spec"),
167                )
168            })
169            .collect()
170    }
171
172    fn format_system_prompt(&self, system_prompt: Option<&str>) -> Option<Vec<BedrockSystemContentBlock>> {
173        system_prompt.map(|s| vec![BedrockSystemContentBlock::Text(s.to_string())])
174    }
175
176    fn map_stop_reason(reason: &aws_sdk_bedrockruntime::types::StopReason) -> StopReason {
177        match reason {
178            aws_sdk_bedrockruntime::types::StopReason::EndTurn => StopReason::EndTurn,
179            aws_sdk_bedrockruntime::types::StopReason::ToolUse => StopReason::ToolUse,
180            aws_sdk_bedrockruntime::types::StopReason::MaxTokens => StopReason::MaxTokens,
181            aws_sdk_bedrockruntime::types::StopReason::StopSequence => StopReason::StopSequence,
182            aws_sdk_bedrockruntime::types::StopReason::ContentFiltered => StopReason::ContentFiltered,
183            aws_sdk_bedrockruntime::types::StopReason::GuardrailIntervened => StopReason::GuardrailIntervention,
184            _ => StopReason::EndTurn,
185        }
186    }
187}
188
189impl Default for BedrockModel {
190    fn default() -> Self {
191        Self::new(DEFAULT_MODEL_ID)
192    }
193}
194
195#[async_trait]
196impl Model for BedrockModel {
197    fn config(&self) -> &ModelConfig {
198        &self.config
199    }
200
201    fn update_config(&mut self, config: ModelConfig) {
202        self.config = config;
203    }
204
205    fn stream<'a>(
206        &'a self,
207        messages: &'a [Message],
208        tool_specs: Option<&'a [ToolSpec]>,
209        system_prompt: Option<&'a str>,
210        _tool_choice: Option<ToolChoice>,
211        _system_prompt_content: Option<&'a [SystemContentBlock]>,
212    ) -> StreamEventStream<'a> {
213        let model_id = self.config.model_id.clone();
214        let max_tokens = self.config.max_tokens.unwrap_or(4096);
215        let temperature = self.config.temperature;
216        let top_p = self.config.top_p;
217        let stop_sequences = self.config.stop_sequences.clone();
218
219        let formatted_messages = self.format_messages(messages);
220        let formatted_tools = tool_specs.map(|specs| self.format_tool_specs(specs));
221        let formatted_system = self.format_system_prompt(system_prompt);
222        let region = self.region.clone();
223
224        Box::pin(async_stream::stream! {
225            let mut config_loader = aws_config::defaults(BehaviorVersion::latest());
226            if let Some(ref r) = region {
227                config_loader = config_loader.region(aws_config::Region::new(r.clone()));
228            }
229            let sdk_config = config_loader.load().await;
230            let client = Client::new(&sdk_config);
231
232            let mut inference_config = InferenceConfiguration::builder().max_tokens(max_tokens as i32);
233
234            if let Some(temp) = temperature {
235                inference_config = inference_config.temperature(temp);
236            }
237            if let Some(p) = top_p {
238                inference_config = inference_config.top_p(p);
239            }
240            if let Some(ref seqs) = stop_sequences {
241                inference_config = inference_config.set_stop_sequences(Some(seqs.clone()));
242            }
243
244            let mut request = client
245                .converse_stream()
246                .model_id(&model_id)
247                .set_messages(Some(formatted_messages))
248                .inference_config(inference_config.build());
249
250            if let Some(system) = formatted_system {
251                request = request.set_system(Some(system));
252            }
253
254            if let Some(tools) = formatted_tools {
255                request = request.tool_config(
256                    ToolConfiguration::builder()
257                        .set_tools(Some(tools))
258                        .build()
259                        .expect("valid tool config"),
260                );
261            }
262
263            let response = match request.send().await {
264                Ok(resp) => resp,
265                Err(e) => {
266                    let err_msg = e.to_string();
267                    if err_msg.contains("ThrottlingException") || err_msg.contains("throttlingException") {
268                        yield Err(StrandsError::ModelThrottled { message: err_msg });
269                    } else if err_msg.contains("Input is too long") || err_msg.contains("context limit") {
270                        yield Err(StrandsError::ContextWindowOverflow { message: err_msg });
271                    } else {
272                        yield Err(StrandsError::model_error(err_msg));
273                    }
274                    return;
275                }
276            };
277
278            let mut stream = response.stream;
279            let mut has_tool_use = false;
280
281            while let Ok(Some(event)) = stream.recv().await {
282                match event {
283                    aws_sdk_bedrockruntime::types::ConverseStreamOutput::MessageStart(msg) => {
284                        let _role = msg.role;
285                        yield Ok(StreamEvent {
286                            message_start: Some(MessageStartEvent { role: Role::Assistant }),
287                            ..Default::default()
288                        });
289                    }
290
291                    aws_sdk_bedrockruntime::types::ConverseStreamOutput::ContentBlockStart(start) => {
292                        let content_block_index = start.content_block_index as u32;
293                        let block_start = if let Some(ref s) = start.start {
294                            match s {
295                                aws_sdk_bedrockruntime::types::ContentBlockStart::ToolUse(tu) => {
296                                    has_tool_use = true;
297                                    Some(ContentBlockStart {
298                                        tool_use: Some(ContentBlockStartToolUse {
299                                            name: tu.name.clone(),
300                                            tool_use_id: tu.tool_use_id.clone(),
301                                        }),
302                                    })
303                                }
304                                _ => None,
305                            }
306                        } else {
307                            None
308                        };
309
310                        yield Ok(StreamEvent {
311                            content_block_start: Some(ContentBlockStartEvent {
312                                content_block_index: Some(content_block_index),
313                                start: block_start,
314                            }),
315                            ..Default::default()
316                        });
317                    }
318
319                    aws_sdk_bedrockruntime::types::ConverseStreamOutput::ContentBlockDelta(delta) => {
320                        if let Some(ref d) = delta.delta {
321                            let block_delta = match d {
322                                aws_sdk_bedrockruntime::types::ContentBlockDelta::Text(text) => {
323                                    ContentBlockDelta {
324                                        text: Some(text.clone()),
325                                        ..Default::default()
326                                    }
327                                }
328                                aws_sdk_bedrockruntime::types::ContentBlockDelta::ToolUse(tu) => {
329                                    ContentBlockDelta {
330                                        tool_use: Some(ContentBlockDeltaToolUse {
331                                            input: tu.input.clone(),
332                                        }),
333                                        ..Default::default()
334                                    }
335                                }
336                                _ => ContentBlockDelta::default(),
337                            };
338
339                            yield Ok(StreamEvent {
340                                content_block_delta: Some(ContentBlockDeltaEvent {
341                                    content_block_index: Some(delta.content_block_index as u32),
342                                    delta: Some(block_delta),
343                                }),
344                                ..Default::default()
345                            });
346                        }
347                    }
348
349                    aws_sdk_bedrockruntime::types::ConverseStreamOutput::ContentBlockStop(stop) => {
350                        yield Ok(StreamEvent {
351                            content_block_stop: Some(ContentBlockStopEvent {
352                                content_block_index: Some(stop.content_block_index as u32),
353                            }),
354                            ..Default::default()
355                        });
356                    }
357
358                    aws_sdk_bedrockruntime::types::ConverseStreamOutput::MessageStop(stop) => {
359                        let mut stop_reason = Self::map_stop_reason(&stop.stop_reason);
360
361                        if has_tool_use && stop_reason == StopReason::EndTurn {
362                            stop_reason = StopReason::ToolUse;
363                        }
364
365                        yield Ok(StreamEvent {
366                            message_stop: Some(MessageStopEvent {
367                                stop_reason: Some(stop_reason),
368                                additional_model_response_fields: None,
369                            }),
370                            ..Default::default()
371                        });
372                    }
373
374                    aws_sdk_bedrockruntime::types::ConverseStreamOutput::Metadata(meta) => {
375                        let usage = meta.usage.map(|u| Usage {
376                            input_tokens: u.input_tokens as u32,
377                            output_tokens: u.output_tokens as u32,
378                            total_tokens: (u.input_tokens + u.output_tokens) as u32,
379                            cache_read_input_tokens: 0,
380                            cache_write_input_tokens: 0,
381                        });
382
383                        let metrics = meta.metrics.map(|m| Metrics {
384                            latency_ms: m.latency_ms as u64,
385                            time_to_first_byte_ms: 0,
386                        });
387
388                        yield Ok(StreamEvent {
389                            metadata: Some(MetadataEvent {
390                                usage,
391                                metrics,
392                                trace: None,
393                            }),
394                            ..Default::default()
395                        });
396                    }
397
398                    _ => {}
399                }
400            }
401        })
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicitly supplied model id ends up unchanged in the config.
    #[test]
    fn test_bedrock_model_creation() {
        let model_id = "anthropic.claude-3-sonnet-20240229-v1:0";
        let model = BedrockModel::new(model_id);
        assert_eq!(model.config().model_id, model_id);
    }

    /// The default provider targets a Claude model.
    #[test]
    fn test_bedrock_model_default() {
        let default_id = &BedrockModel::default().config().model_id.clone();
        assert!(default_id.contains("claude"));
    }

    /// `with_region` stores the given region on the provider.
    #[test]
    fn test_bedrock_with_region() {
        let model = BedrockModel::default().with_region("us-east-1");
        assert_eq!(model.region, Some(String::from("us-east-1")));
    }

    /// A JSON object converts to a document object with the same keys.
    #[test]
    fn test_json_to_document() {
        let doc = BedrockModel::json_to_document(&serde_json::json!({"key": "value", "num": 42}));
        if let Document::Object(map) = doc {
            assert!(map.contains_key("key"));
            assert!(map.contains_key("num"));
        } else {
            panic!("expected object");
        }
    }
}
439}