1use std::time::Duration;
18
19use aws_sdk_bedrockruntime::Client as BedrockClient;
20use aws_sdk_bedrockruntime::types::{
21 CachePointBlock, CachePointType, ContentBlock, ConversationRole, GuardrailConfiguration,
22 GuardrailStreamConfiguration, InferenceConfiguration, Message as BedrockMessage,
23 SystemContentBlock, ToolConfiguration, ToolInputSchema, ToolResultBlock,
24 ToolResultContentBlock, ToolResultStatus, ToolSpecification, ToolUseBlock,
25};
26
27mod embedding;
28
29#[cfg(feature = "sqs")]
30pub mod sqs;
31
32pub use embedding::BedrockEmbedding;
33
34#[cfg(feature = "sqs")]
35pub use sqs::SqsBroker;
36
37use daimon_core::{
38 ChatRequest, ChatResponse, DaimonError, Message, Model, ResponseStream, Result, Role,
39 StopReason, StreamEvent, ToolCall, Usage,
40};
41
42#[derive(Debug)]
47pub struct Bedrock {
48 model_id: String,
49 client: Option<BedrockClient>,
50 region: Option<String>,
51 max_retries: u32,
52 guardrail_id: Option<String>,
53 guardrail_version: Option<String>,
54 use_prompt_caching: bool,
55}
56
57impl Bedrock {
58 pub fn new(model_id: impl Into<String>) -> Self {
60 Self {
61 model_id: model_id.into(),
62 client: None,
63 region: None,
64 max_retries: 3,
65 guardrail_id: None,
66 guardrail_version: None,
67 use_prompt_caching: false,
68 }
69 }
70
71 pub fn with_client(mut self, client: BedrockClient) -> Self {
73 self.client = Some(client);
74 self
75 }
76
77 pub fn with_region(mut self, region: impl Into<String>) -> Self {
79 self.region = Some(region.into());
80 self
81 }
82
83 pub fn with_max_retries(mut self, retries: u32) -> Self {
85 self.max_retries = retries;
86 self
87 }
88
89 pub fn with_guardrail(mut self, id: impl Into<String>, version: impl Into<String>) -> Self {
91 self.guardrail_id = Some(id.into());
92 self.guardrail_version = Some(version.into());
93 self
94 }
95
96 pub fn with_prompt_caching(mut self) -> Self {
101 self.use_prompt_caching = true;
102 self
103 }
104
105 async fn get_client(&self) -> Result<BedrockClient> {
106 if let Some(ref client) = self.client {
107 return Ok(client.clone());
108 }
109
110 let mut config_loader = aws_config::from_env();
111 if let Some(ref region) = self.region {
112 config_loader = config_loader.region(aws_config::Region::new(region.clone()));
113 }
114 let config = config_loader.load().await;
115 Ok(BedrockClient::new(&config))
116 }
117
118 fn build_messages(
119 request: &ChatRequest,
120 use_prompt_caching: bool,
121 ) -> (Vec<SystemContentBlock>, Vec<BedrockMessage>) {
122 let mut system_blocks = Vec::new();
123 let mut messages = Vec::new();
124
125 for msg in &request.messages {
126 match msg.role {
127 Role::System => {
128 if let Some(ref text) = msg.content {
129 system_blocks.push(SystemContentBlock::Text(text.clone()));
130 }
131 }
132 Role::User => {
133 if let Some(ref text) = msg.content {
134 messages.push(
135 BedrockMessage::builder()
136 .role(ConversationRole::User)
137 .content(ContentBlock::Text(text.clone()))
138 .build()
139 .expect("valid bedrock message"),
140 );
141 }
142 }
143 Role::Assistant => {
144 let mut content_blocks = Vec::new();
145 if let Some(ref text) = msg.content {
146 content_blocks.push(ContentBlock::Text(text.clone()));
147 }
148 for tc in &msg.tool_calls {
149 let input_doc = json_to_document(&tc.arguments);
150 content_blocks.push(ContentBlock::ToolUse(
151 ToolUseBlock::builder()
152 .tool_use_id(&tc.id)
153 .name(&tc.name)
154 .input(input_doc)
155 .build()
156 .expect("valid tool use block"),
157 ));
158 }
159 if !content_blocks.is_empty() {
160 let mut builder =
161 BedrockMessage::builder().role(ConversationRole::Assistant);
162 for block in content_blocks {
163 builder = builder.content(block);
164 }
165 messages.push(builder.build().expect("valid bedrock message"));
166 }
167 }
168 Role::Tool => {
169 let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
170 let content = msg.content.clone().unwrap_or_default();
171 let tool_result = ContentBlock::ToolResult(
172 ToolResultBlock::builder()
173 .tool_use_id(tool_call_id)
174 .status(ToolResultStatus::Success)
175 .content(ToolResultContentBlock::Text(content))
176 .build()
177 .expect("valid tool result block"),
178 );
179 messages.push(
180 BedrockMessage::builder()
181 .role(ConversationRole::User)
182 .content(tool_result)
183 .build()
184 .expect("valid bedrock message"),
185 );
186 }
187 }
188 }
189
190 if use_prompt_caching && !system_blocks.is_empty() {
191 system_blocks.push(SystemContentBlock::CachePoint(
192 CachePointBlock::builder()
193 .r#type(CachePointType::Default)
194 .build()
195 .expect("valid cache point block"),
196 ));
197 }
198
199 (system_blocks, messages)
200 }
201
202 fn build_tool_config(
203 request: &ChatRequest,
204 use_prompt_caching: bool,
205 ) -> Option<ToolConfiguration> {
206 if request.tools.is_empty() {
207 return None;
208 }
209
210 let tools: Vec<aws_sdk_bedrockruntime::types::Tool> = request
211 .tools
212 .iter()
213 .map(|spec| {
214 let schema_doc = json_to_document(&spec.parameters);
215 aws_sdk_bedrockruntime::types::Tool::ToolSpec(
216 ToolSpecification::builder()
217 .name(&spec.name)
218 .description(&spec.description)
219 .input_schema(ToolInputSchema::Json(schema_doc))
220 .build()
221 .expect("valid tool spec"),
222 )
223 })
224 .collect();
225
226 let mut builder = ToolConfiguration::builder();
227 for tool in tools {
228 builder = builder.tools(tool);
229 }
230 if use_prompt_caching {
231 builder = builder.tools(aws_sdk_bedrockruntime::types::Tool::CachePoint(
232 CachePointBlock::builder()
233 .r#type(CachePointType::Default)
234 .build()
235 .expect("valid cache point block"),
236 ));
237 }
238 Some(builder.build().expect("valid tool config"))
239 }
240
241 fn parse_converse_output(
242 &self,
243 output: aws_sdk_bedrockruntime::operation::converse::ConverseOutput,
244 ) -> Result<ChatResponse> {
245 let stop_reason = match *output.stop_reason() {
246 aws_sdk_bedrockruntime::types::StopReason::ToolUse => StopReason::ToolUse,
247 aws_sdk_bedrockruntime::types::StopReason::MaxTokens => StopReason::MaxTokens,
248 _ => StopReason::EndTurn,
249 };
250
251 let mut text_content = String::new();
252 let mut tool_calls = Vec::new();
253
254 if let Some(aws_sdk_bedrockruntime::types::ConverseOutput::Message(msg)) = output.output()
255 {
256 for block in msg.content() {
257 match block {
258 ContentBlock::Text(t) => text_content.push_str(t),
259 ContentBlock::ToolUse(tu) => {
260 let args = document_to_json(tu.input());
261 tool_calls.push(ToolCall {
262 id: tu.tool_use_id().to_string(),
263 name: tu.name().to_string(),
264 arguments: args,
265 });
266 }
267 _ => {}
268 }
269 }
270 }
271
272 let message = if tool_calls.is_empty() {
273 Message::assistant(text_content)
274 } else {
275 Message {
276 role: Role::Assistant,
277 content: if text_content.is_empty() {
278 None
279 } else {
280 Some(text_content)
281 },
282 tool_calls,
283 tool_call_id: None,
284 }
285 };
286
287 let usage = output.usage().map(|u| Usage {
288 input_tokens: u.input_tokens() as u32,
289 output_tokens: u.output_tokens() as u32,
290 cached_tokens: u.cache_read_input_tokens().unwrap_or(0) as u32,
291 });
292
293 Ok(ChatResponse {
294 message,
295 stop_reason,
296 usage,
297 })
298 }
299}
300
/// Heuristically decides whether an error message describes a transient failure worth retrying.
///
/// Matching is done on a lowercase copy with whitespace, `_` and `-` removed. This lets both
/// prose forms ("Service Unavailable", "internal server error") and SDK exception names
/// (`ThrottlingException`, `ServiceUnavailableException`, `InternalServerException`,
/// `ModelNotReadyException`) match.
///
/// The raw text is also checked for the HTTP status codes 503 and 429.
fn is_retryable_error(err: impl std::fmt::Display) -> bool {
    const RETRYABLE_MARKERS: [&str; 5] = [
        "throttl",
        "serviceunavailable",
        "internalserver",
        "modelnotready",
        "toomanyrequests",
    ];

    let text = err.to_string();
    let compact: String = text
        .chars()
        .filter(|c| !c.is_whitespace() && !matches!(c, '_' | '-'))
        .flat_map(char::to_lowercase)
        .collect();

    RETRYABLE_MARKERS.iter().any(|marker| compact.contains(marker))
        || text.contains("503")
        || text.contains("429")
}
310
311impl Model for Bedrock {
312 #[tracing::instrument(skip_all, fields(model = %self.model_id))]
313 async fn generate(&self, request: &ChatRequest) -> Result<ChatResponse> {
314 let client = self.get_client().await?;
315 tracing::debug!("obtained Bedrock client");
316
317 let (system_blocks, messages) = Self::build_messages(request, self.use_prompt_caching);
318 let tool_config = Self::build_tool_config(request, self.use_prompt_caching);
319 tracing::debug!(
320 system_blocks = system_blocks.len(),
321 message_count = messages.len(),
322 has_tools = tool_config.is_some(),
323 prompt_caching = self.use_prompt_caching,
324 "built request messages"
325 );
326
327 let mut last_error = None;
328 for attempt in 0..=self.max_retries {
329 let mut req_builder = client.converse().model_id(&self.model_id);
330
331 for block in system_blocks.clone() {
332 req_builder = req_builder.system(block);
333 }
334 for msg in messages.clone() {
335 req_builder = req_builder.messages(msg);
336 }
337 if let Some(ref tc) = tool_config {
338 req_builder = req_builder.tool_config(tc.clone());
339 }
340
341 let mut inference_config = InferenceConfiguration::builder();
342 if let Some(temp) = request.temperature {
343 inference_config = inference_config.temperature(temp);
344 }
345 if let Some(max_tok) = request.max_tokens {
346 inference_config = inference_config.max_tokens(max_tok as i32);
347 }
348 req_builder = req_builder.inference_config(inference_config.build());
349
350 if let (Some(id), Some(version)) = (&self.guardrail_id, &self.guardrail_version) {
351 let guardrail_config = GuardrailConfiguration::builder()
352 .guardrail_identifier(id)
353 .guardrail_version(version)
354 .build()
355 .expect("valid guardrail config");
356 req_builder = req_builder.guardrail_config(guardrail_config);
357 tracing::debug!(guardrail_id = %id, "applied guardrail config");
358 }
359
360 match req_builder.send().await {
361 Ok(output) => {
362 tracing::debug!("received successful Converse response");
363 return self.parse_converse_output(output);
364 }
365 Err(e) => {
366 last_error = Some(e.to_string());
367 if is_retryable_error(e.to_string()) && attempt < self.max_retries {
368 let delay_ms = 100 * 2u64.pow(attempt);
369 tracing::debug!(
370 attempt = attempt + 1,
371 max_retries = self.max_retries,
372 delay_ms,
373 "retryable error, backing off"
374 );
375 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
376 } else {
377 return Err(DaimonError::Model(format!(
378 "Bedrock Converse error: {}",
379 last_error.unwrap_or_default()
380 )));
381 }
382 }
383 }
384 }
385
386 Err(DaimonError::Model(format!(
387 "Bedrock Converse error: {}",
388 last_error.unwrap_or_else(|| "unknown".into())
389 )))
390 }
391
392 #[tracing::instrument(skip_all, fields(model = %self.model_id))]
393 async fn generate_stream(&self, request: &ChatRequest) -> Result<ResponseStream> {
394 let client = self.get_client().await?;
395 tracing::debug!("obtained Bedrock client for streaming");
396
397 let (system_blocks, messages) = Self::build_messages(request, self.use_prompt_caching);
398 let tool_config = Self::build_tool_config(request, self.use_prompt_caching);
399 tracing::debug!(
400 system_blocks = system_blocks.len(),
401 message_count = messages.len(),
402 has_tools = tool_config.is_some(),
403 prompt_caching = self.use_prompt_caching,
404 "built request messages for stream"
405 );
406
407 let mut req_builder = client.converse_stream().model_id(&self.model_id);
408
409 for block in system_blocks {
410 req_builder = req_builder.system(block);
411 }
412 for msg in messages {
413 req_builder = req_builder.messages(msg);
414 }
415 if let Some(tc) = tool_config {
416 req_builder = req_builder.tool_config(tc);
417 }
418
419 let mut inference_config = InferenceConfiguration::builder();
420 if let Some(temp) = request.temperature {
421 inference_config = inference_config.temperature(temp);
422 }
423 if let Some(max_tok) = request.max_tokens {
424 inference_config = inference_config.max_tokens(max_tok as i32);
425 }
426 req_builder = req_builder.inference_config(inference_config.build());
427
428 if let (Some(id), Some(version)) = (&self.guardrail_id, &self.guardrail_version) {
429 let guardrail_config = GuardrailStreamConfiguration::builder()
430 .guardrail_identifier(id)
431 .guardrail_version(version)
432 .build()
433 .expect("valid guardrail stream config");
434 req_builder = req_builder.guardrail_config(guardrail_config);
435 tracing::debug!(guardrail_id = %id, "applied guardrail config for stream");
436 }
437
438 let mut event_stream = req_builder
439 .send()
440 .await
441 .map_err(|e| DaimonError::Model(format!("Bedrock ConverseStream error: {e}")))?;
442
443 tracing::debug!("stream established, processing events");
444
445 let stream = async_stream::try_stream! {
446 let stream_output = &mut event_stream.stream;
447 while let Some(event) = stream_output.recv().await.map_err(|e| {
448 DaimonError::Model(format!("Bedrock stream error: {e}"))
449 })? {
450 use aws_sdk_bedrockruntime::types::ConverseStreamOutput;
451 match event {
452 ConverseStreamOutput::ContentBlockDelta(delta) => {
453 if let Some(d) = delta.delta() {
454 use aws_sdk_bedrockruntime::types::ContentBlockDelta as CBD;
455 match d {
456 CBD::Text(t) => {
457 yield StreamEvent::TextDelta(t.to_string());
458 }
459 CBD::ToolUse(tu) => {
460 yield StreamEvent::ToolCallDelta {
461 id: String::new(),
462 arguments_delta: tu.input().to_string(),
463 };
464 }
465 _ => {}
466 }
467 }
468 }
469 ConverseStreamOutput::ContentBlockStart(start) => {
470 if let Some(s) = start.start() {
471 use aws_sdk_bedrockruntime::types::ContentBlockStart as CBS;
472 if let CBS::ToolUse(tu) = s {
473 yield StreamEvent::ToolCallStart {
474 id: tu.tool_use_id().to_string(),
475 name: tu.name().to_string(),
476 };
477 }
478 }
479 }
480 ConverseStreamOutput::MessageStop(_) => {
481 yield StreamEvent::Done;
482 }
483 _ => {}
484 }
485 }
486 };
487
488 Ok(Box::pin(stream))
489 }
490}
491
492fn json_to_document(value: &serde_json::Value) -> aws_smithy_types::Document {
493 match value {
494 serde_json::Value::Null => aws_smithy_types::Document::Null,
495 serde_json::Value::Bool(b) => aws_smithy_types::Document::Bool(*b),
496 serde_json::Value::Number(n) => {
497 if let Some(i) = n.as_i64() {
498 aws_smithy_types::Document::Number(aws_smithy_types::Number::PosInt(i as u64))
499 } else if let Some(f) = n.as_f64() {
500 aws_smithy_types::Document::Number(aws_smithy_types::Number::Float(f))
501 } else {
502 aws_smithy_types::Document::Null
503 }
504 }
505 serde_json::Value::String(s) => aws_smithy_types::Document::String(s.clone()),
506 serde_json::Value::Array(arr) => {
507 aws_smithy_types::Document::Array(arr.iter().map(json_to_document).collect())
508 }
509 serde_json::Value::Object(obj) => {
510 let map: std::collections::HashMap<String, aws_smithy_types::Document> = obj
511 .iter()
512 .map(|(k, v)| (k.clone(), json_to_document(v)))
513 .collect();
514 aws_smithy_types::Document::Object(map)
515 }
516 }
517}
518
519fn document_to_json(doc: &aws_smithy_types::Document) -> serde_json::Value {
520 match doc {
521 aws_smithy_types::Document::Object(map) => {
522 let obj: serde_json::Map<String, serde_json::Value> = map
523 .iter()
524 .map(|(k, v)| (k.clone(), document_to_json(v)))
525 .collect();
526 serde_json::Value::Object(obj)
527 }
528 aws_smithy_types::Document::Array(arr) => {
529 serde_json::Value::Array(arr.iter().map(document_to_json).collect())
530 }
531 aws_smithy_types::Document::Number(n) => match n {
532 aws_smithy_types::Number::PosInt(i) => serde_json::Value::Number((*i).into()),
533 aws_smithy_types::Number::NegInt(i) => serde_json::Value::Number((*i).into()),
534 aws_smithy_types::Number::Float(f) => serde_json::Value::Number(
535 serde_json::Number::from_f64(*f).unwrap_or(serde_json::Number::from(0)),
536 ),
537 },
538 aws_smithy_types::Document::String(s) => serde_json::Value::String(s.clone()),
539 aws_smithy_types::Document::Bool(b) => serde_json::Value::Bool(*b),
540 aws_smithy_types::Document::Null => serde_json::Value::Null,
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;
    use daimon_core::ToolSpec;

    /// Builds a `ChatRequest` with no sampling overrides.
    fn request_of(messages: Vec<Message>, tools: Vec<ToolSpec>) -> ChatRequest {
        ChatRequest {
            messages,
            tools,
            temperature: None,
            max_tokens: None,
        }
    }

    #[test]
    fn test_bedrock_new() {
        let id = "us.anthropic.claude-sonnet-4-20250514";
        let bedrock = Bedrock::new(id);
        assert_eq!(bedrock.model_id, id);
        assert!(bedrock.client.is_none());
    }

    #[test]
    fn test_bedrock_with_region() {
        let bedrock = Bedrock::new("test").with_region("us-east-1");
        assert_eq!(bedrock.region.as_deref(), Some("us-east-1"));
    }

    #[test]
    fn test_bedrock_with_max_retries() {
        assert_eq!(Bedrock::new("test").with_max_retries(5).max_retries, 5);
    }

    #[test]
    fn test_bedrock_with_max_retries_default() {
        assert_eq!(Bedrock::new("test").max_retries, 3);
    }

    #[test]
    fn test_bedrock_with_guardrail() {
        let bedrock = Bedrock::new("test").with_guardrail("guardrail-123", "DRAFT");
        assert_eq!(
            (
                bedrock.guardrail_id.as_deref(),
                bedrock.guardrail_version.as_deref()
            ),
            (Some("guardrail-123"), Some("DRAFT"))
        );
    }

    #[test]
    fn test_bedrock_with_guardrail_default_none() {
        let bedrock = Bedrock::new("test");
        assert!(bedrock.guardrail_id.is_none());
        assert!(bedrock.guardrail_version.is_none());
    }

    #[test]
    fn test_build_messages_basic() {
        let req = request_of(
            vec![Message::system("Be helpful"), Message::user("hello")],
            vec![],
        );
        let (system, turns) = Bedrock::build_messages(&req, false);
        assert_eq!((system.len(), turns.len()), (1, 1));
    }

    #[test]
    fn test_build_messages_with_tool_results() {
        let call = ToolCall {
            id: "tc_1".into(),
            name: "calc".into(),
            arguments: serde_json::json!({}),
        };
        let req = request_of(
            vec![
                Message::user("calc"),
                Message::assistant_with_tool_calls(vec![call]),
                Message::tool_result("tc_1", "42"),
            ],
            vec![],
        );
        let (_, turns) = Bedrock::build_messages(&req, false);
        assert_eq!(turns.len(), 3);
    }

    #[test]
    fn test_build_messages_with_caching() {
        let req = request_of(
            vec![Message::system("Be helpful"), Message::user("hello")],
            vec![],
        );
        let (system, _) = Bedrock::build_messages(&req, true);
        assert_eq!(system.len(), 2, "should have text + cache point");
    }

    #[test]
    fn test_build_messages_caching_no_system() {
        let req = request_of(vec![Message::user("hello")], vec![]);
        let (system, _) = Bedrock::build_messages(&req, true);
        assert!(system.is_empty(), "no cache point when no system prompt");
    }

    #[test]
    fn test_json_to_document_string() {
        let converted = json_to_document(&serde_json::json!("hello"));
        assert!(matches!(converted, aws_smithy_types::Document::String(s) if s == "hello"));
    }

    #[test]
    fn test_json_to_document_object() {
        let converted = json_to_document(&serde_json::json!({"key": "value"}));
        let aws_smithy_types::Document::Object(map) = converted else {
            panic!("expected Document::Object");
        };
        assert!(map.contains_key("key"));
    }

    #[test]
    fn test_json_to_document_null() {
        let converted = json_to_document(&serde_json::Value::Null);
        assert!(matches!(converted, aws_smithy_types::Document::Null));
    }

    #[test]
    fn test_document_to_json_object() {
        let doc = aws_smithy_types::Document::Object(std::collections::HashMap::from([(
            "key".to_string(),
            aws_smithy_types::Document::String("value".into()),
        )]));
        let value = document_to_json(&doc);
        assert_eq!(value["key"], "value");
    }

    #[test]
    fn test_document_to_json_null() {
        assert!(document_to_json(&aws_smithy_types::Document::Null).is_null());
    }

    #[test]
    fn test_document_to_json_bool() {
        assert_eq!(
            document_to_json(&aws_smithy_types::Document::Bool(true)),
            serde_json::Value::Bool(true)
        );
    }

    #[test]
    fn test_document_to_json_array() {
        let doc = aws_smithy_types::Document::Array(
            ["a", "b"]
                .into_iter()
                .map(|s| aws_smithy_types::Document::String(s.into()))
                .collect(),
        );
        let value = document_to_json(&doc);
        assert!(value.is_array());
        assert_eq!(value.as_array().map(Vec::len), Some(2));
    }

    #[test]
    fn test_roundtrip_json_document() {
        let original = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "count": 42,
                "active": true
            }
        });
        let restored = document_to_json(&json_to_document(&original));
        assert_eq!(original, restored);
    }

    #[test]
    fn test_build_tool_config_empty() {
        let req = request_of(vec![], vec![]);
        assert!(Bedrock::build_tool_config(&req, false).is_none());
    }

    #[test]
    fn test_build_tool_config_with_tools() {
        let calc = ToolSpec {
            name: "calc".into(),
            description: "Calculator".into(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let req = request_of(vec![], vec![calc]);
        assert!(Bedrock::build_tool_config(&req, false).is_some());
    }

    #[test]
    fn test_with_prompt_caching() {
        assert!(Bedrock::new("test").with_prompt_caching().use_prompt_caching);
    }
}