Skip to main content

pi/providers/
bedrock.rs

1//! Amazon Bedrock Converse provider implementation.
2//!
3//! This provider targets the Bedrock Converse API and maps its non-streaming
4//! JSON response into Pi stream events.
5
6use crate::auth::{AuthStorage, AwsResolvedCredentials, resolve_aws_credentials};
7use crate::config::Config;
8use crate::error::{Error, Result};
9use crate::http::client::Client;
10use crate::model::{
11    AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ToolCall,
12    ToolResultMessage, Usage, UserContent,
13};
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamOptions, ToolDef};
16use async_trait::async_trait;
17use chrono::{DateTime, Utc};
18use futures::Stream;
19use futures::stream;
20use hmac::{Hmac, Mac};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use sha2::{Digest, Sha256};
24use std::fmt::Write as _;
25#[cfg(test)]
26use std::path::Path;
27use std::path::PathBuf;
28use std::pin::Pin;
29use url::Url;
30
31const DEFAULT_REGION: &str = "us-east-1";
32const BEDROCK_SERVICE: &str = "bedrock";
33
34type HmacSha256 = Hmac<Sha256>;
35
/// How a Bedrock request is authenticated.
///
/// `Debug` is implemented by hand instead of derived so that secrets (the
/// secret access key, the STS session token and bearer tokens) never end up
/// in logs, error reports or panic messages. Only the access key ID, which is
/// an identifier rather than a secret, is printed.
#[derive(Clone)]
enum BedrockAuth {
    /// AWS Signature Version 4 signing with access keys (optionally temporary).
    Sigv4 {
        access_key_id: String,
        secret_access_key: String,
        session_token: Option<String>,
    },
    /// Token sent verbatim as `Authorization: Bearer <token>`.
    Bearer {
        token: String,
    },
}

impl std::fmt::Debug for BedrockAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        match self {
            Self::Sigv4 {
                access_key_id,
                session_token,
                ..
            } => f
                .debug_struct("Sigv4")
                .field("access_key_id", access_key_id)
                .field("secret_access_key", &REDACTED)
                // Keep presence visible (useful when debugging auth) but hide the value.
                .field("session_token", &session_token.as_ref().map(|_| REDACTED))
                .finish(),
            Self::Bearer { .. } => f.debug_struct("Bearer").field("token", &REDACTED).finish(),
        }
    }
}
47
48#[derive(Debug, Clone)]
49struct BedrockAuthContext {
50    auth: BedrockAuth,
51    region: String,
52}
53
/// Header values produced by SigV4 signing for a single request.
///
/// `Debug` is implemented by hand: the `Authorization` value carries a
/// request signature and `security_token` is an STS credential, so neither is
/// printed. The date and payload hash are not sensitive and stay visible.
#[derive(Clone)]
struct Sigv4Headers {
    /// Full `Authorization` header value (`AWS4-HMAC-SHA256 Credential=...`).
    authorization: String,
    /// `x-amz-date` value (`YYYYMMDD'T'HHMMSS'Z'`).
    amz_date: String,
    /// Hex SHA-256 of the request body, sent as `x-amz-content-sha256`.
    payload_hash: String,
    /// Session token to send as `x-amz-security-token`, when present.
    security_token: Option<String>,
}

impl std::fmt::Debug for Sigv4Headers {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("Sigv4Headers")
            .field("authorization", &REDACTED)
            .field("amz_date", &self.amz_date)
            .field("payload_hash", &self.payload_hash)
            .field(
                "security_token",
                &self.security_token.as_ref().map(|_| REDACTED),
            )
            .finish()
    }
}
61
62/// Amazon Bedrock provider.
63pub struct BedrockProvider {
64    client: Client,
65    model: String,
66    provider_name: String,
67    base_url_override: Option<String>,
68    compat: Option<CompatConfig>,
69    auth_path_override: Option<PathBuf>,
70}
71
72impl BedrockProvider {
73    /// Create a Bedrock provider for the given model ID.
74    pub fn new(model: impl Into<String>) -> Self {
75        let raw_model = model.into();
76        let normalized_model = normalize_model_id(&raw_model)
77            .ok()
78            .unwrap_or_else(|| raw_model.trim().to_string());
79        Self {
80            client: Client::new(),
81            model: normalized_model,
82            provider_name: "amazon-bedrock".to_string(),
83            base_url_override: None,
84            compat: None,
85            auth_path_override: None,
86        }
87    }
88
89    /// Set provider name for event attribution.
90    #[must_use]
91    pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
92        self.provider_name = provider_name.into();
93        self
94    }
95
96    /// Override Bedrock base URL (useful for tests/proxies).
97    #[must_use]
98    pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Self {
99        let trimmed = base_url.as_ref().trim();
100        if !trimmed.is_empty() {
101            self.base_url_override = Some(trimmed.to_string());
102        }
103        self
104    }
105
106    /// Attach provider compatibility overrides.
107    #[must_use]
108    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
109        self.compat = compat;
110        self
111    }
112
113    /// Inject a custom HTTP client.
114    #[must_use]
115    pub fn with_client(mut self, client: Client) -> Self {
116        self.client = client;
117        self
118    }
119
120    #[cfg(test)]
121    #[must_use]
122    fn with_auth_path(mut self, path: impl AsRef<Path>) -> Self {
123        self.auth_path_override = Some(path.as_ref().to_path_buf());
124        self
125    }
126
127    fn auth_path(&self) -> PathBuf {
128        self.auth_path_override
129            .clone()
130            .unwrap_or_else(Config::auth_path)
131    }
132
133    fn load_auth_storage(&self) -> Result<AuthStorage> {
134        AuthStorage::load(self.auth_path())
135            .map_err(|err| Error::auth(format!("Failed to load Bedrock credentials: {err}")))
136    }
137
138    fn resolve_auth_context(&self, options: &StreamOptions) -> Result<BedrockAuthContext> {
139        let auth_storage = self.load_auth_storage()?;
140        if let Some(resolved) = resolve_aws_credentials(&auth_storage) {
141            return Ok(match resolved {
142                AwsResolvedCredentials::Sigv4 {
143                    access_key_id,
144                    secret_access_key,
145                    session_token,
146                    region,
147                } => BedrockAuthContext {
148                    auth: BedrockAuth::Sigv4 {
149                        access_key_id,
150                        secret_access_key,
151                        session_token,
152                    },
153                    region,
154                },
155                AwsResolvedCredentials::Bearer { token, region } => BedrockAuthContext {
156                    auth: BedrockAuth::Bearer { token },
157                    region,
158                },
159            });
160        }
161
162        if let Some(token) = options
163            .api_key
164            .as_deref()
165            .map(str::trim)
166            .filter(|token| !token.is_empty())
167        {
168            return Ok(BedrockAuthContext {
169                auth: BedrockAuth::Bearer {
170                    token: token.to_string(),
171                },
172                region: std::env::var("AWS_REGION")
173                    .ok()
174                    .or_else(|| std::env::var("AWS_DEFAULT_REGION").ok())
175                    .unwrap_or_else(|| DEFAULT_REGION.to_string()),
176            });
177        }
178
179        Err(Error::auth(
180            "Amazon Bedrock requires AWS credentials. Set AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY, AWS_BEARER_TOKEN_BEDROCK, or store amazon-bedrock credentials in auth.json.",
181        ))
182    }
183
184    fn converse_url(&self, region: &str) -> Result<Url> {
185        let base = self
186            .base_url_override
187            .clone()
188            .unwrap_or_else(|| format!("https://bedrock-runtime.{region}.amazonaws.com"));
189        let mut url = Url::parse(&base)
190            .map_err(|err| Error::provider("amazon-bedrock", format!("Invalid base URL: {err}")))?;
191
192        if self.model.trim().is_empty() {
193            return Err(Error::provider(
194                "amazon-bedrock",
195                "Bedrock model id cannot be empty",
196            ));
197        }
198
199        if url.path().ends_with("/converse") || url.path().ends_with("/converse-stream") {
200            return Ok(url);
201        }
202
203        {
204            let mut segments = url.path_segments_mut().map_err(|()| {
205                Error::provider(
206                    "amazon-bedrock",
207                    "Bedrock base URL does not support path segments",
208                )
209            })?;
210            segments.push("model");
211            segments.push(&self.model);
212            segments.push("converse");
213        }
214        Ok(url)
215    }
216
217    pub fn build_request(context: &Context<'_>, options: &StreamOptions) -> BedrockConverseRequest {
218        let mut system = Vec::new();
219        if let Some(system_prompt) = context
220            .system_prompt
221            .as_deref()
222            .map(str::trim)
223            .filter(|prompt| !prompt.is_empty())
224        {
225            system.push(BedrockSystemContent {
226                text: system_prompt.to_string(),
227            });
228        }
229
230        let mut messages = Vec::new();
231        for message in context.messages.iter() {
232            if let Some(converted) = convert_message(message) {
233                messages.push(converted);
234            }
235        }
236
237        if messages.is_empty() {
238            messages.push(BedrockMessage {
239                role: "user",
240                content: vec![BedrockContent::Text {
241                    text: "Hello".to_string(),
242                }],
243            });
244        }
245
246        let inference_config = if options.max_tokens.is_some() || options.temperature.is_some() {
247            Some(BedrockInferenceConfig {
248                max_tokens: options.max_tokens,
249                temperature: options.temperature,
250            })
251        } else {
252            None
253        };
254
255        let tool_config = if context.tools.is_empty() {
256            None
257        } else {
258            Some(BedrockToolConfig {
259                tools: context.tools.iter().map(convert_tool).collect(),
260            })
261        };
262
263        BedrockConverseRequest {
264            system,
265            messages,
266            inference_config,
267            tool_config,
268        }
269    }
270
271    fn response_to_message(&self, response: BedrockConverseResponse) -> AssistantMessage {
272        let usage = response
273            .usage
274            .as_ref()
275            .map_or_else(Usage::default, convert_usage);
276
277        let stop_reason = map_stop_reason(response.stop_reason.as_deref());
278        let mut content = Vec::new();
279
280        if let Some(output) = response.output {
281            for block in output.message.content {
282                match block {
283                    BedrockResponseContent::Text { text } => {
284                        if !text.is_empty() {
285                            content.push(ContentBlock::Text(TextContent {
286                                text,
287                                text_signature: None,
288                            }));
289                        }
290                    }
291                    BedrockResponseContent::ToolUse { tool_use } => {
292                        content.push(ContentBlock::ToolCall(ToolCall {
293                            id: tool_use.tool_use_id,
294                            name: tool_use.name,
295                            arguments: tool_use.input,
296                            thought_signature: None,
297                        }));
298                    }
299                }
300            }
301        }
302
303        AssistantMessage {
304            content,
305            api: "bedrock-converse-stream".to_string(),
306            provider: self.provider_name.clone(),
307            model: self.model.clone(),
308            usage,
309            stop_reason,
310            error_message: None,
311            timestamp: Utc::now().timestamp_millis(),
312        }
313    }
314
315    fn message_events(message: &AssistantMessage) -> Vec<Result<StreamEvent>> {
316        let mut events = Vec::new();
317        events.push(Ok(StreamEvent::Start {
318            partial: message.clone(),
319        }));
320        for (content_index, block) in message.content.iter().enumerate() {
321            match block {
322                ContentBlock::Text(text) => {
323                    events.push(Ok(StreamEvent::TextStart { content_index }));
324                    events.push(Ok(StreamEvent::TextDelta {
325                        content_index,
326                        delta: text.text.clone(),
327                    }));
328                    events.push(Ok(StreamEvent::TextEnd {
329                        content_index,
330                        content: text.text.clone(),
331                    }));
332                }
333                ContentBlock::ToolCall(tool_call) => {
334                    let delta = serde_json::to_string(&tool_call.arguments)
335                        .unwrap_or_else(|_| "{}".to_string());
336                    events.push(Ok(StreamEvent::ToolCallStart { content_index }));
337                    events.push(Ok(StreamEvent::ToolCallDelta {
338                        content_index,
339                        delta,
340                    }));
341                    events.push(Ok(StreamEvent::ToolCallEnd {
342                        content_index,
343                        tool_call: tool_call.clone(),
344                    }));
345                }
346                _ => {}
347            }
348        }
349
350        events.push(Ok(StreamEvent::Done {
351            reason: message.stop_reason,
352            message: message.clone(),
353        }));
354        events
355    }
356}
357
358#[async_trait]
359impl Provider for BedrockProvider {
360    fn name(&self) -> &str {
361        &self.provider_name
362    }
363
364    fn api(&self) -> &'static str {
365        "bedrock-converse-stream"
366    }
367
368    fn model_id(&self) -> &str {
369        &self.model
370    }
371
372    async fn stream(
373        &self,
374        context: &Context<'_>,
375        options: &StreamOptions,
376    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
377        let request_body = Self::build_request(context, options);
378        let body = serde_json::to_vec(&request_body).map_err(|err| {
379            Error::provider(
380                "amazon-bedrock",
381                format!("Failed to serialize request body: {err}"),
382            )
383        })?;
384
385        let auth_context = self.resolve_auth_context(options)?;
386        let url = self.converse_url(&auth_context.region)?;
387
388        let mut request = self
389            .client
390            .post(url.as_str())
391            .header("Content-Type", "application/json")
392            .header("Accept", "application/json");
393
394        match auth_context.auth {
395            BedrockAuth::Bearer { token } => {
396                request = request.header("Authorization", format!("Bearer {token}"));
397            }
398            BedrockAuth::Sigv4 {
399                access_key_id,
400                secret_access_key,
401                session_token,
402            } => {
403                let signing_headers = build_sigv4_headers(
404                    &url,
405                    &body,
406                    &access_key_id,
407                    &secret_access_key,
408                    session_token.as_deref(),
409                    &auth_context.region,
410                    Utc::now(),
411                )?;
412                request = request
413                    .header("Authorization", signing_headers.authorization)
414                    .header("x-amz-date", signing_headers.amz_date)
415                    .header("x-amz-content-sha256", signing_headers.payload_hash);
416                if let Some(token) = signing_headers.security_token {
417                    request = request.header("x-amz-security-token", token);
418                }
419            }
420        }
421
422        if let Some(compat) = &self.compat
423            && let Some(custom_headers) = &compat.custom_headers
424        {
425            for (name, value) in custom_headers {
426                request = request.header(name, value);
427            }
428        }
429
430        for (name, value) in &options.headers {
431            request = request.header(name, value);
432        }
433
434        let response = request.body(body).send().await?;
435        let status = response.status();
436        let response_text = response
437            .text()
438            .await
439            .unwrap_or_else(|err| format!("<failed to read body: {err}>"));
440
441        if !(200..300).contains(&status) {
442            return Err(Error::provider(
443                "amazon-bedrock",
444                format!("Bedrock Converse API error (HTTP {status}): {response_text}"),
445            ));
446        }
447
448        let parsed: BedrockConverseResponse =
449            serde_json::from_str(&response_text).map_err(|err| {
450                Error::provider(
451                    "amazon-bedrock",
452                    format!("Failed to parse Bedrock response: {err}"),
453                )
454            })?;
455
456        let message = self.response_to_message(parsed);
457        Ok(Box::pin(stream::iter(Self::message_events(&message))))
458    }
459}
460
461#[derive(Debug, Serialize)]
462#[serde(rename_all = "camelCase")]
463pub struct BedrockConverseRequest {
464    #[serde(skip_serializing_if = "Vec::is_empty")]
465    system: Vec<BedrockSystemContent>,
466    messages: Vec<BedrockMessage>,
467    #[serde(rename = "inferenceConfig", skip_serializing_if = "Option::is_none")]
468    inference_config: Option<BedrockInferenceConfig>,
469    #[serde(rename = "toolConfig", skip_serializing_if = "Option::is_none")]
470    tool_config: Option<BedrockToolConfig>,
471}
472
473#[derive(Debug, Serialize)]
474struct BedrockSystemContent {
475    text: String,
476}
477
478#[derive(Debug, Serialize)]
479struct BedrockMessage {
480    role: &'static str,
481    content: Vec<BedrockContent>,
482}
483
484#[derive(Debug, Serialize)]
485#[serde(untagged)]
486enum BedrockContent {
487    Text {
488        text: String,
489    },
490    Image {
491        image: BedrockImageBlock,
492    },
493    ToolUse {
494        #[serde(rename = "toolUse")]
495        tool_use: BedrockToolUse,
496    },
497    ToolResult {
498        #[serde(rename = "toolResult")]
499        tool_result: BedrockToolResult,
500    },
501}
502
503#[derive(Debug, Serialize)]
504struct BedrockImageBlock {
505    format: String,
506    source: BedrockImageSource,
507}
508
509#[derive(Debug, Serialize)]
510struct BedrockImageSource {
511    bytes: String,
512}
513
514#[derive(Debug, Serialize)]
515#[serde(rename_all = "camelCase")]
516struct BedrockToolUse {
517    tool_use_id: String,
518    name: String,
519    input: Value,
520}
521
522#[derive(Debug, Serialize)]
523#[serde(rename_all = "camelCase")]
524struct BedrockToolResult {
525    tool_use_id: String,
526    content: Vec<BedrockToolResultContent>,
527    status: String,
528}
529
530#[derive(Debug, Serialize)]
531#[serde(untagged)]
532enum BedrockToolResultContent {
533    Text { text: String },
534}
535
536#[derive(Debug, Serialize)]
537#[serde(rename_all = "camelCase")]
538struct BedrockInferenceConfig {
539    #[serde(skip_serializing_if = "Option::is_none")]
540    max_tokens: Option<u32>,
541    #[serde(skip_serializing_if = "Option::is_none")]
542    temperature: Option<f32>,
543}
544
545#[derive(Debug, Serialize)]
546struct BedrockToolConfig {
547    tools: Vec<BedrockToolDef>,
548}
549
550#[derive(Debug, Serialize)]
551#[serde(rename_all = "camelCase")]
552struct BedrockToolDef {
553    tool_spec: BedrockToolSpec,
554}
555
556#[derive(Debug, Serialize)]
557#[serde(rename_all = "camelCase")]
558struct BedrockToolSpec {
559    name: String,
560    description: String,
561    input_schema: BedrockInputSchema,
562}
563
564#[derive(Debug, Serialize)]
565struct BedrockInputSchema {
566    json: Value,
567}
568
569fn convert_message(message: &Message) -> Option<BedrockMessage> {
570    match message {
571        Message::User(user_message) => convert_user_message(user_message),
572        Message::Assistant(assistant_message) => convert_assistant_message(assistant_message),
573        Message::ToolResult(tool_result_message) => {
574            Some(convert_tool_result_message(tool_result_message))
575        }
576        Message::Custom(_) => None,
577    }
578}
579
580fn convert_user_message(message: &crate::model::UserMessage) -> Option<BedrockMessage> {
581    let mut content = Vec::new();
582    match &message.content {
583        UserContent::Text(text) => {
584            if !text.trim().is_empty() {
585                content.push(BedrockContent::Text { text: text.clone() });
586            }
587        }
588        UserContent::Blocks(blocks) => {
589            for block in blocks {
590                match block {
591                    ContentBlock::Text(text) => {
592                        if !text.text.trim().is_empty() {
593                            content.push(BedrockContent::Text {
594                                text: text.text.clone(),
595                            });
596                        }
597                    }
598                    ContentBlock::Image(img) => {
599                        let format = img
600                            .mime_type
601                            .rsplit('/')
602                            .next()
603                            .unwrap_or("png")
604                            .to_string();
605                        content.push(BedrockContent::Image {
606                            image: BedrockImageBlock {
607                                format,
608                                source: BedrockImageSource {
609                                    bytes: img.data.clone(),
610                                },
611                            },
612                        });
613                    }
614                    _ => {}
615                }
616            }
617        }
618    }
619
620    if content.is_empty() {
621        None
622    } else {
623        Some(BedrockMessage {
624            role: "user",
625            content,
626        })
627    }
628}
629
630fn convert_assistant_message(message: &AssistantMessage) -> Option<BedrockMessage> {
631    let mut content = Vec::new();
632    for block in &message.content {
633        match block {
634            ContentBlock::Text(text) => {
635                if !text.text.trim().is_empty() {
636                    content.push(BedrockContent::Text {
637                        text: text.text.clone(),
638                    });
639                }
640            }
641            ContentBlock::ToolCall(tool_call) => {
642                content.push(BedrockContent::ToolUse {
643                    tool_use: BedrockToolUse {
644                        tool_use_id: tool_call.id.clone(),
645                        name: tool_call.name.clone(),
646                        input: tool_call.arguments.clone(),
647                    },
648                });
649            }
650            _ => {}
651        }
652    }
653
654    if content.is_empty() {
655        None
656    } else {
657        Some(BedrockMessage {
658            role: "assistant",
659            content,
660        })
661    }
662}
663
664fn convert_tool_result_message(message: &ToolResultMessage) -> BedrockMessage {
665    let text = message
666        .content
667        .iter()
668        .filter_map(|block| match block {
669            ContentBlock::Text(text) => Some(text.text.as_str()),
670            _ => None,
671        })
672        .collect::<Vec<_>>()
673        .join("\n");
674
675    let result_text = if text.trim().is_empty() {
676        "{}".to_string()
677    } else {
678        text
679    };
680
681    BedrockMessage {
682        role: "user",
683        content: vec![BedrockContent::ToolResult {
684            tool_result: BedrockToolResult {
685                tool_use_id: message.tool_call_id.clone(),
686                content: vec![BedrockToolResultContent::Text { text: result_text }],
687                status: if message.is_error {
688                    "error".to_string()
689                } else {
690                    "success".to_string()
691                },
692            },
693        }],
694    }
695}
696
697fn convert_tool(tool: &ToolDef) -> BedrockToolDef {
698    BedrockToolDef {
699        tool_spec: BedrockToolSpec {
700            name: tool.name.clone(),
701            description: tool.description.clone(),
702            input_schema: BedrockInputSchema {
703                json: tool.parameters.clone(),
704            },
705        },
706    }
707}
708
709#[derive(Debug, Deserialize)]
710#[serde(rename_all = "camelCase")]
711struct BedrockConverseResponse {
712    #[serde(default)]
713    output: Option<BedrockResponseOutput>,
714    #[serde(default)]
715    stop_reason: Option<String>,
716    #[serde(default)]
717    usage: Option<BedrockUsage>,
718}
719
720#[derive(Debug, Deserialize)]
721struct BedrockResponseOutput {
722    message: BedrockResponseMessage,
723}
724
725#[derive(Debug, Deserialize)]
726struct BedrockResponseMessage {
727    #[allow(dead_code)]
728    role: Option<String>,
729    #[serde(default)]
730    content: Vec<BedrockResponseContent>,
731}
732
733#[derive(Debug, Deserialize)]
734#[serde(untagged)]
735enum BedrockResponseContent {
736    Text {
737        text: String,
738    },
739    ToolUse {
740        #[serde(rename = "toolUse")]
741        tool_use: BedrockResponseToolUse,
742    },
743}
744
745#[derive(Debug, Deserialize)]
746#[serde(rename_all = "camelCase")]
747struct BedrockResponseToolUse {
748    tool_use_id: String,
749    name: String,
750    #[serde(default)]
751    input: Value,
752}
753
754#[derive(Debug, Deserialize)]
755#[serde(rename_all = "camelCase")]
756#[allow(clippy::struct_field_names)]
757struct BedrockUsage {
758    #[serde(default)]
759    input_tokens: u64,
760    #[serde(default)]
761    output_tokens: u64,
762    #[serde(default)]
763    total_tokens: u64,
764}
765
766fn convert_usage(usage: &BedrockUsage) -> Usage {
767    let total = if usage.total_tokens > 0 {
768        usage.total_tokens
769    } else {
770        usage.input_tokens + usage.output_tokens
771    };
772
773    Usage {
774        input: usage.input_tokens,
775        output: usage.output_tokens,
776        total_tokens: total,
777        ..Usage::default()
778    }
779}
780
781fn map_stop_reason(stop_reason: Option<&str>) -> StopReason {
782    match stop_reason.unwrap_or("end_turn") {
783        "tool_use" => StopReason::ToolUse,
784        "max_tokens" => StopReason::Length,
785        "guardrail_intervened" | "content_filtered" => StopReason::Error,
786        _ => StopReason::Stop,
787    }
788}
789
790fn normalize_model_id(model_id: &str) -> Result<String> {
791    let mut normalized = model_id.trim().trim_matches('/');
792    if normalized.is_empty() {
793        return Err(Error::provider(
794            "amazon-bedrock",
795            "Bedrock model id cannot be empty",
796        ));
797    }
798
799    for prefix in ["amazon-bedrock/", "bedrock/", "model/"] {
800        if let Some(stripped) = normalized.strip_prefix(prefix) {
801            normalized = stripped;
802            break;
803        }
804    }
805
806    if let Some((_, stripped)) = normalized.split_once("/model/") {
807        normalized = stripped;
808    }
809
810    for suffix in ["/converse-stream", "/converse"] {
811        if let Some(stripped) = normalized.strip_suffix(suffix) {
812            normalized = stripped;
813            break;
814        }
815    }
816
817    let final_id = normalized.trim_matches('/');
818    if final_id.is_empty() {
819        return Err(Error::provider(
820            "amazon-bedrock",
821            "Bedrock model id cannot be empty",
822        ));
823    }
824
825    Ok(final_id.to_string())
826}
827
/// Compute AWS Signature Version 4 headers for a `POST` to `url` carrying `payload`.
///
/// The signed headers are `content-type` (fixed to `application/json`),
/// `host`, `x-amz-content-sha256`, `x-amz-date`, plus `x-amz-security-token`
/// when a session token is given. The request must be sent with exactly these
/// values or the signature will not match.
///
/// Parameters:
/// - `url`: final request URL; its host, path and query go into the canonical request.
/// - `payload`: exact request body bytes; their SHA-256 is signed.
/// - `access_key_id` / `secret_access_key` / `session_token`: AWS credentials.
/// - `region`: region for the credential scope.
/// - `now`: signing time; formatted as `x-amz-date` and the scope date.
///
/// # Errors
/// Fails if the URL has no host, if building the canonical header block
/// fails, or if HMAC initialization fails.
fn build_sigv4_headers(
    url: &Url,
    payload: &[u8],
    access_key_id: &str,
    secret_access_key: &str,
    session_token: Option<&str>,
    region: &str,
    now: DateTime<Utc>,
) -> Result<Sigv4Headers> {
    let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
    let date_stamp = now.format("%Y%m%d").to_string();
    let payload_hash = sha256_hex(payload);
    let host = canonical_host(url)?;
    let canonical_uri = canonical_uri(url);
    let canonical_query = canonical_query(url);

    // Header names are already lowercase, as SigV4 requires.
    let mut canonical_headers = vec![
        ("content-type".to_string(), "application/json".to_string()),
        ("host".to_string(), host),
        ("x-amz-content-sha256".to_string(), payload_hash.clone()),
        ("x-amz-date".to_string(), amz_date.clone()),
    ];
    if let Some(token) = session_token {
        canonical_headers.push(("x-amz-security-token".to_string(), token.to_string()));
    }
    // SigV4 requires canonical headers sorted by name.
    canonical_headers.sort_by(|left, right| left.0.cmp(&right.0));

    let signed_headers = canonical_headers
        .iter()
        .map(|(name, _)| name.as_str())
        .collect::<Vec<_>>()
        .join(";");

    // Each entry is newline-terminated; together with the `\n` in the format
    // string below this yields the blank line SigV4 expects after the headers.
    let mut canonical_headers_block = String::new();
    for (name, value) in &canonical_headers {
        let trimmed = value.trim();
        writeln!(&mut canonical_headers_block, "{name}:{trimmed}")
            .map_err(|err| Error::api(format!("Failed to build canonical headers: {err}")))?;
    }

    let canonical_request = format!(
        "POST\n{canonical_uri}\n{canonical_query}\n{canonical_headers_block}\n{signed_headers}\n{payload_hash}"
    );
    let canonical_request_hash = sha256_hex(canonical_request.as_bytes());
    let credential_scope = format!("{date_stamp}/{region}/{BEDROCK_SERVICE}/aws4_request");
    let string_to_sign =
        format!("AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{canonical_request_hash}");
    // Despite its name, `signing_key` returns the final HMAC over
    // `string_to_sign`, i.e. the raw signature bytes.
    let signature = hex_encode(&signing_key(
        secret_access_key,
        &date_stamp,
        region,
        &string_to_sign,
    )?);

    let authorization = format!(
        "AWS4-HMAC-SHA256 Credential={access_key_id}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
    );

    Ok(Sigv4Headers {
        authorization,
        amz_date,
        payload_hash,
        security_token: session_token.map(ToString::to_string),
    })
}
893
894fn canonical_host(url: &Url) -> Result<String> {
895    let host = url.host_str().ok_or_else(|| {
896        Error::provider("amazon-bedrock", "Bedrock endpoint URL is missing a host")
897    })?;
898    Ok(url
899        .port()
900        .map_or_else(|| host.to_string(), |port| format!("{host}:{port}")))
901}
902
903fn canonical_uri(url: &Url) -> String {
904    let segments = url
905        .path_segments()
906        .map(|parts| parts.map(aws_percent_encode).collect::<Vec<_>>())
907        .unwrap_or_default();
908
909    if segments.is_empty() {
910        "/".to_string()
911    } else {
912        format!("/{}", segments.join("/"))
913    }
914}
915
916fn canonical_query(url: &Url) -> String {
917    let mut pairs = url
918        .query_pairs()
919        .map(|(key, value)| (aws_percent_encode(&key), aws_percent_encode(&value)))
920        .collect::<Vec<_>>();
921    pairs.sort();
922    pairs
923        .into_iter()
924        .map(|(key, value)| format!("{key}={value}"))
925        .collect::<Vec<_>>()
926        .join("&")
927}
928
929fn aws_percent_encode(value: &str) -> String {
930    let mut encoded = String::with_capacity(value.len());
931    for byte in value.bytes() {
932        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
933            encoded.push(char::from(byte));
934        } else {
935            encoded.push('%');
936            encoded.push(nibble_to_hex(byte >> 4));
937            encoded.push(nibble_to_hex(byte & 0x0f));
938        }
939    }
940    encoded
941}
942
/// Uppercase hex digit for a 4-bit value; inputs above 15 yield `'0'`.
fn nibble_to_hex(nibble: u8) -> char {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    DIGITS
        .get(usize::from(nibble))
        .map_or('0', |&digit| char::from(digit))
}
950
951fn signing_key(
952    secret_access_key: &str,
953    date_stamp: &str,
954    region: &str,
955    string_to_sign: &str,
956) -> Result<Vec<u8>> {
957    let key_date = hmac_sha256(
958        format!("AWS4{secret_access_key}").as_bytes(),
959        date_stamp.as_bytes(),
960    )?;
961    let key_region = hmac_sha256(&key_date, region.as_bytes())?;
962    let key_service = hmac_sha256(&key_region, BEDROCK_SERVICE.as_bytes())?;
963    let key_signing = hmac_sha256(&key_service, b"aws4_request")?;
964    hmac_sha256(&key_signing, string_to_sign.as_bytes())
965}
966
967fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
968    let mut mac = HmacSha256::new_from_slice(key)
969        .map_err(|err| Error::api(format!("Failed to initialize HMAC: {err}")))?;
970    mac.update(data);
971    Ok(mac.finalize().into_bytes().to_vec())
972}
973
974fn sha256_hex(bytes: &[u8]) -> String {
975    let digest = Sha256::digest(bytes);
976    hex_encode(&digest)
977}
978
/// Renders `bytes` as a lowercase hexadecimal string, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
            // Formatting into a `String` cannot fail, so the result is ignored.
            let _ = write!(acc, "{byte:02x}");
            acc
        })
}
986
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone as _;
    use serde_json::json;

    /// Builds a context with a system prompt, a user turn, an assistant tool
    /// call and the matching tool result, plus one tool definition.
    fn test_context_with_tools() -> Context<'static> {
        Context {
            system_prompt: Some("You are concise.".to_string().into()),
            messages: vec![
                Message::User(crate::model::UserMessage {
                    content: UserContent::Text("Ping".to_string()),
                    timestamp: 0,
                }),
                Message::assistant(AssistantMessage {
                    content: vec![ContentBlock::ToolCall(ToolCall {
                        id: "tool_1".to_string(),
                        name: "search".to_string(),
                        arguments: json!({ "q": "rust" }),
                        thought_signature: None,
                    })],
                    api: "bedrock-converse-stream".to_string(),
                    provider: "amazon-bedrock".to_string(),
                    model: "m".to_string(),
                    usage: Usage::default(),
                    stop_reason: StopReason::ToolUse,
                    error_message: None,
                    timestamp: 0,
                }),
                Message::tool_result(ToolResultMessage {
                    tool_call_id: "tool_1".to_string(),
                    tool_name: "search".to_string(),
                    content: vec![ContentBlock::Text(TextContent {
                        text: "result".to_string(),
                        text_signature: None,
                    })],
                    details: None,
                    is_error: false,
                    timestamp: 0,
                }),
            ]
            .into(),
            tools: vec![ToolDef {
                name: "search".to_string(),
                description: "Search docs".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {"q": {"type": "string"}},
                    "required": ["q"]
                }),
            }]
            .into(),
        }
    }

    #[test]
    fn build_request_includes_system_messages_and_tools() {
        let request = BedrockProvider::build_request(
            &test_context_with_tools(),
            &StreamOptions {
                max_tokens: Some(321),
                temperature: Some(0.2),
                ..StreamOptions::default()
            },
        );

        let value = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(value["system"][0]["text"], "You are concise.");
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(
            value["messages"][1]["content"][0]["toolUse"]["name"],
            "search"
        );
        assert_eq!(
            value["messages"][2]["content"][0]["toolResult"]["status"],
            "success"
        );
        assert_eq!(value["inferenceConfig"]["maxTokens"], 321);
        assert_eq!(
            value["toolConfig"]["tools"][0]["toolSpec"]["name"],
            "search"
        );
    }

    #[test]
    fn converse_url_appends_model_path_and_encodes_model_id() {
        let provider = BedrockProvider::new("anthropic.claude-3-5-sonnet-20240620-v1:0")
            .with_base_url("https://bedrock-runtime.us-east-1.amazonaws.com");
        let url = provider
            .converse_url("us-east-1")
            .expect("build converse URL");
        assert_eq!(
            url.path(),
            "/model/anthropic.claude-3-5-sonnet-20240620-v1:0/converse"
        );
    }

    #[test]
    fn normalize_model_id_accepts_prefixed_variants() {
        assert_eq!(
            normalize_model_id("bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
                .expect("normalize regional prefix"),
            "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
        );
        assert_eq!(
            normalize_model_id("model/anthropic.claude-3-5-sonnet-20240620-v1:0/converse")
                .expect("normalize model path"),
            "anthropic.claude-3-5-sonnet-20240620-v1:0"
        );
    }

    #[test]
    fn sigv4_headers_include_expected_scope_and_token() {
        let url =
            Url::parse("https://bedrock-runtime.us-west-2.amazonaws.com/model/m.converse/converse")
                .expect("url");
        let now = Utc
            .with_ymd_and_hms(2026, 2, 10, 8, 0, 0)
            .single()
            .expect("datetime");
        let headers = build_sigv4_headers(
            &url,
            br#"{"messages":[{"role":"user","content":[{"text":"Ping"}]}]}"#,
            "AKIDEXAMPLE",
            "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
            Some("session-token"),
            "us-west-2",
            now,
        )
        .expect("sign headers");

        assert!(
            headers
                .authorization
                .contains("Credential=AKIDEXAMPLE/20260210/us-west-2/bedrock/aws4_request")
        );
        assert!(headers.authorization.contains(
            "SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date;x-amz-security-token"
        ));
        assert_eq!(headers.security_token.as_deref(), Some("session-token"));
        assert_eq!(headers.amz_date, "20260210T080000Z");
        assert_eq!(headers.payload_hash.len(), 64);
    }

    // SigV4 is unforgiving: a wrong byte in the canonical request only shows
    // up as a signature-mismatch error from AWS. These known-answer tests pin
    // the low-level encoding helpers directly.

    #[test]
    fn aws_percent_encode_keeps_unreserved_and_uppercases_hex() {
        assert_eq!(aws_percent_encode(""), "");
        assert_eq!(aws_percent_encode("AZaz09-_.~"), "AZaz09-_.~");
        // Space, '/', ':' and a multi-byte UTF-8 char are encoded per byte.
        assert_eq!(aws_percent_encode("a b/c:é"), "a%20b%2Fc%3A%C3%A9");
    }

    #[test]
    fn canonical_query_sorts_and_reencodes_pairs() {
        let url = Url::parse("https://example.com/p?b=2&a=x%2Fy&a=1").expect("url");
        assert_eq!(canonical_query(&url), "a=1&a=x%2Fy&b=2");

        let no_query = Url::parse("https://example.com/p").expect("url");
        assert_eq!(canonical_query(&no_query), "");
    }

    #[test]
    fn hex_helpers_produce_lowercase_digests() {
        assert_eq!(hex_encode(&[]), "");
        assert_eq!(hex_encode(&[0x00, 0x0f, 0xab, 0xff]), "000fabff");
        assert_eq!(
            sha256_hex(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            sha256_hex(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn response_to_message_maps_tool_use_and_usage() {
        let provider = BedrockProvider::new("anthropic.claude-3-5-sonnet-20240620-v1:0");
        let response: BedrockConverseResponse = serde_json::from_value(json!({
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [
                        {"text": "I can help."},
                        {"toolUse": {"toolUseId": "call_1", "name": "search", "input": {"q": "rust"}}}
                    ]
                }
            },
            "stopReason": "tool_use",
            "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}
        }))
        .expect("parse response");

        let message = provider.response_to_message(response);
        assert_eq!(message.stop_reason, StopReason::ToolUse);
        assert_eq!(message.usage.input, 10);
        assert_eq!(message.usage.output, 5);
        assert_eq!(message.usage.total_tokens, 15);
        assert!(matches!(message.content[0], ContentBlock::Text(_)));
        assert!(matches!(message.content[1], ContentBlock::ToolCall(_)));
    }

    #[test]
    fn resolve_auth_context_uses_stream_option_api_key_fallback() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let provider =
            BedrockProvider::new("model").with_auth_path(temp_dir.path().join("auth.json"));
        let auth = provider
            .resolve_auth_context(&StreamOptions {
                api_key: Some("bedrock-bearer".to_string()),
                ..StreamOptions::default()
            })
            .expect("resolve auth context");
        assert!(matches!(auth.auth, BedrockAuth::Bearer { .. }));
    }
}