Skip to main content

daimon_provider_bedrock/
lib.rs

1//! Amazon Bedrock model provider for the [Daimon](https://docs.rs/daimon) agent framework.
2//!
3//! Supports the Bedrock Converse API for non-streaming and streaming inference,
4//! with optional guardrails, prompt caching, and configurable retries.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use daimon_provider_bedrock::Bedrock;
10//! use daimon_core::Model;
11//!
12//! let model = Bedrock::new("us.anthropic.claude-sonnet-4-20250514")
13//!     .with_region("us-east-1")
14//!     .with_prompt_caching();
15//! ```
16
17use std::time::Duration;
18
19use aws_sdk_bedrockruntime::Client as BedrockClient;
20use aws_sdk_bedrockruntime::types::{
21    CachePointBlock, CachePointType, ContentBlock, ConversationRole, GuardrailConfiguration,
22    GuardrailStreamConfiguration, InferenceConfiguration, Message as BedrockMessage,
23    SystemContentBlock, ToolConfiguration, ToolInputSchema, ToolResultBlock,
24    ToolResultContentBlock, ToolResultStatus, ToolSpecification, ToolUseBlock,
25};
26
27mod embedding;
28
29#[cfg(feature = "sqs")]
30pub mod sqs;
31
32pub use embedding::BedrockEmbedding;
33
34#[cfg(feature = "sqs")]
35pub use sqs::SqsBroker;
36
37use daimon_core::{
38    ChatRequest, ChatResponse, DaimonError, Message, Model, ResponseStream, Result, Role,
39    StopReason, StreamEvent, ToolCall, Usage,
40};
41
42/// Amazon Bedrock model provider using the Converse API.
43///
44/// Supports both non-streaming and streaming inference, with optional
45/// guardrails for content filtering and configurable retry behavior.
46#[derive(Debug)]
47pub struct Bedrock {
48    model_id: String,
49    client: Option<BedrockClient>,
50    region: Option<String>,
51    max_retries: u32,
52    guardrail_id: Option<String>,
53    guardrail_version: Option<String>,
54    use_prompt_caching: bool,
55}
56
57impl Bedrock {
58    /// Creates a new Bedrock provider for the given model ID.
59    pub fn new(model_id: impl Into<String>) -> Self {
60        Self {
61            model_id: model_id.into(),
62            client: None,
63            region: None,
64            max_retries: 3,
65            guardrail_id: None,
66            guardrail_version: None,
67            use_prompt_caching: false,
68        }
69    }
70
71    /// Sets the Bedrock client to use (otherwise created from env config).
72    pub fn with_client(mut self, client: BedrockClient) -> Self {
73        self.client = Some(client);
74        self
75    }
76
77    /// Sets the AWS region for the Bedrock client.
78    pub fn with_region(mut self, region: impl Into<String>) -> Self {
79        self.region = Some(region.into());
80        self
81    }
82
83    /// Sets the maximum number of retries for throttling/server errors (default: 3).
84    pub fn with_max_retries(mut self, retries: u32) -> Self {
85        self.max_retries = retries;
86        self
87    }
88
89    /// Configures a guardrail for content filtering.
90    pub fn with_guardrail(mut self, id: impl Into<String>, version: impl Into<String>) -> Self {
91        self.guardrail_id = Some(id.into());
92        self.guardrail_version = Some(version.into());
93        self
94    }
95
96    /// Enables prompt caching for system messages and tool definitions.
97    ///
98    /// When enabled, a `CachePoint` content block is appended after the
99    /// system prompt and after the tool configuration in each request.
100    pub fn with_prompt_caching(mut self) -> Self {
101        self.use_prompt_caching = true;
102        self
103    }
104
105    async fn get_client(&self) -> Result<BedrockClient> {
106        if let Some(ref client) = self.client {
107            return Ok(client.clone());
108        }
109
110        let mut config_loader = aws_config::from_env();
111        if let Some(ref region) = self.region {
112            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
113        }
114        let config = config_loader.load().await;
115        Ok(BedrockClient::new(&config))
116    }
117
118    fn build_messages(
119        request: &ChatRequest,
120        use_prompt_caching: bool,
121    ) -> (Vec<SystemContentBlock>, Vec<BedrockMessage>) {
122        let mut system_blocks = Vec::new();
123        let mut messages = Vec::new();
124
125        for msg in &request.messages {
126            match msg.role {
127                Role::System => {
128                    if let Some(ref text) = msg.content {
129                        system_blocks.push(SystemContentBlock::Text(text.clone()));
130                    }
131                }
132                Role::User => {
133                    if let Some(ref text) = msg.content {
134                        messages.push(
135                            BedrockMessage::builder()
136                                .role(ConversationRole::User)
137                                .content(ContentBlock::Text(text.clone()))
138                                .build()
139                                .expect("valid bedrock message"),
140                        );
141                    }
142                }
143                Role::Assistant => {
144                    let mut content_blocks = Vec::new();
145                    if let Some(ref text) = msg.content {
146                        content_blocks.push(ContentBlock::Text(text.clone()));
147                    }
148                    for tc in &msg.tool_calls {
149                        let input_doc = json_to_document(&tc.arguments);
150                        content_blocks.push(ContentBlock::ToolUse(
151                            ToolUseBlock::builder()
152                                .tool_use_id(&tc.id)
153                                .name(&tc.name)
154                                .input(input_doc)
155                                .build()
156                                .expect("valid tool use block"),
157                        ));
158                    }
159                    if !content_blocks.is_empty() {
160                        let mut builder =
161                            BedrockMessage::builder().role(ConversationRole::Assistant);
162                        for block in content_blocks {
163                            builder = builder.content(block);
164                        }
165                        messages.push(builder.build().expect("valid bedrock message"));
166                    }
167                }
168                Role::Tool => {
169                    let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
170                    let content = msg.content.clone().unwrap_or_default();
171                    let tool_result = ContentBlock::ToolResult(
172                        ToolResultBlock::builder()
173                            .tool_use_id(tool_call_id)
174                            .status(ToolResultStatus::Success)
175                            .content(ToolResultContentBlock::Text(content))
176                            .build()
177                            .expect("valid tool result block"),
178                    );
179                    messages.push(
180                        BedrockMessage::builder()
181                            .role(ConversationRole::User)
182                            .content(tool_result)
183                            .build()
184                            .expect("valid bedrock message"),
185                    );
186                }
187            }
188        }
189
190        if use_prompt_caching && !system_blocks.is_empty() {
191            system_blocks.push(SystemContentBlock::CachePoint(
192                CachePointBlock::builder()
193                    .r#type(CachePointType::Default)
194                    .build()
195                    .expect("valid cache point block"),
196            ));
197        }
198
199        (system_blocks, messages)
200    }
201
202    fn build_tool_config(
203        request: &ChatRequest,
204        use_prompt_caching: bool,
205    ) -> Option<ToolConfiguration> {
206        if request.tools.is_empty() {
207            return None;
208        }
209
210        let tools: Vec<aws_sdk_bedrockruntime::types::Tool> = request
211            .tools
212            .iter()
213            .map(|spec| {
214                let schema_doc = json_to_document(&spec.parameters);
215                aws_sdk_bedrockruntime::types::Tool::ToolSpec(
216                    ToolSpecification::builder()
217                        .name(&spec.name)
218                        .description(&spec.description)
219                        .input_schema(ToolInputSchema::Json(schema_doc))
220                        .build()
221                        .expect("valid tool spec"),
222                )
223            })
224            .collect();
225
226        let mut builder = ToolConfiguration::builder();
227        for tool in tools {
228            builder = builder.tools(tool);
229        }
230        if use_prompt_caching {
231            builder = builder.tools(aws_sdk_bedrockruntime::types::Tool::CachePoint(
232                CachePointBlock::builder()
233                    .r#type(CachePointType::Default)
234                    .build()
235                    .expect("valid cache point block"),
236            ));
237        }
238        Some(builder.build().expect("valid tool config"))
239    }
240
241    fn parse_converse_output(
242        &self,
243        output: aws_sdk_bedrockruntime::operation::converse::ConverseOutput,
244    ) -> Result<ChatResponse> {
245        let stop_reason = match *output.stop_reason() {
246            aws_sdk_bedrockruntime::types::StopReason::ToolUse => StopReason::ToolUse,
247            aws_sdk_bedrockruntime::types::StopReason::MaxTokens => StopReason::MaxTokens,
248            _ => StopReason::EndTurn,
249        };
250
251        let mut text_content = String::new();
252        let mut tool_calls = Vec::new();
253
254        if let Some(aws_sdk_bedrockruntime::types::ConverseOutput::Message(msg)) = output.output()
255        {
256            for block in msg.content() {
257                match block {
258                    ContentBlock::Text(t) => text_content.push_str(t),
259                    ContentBlock::ToolUse(tu) => {
260                        let args = document_to_json(tu.input());
261                        tool_calls.push(ToolCall {
262                            id: tu.tool_use_id().to_string(),
263                            name: tu.name().to_string(),
264                            arguments: args,
265                        });
266                    }
267                    _ => {}
268                }
269            }
270        }
271
272        let message = if tool_calls.is_empty() {
273            Message::assistant(text_content)
274        } else {
275            Message {
276                role: Role::Assistant,
277                content: if text_content.is_empty() {
278                    None
279                } else {
280                    Some(text_content)
281                },
282                tool_calls,
283                tool_call_id: None,
284            }
285        };
286
287        let usage = output.usage().map(|u| Usage {
288            input_tokens: u.input_tokens() as u32,
289            output_tokens: u.output_tokens() as u32,
290            cached_tokens: u.cache_read_input_tokens().unwrap_or(0) as u32,
291        });
292
293        Ok(ChatResponse {
294            message,
295            stop_reason,
296            usage,
297        })
298    }
299}
300
/// Reports whether an error's rendered text looks like a transient failure.
///
/// The text is retryable when it contains, ignoring case, `throttl`,
/// `service unavailable`, or `internal server`, or when it contains the
/// status codes `503` or `429` anywhere in the message.
fn is_retryable_error(err: impl std::fmt::Display) -> bool {
    const CASE_INSENSITIVE_MARKERS: [&str; 3] =
        ["throttl", "service unavailable", "internal server"];
    const STATUS_MARKERS: [&str; 2] = ["503", "429"];

    let rendered = err.to_string();
    let lowered = rendered.to_lowercase();

    CASE_INSENSITIVE_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
        || STATUS_MARKERS
            .iter()
            .any(|marker| rendered.contains(*marker))
}
310
311impl Model for Bedrock {
312    #[tracing::instrument(skip_all, fields(model = %self.model_id))]
313    async fn generate(&self, request: &ChatRequest) -> Result<ChatResponse> {
314        let client = self.get_client().await?;
315        tracing::debug!("obtained Bedrock client");
316
317        let (system_blocks, messages) = Self::build_messages(request, self.use_prompt_caching);
318        let tool_config = Self::build_tool_config(request, self.use_prompt_caching);
319        tracing::debug!(
320            system_blocks = system_blocks.len(),
321            message_count = messages.len(),
322            has_tools = tool_config.is_some(),
323            prompt_caching = self.use_prompt_caching,
324            "built request messages"
325        );
326
327        let mut last_error = None;
328        for attempt in 0..=self.max_retries {
329            let mut req_builder = client.converse().model_id(&self.model_id);
330
331            for block in system_blocks.clone() {
332                req_builder = req_builder.system(block);
333            }
334            for msg in messages.clone() {
335                req_builder = req_builder.messages(msg);
336            }
337            if let Some(ref tc) = tool_config {
338                req_builder = req_builder.tool_config(tc.clone());
339            }
340
341            let mut inference_config = InferenceConfiguration::builder();
342            if let Some(temp) = request.temperature {
343                inference_config = inference_config.temperature(temp);
344            }
345            if let Some(max_tok) = request.max_tokens {
346                inference_config = inference_config.max_tokens(max_tok as i32);
347            }
348            req_builder = req_builder.inference_config(inference_config.build());
349
350            if let (Some(id), Some(version)) = (&self.guardrail_id, &self.guardrail_version) {
351                let guardrail_config = GuardrailConfiguration::builder()
352                    .guardrail_identifier(id)
353                    .guardrail_version(version)
354                    .build()
355                    .expect("valid guardrail config");
356                req_builder = req_builder.guardrail_config(guardrail_config);
357                tracing::debug!(guardrail_id = %id, "applied guardrail config");
358            }
359
360            match req_builder.send().await {
361                Ok(output) => {
362                    tracing::debug!("received successful Converse response");
363                    return self.parse_converse_output(output);
364                }
365                Err(e) => {
366                    last_error = Some(e.to_string());
367                    if is_retryable_error(e.to_string()) && attempt < self.max_retries {
368                        let delay_ms = 100 * 2u64.pow(attempt);
369                        tracing::debug!(
370                            attempt = attempt + 1,
371                            max_retries = self.max_retries,
372                            delay_ms,
373                            "retryable error, backing off"
374                        );
375                        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
376                    } else {
377                        return Err(DaimonError::Model(format!(
378                            "Bedrock Converse error: {}",
379                            last_error.unwrap_or_default()
380                        )));
381                    }
382                }
383            }
384        }
385
386        Err(DaimonError::Model(format!(
387            "Bedrock Converse error: {}",
388            last_error.unwrap_or_else(|| "unknown".into())
389        )))
390    }
391
392    #[tracing::instrument(skip_all, fields(model = %self.model_id))]
393    async fn generate_stream(&self, request: &ChatRequest) -> Result<ResponseStream> {
394        let client = self.get_client().await?;
395        tracing::debug!("obtained Bedrock client for streaming");
396
397        let (system_blocks, messages) = Self::build_messages(request, self.use_prompt_caching);
398        let tool_config = Self::build_tool_config(request, self.use_prompt_caching);
399        tracing::debug!(
400            system_blocks = system_blocks.len(),
401            message_count = messages.len(),
402            has_tools = tool_config.is_some(),
403            prompt_caching = self.use_prompt_caching,
404            "built request messages for stream"
405        );
406
407        let mut req_builder = client.converse_stream().model_id(&self.model_id);
408
409        for block in system_blocks {
410            req_builder = req_builder.system(block);
411        }
412        for msg in messages {
413            req_builder = req_builder.messages(msg);
414        }
415        if let Some(tc) = tool_config {
416            req_builder = req_builder.tool_config(tc);
417        }
418
419        let mut inference_config = InferenceConfiguration::builder();
420        if let Some(temp) = request.temperature {
421            inference_config = inference_config.temperature(temp);
422        }
423        if let Some(max_tok) = request.max_tokens {
424            inference_config = inference_config.max_tokens(max_tok as i32);
425        }
426        req_builder = req_builder.inference_config(inference_config.build());
427
428        if let (Some(id), Some(version)) = (&self.guardrail_id, &self.guardrail_version) {
429            let guardrail_config = GuardrailStreamConfiguration::builder()
430                .guardrail_identifier(id)
431                .guardrail_version(version)
432                .build()
433                .expect("valid guardrail stream config");
434            req_builder = req_builder.guardrail_config(guardrail_config);
435            tracing::debug!(guardrail_id = %id, "applied guardrail config for stream");
436        }
437
438        let mut event_stream = req_builder
439            .send()
440            .await
441            .map_err(|e| DaimonError::Model(format!("Bedrock ConverseStream error: {e}")))?;
442
443        tracing::debug!("stream established, processing events");
444
445        let stream = async_stream::try_stream! {
446            let stream_output = &mut event_stream.stream;
447            while let Some(event) = stream_output.recv().await.map_err(|e| {
448                DaimonError::Model(format!("Bedrock stream error: {e}"))
449            })? {
450                use aws_sdk_bedrockruntime::types::ConverseStreamOutput;
451                match event {
452                    ConverseStreamOutput::ContentBlockDelta(delta) => {
453                        if let Some(d) = delta.delta() {
454                            use aws_sdk_bedrockruntime::types::ContentBlockDelta as CBD;
455                            match d {
456                                CBD::Text(t) => {
457                                    yield StreamEvent::TextDelta(t.to_string());
458                                }
459                                CBD::ToolUse(tu) => {
460                                    yield StreamEvent::ToolCallDelta {
461                                        id: String::new(),
462                                        arguments_delta: tu.input().to_string(),
463                                    };
464                                }
465                                _ => {}
466                            }
467                        }
468                    }
469                    ConverseStreamOutput::ContentBlockStart(start) => {
470                        if let Some(s) = start.start() {
471                            use aws_sdk_bedrockruntime::types::ContentBlockStart as CBS;
472                            if let CBS::ToolUse(tu) = s {
473                                yield StreamEvent::ToolCallStart {
474                                    id: tu.tool_use_id().to_string(),
475                                    name: tu.name().to_string(),
476                                };
477                            }
478                        }
479                    }
480                    ConverseStreamOutput::MessageStop(_) => {
481                        yield StreamEvent::Done;
482                    }
483                    _ => {}
484                }
485            }
486        };
487
488        Ok(Box::pin(stream))
489    }
490}
491
492fn json_to_document(value: &serde_json::Value) -> aws_smithy_types::Document {
493    match value {
494        serde_json::Value::Null => aws_smithy_types::Document::Null,
495        serde_json::Value::Bool(b) => aws_smithy_types::Document::Bool(*b),
496        serde_json::Value::Number(n) => {
497            if let Some(i) = n.as_i64() {
498                aws_smithy_types::Document::Number(aws_smithy_types::Number::PosInt(i as u64))
499            } else if let Some(f) = n.as_f64() {
500                aws_smithy_types::Document::Number(aws_smithy_types::Number::Float(f))
501            } else {
502                aws_smithy_types::Document::Null
503            }
504        }
505        serde_json::Value::String(s) => aws_smithy_types::Document::String(s.clone()),
506        serde_json::Value::Array(arr) => {
507            aws_smithy_types::Document::Array(arr.iter().map(json_to_document).collect())
508        }
509        serde_json::Value::Object(obj) => {
510            let map: std::collections::HashMap<String, aws_smithy_types::Document> = obj
511                .iter()
512                .map(|(k, v)| (k.clone(), json_to_document(v)))
513                .collect();
514            aws_smithy_types::Document::Object(map)
515        }
516    }
517}
518
519fn document_to_json(doc: &aws_smithy_types::Document) -> serde_json::Value {
520    match doc {
521        aws_smithy_types::Document::Object(map) => {
522            let obj: serde_json::Map<String, serde_json::Value> = map
523                .iter()
524                .map(|(k, v)| (k.clone(), document_to_json(v)))
525                .collect();
526            serde_json::Value::Object(obj)
527        }
528        aws_smithy_types::Document::Array(arr) => {
529            serde_json::Value::Array(arr.iter().map(document_to_json).collect())
530        }
531        aws_smithy_types::Document::Number(n) => match n {
532            aws_smithy_types::Number::PosInt(i) => serde_json::Value::Number((*i).into()),
533            aws_smithy_types::Number::NegInt(i) => serde_json::Value::Number((*i).into()),
534            aws_smithy_types::Number::Float(f) => serde_json::Value::Number(
535                serde_json::Number::from_f64(*f).unwrap_or(serde_json::Number::from(0)),
536            ),
537        },
538        aws_smithy_types::Document::String(s) => serde_json::Value::String(s.clone()),
539        aws_smithy_types::Document::Bool(b) => serde_json::Value::Bool(*b),
540        aws_smithy_types::Document::Null => serde_json::Value::Null,
541    }
542}
543
#[cfg(test)]
mod tests {
    use super::*;
    use daimon_core::ToolSpec;

    /// Builds a request with the given messages and tools and no sampling overrides.
    fn request_with(messages: Vec<Message>, tools: Vec<ToolSpec>) -> ChatRequest {
        ChatRequest {
            messages,
            tools,
            temperature: None,
            max_tokens: None,
        }
    }

    // --- builder configuration ---

    #[test]
    fn test_bedrock_new() {
        let model = Bedrock::new("us.anthropic.claude-sonnet-4-20250514");
        assert_eq!(model.model_id, "us.anthropic.claude-sonnet-4-20250514");
        assert!(model.client.is_none());
    }

    #[test]
    fn test_bedrock_with_region() {
        let model = Bedrock::new("test").with_region("us-east-1");
        assert_eq!(model.region.as_deref(), Some("us-east-1"));
    }

    #[test]
    fn test_bedrock_with_max_retries() {
        let model = Bedrock::new("test").with_max_retries(5);
        assert_eq!(model.max_retries, 5);
    }

    #[test]
    fn test_bedrock_with_max_retries_default() {
        let model = Bedrock::new("test");
        assert_eq!(model.max_retries, 3);
    }

    #[test]
    fn test_bedrock_with_guardrail() {
        let model = Bedrock::new("test").with_guardrail("guardrail-123", "DRAFT");
        assert_eq!(model.guardrail_id.as_deref(), Some("guardrail-123"));
        assert_eq!(model.guardrail_version.as_deref(), Some("DRAFT"));
    }

    #[test]
    fn test_bedrock_with_guardrail_default_none() {
        let model = Bedrock::new("test");
        assert!(model.guardrail_id.is_none());
        assert!(model.guardrail_version.is_none());
    }

    #[test]
    fn test_with_prompt_caching() {
        let model = Bedrock::new("test").with_prompt_caching();
        assert!(model.use_prompt_caching);
    }

    // --- message conversion ---

    #[test]
    fn test_build_messages_basic() {
        let request = request_with(
            vec![Message::system("Be helpful"), Message::user("hello")],
            vec![],
        );
        let (system, messages) = Bedrock::build_messages(&request, false);
        assert_eq!(system.len(), 1);
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_build_messages_with_tool_results() {
        let tool_call = ToolCall {
            id: "tc_1".into(),
            name: "calc".into(),
            arguments: serde_json::json!({}),
        };
        let request = request_with(
            vec![
                Message::user("calc"),
                Message::assistant_with_tool_calls(vec![tool_call]),
                Message::tool_result("tc_1", "42"),
            ],
            vec![],
        );
        let (_, messages) = Bedrock::build_messages(&request, false);
        assert_eq!(messages.len(), 3);
    }

    #[test]
    fn test_build_messages_with_caching() {
        let request = request_with(
            vec![Message::system("Be helpful"), Message::user("hello")],
            vec![],
        );
        let (system, _) = Bedrock::build_messages(&request, true);
        assert_eq!(system.len(), 2, "should have text + cache point");
    }

    #[test]
    fn test_build_messages_caching_no_system() {
        let request = request_with(vec![Message::user("hello")], vec![]);
        let (system, _) = Bedrock::build_messages(&request, true);
        assert!(system.is_empty(), "no cache point when no system prompt");
    }

    // --- tool configuration ---

    #[test]
    fn test_build_tool_config_empty() {
        let request = request_with(vec![], vec![]);
        assert!(Bedrock::build_tool_config(&request, false).is_none());
    }

    #[test]
    fn test_build_tool_config_with_tools() {
        let calculator = ToolSpec {
            name: "calc".into(),
            description: "Calculator".into(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let request = request_with(vec![], vec![calculator]);
        assert!(Bedrock::build_tool_config(&request, false).is_some());
    }

    // --- JSON <-> Document conversion ---

    #[test]
    fn test_json_to_document_string() {
        let doc = json_to_document(&serde_json::json!("hello"));
        assert!(matches!(doc, aws_smithy_types::Document::String(s) if s == "hello"));
    }

    #[test]
    fn test_json_to_document_object() {
        let doc = json_to_document(&serde_json::json!({"key": "value"}));
        match doc {
            aws_smithy_types::Document::Object(map) => assert!(map.contains_key("key")),
            _ => panic!("expected Document::Object"),
        }
    }

    #[test]
    fn test_json_to_document_null() {
        let doc = json_to_document(&serde_json::Value::Null);
        assert!(matches!(doc, aws_smithy_types::Document::Null));
    }

    #[test]
    fn test_document_to_json_object() {
        let fields = std::collections::HashMap::from([(
            "key".to_string(),
            aws_smithy_types::Document::String("value".into()),
        )]);
        let json = document_to_json(&aws_smithy_types::Document::Object(fields));
        assert_eq!(json["key"], "value");
    }

    #[test]
    fn test_document_to_json_null() {
        let json = document_to_json(&aws_smithy_types::Document::Null);
        assert!(json.is_null());
    }

    #[test]
    fn test_document_to_json_bool() {
        let json = document_to_json(&aws_smithy_types::Document::Bool(true));
        assert_eq!(json, serde_json::Value::Bool(true));
    }

    #[test]
    fn test_document_to_json_array() {
        let items = ["a", "b"]
            .into_iter()
            .map(|item| aws_smithy_types::Document::String(item.into()))
            .collect();
        let json = document_to_json(&aws_smithy_types::Document::Array(items));
        assert!(json.is_array());
        assert_eq!(json.as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_roundtrip_json_document() {
        let original = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "count": 42,
                "active": true
            }
        });
        let back = document_to_json(&json_to_document(&original));
        assert_eq!(original, back);
    }
}