1use crate::auth::{AuthStorage, AwsResolvedCredentials, resolve_aws_credentials};
7use crate::config::Config;
8use crate::error::{Error, Result};
9use crate::http::client::Client;
10use crate::model::{
11 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ToolCall,
12 ToolResultMessage, Usage, UserContent,
13};
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamOptions, ToolDef};
16use async_trait::async_trait;
17use chrono::{DateTime, Utc};
18use futures::Stream;
19use futures::stream;
20use hmac::{Hmac, Mac};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use sha2::{Digest, Sha256};
24use std::fmt::Write as _;
25#[cfg(test)]
26use std::path::Path;
27use std::path::PathBuf;
28use std::pin::Pin;
29use url::Url;
30
/// Region used by the bearer-token fallback when neither `AWS_REGION` nor
/// `AWS_DEFAULT_REGION` is available in the environment.
const DEFAULT_REGION: &str = "us-east-1";
/// Service name used in the SigV4 credential scope and signing-key derivation.
const BEDROCK_SERVICE: &str = "bedrock";

/// HMAC-SHA256 instantiation used by the SigV4 helpers in this file.
type HmacSha256 = Hmac<Sha256>;
35
/// Credential material used to authenticate a single Bedrock request.
#[derive(Debug, Clone)]
enum BedrockAuth {
    /// Access-key credentials; the request is signed via `build_sigv4_headers`.
    Sigv4 {
        access_key_id: String,
        secret_access_key: String,
        /// Optional session token, sent as `x-amz-security-token` when present.
        session_token: Option<String>,
    },
    /// Bearer token, sent as `Authorization: Bearer <token>`.
    Bearer {
        token: String,
    },
}
47
/// Resolved authentication for one request: the credentials plus the region
/// used for the default endpoint host and the SigV4 credential scope.
#[derive(Debug, Clone)]
struct BedrockAuthContext {
    auth: BedrockAuth,
    region: String,
}
53
/// Header values produced by `build_sigv4_headers` for a signed request.
#[derive(Debug, Clone)]
struct Sigv4Headers {
    /// Full `Authorization` header value (algorithm, credential, signed headers, signature).
    authorization: String,
    /// Timestamp for `x-amz-date`, formatted as `%Y%m%dT%H%M%SZ`.
    amz_date: String,
    /// Lowercase hex SHA-256 of the request body, for `x-amz-content-sha256`.
    payload_hash: String,
    /// Session token echoed back for `x-amz-security-token`, if one was supplied.
    security_token: Option<String>,
}
61
/// Provider that talks to the Amazon Bedrock Converse endpoint.
pub struct BedrockProvider {
    /// HTTP client used to send requests.
    client: Client,
    /// Model id placed in the `/model/{id}/converse` path.
    model: String,
    /// Name reported by `Provider::name` and stamped on produced messages.
    provider_name: String,
    /// Optional endpoint override; when absent the regional AWS endpoint is used.
    base_url_override: Option<String>,
    /// Optional compatibility settings; only `custom_headers` is read in this file.
    compat: Option<CompatConfig>,
    /// Optional auth-file location; falls back to `Config::auth_path` when unset.
    auth_path_override: Option<PathBuf>,
}
71
72impl BedrockProvider {
73 pub fn new(model: impl Into<String>) -> Self {
75 let raw_model = model.into();
76 let normalized_model = normalize_model_id(&raw_model)
77 .ok()
78 .unwrap_or_else(|| raw_model.trim().to_string());
79 Self {
80 client: Client::new(),
81 model: normalized_model,
82 provider_name: "amazon-bedrock".to_string(),
83 base_url_override: None,
84 compat: None,
85 auth_path_override: None,
86 }
87 }
88
89 #[must_use]
91 pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
92 self.provider_name = provider_name.into();
93 self
94 }
95
96 #[must_use]
98 pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Self {
99 let trimmed = base_url.as_ref().trim();
100 if !trimmed.is_empty() {
101 self.base_url_override = Some(trimmed.to_string());
102 }
103 self
104 }
105
106 #[must_use]
108 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
109 self.compat = compat;
110 self
111 }
112
113 #[must_use]
115 pub fn with_client(mut self, client: Client) -> Self {
116 self.client = client;
117 self
118 }
119
120 #[cfg(test)]
121 #[must_use]
122 fn with_auth_path(mut self, path: impl AsRef<Path>) -> Self {
123 self.auth_path_override = Some(path.as_ref().to_path_buf());
124 self
125 }
126
127 fn auth_path(&self) -> PathBuf {
128 self.auth_path_override
129 .clone()
130 .unwrap_or_else(Config::auth_path)
131 }
132
133 fn load_auth_storage(&self) -> Result<AuthStorage> {
134 AuthStorage::load(self.auth_path())
135 .map_err(|err| Error::auth(format!("Failed to load Bedrock credentials: {err}")))
136 }
137
138 fn resolve_auth_context(&self, options: &StreamOptions) -> Result<BedrockAuthContext> {
139 let auth_storage = self.load_auth_storage()?;
140 if let Some(resolved) = resolve_aws_credentials(&auth_storage) {
141 return Ok(match resolved {
142 AwsResolvedCredentials::Sigv4 {
143 access_key_id,
144 secret_access_key,
145 session_token,
146 region,
147 } => BedrockAuthContext {
148 auth: BedrockAuth::Sigv4 {
149 access_key_id,
150 secret_access_key,
151 session_token,
152 },
153 region,
154 },
155 AwsResolvedCredentials::Bearer { token, region } => BedrockAuthContext {
156 auth: BedrockAuth::Bearer { token },
157 region,
158 },
159 });
160 }
161
162 if let Some(token) = options
163 .api_key
164 .as_deref()
165 .map(str::trim)
166 .filter(|token| !token.is_empty())
167 {
168 return Ok(BedrockAuthContext {
169 auth: BedrockAuth::Bearer {
170 token: token.to_string(),
171 },
172 region: std::env::var("AWS_REGION")
173 .ok()
174 .or_else(|| std::env::var("AWS_DEFAULT_REGION").ok())
175 .unwrap_or_else(|| DEFAULT_REGION.to_string()),
176 });
177 }
178
179 Err(Error::auth(
180 "Amazon Bedrock requires AWS credentials. Set AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY, AWS_BEARER_TOKEN_BEDROCK, or store amazon-bedrock credentials in auth.json.",
181 ))
182 }
183
184 fn converse_url(&self, region: &str) -> Result<Url> {
185 let base = self
186 .base_url_override
187 .clone()
188 .unwrap_or_else(|| format!("https://bedrock-runtime.{region}.amazonaws.com"));
189 let mut url = Url::parse(&base)
190 .map_err(|err| Error::provider("amazon-bedrock", format!("Invalid base URL: {err}")))?;
191
192 if self.model.trim().is_empty() {
193 return Err(Error::provider(
194 "amazon-bedrock",
195 "Bedrock model id cannot be empty",
196 ));
197 }
198
199 if url.path().ends_with("/converse") || url.path().ends_with("/converse-stream") {
200 return Ok(url);
201 }
202
203 {
204 let mut segments = url.path_segments_mut().map_err(|()| {
205 Error::provider(
206 "amazon-bedrock",
207 "Bedrock base URL does not support path segments",
208 )
209 })?;
210 segments.push("model");
211 segments.push(&self.model);
212 segments.push("converse");
213 }
214 Ok(url)
215 }
216
217 pub fn build_request(context: &Context<'_>, options: &StreamOptions) -> BedrockConverseRequest {
218 let mut system = Vec::new();
219 if let Some(system_prompt) = context
220 .system_prompt
221 .as_deref()
222 .map(str::trim)
223 .filter(|prompt| !prompt.is_empty())
224 {
225 system.push(BedrockSystemContent {
226 text: system_prompt.to_string(),
227 });
228 }
229
230 let mut messages = Vec::new();
231 for message in context.messages.iter() {
232 if let Some(converted) = convert_message(message) {
233 messages.push(converted);
234 }
235 }
236
237 if messages.is_empty() {
238 messages.push(BedrockMessage {
239 role: "user",
240 content: vec![BedrockContent::Text {
241 text: "Hello".to_string(),
242 }],
243 });
244 }
245
246 let inference_config = if options.max_tokens.is_some() || options.temperature.is_some() {
247 Some(BedrockInferenceConfig {
248 max_tokens: options.max_tokens,
249 temperature: options.temperature,
250 })
251 } else {
252 None
253 };
254
255 let tool_config = if context.tools.is_empty() {
256 None
257 } else {
258 Some(BedrockToolConfig {
259 tools: context.tools.iter().map(convert_tool).collect(),
260 })
261 };
262
263 BedrockConverseRequest {
264 system,
265 messages,
266 inference_config,
267 tool_config,
268 }
269 }
270
271 fn response_to_message(&self, response: BedrockConverseResponse) -> AssistantMessage {
272 let usage = response
273 .usage
274 .as_ref()
275 .map_or_else(Usage::default, convert_usage);
276
277 let stop_reason = map_stop_reason(response.stop_reason.as_deref());
278 let mut content = Vec::new();
279
280 if let Some(output) = response.output {
281 for block in output.message.content {
282 match block {
283 BedrockResponseContent::Text { text } => {
284 if !text.is_empty() {
285 content.push(ContentBlock::Text(TextContent {
286 text,
287 text_signature: None,
288 }));
289 }
290 }
291 BedrockResponseContent::ToolUse { tool_use } => {
292 content.push(ContentBlock::ToolCall(ToolCall {
293 id: tool_use.tool_use_id,
294 name: tool_use.name,
295 arguments: tool_use.input,
296 thought_signature: None,
297 }));
298 }
299 }
300 }
301 }
302
303 AssistantMessage {
304 content,
305 api: "bedrock-converse-stream".to_string(),
306 provider: self.provider_name.clone(),
307 model: self.model.clone(),
308 usage,
309 stop_reason,
310 error_message: None,
311 timestamp: Utc::now().timestamp_millis(),
312 }
313 }
314
315 fn message_events(message: &AssistantMessage) -> Vec<Result<StreamEvent>> {
316 let mut events = Vec::new();
317 events.push(Ok(StreamEvent::Start {
318 partial: message.clone(),
319 }));
320 for (content_index, block) in message.content.iter().enumerate() {
321 match block {
322 ContentBlock::Text(text) => {
323 events.push(Ok(StreamEvent::TextStart { content_index }));
324 events.push(Ok(StreamEvent::TextDelta {
325 content_index,
326 delta: text.text.clone(),
327 }));
328 events.push(Ok(StreamEvent::TextEnd {
329 content_index,
330 content: text.text.clone(),
331 }));
332 }
333 ContentBlock::ToolCall(tool_call) => {
334 let delta = serde_json::to_string(&tool_call.arguments)
335 .unwrap_or_else(|_| "{}".to_string());
336 events.push(Ok(StreamEvent::ToolCallStart { content_index }));
337 events.push(Ok(StreamEvent::ToolCallDelta {
338 content_index,
339 delta,
340 }));
341 events.push(Ok(StreamEvent::ToolCallEnd {
342 content_index,
343 tool_call: tool_call.clone(),
344 }));
345 }
346 _ => {}
347 }
348 }
349
350 events.push(Ok(StreamEvent::Done {
351 reason: message.stop_reason,
352 message: message.clone(),
353 }));
354 events
355 }
356}
357
/// `Provider` implementation backed by the Bedrock Converse endpoint.
#[async_trait]
impl Provider for BedrockProvider {
    /// Returns the configured provider name.
    fn name(&self) -> &str {
        &self.provider_name
    }

    /// Returns the API identifier stamped on produced messages.
    fn api(&self) -> &'static str {
        "bedrock-converse-stream"
    }

    /// Returns the model id stored at construction time.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Sends one Converse request and replays the complete response as a
    /// stream of events.
    ///
    /// The response body is read in full and parsed before any event is
    /// produced; the returned stream iterates over a pre-built event list.
    ///
    /// # Errors
    /// Fails when the request body cannot be serialized, credentials or the
    /// URL cannot be resolved, signing fails, the HTTP call fails, the status
    /// is outside 200..300, or the response body is not a valid Converse
    /// response.
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let request_body = Self::build_request(context, options);
        // Serialize once up front: the same bytes are hashed for signing and sent as the body.
        let body = serde_json::to_vec(&request_body).map_err(|err| {
            Error::provider(
                "amazon-bedrock",
                format!("Failed to serialize request body: {err}"),
            )
        })?;

        let auth_context = self.resolve_auth_context(options)?;
        let url = self.converse_url(&auth_context.region)?;

        let mut request = self
            .client
            .post(url.as_str())
            .header("Content-Type", "application/json")
            .header("Accept", "application/json");

        match auth_context.auth {
            BedrockAuth::Bearer { token } => {
                request = request.header("Authorization", format!("Bearer {token}"));
            }
            BedrockAuth::Sigv4 {
                access_key_id,
                secret_access_key,
                session_token,
            } => {
                let signing_headers = build_sigv4_headers(
                    &url,
                    &body,
                    &access_key_id,
                    &secret_access_key,
                    session_token.as_deref(),
                    &auth_context.region,
                    Utc::now(),
                )?;
                request = request
                    .header("Authorization", signing_headers.authorization)
                    .header("x-amz-date", signing_headers.amz_date)
                    .header("x-amz-content-sha256", signing_headers.payload_hash);
                if let Some(token) = signing_headers.security_token {
                    request = request.header("x-amz-security-token", token);
                }
            }
        }

        // Custom headers are applied after authentication; they are not part of
        // the SigV4 signed-header list built above.
        if let Some(compat) = &self.compat
            && let Some(custom_headers) = &compat.custom_headers
        {
            for (name, value) in custom_headers {
                request = request.header(name, value);
            }
        }

        // Per-request headers are applied last.
        for (name, value) in &options.headers {
            request = request.header(name, value);
        }

        let response = request.body(body).send().await?;
        let status = response.status();
        // A body-read failure is turned into placeholder text so the status
        // check below can still report it.
        let response_text = response
            .text()
            .await
            .unwrap_or_else(|err| format!("<failed to read body: {err}>"));

        if !(200..300).contains(&status) {
            return Err(Error::provider(
                "amazon-bedrock",
                format!("Bedrock Converse API error (HTTP {status}): {response_text}"),
            ));
        }

        let parsed: BedrockConverseResponse =
            serde_json::from_str(&response_text).map_err(|err| {
                Error::provider(
                    "amazon-bedrock",
                    format!("Failed to parse Bedrock response: {err}"),
                )
            })?;

        let message = self.response_to_message(parsed);
        Ok(Box::pin(stream::iter(Self::message_events(&message))))
    }
}
460
/// JSON body of a Converse request, as produced by `BedrockProvider::build_request`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BedrockConverseRequest {
    /// System prompt entries; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<BedrockSystemContent>,
    /// Conversation messages in order.
    messages: Vec<BedrockMessage>,
    /// Sampling limits; omitted when `None`.
    #[serde(rename = "inferenceConfig", skip_serializing_if = "Option::is_none")]
    inference_config: Option<BedrockInferenceConfig>,
    /// Tool definitions; omitted when `None`.
    #[serde(rename = "toolConfig", skip_serializing_if = "Option::is_none")]
    tool_config: Option<BedrockToolConfig>,
}
472
/// One text entry of the request's `system` array.
#[derive(Debug, Serialize)]
struct BedrockSystemContent {
    text: String,
}
477
/// One request message: a role plus its content blocks.
#[derive(Debug, Serialize)]
struct BedrockMessage {
    /// `"user"` or `"assistant"`, as set by the converters in this file.
    role: &'static str,
    content: Vec<BedrockContent>,
}
483
/// A request content block.
///
/// Serialized untagged, so each variant becomes a JSON object keyed by its
/// single field (`text`, `image`, `toolUse`, or `toolResult`).
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum BedrockContent {
    /// Plain text.
    Text {
        text: String,
    },
    /// Inline image.
    Image {
        image: BedrockImageBlock,
    },
    /// Tool invocation previously issued by the assistant.
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: BedrockToolUse,
    },
    /// Result returned for an earlier tool invocation.
    ToolResult {
        #[serde(rename = "toolResult")]
        tool_result: BedrockToolResult,
    },
}
502
/// Image payload of a request content block.
#[derive(Debug, Serialize)]
struct BedrockImageBlock {
    /// Image format name, derived from the MIME subtype by `convert_user_message`.
    format: String,
    source: BedrockImageSource,
}
508
/// Image data holder.
#[derive(Debug, Serialize)]
struct BedrockImageSource {
    /// Image data copied verbatim from the source block.
    /// NOTE(review): presumably base64-encoded — confirm against the image content type.
    bytes: String,
}
513
/// Tool invocation as sent back to Bedrock (`toolUseId`, `name`, `input`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockToolUse {
    tool_use_id: String,
    name: String,
    /// Tool arguments as arbitrary JSON.
    input: Value,
}
521
/// Result of a tool invocation, linked to the call by `toolUseId`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockToolResult {
    tool_use_id: String,
    content: Vec<BedrockToolResultContent>,
    /// `"success"` or `"error"`, as set by `convert_tool_result_message`.
    status: String,
}
529
/// Content of a tool result; only text is produced by this file.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum BedrockToolResultContent {
    /// Serialized as `{"text": ...}`.
    Text { text: String },
}
535
/// Optional inference parameters; each field is omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockInferenceConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
544
/// Tool definitions offered to the model.
#[derive(Debug, Serialize)]
struct BedrockToolConfig {
    tools: Vec<BedrockToolDef>,
}
549
/// Wrapper that serializes a tool as `{"toolSpec": {...}}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockToolDef {
    tool_spec: BedrockToolSpec,
}
555
/// Name, description, and input schema of one tool.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockToolSpec {
    name: String,
    description: String,
    input_schema: BedrockInputSchema,
}
563
/// Tool input schema, carried as a JSON value under the `json` key.
#[derive(Debug, Serialize)]
struct BedrockInputSchema {
    json: Value,
}
568
569fn convert_message(message: &Message) -> Option<BedrockMessage> {
570 match message {
571 Message::User(user_message) => convert_user_message(user_message),
572 Message::Assistant(assistant_message) => convert_assistant_message(assistant_message),
573 Message::ToolResult(tool_result_message) => {
574 Some(convert_tool_result_message(tool_result_message))
575 }
576 Message::Custom(_) => None,
577 }
578}
579
580fn convert_user_message(message: &crate::model::UserMessage) -> Option<BedrockMessage> {
581 let mut content = Vec::new();
582 match &message.content {
583 UserContent::Text(text) => {
584 if !text.trim().is_empty() {
585 content.push(BedrockContent::Text { text: text.clone() });
586 }
587 }
588 UserContent::Blocks(blocks) => {
589 for block in blocks {
590 match block {
591 ContentBlock::Text(text) => {
592 if !text.text.trim().is_empty() {
593 content.push(BedrockContent::Text {
594 text: text.text.clone(),
595 });
596 }
597 }
598 ContentBlock::Image(img) => {
599 let format = img
600 .mime_type
601 .rsplit('/')
602 .next()
603 .unwrap_or("png")
604 .to_string();
605 content.push(BedrockContent::Image {
606 image: BedrockImageBlock {
607 format,
608 source: BedrockImageSource {
609 bytes: img.data.clone(),
610 },
611 },
612 });
613 }
614 _ => {}
615 }
616 }
617 }
618 }
619
620 if content.is_empty() {
621 None
622 } else {
623 Some(BedrockMessage {
624 role: "user",
625 content,
626 })
627 }
628}
629
630fn convert_assistant_message(message: &AssistantMessage) -> Option<BedrockMessage> {
631 let mut content = Vec::new();
632 for block in &message.content {
633 match block {
634 ContentBlock::Text(text) => {
635 if !text.text.trim().is_empty() {
636 content.push(BedrockContent::Text {
637 text: text.text.clone(),
638 });
639 }
640 }
641 ContentBlock::ToolCall(tool_call) => {
642 content.push(BedrockContent::ToolUse {
643 tool_use: BedrockToolUse {
644 tool_use_id: tool_call.id.clone(),
645 name: tool_call.name.clone(),
646 input: tool_call.arguments.clone(),
647 },
648 });
649 }
650 _ => {}
651 }
652 }
653
654 if content.is_empty() {
655 None
656 } else {
657 Some(BedrockMessage {
658 role: "assistant",
659 content,
660 })
661 }
662}
663
664fn convert_tool_result_message(message: &ToolResultMessage) -> BedrockMessage {
665 let text = message
666 .content
667 .iter()
668 .filter_map(|block| match block {
669 ContentBlock::Text(text) => Some(text.text.as_str()),
670 _ => None,
671 })
672 .collect::<Vec<_>>()
673 .join("\n");
674
675 let result_text = if text.trim().is_empty() {
676 "{}".to_string()
677 } else {
678 text
679 };
680
681 BedrockMessage {
682 role: "user",
683 content: vec![BedrockContent::ToolResult {
684 tool_result: BedrockToolResult {
685 tool_use_id: message.tool_call_id.clone(),
686 content: vec![BedrockToolResultContent::Text { text: result_text }],
687 status: if message.is_error {
688 "error".to_string()
689 } else {
690 "success".to_string()
691 },
692 },
693 }],
694 }
695}
696
697fn convert_tool(tool: &ToolDef) -> BedrockToolDef {
698 BedrockToolDef {
699 tool_spec: BedrockToolSpec {
700 name: tool.name.clone(),
701 description: tool.description.clone(),
702 input_schema: BedrockInputSchema {
703 json: tool.parameters.clone(),
704 },
705 },
706 }
707}
708
/// Top-level Converse response. Every field is optional and defaults to `None`
/// when absent from the JSON.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BedrockConverseResponse {
    #[serde(default)]
    output: Option<BedrockResponseOutput>,
    /// Raw stop reason string; interpreted by `map_stop_reason`.
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<BedrockUsage>,
}
719
/// `output` object of a Converse response, wrapping the generated message.
#[derive(Debug, Deserialize)]
struct BedrockResponseOutput {
    message: BedrockResponseMessage,
}
724
/// Message generated by the model.
#[derive(Debug, Deserialize)]
struct BedrockResponseMessage {
    /// Parsed but not read anywhere in this file.
    #[allow(dead_code)]
    role: Option<String>,
    /// Content blocks; defaults to empty when the key is missing.
    #[serde(default)]
    content: Vec<BedrockResponseContent>,
}
732
/// A response content block, matched untagged by shape.
///
/// Only `{"text": ...}` and `{"toolUse": ...}` objects are representable here;
/// no catch-all variant exists for other block shapes.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum BedrockResponseContent {
    /// Generated text.
    Text {
        text: String,
    },
    /// Tool invocation requested by the model.
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: BedrockResponseToolUse,
    },
}
744
/// Tool invocation details from a response (`toolUseId`, `name`, `input`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BedrockResponseToolUse {
    tool_use_id: String,
    name: String,
    /// Tool arguments; defaults to JSON `null` when the key is missing.
    #[serde(default)]
    input: Value,
}
753
/// Token counts reported by Bedrock; each defaults to 0 when missing.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(clippy::struct_field_names)]
struct BedrockUsage {
    #[serde(default)]
    input_tokens: u64,
    #[serde(default)]
    output_tokens: u64,
    /// When 0, `convert_usage` derives the total from the other two counts.
    #[serde(default)]
    total_tokens: u64,
}
765
766fn convert_usage(usage: &BedrockUsage) -> Usage {
767 let total = if usage.total_tokens > 0 {
768 usage.total_tokens
769 } else {
770 usage.input_tokens + usage.output_tokens
771 };
772
773 Usage {
774 input: usage.input_tokens,
775 output: usage.output_tokens,
776 total_tokens: total,
777 ..Usage::default()
778 }
779}
780
781fn map_stop_reason(stop_reason: Option<&str>) -> StopReason {
782 match stop_reason.unwrap_or("end_turn") {
783 "tool_use" => StopReason::ToolUse,
784 "max_tokens" => StopReason::Length,
785 "guardrail_intervened" | "content_filtered" => StopReason::Error,
786 _ => StopReason::Stop,
787 }
788}
789
790fn normalize_model_id(model_id: &str) -> Result<String> {
791 let mut normalized = model_id.trim().trim_matches('/');
792 if normalized.is_empty() {
793 return Err(Error::provider(
794 "amazon-bedrock",
795 "Bedrock model id cannot be empty",
796 ));
797 }
798
799 for prefix in ["amazon-bedrock/", "bedrock/", "model/"] {
800 if let Some(stripped) = normalized.strip_prefix(prefix) {
801 normalized = stripped;
802 break;
803 }
804 }
805
806 if let Some((_, stripped)) = normalized.split_once("/model/") {
807 normalized = stripped;
808 }
809
810 for suffix in ["/converse-stream", "/converse"] {
811 if let Some(stripped) = normalized.strip_suffix(suffix) {
812 normalized = stripped;
813 break;
814 }
815 }
816
817 let final_id = normalized.trim_matches('/');
818 if final_id.is_empty() {
819 return Err(Error::provider(
820 "amazon-bedrock",
821 "Bedrock model id cannot be empty",
822 ));
823 }
824
825 Ok(final_id.to_string())
826}
827
/// Computes the SigV4 headers for a `POST` of `payload` to `url`.
///
/// Signs the headers `content-type` (fixed to `application/json`), `host`,
/// `x-amz-content-sha256`, `x-amz-date`, and `x-amz-security-token` when a
/// session token is supplied. `now` supplies the timestamp and the date part
/// of the credential scope, which lets tests pin the time.
///
/// # Errors
/// Fails if the URL has no host, the canonical headers cannot be formatted, or
/// HMAC initialization fails.
fn build_sigv4_headers(
    url: &Url,
    payload: &[u8],
    access_key_id: &str,
    secret_access_key: &str,
    session_token: Option<&str>,
    region: &str,
    now: DateTime<Utc>,
) -> Result<Sigv4Headers> {
    let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
    let date_stamp = now.format("%Y%m%d").to_string();
    let payload_hash = sha256_hex(payload);
    let host = canonical_host(url)?;
    let canonical_uri = canonical_uri(url);
    let canonical_query = canonical_query(url);

    let mut canonical_headers = vec![
        ("content-type".to_string(), "application/json".to_string()),
        ("host".to_string(), host),
        ("x-amz-content-sha256".to_string(), payload_hash.clone()),
        ("x-amz-date".to_string(), amz_date.clone()),
    ];
    if let Some(token) = session_token {
        canonical_headers.push(("x-amz-security-token".to_string(), token.to_string()));
    }
    // Headers are signed in name order.
    canonical_headers.sort_by(|left, right| left.0.cmp(&right.0));

    let signed_headers = canonical_headers
        .iter()
        .map(|(name, _)| name.as_str())
        .collect::<Vec<_>>()
        .join(";");

    // One `name:value` line per header, each terminated by a newline.
    let mut canonical_headers_block = String::new();
    for (name, value) in &canonical_headers {
        let trimmed = value.trim();
        writeln!(&mut canonical_headers_block, "{name}:{trimmed}")
            .map_err(|err| Error::api(format!("Failed to build canonical headers: {err}")))?;
    }

    // The headers block already ends with a newline, so the extra `\n` after it
    // yields the blank line that separates headers from the signed-header list.
    let canonical_request = format!(
        "POST\n{canonical_uri}\n{canonical_query}\n{canonical_headers_block}\n{signed_headers}\n{payload_hash}"
    );
    let canonical_request_hash = sha256_hex(canonical_request.as_bytes());
    let credential_scope = format!("{date_stamp}/{region}/{BEDROCK_SERVICE}/aws4_request");
    let string_to_sign =
        format!("AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{canonical_request_hash}");
    // Despite its name, `signing_key` returns the final signature bytes.
    let signature = hex_encode(&signing_key(
        secret_access_key,
        &date_stamp,
        region,
        &string_to_sign,
    )?);

    let authorization = format!(
        "AWS4-HMAC-SHA256 Credential={access_key_id}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
    );

    Ok(Sigv4Headers {
        authorization,
        amz_date,
        payload_hash,
        security_token: session_token.map(ToString::to_string),
    })
}
893
894fn canonical_host(url: &Url) -> Result<String> {
895 let host = url.host_str().ok_or_else(|| {
896 Error::provider("amazon-bedrock", "Bedrock endpoint URL is missing a host")
897 })?;
898 Ok(url
899 .port()
900 .map_or_else(|| host.to_string(), |port| format!("{host}:{port}")))
901}
902
903fn canonical_uri(url: &Url) -> String {
904 let segments = url
905 .path_segments()
906 .map(|parts| parts.map(aws_percent_encode).collect::<Vec<_>>())
907 .unwrap_or_default();
908
909 if segments.is_empty() {
910 "/".to_string()
911 } else {
912 format!("/{}", segments.join("/"))
913 }
914}
915
916fn canonical_query(url: &Url) -> String {
917 let mut pairs = url
918 .query_pairs()
919 .map(|(key, value)| (aws_percent_encode(&key), aws_percent_encode(&value)))
920 .collect::<Vec<_>>();
921 pairs.sort();
922 pairs
923 .into_iter()
924 .map(|(key, value)| format!("{key}={value}"))
925 .collect::<Vec<_>>()
926 .join("&")
927}
928
929fn aws_percent_encode(value: &str) -> String {
930 let mut encoded = String::with_capacity(value.len());
931 for byte in value.bytes() {
932 if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
933 encoded.push(char::from(byte));
934 } else {
935 encoded.push('%');
936 encoded.push(nibble_to_hex(byte >> 4));
937 encoded.push(nibble_to_hex(byte & 0x0f));
938 }
939 }
940 encoded
941}
942
/// Maps a 4-bit value to its uppercase hex digit; anything above 15 maps to `'0'`.
fn nibble_to_hex(nibble: u8) -> char {
    const UPPER_HEX: &[u8; 16] = b"0123456789ABCDEF";
    UPPER_HEX
        .get(usize::from(nibble))
        .map_or('0', |&digit| char::from(digit))
}
950
951fn signing_key(
952 secret_access_key: &str,
953 date_stamp: &str,
954 region: &str,
955 string_to_sign: &str,
956) -> Result<Vec<u8>> {
957 let key_date = hmac_sha256(
958 format!("AWS4{secret_access_key}").as_bytes(),
959 date_stamp.as_bytes(),
960 )?;
961 let key_region = hmac_sha256(&key_date, region.as_bytes())?;
962 let key_service = hmac_sha256(&key_region, BEDROCK_SERVICE.as_bytes())?;
963 let key_signing = hmac_sha256(&key_service, b"aws4_request")?;
964 hmac_sha256(&key_signing, string_to_sign.as_bytes())
965}
966
967fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
968 let mut mac = HmacSha256::new_from_slice(key)
969 .map_err(|err| Error::api(format!("Failed to initialize HMAC: {err}")))?;
970 mac.update(data);
971 Ok(mac.finalize().into_bytes().to_vec())
972}
973
974fn sha256_hex(bytes: &[u8]) -> String {
975 let digest = Sha256::digest(bytes);
976 hex_encode(&digest)
977}
978
/// Encodes `bytes` as lowercase hex, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const LOWER_HEX: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(LOWER_HEX[usize::from(byte >> 4)]));
        encoded.push(char::from(LOWER_HEX[usize::from(byte & 0x0f)]));
    }
    encoded
}
986
987#[cfg(test)]
988mod tests {
989 use super::*;
990 use chrono::TimeZone as _;
991 use serde_json::json;
992
993 fn test_context_with_tools() -> Context<'static> {
994 Context {
995 system_prompt: Some("You are concise.".to_string().into()),
996 messages: vec![
997 Message::User(crate::model::UserMessage {
998 content: UserContent::Text("Ping".to_string()),
999 timestamp: 0,
1000 }),
1001 Message::assistant(AssistantMessage {
1002 content: vec![ContentBlock::ToolCall(ToolCall {
1003 id: "tool_1".to_string(),
1004 name: "search".to_string(),
1005 arguments: json!({ "q": "rust" }),
1006 thought_signature: None,
1007 })],
1008 api: "bedrock-converse-stream".to_string(),
1009 provider: "amazon-bedrock".to_string(),
1010 model: "m".to_string(),
1011 usage: Usage::default(),
1012 stop_reason: StopReason::ToolUse,
1013 error_message: None,
1014 timestamp: 0,
1015 }),
1016 Message::tool_result(ToolResultMessage {
1017 tool_call_id: "tool_1".to_string(),
1018 tool_name: "search".to_string(),
1019 content: vec![ContentBlock::Text(TextContent {
1020 text: "result".to_string(),
1021 text_signature: None,
1022 })],
1023 details: None,
1024 is_error: false,
1025 timestamp: 0,
1026 }),
1027 ]
1028 .into(),
1029 tools: vec![ToolDef {
1030 name: "search".to_string(),
1031 description: "Search docs".to_string(),
1032 parameters: json!({
1033 "type": "object",
1034 "properties": {"q": {"type": "string"}},
1035 "required": ["q"]
1036 }),
1037 }]
1038 .into(),
1039 }
1040 }
1041
1042 #[test]
1043 fn build_request_includes_system_messages_and_tools() {
1044 let request = BedrockProvider::build_request(
1045 &test_context_with_tools(),
1046 &StreamOptions {
1047 max_tokens: Some(321),
1048 temperature: Some(0.2),
1049 ..StreamOptions::default()
1050 },
1051 );
1052
1053 let value = serde_json::to_value(&request).expect("serialize request");
1054 assert_eq!(value["system"][0]["text"], "You are concise.");
1055 assert_eq!(value["messages"][0]["role"], "user");
1056 assert_eq!(
1057 value["messages"][1]["content"][0]["toolUse"]["name"],
1058 "search"
1059 );
1060 assert_eq!(
1061 value["messages"][2]["content"][0]["toolResult"]["status"],
1062 "success"
1063 );
1064 assert_eq!(value["inferenceConfig"]["maxTokens"], 321);
1065 assert_eq!(
1066 value["toolConfig"]["tools"][0]["toolSpec"]["name"],
1067 "search"
1068 );
1069 }
1070
1071 #[test]
1072 fn converse_url_appends_model_path_and_encodes_model_id() {
1073 let provider = BedrockProvider::new("anthropic.claude-3-5-sonnet-20240620-v1:0")
1074 .with_base_url("https://bedrock-runtime.us-east-1.amazonaws.com");
1075 let url = provider
1076 .converse_url("us-east-1")
1077 .expect("build converse URL");
1078 assert_eq!(
1079 url.path(),
1080 "/model/anthropic.claude-3-5-sonnet-20240620-v1:0/converse"
1081 );
1082 }
1083
1084 #[test]
1085 fn normalize_model_id_accepts_prefixed_variants() {
1086 assert_eq!(
1087 normalize_model_id("bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
1088 .expect("normalize regional prefix"),
1089 "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
1090 );
1091 assert_eq!(
1092 normalize_model_id("model/anthropic.claude-3-5-sonnet-20240620-v1:0/converse")
1093 .expect("normalize model path"),
1094 "anthropic.claude-3-5-sonnet-20240620-v1:0"
1095 );
1096 }
1097
1098 #[test]
1099 fn sigv4_headers_include_expected_scope_and_token() {
1100 let url =
1101 Url::parse("https://bedrock-runtime.us-west-2.amazonaws.com/model/m.converse/converse")
1102 .expect("url");
1103 let now = Utc
1104 .with_ymd_and_hms(2026, 2, 10, 8, 0, 0)
1105 .single()
1106 .expect("datetime");
1107 let headers = build_sigv4_headers(
1108 &url,
1109 br#"{"messages":[{"role":"user","content":[{"text":"Ping"}]}]}"#,
1110 "AKIDEXAMPLE",
1111 "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
1112 Some("session-token"),
1113 "us-west-2",
1114 now,
1115 )
1116 .expect("sign headers");
1117
1118 assert!(
1119 headers
1120 .authorization
1121 .contains("Credential=AKIDEXAMPLE/20260210/us-west-2/bedrock/aws4_request")
1122 );
1123 assert!(headers.authorization.contains(
1124 "SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date;x-amz-security-token"
1125 ));
1126 assert_eq!(headers.security_token.as_deref(), Some("session-token"));
1127 assert_eq!(headers.amz_date, "20260210T080000Z");
1128 assert_eq!(headers.payload_hash.len(), 64);
1129 }
1130
1131 #[test]
1132 fn response_to_message_maps_tool_use_and_usage() {
1133 let provider = BedrockProvider::new("anthropic.claude-3-5-sonnet-20240620-v1:0");
1134 let response: BedrockConverseResponse = serde_json::from_value(json!({
1135 "output": {
1136 "message": {
1137 "role": "assistant",
1138 "content": [
1139 {"text": "I can help."},
1140 {"toolUse": {"toolUseId": "call_1", "name": "search", "input": {"q": "rust"}}}
1141 ]
1142 }
1143 },
1144 "stopReason": "tool_use",
1145 "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}
1146 }))
1147 .expect("parse response");
1148
1149 let message = provider.response_to_message(response);
1150 assert_eq!(message.stop_reason, StopReason::ToolUse);
1151 assert_eq!(message.usage.input, 10);
1152 assert_eq!(message.usage.output, 5);
1153 assert_eq!(message.usage.total_tokens, 15);
1154 assert!(matches!(message.content[0], ContentBlock::Text(_)));
1155 assert!(matches!(message.content[1], ContentBlock::ToolCall(_)));
1156 }
1157
1158 #[test]
1159 fn resolve_auth_context_uses_stream_option_api_key_fallback() {
1160 let temp_dir = tempfile::tempdir().expect("tempdir");
1161 let provider =
1162 BedrockProvider::new("model").with_auth_path(temp_dir.path().join("auth.json"));
1163 let auth = provider
1164 .resolve_auth_context(&StreamOptions {
1165 api_key: Some("bedrock-bearer".to_string()),
1166 ..StreamOptions::default()
1167 })
1168 .expect("resolve auth context");
1169 assert!(matches!(auth.auth, BedrockAuth::Bearer { .. }));
1170 }
1171}