Skip to main content

pi/providers/
bedrock.rs

1//! Amazon Bedrock Converse provider implementation.
2//!
3//! This provider targets the Bedrock Converse API and maps its non-streaming
4//! JSON response into Pi stream events.
5
6use crate::auth::{AuthStorage, AwsResolvedCredentials, resolve_aws_credentials};
7use crate::config::Config;
8use crate::error::{Error, Result};
9use crate::http::client::Client;
10use crate::model::{
11    AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ToolCall,
12    ToolResultMessage, Usage, UserContent,
13};
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamOptions, ToolDef};
16use async_trait::async_trait;
17use chrono::{DateTime, Utc};
18use futures::Stream;
19use futures::stream;
20use hmac::{Hmac, Mac};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use sha2::{Digest, Sha256};
24use std::fmt::Write as _;
25#[cfg(test)]
26use std::path::Path;
27use std::path::PathBuf;
28use std::pin::Pin;
29use url::Url;
30
/// Region used when neither `AWS_REGION` nor `AWS_DEFAULT_REGION` supplies one.
const DEFAULT_REGION: &str = "us-east-1";
/// Service name placed in the SigV4 credential scope and signing-key derivation.
const BEDROCK_SERVICE: &str = "bedrock";

/// HMAC instantiated with SHA-256, as used by the SigV4 helpers in this file.
type HmacSha256 = Hmac<Sha256>;
35
/// Credential material used to authenticate a Bedrock request.
#[derive(Debug, Clone)]
enum BedrockAuth {
    /// AWS access-key credentials; requests are signed with SigV4.
    Sigv4 {
        access_key_id: String,
        secret_access_key: String,
        // Forwarded as `x-amz-security-token` when present.
        session_token: Option<String>,
    },
    /// Bearer token sent verbatim in the `Authorization` header.
    Bearer {
        token: String,
    },
}
47
/// Resolved credentials together with the AWS region they apply to.
#[derive(Debug, Clone)]
struct BedrockAuthContext {
    auth: BedrockAuth,
    // Used for the default endpoint host and the SigV4 credential scope.
    region: String,
}
53
/// Header values produced by SigV4 signing, ready to attach to the request.
#[derive(Debug, Clone)]
struct Sigv4Headers {
    // Full `Authorization` header value.
    authorization: String,
    // Value for `x-amz-date`.
    amz_date: String,
    // Value for `x-amz-content-sha256` (hex SHA-256 of the body).
    payload_hash: String,
    // Value for `x-amz-security-token`, when a session token was supplied.
    security_token: Option<String>,
}
61
/// Amazon Bedrock provider.
///
/// Holds the HTTP client, the normalized model id and optional overrides
/// (base URL, compatibility settings, auth file location).
pub struct BedrockProvider {
    client: Client,
    // Model id after `normalize_model_id` (or the trimmed raw id if that failed).
    model: String,
    // Name reported in emitted messages; defaults to "amazon-bedrock".
    provider_name: String,
    // When set, replaces the regional `bedrock-runtime` endpoint.
    base_url_override: Option<String>,
    compat: Option<CompatConfig>,
    // When set, replaces `Config::auth_path()` as the credentials file.
    auth_path_override: Option<PathBuf>,
}
71
72impl BedrockProvider {
73    /// Create a Bedrock provider for the given model ID.
74    pub fn new(model: impl Into<String>) -> Self {
75        let raw_model = model.into();
76        let normalized_model = normalize_model_id(&raw_model)
77            .ok()
78            .unwrap_or_else(|| raw_model.trim().to_string());
79        Self {
80            client: Client::new(),
81            model: normalized_model,
82            provider_name: "amazon-bedrock".to_string(),
83            base_url_override: None,
84            compat: None,
85            auth_path_override: None,
86        }
87    }
88
89    /// Set provider name for event attribution.
90    #[must_use]
91    pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
92        self.provider_name = provider_name.into();
93        self
94    }
95
96    /// Override Bedrock base URL (useful for tests/proxies).
97    #[must_use]
98    pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Self {
99        let trimmed = base_url.as_ref().trim();
100        if !trimmed.is_empty() {
101            self.base_url_override = Some(trimmed.to_string());
102        }
103        self
104    }
105
106    /// Attach provider compatibility overrides.
107    #[must_use]
108    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
109        self.compat = compat;
110        self
111    }
112
113    /// Inject a custom HTTP client.
114    #[must_use]
115    pub fn with_client(mut self, client: Client) -> Self {
116        self.client = client;
117        self
118    }
119
120    #[cfg(test)]
121    #[must_use]
122    fn with_auth_path(mut self, path: impl AsRef<Path>) -> Self {
123        self.auth_path_override = Some(path.as_ref().to_path_buf());
124        self
125    }
126
127    fn auth_path(&self) -> PathBuf {
128        self.auth_path_override
129            .clone()
130            .unwrap_or_else(Config::auth_path)
131    }
132
133    fn load_auth_storage(&self) -> Result<AuthStorage> {
134        AuthStorage::load(self.auth_path())
135            .map_err(|err| Error::auth(format!("Failed to load Bedrock credentials: {err}")))
136    }
137
138    fn resolve_auth_context(&self, options: &StreamOptions) -> Result<BedrockAuthContext> {
139        let auth_storage = self.load_auth_storage()?;
140        if let Some(resolved) = resolve_aws_credentials(&auth_storage) {
141            return Ok(match resolved {
142                AwsResolvedCredentials::Sigv4 {
143                    access_key_id,
144                    secret_access_key,
145                    session_token,
146                    region,
147                } => BedrockAuthContext {
148                    auth: BedrockAuth::Sigv4 {
149                        access_key_id,
150                        secret_access_key,
151                        session_token,
152                    },
153                    region,
154                },
155                AwsResolvedCredentials::Bearer { token, region } => BedrockAuthContext {
156                    auth: BedrockAuth::Bearer { token },
157                    region,
158                },
159            });
160        }
161
162        if let Some(token) = options
163            .api_key
164            .as_deref()
165            .map(str::trim)
166            .filter(|token| !token.is_empty())
167        {
168            return Ok(BedrockAuthContext {
169                auth: BedrockAuth::Bearer {
170                    token: token.to_string(),
171                },
172                region: std::env::var("AWS_REGION")
173                    .ok()
174                    .or_else(|| std::env::var("AWS_DEFAULT_REGION").ok())
175                    .unwrap_or_else(|| DEFAULT_REGION.to_string()),
176            });
177        }
178
179        Err(Error::auth(
180            "Amazon Bedrock requires AWS credentials. Set AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY, AWS_BEARER_TOKEN_BEDROCK, or store amazon-bedrock credentials in auth.json.",
181        ))
182    }
183
184    fn converse_url(&self, region: &str) -> Result<Url> {
185        let base = self
186            .base_url_override
187            .clone()
188            .unwrap_or_else(|| format!("https://bedrock-runtime.{region}.amazonaws.com"));
189        let mut url = Url::parse(&base)
190            .map_err(|err| Error::provider("amazon-bedrock", format!("Invalid base URL: {err}")))?;
191
192        if self.model.trim().is_empty() {
193            return Err(Error::provider(
194                "amazon-bedrock",
195                "Bedrock model id cannot be empty",
196            ));
197        }
198
199        if url.path().ends_with("/converse") || url.path().ends_with("/converse-stream") {
200            return Ok(url);
201        }
202
203        {
204            let mut segments = url.path_segments_mut().map_err(|()| {
205                Error::provider(
206                    "amazon-bedrock",
207                    "Bedrock base URL does not support path segments",
208                )
209            })?;
210            segments.push("model");
211            segments.push(&self.model);
212            segments.push("converse");
213        }
214        Ok(url)
215    }
216
217    pub fn build_request(context: &Context<'_>, options: &StreamOptions) -> BedrockConverseRequest {
218        let mut system = Vec::new();
219        if let Some(system_prompt) = context
220            .system_prompt
221            .as_deref()
222            .map(str::trim)
223            .filter(|prompt| !prompt.is_empty())
224        {
225            system.push(BedrockSystemContent {
226                text: system_prompt.to_string(),
227            });
228        }
229
230        let mut messages = Vec::new();
231        for message in context.messages.iter() {
232            if let Some(converted) = convert_message(message) {
233                messages.push(converted);
234            }
235        }
236
237        if messages.is_empty() {
238            messages.push(BedrockMessage {
239                role: "user",
240                content: vec![BedrockContent::Text {
241                    text: "Hello".to_string(),
242                }],
243            });
244        }
245
246        let inference_config = if options.max_tokens.is_some() || options.temperature.is_some() {
247            Some(BedrockInferenceConfig {
248                max_tokens: options.max_tokens,
249                temperature: options.temperature,
250            })
251        } else {
252            None
253        };
254
255        let tool_config = if context.tools.is_empty() {
256            None
257        } else {
258            Some(BedrockToolConfig {
259                tools: context.tools.iter().map(convert_tool).collect(),
260            })
261        };
262
263        BedrockConverseRequest {
264            system,
265            messages,
266            inference_config,
267            tool_config,
268        }
269    }
270
271    fn response_to_message(&self, response: BedrockConverseResponse) -> AssistantMessage {
272        let usage = response
273            .usage
274            .as_ref()
275            .map_or_else(Usage::default, convert_usage);
276
277        let stop_reason = map_stop_reason(response.stop_reason.as_deref());
278        let mut content = Vec::new();
279
280        if let Some(output) = response.output {
281            for block in output.message.content {
282                match block {
283                    BedrockResponseContent::Text { text } => {
284                        if !text.is_empty() {
285                            content.push(ContentBlock::Text(TextContent {
286                                text,
287                                text_signature: None,
288                            }));
289                        }
290                    }
291                    BedrockResponseContent::ToolUse { tool_use } => {
292                        content.push(ContentBlock::ToolCall(ToolCall {
293                            id: tool_use.tool_use_id,
294                            name: tool_use.name,
295                            arguments: tool_use.input,
296                            thought_signature: None,
297                        }));
298                    }
299                }
300            }
301        }
302
303        AssistantMessage {
304            content,
305            api: "bedrock-converse-stream".to_string(),
306            provider: self.provider_name.clone(),
307            model: self.model.clone(),
308            usage,
309            stop_reason,
310            error_message: None,
311            timestamp: Utc::now().timestamp_millis(),
312        }
313    }
314
315    fn message_events(message: &AssistantMessage) -> Vec<Result<StreamEvent>> {
316        let mut events = Vec::new();
317        events.push(Ok(StreamEvent::Start {
318            partial: message.clone(),
319        }));
320        for (content_index, block) in message.content.iter().enumerate() {
321            match block {
322                ContentBlock::Text(text) => {
323                    events.push(Ok(StreamEvent::TextStart { content_index }));
324                    events.push(Ok(StreamEvent::TextDelta {
325                        content_index,
326                        delta: text.text.clone(),
327                    }));
328                    events.push(Ok(StreamEvent::TextEnd {
329                        content_index,
330                        content: text.text.clone(),
331                    }));
332                }
333                ContentBlock::ToolCall(tool_call) => {
334                    let delta = serde_json::to_string(&tool_call.arguments)
335                        .unwrap_or_else(|_| "{}".to_string());
336                    events.push(Ok(StreamEvent::ToolCallStart { content_index }));
337                    events.push(Ok(StreamEvent::ToolCallDelta {
338                        content_index,
339                        delta,
340                    }));
341                    events.push(Ok(StreamEvent::ToolCallEnd {
342                        content_index,
343                        tool_call: tool_call.clone(),
344                    }));
345                }
346                _ => {}
347            }
348        }
349
350        events.push(Ok(StreamEvent::Done {
351            reason: message.stop_reason,
352            message: message.clone(),
353        }));
354        events
355    }
356}
357
358#[async_trait]
359impl Provider for BedrockProvider {
360    fn name(&self) -> &str {
361        &self.provider_name
362    }
363
364    fn api(&self) -> &'static str {
365        "bedrock-converse-stream"
366    }
367
368    fn model_id(&self) -> &str {
369        &self.model
370    }
371
372    async fn stream(
373        &self,
374        context: &Context<'_>,
375        options: &StreamOptions,
376    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
377        let request_body = Self::build_request(context, options);
378        let body = serde_json::to_vec(&request_body).map_err(|err| {
379            Error::provider(
380                "amazon-bedrock",
381                format!("Failed to serialize request body: {err}"),
382            )
383        })?;
384
385        let auth_context = self.resolve_auth_context(options)?;
386        let url = self.converse_url(&auth_context.region)?;
387
388        let mut request = self
389            .client
390            .post(url.as_str())
391            .header("Content-Type", "application/json")
392            .header("Accept", "application/json");
393
394        match auth_context.auth {
395            BedrockAuth::Bearer { token } => {
396                request = request.header("Authorization", format!("Bearer {token}"));
397            }
398            BedrockAuth::Sigv4 {
399                access_key_id,
400                secret_access_key,
401                session_token,
402            } => {
403                let signing_headers = build_sigv4_headers(
404                    &url,
405                    &body,
406                    &access_key_id,
407                    &secret_access_key,
408                    session_token.as_deref(),
409                    &auth_context.region,
410                    Utc::now(),
411                )?;
412                request = request
413                    .header("Authorization", signing_headers.authorization)
414                    .header("x-amz-date", signing_headers.amz_date)
415                    .header("x-amz-content-sha256", signing_headers.payload_hash);
416                if let Some(token) = signing_headers.security_token {
417                    request = request.header("x-amz-security-token", token);
418                }
419            }
420        }
421
422        if let Some(compat) = &self.compat
423            && let Some(custom_headers) = &compat.custom_headers
424        {
425            for (name, value) in custom_headers {
426                request = request.header(name, value);
427            }
428        }
429
430        for (name, value) in &options.headers {
431            request = request.header(name, value);
432        }
433
434        let response = request.body(body).send().await?;
435        let status = response.status();
436        let response_text = response
437            .text()
438            .await
439            .unwrap_or_else(|err| format!("<failed to read body: {err}>"));
440
441        if !(200..300).contains(&status) {
442            return Err(Error::provider(
443                "amazon-bedrock",
444                format!("Bedrock Converse API error (HTTP {status}): {response_text}"),
445            ));
446        }
447
448        let parsed: BedrockConverseResponse =
449            serde_json::from_str(&response_text).map_err(|err| {
450                Error::provider(
451                    "amazon-bedrock",
452                    format!("Failed to parse Bedrock response: {err}\nData: {response_text}"),
453                )
454            })?;
455
456        let message = self.response_to_message(parsed);
457        Ok(Box::pin(stream::iter(Self::message_events(&message))))
458    }
459}
460
/// JSON body of a Converse request, serialized with camelCase keys.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct BedrockConverseRequest {
    // Omitted from the JSON when there is no system prompt.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    system: Vec<BedrockSystemContent>,
    messages: Vec<BedrockMessage>,
    // Omitted when neither max tokens nor temperature was requested.
    #[serde(rename = "inferenceConfig", skip_serializing_if = "Option::is_none")]
    inference_config: Option<BedrockInferenceConfig>,
    // Omitted when the context carries no tools.
    #[serde(rename = "toolConfig", skip_serializing_if = "Option::is_none")]
    tool_config: Option<BedrockToolConfig>,
}
472
/// One text entry of the request's `system` array.
#[derive(Debug, Serialize)]
struct BedrockSystemContent {
    text: String,
}
477
/// One conversation turn in the request.
#[derive(Debug, Serialize)]
struct BedrockMessage {
    // "user" or "assistant" as set by the converters in this file.
    role: &'static str,
    content: Vec<BedrockContent>,
}
483
/// A content block inside a request message.
///
/// Untagged: each variant serializes as an object holding only its own field
/// (`text`, `image`, `toolUse` or `toolResult`).
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum BedrockContent {
    /// Plain text.
    Text {
        text: String,
    },
    /// Inline image.
    Image {
        image: BedrockImageBlock,
    },
    /// Tool invocation previously requested by the assistant.
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: BedrockToolUse,
    },
    /// Result of a tool invocation.
    ToolResult {
        #[serde(rename = "toolResult")]
        tool_result: BedrockToolResult,
    },
}
502
/// Image payload: a format label plus the image source.
#[derive(Debug, Serialize)]
struct BedrockImageBlock {
    // Derived by the converters from the part of the MIME type after the last '/'.
    format: String,
    source: BedrockImageSource,
}
508
/// Image data carried inline in the request.
#[derive(Debug, Serialize)]
struct BedrockImageSource {
    // Copied from the image block's `data` string; presumably base64 — TODO confirm.
    bytes: String,
}
513
/// Tool invocation echoed back to Bedrock as part of an assistant turn
/// (serialized with camelCase keys).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockToolUse {
    tool_use_id: String,
    name: String,
    // Tool arguments as arbitrary JSON.
    input: Value,
}
521
/// Outcome of a tool invocation (serialized with camelCase keys).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockToolResult {
    // Id of the tool call this result answers.
    tool_use_id: String,
    content: Vec<BedrockToolResultContent>,
    // "success" or "error", as set by `convert_tool_result_message`.
    status: String,
}
529
/// A content block inside a tool result.
///
/// Untagged: serializes as an object holding only `text` or only `image`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum BedrockToolResultContent {
    Text { text: String },
    Image { image: BedrockImageBlock },
}
536
/// Optional sampling limits for the request (camelCase keys); unset fields
/// are left out of the JSON.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockInferenceConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
545
/// Tool definitions made available to the model for this request.
#[derive(Debug, Serialize)]
struct BedrockToolConfig {
    tools: Vec<BedrockToolDef>,
}
550
/// Wrapper that serializes a tool as `{ "toolSpec": { ... } }`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockToolDef {
    tool_spec: BedrockToolSpec,
}
556
/// Name, description and input schema of one tool (camelCase keys).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockToolSpec {
    name: String,
    description: String,
    input_schema: BedrockInputSchema,
}
564
/// Tool input schema, serialized as `{ "json": <schema> }`.
#[derive(Debug, Serialize)]
struct BedrockInputSchema {
    // Copied from the tool definition's `parameters` value.
    json: Value,
}
569
570fn convert_message(message: &Message) -> Option<BedrockMessage> {
571    match message {
572        Message::User(user_message) => convert_user_message(user_message),
573        Message::Assistant(assistant_message) => convert_assistant_message(assistant_message),
574        Message::ToolResult(tool_result_message) => {
575            Some(convert_tool_result_message(tool_result_message))
576        }
577        Message::Custom(_) => None,
578    }
579}
580
581fn convert_user_message(message: &crate::model::UserMessage) -> Option<BedrockMessage> {
582    let mut content = Vec::new();
583    match &message.content {
584        UserContent::Text(text) => {
585            if !text.trim().is_empty() {
586                content.push(BedrockContent::Text { text: text.clone() });
587            }
588        }
589        UserContent::Blocks(blocks) => {
590            for block in blocks {
591                match block {
592                    ContentBlock::Text(text) if !text.text.trim().is_empty() => {
593                        content.push(BedrockContent::Text {
594                            text: text.text.clone(),
595                        });
596                    }
597                    ContentBlock::Image(img) => {
598                        let format = img
599                            .mime_type
600                            .rsplit('/')
601                            .next()
602                            .unwrap_or("png")
603                            .to_string();
604                        content.push(BedrockContent::Image {
605                            image: BedrockImageBlock {
606                                format,
607                                source: BedrockImageSource {
608                                    bytes: img.data.clone(),
609                                },
610                            },
611                        });
612                    }
613                    _ => {}
614                }
615            }
616        }
617    }
618
619    if content.is_empty() {
620        None
621    } else {
622        Some(BedrockMessage {
623            role: "user",
624            content,
625        })
626    }
627}
628
629fn convert_assistant_message(message: &AssistantMessage) -> Option<BedrockMessage> {
630    let mut content = Vec::new();
631    for block in &message.content {
632        match block {
633            ContentBlock::Text(text) if !text.text.trim().is_empty() => {
634                content.push(BedrockContent::Text {
635                    text: text.text.clone(),
636                });
637            }
638            ContentBlock::ToolCall(tool_call) => {
639                content.push(BedrockContent::ToolUse {
640                    tool_use: BedrockToolUse {
641                        tool_use_id: tool_call.id.clone(),
642                        name: tool_call.name.clone(),
643                        input: tool_call.arguments.clone(),
644                    },
645                });
646            }
647            _ => {}
648        }
649    }
650
651    if content.is_empty() {
652        None
653    } else {
654        Some(BedrockMessage {
655            role: "assistant",
656            content,
657        })
658    }
659}
660
661fn convert_tool_result_message(message: &ToolResultMessage) -> BedrockMessage {
662    let mut contents = Vec::new();
663
664    for block in &message.content {
665        match block {
666            ContentBlock::Text(text) if !text.text.is_empty() => {
667                contents.push(BedrockToolResultContent::Text {
668                    text: text.text.clone(),
669                });
670            }
671            ContentBlock::Image(img) => {
672                let format = img
673                    .mime_type
674                    .rsplit('/')
675                    .next()
676                    .unwrap_or("png")
677                    .to_string();
678                contents.push(BedrockToolResultContent::Image {
679                    image: BedrockImageBlock {
680                        format,
681                        source: BedrockImageSource {
682                            bytes: img.data.clone(),
683                        },
684                    },
685                });
686            }
687            _ => {}
688        }
689    }
690
691    if contents.is_empty() {
692        contents.push(BedrockToolResultContent::Text {
693            text: "{}".to_string(),
694        });
695    }
696
697    BedrockMessage {
698        role: "user",
699        content: vec![BedrockContent::ToolResult {
700            tool_result: BedrockToolResult {
701                tool_use_id: message.tool_call_id.clone(),
702                content: contents,
703                status: if message.is_error {
704                    "error".to_string()
705                } else {
706                    "success".to_string()
707                },
708            },
709        }],
710    }
711}
712
713fn convert_tool(tool: &ToolDef) -> BedrockToolDef {
714    BedrockToolDef {
715        tool_spec: BedrockToolSpec {
716            name: tool.name.clone(),
717            description: tool.description.clone(),
718            input_schema: BedrockInputSchema {
719                json: tool.parameters.clone(),
720            },
721        },
722    }
723}
724
/// Top-level JSON body of a Converse response (camelCase keys).
///
/// Every field is optional; absent fields deserialize to `None`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BedrockConverseResponse {
    #[serde(default)]
    output: Option<BedrockResponseOutput>,
    // Raw stop reason string; interpreted by `map_stop_reason`.
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<BedrockUsage>,
}
735
/// The `output` object of a response, wrapping the generated message.
#[derive(Debug, Deserialize)]
struct BedrockResponseOutput {
    message: BedrockResponseMessage,
}
740
/// The generated message: an optional role plus its content blocks.
#[derive(Debug, Deserialize)]
struct BedrockResponseMessage {
    // Deserialized but not read anywhere in this file.
    #[allow(dead_code)]
    role: Option<String>,
    // Missing `content` deserializes to an empty list.
    #[serde(default)]
    content: Vec<BedrockResponseContent>,
}
748
/// A content block in the response message.
///
/// Untagged: a block deserializes as `Text` when it has a `text` field,
/// otherwise as `ToolUse` when it has a `toolUse` field.
// NOTE(review): a block matching neither variant makes deserialization of the
// whole response fail, since there is no catch-all variant — confirm whether
// other block kinds can occur for the models in use.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum BedrockResponseContent {
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: BedrockResponseToolUse,
    },
}
760
/// Tool invocation requested by the model (camelCase keys).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BedrockResponseToolUse {
    tool_use_id: String,
    name: String,
    // Missing `input` deserializes to JSON null (the `Value` default).
    #[serde(default)]
    input: Value,
}
769
/// Token counts reported with a response (camelCase keys); absent counts
/// deserialize to zero.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(clippy::struct_field_names)]
struct BedrockUsage {
    #[serde(default)]
    input_tokens: u64,
    #[serde(default)]
    output_tokens: u64,
    #[serde(default)]
    total_tokens: u64,
}
781
782fn convert_usage(usage: &BedrockUsage) -> Usage {
783    let total = if usage.total_tokens > 0 {
784        usage.total_tokens
785    } else {
786        usage.input_tokens + usage.output_tokens
787    };
788
789    Usage {
790        input: usage.input_tokens,
791        output: usage.output_tokens,
792        total_tokens: total,
793        ..Usage::default()
794    }
795}
796
797fn map_stop_reason(stop_reason: Option<&str>) -> StopReason {
798    match stop_reason.unwrap_or("end_turn") {
799        "tool_use" => StopReason::ToolUse,
800        "max_tokens" => StopReason::Length,
801        "guardrail_intervened" | "content_filtered" => StopReason::Error,
802        _ => StopReason::Stop,
803    }
804}
805
806fn normalize_model_id(model_id: &str) -> Result<String> {
807    let mut normalized = model_id.trim().trim_matches('/');
808    if normalized.is_empty() {
809        return Err(Error::provider(
810            "amazon-bedrock",
811            "Bedrock model id cannot be empty",
812        ));
813    }
814
815    for prefix in ["amazon-bedrock/", "bedrock/", "model/"] {
816        if let Some(stripped) = normalized.strip_prefix(prefix) {
817            normalized = stripped;
818            break;
819        }
820    }
821
822    if let Some((_, stripped)) = normalized.split_once("/model/") {
823        normalized = stripped;
824    }
825
826    for suffix in ["/converse-stream", "/converse"] {
827        if let Some(stripped) = normalized.strip_suffix(suffix) {
828            normalized = stripped;
829            break;
830        }
831    }
832
833    let final_id = normalized.trim_matches('/');
834    if final_id.is_empty() {
835        return Err(Error::provider(
836            "amazon-bedrock",
837            "Bedrock model id cannot be empty",
838        ));
839    }
840
841    Ok(final_id.to_string())
842}
843
844fn build_sigv4_headers(
845    url: &Url,
846    payload: &[u8],
847    access_key_id: &str,
848    secret_access_key: &str,
849    session_token: Option<&str>,
850    region: &str,
851    now: DateTime<Utc>,
852) -> Result<Sigv4Headers> {
853    let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
854    let date_stamp = now.format("%Y%m%d").to_string();
855    let payload_hash = sha256_hex(payload);
856    let host = canonical_host(url)?;
857    let canonical_uri = canonical_uri(url);
858    let canonical_query = canonical_query(url);
859
860    let mut canonical_headers = vec![
861        ("content-type".to_string(), "application/json".to_string()),
862        ("host".to_string(), host),
863        ("x-amz-content-sha256".to_string(), payload_hash.clone()),
864        ("x-amz-date".to_string(), amz_date.clone()),
865    ];
866    if let Some(token) = session_token {
867        canonical_headers.push(("x-amz-security-token".to_string(), token.to_string()));
868    }
869    canonical_headers.sort_by(|left, right| left.0.cmp(&right.0));
870
871    let signed_headers = canonical_headers
872        .iter()
873        .map(|(name, _)| name.as_str())
874        .collect::<Vec<_>>()
875        .join(";");
876
877    let mut canonical_headers_block = String::new();
878    for (name, value) in &canonical_headers {
879        let trimmed = value.trim();
880        writeln!(&mut canonical_headers_block, "{name}:{trimmed}")
881            .map_err(|err| Error::api(format!("Failed to build canonical headers: {err}")))?;
882    }
883
884    let canonical_request = format!(
885        "POST\n{canonical_uri}\n{canonical_query}\n{canonical_headers_block}\n{signed_headers}\n{payload_hash}"
886    );
887    let canonical_request_hash = sha256_hex(canonical_request.as_bytes());
888    let credential_scope = format!("{date_stamp}/{region}/{BEDROCK_SERVICE}/aws4_request");
889    let string_to_sign =
890        format!("AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{canonical_request_hash}");
891    let signature = hex_encode(&signing_key(
892        secret_access_key,
893        &date_stamp,
894        region,
895        &string_to_sign,
896    )?);
897
898    let authorization = format!(
899        "AWS4-HMAC-SHA256 Credential={access_key_id}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
900    );
901
902    Ok(Sigv4Headers {
903        authorization,
904        amz_date,
905        payload_hash,
906        security_token: session_token.map(ToString::to_string),
907    })
908}
909
910fn canonical_host(url: &Url) -> Result<String> {
911    let host = url.host_str().ok_or_else(|| {
912        Error::provider("amazon-bedrock", "Bedrock endpoint URL is missing a host")
913    })?;
914    Ok(url
915        .port()
916        .map_or_else(|| host.to_string(), |port| format!("{host}:{port}")))
917}
918
919fn canonical_uri(url: &Url) -> String {
920    let segments = url
921        .path_segments()
922        .map(|parts| parts.map(aws_percent_encode).collect::<Vec<_>>())
923        .unwrap_or_default();
924
925    if segments.is_empty() {
926        "/".to_string()
927    } else {
928        format!("/{}", segments.join("/"))
929    }
930}
931
932fn canonical_query(url: &Url) -> String {
933    let mut pairs = url
934        .query_pairs()
935        .map(|(key, value)| (aws_percent_encode(&key), aws_percent_encode(&value)))
936        .collect::<Vec<_>>();
937    pairs.sort();
938    pairs
939        .into_iter()
940        .map(|(key, value)| format!("{key}={value}"))
941        .collect::<Vec<_>>()
942        .join("&")
943}
944
945fn aws_percent_encode(value: &str) -> String {
946    let mut encoded = String::with_capacity(value.len());
947    for byte in value.bytes() {
948        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
949            encoded.push(char::from(byte));
950        } else {
951            encoded.push('%');
952            encoded.push(nibble_to_hex(byte >> 4));
953            encoded.push(nibble_to_hex(byte & 0x0f));
954        }
955    }
956    encoded
957}
958
/// Uppercase hex digit for a 4-bit value; anything above 15 yields `'0'`.
fn nibble_to_hex(nibble: u8) -> char {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    DIGITS
        .get(usize::from(nibble))
        .map_or('0', |&digit| char::from(digit))
}
966
967fn signing_key(
968    secret_access_key: &str,
969    date_stamp: &str,
970    region: &str,
971    string_to_sign: &str,
972) -> Result<Vec<u8>> {
973    let key_date = hmac_sha256(
974        format!("AWS4{secret_access_key}").as_bytes(),
975        date_stamp.as_bytes(),
976    )?;
977    let key_region = hmac_sha256(&key_date, region.as_bytes())?;
978    let key_service = hmac_sha256(&key_region, BEDROCK_SERVICE.as_bytes())?;
979    let key_signing = hmac_sha256(&key_service, b"aws4_request")?;
980    hmac_sha256(&key_signing, string_to_sign.as_bytes())
981}
982
983fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
984    let mut mac = HmacSha256::new_from_slice(key)
985        .map_err(|err| Error::api(format!("Failed to initialize HMAC: {err}")))?;
986    mac.update(data);
987    Ok(mac.finalize().into_bytes().to_vec())
988}
989
990fn sha256_hex(bytes: &[u8]) -> String {
991    let digest = Sha256::digest(bytes);
992    hex_encode(&digest)
993}
994
/// Renders `bytes` as a lowercase hexadecimal string, two digits per byte.
///
/// An empty slice yields an empty string.
fn hex_encode(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut hex, byte| {
            // Writing into a `String` cannot fail, so the result is ignored.
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
1002
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone as _;
    use serde_json::json;

    /// Wraps `text` in a text content block without a signature.
    fn text_block(text: &str) -> ContentBlock {
        ContentBlock::Text(TextContent {
            text: text.to_string(),
            text_signature: None,
        })
    }

    /// Wraps an image payload and its MIME type in an image content block.
    fn image_block(data: &str, mime_type: &str) -> ContentBlock {
        ContentBlock::Image(crate::model::ImageContent {
            data: data.to_string(),
            mime_type: mime_type.to_string(),
        })
    }

    /// Converts a tool result into its Bedrock message and serializes it to JSON.
    fn tool_result_json(message: &ToolResultMessage) -> serde_json::Value {
        let bedrock_msg = convert_tool_result_message(message);
        serde_json::to_value(&bedrock_msg).expect("serialize")
    }

    /// Conversation fixture: a user turn, an assistant tool call, the matching
    /// tool result, and a single `search` tool definition.
    fn test_context_with_tools() -> Context<'static> {
        let user_turn = Message::User(crate::model::UserMessage {
            content: UserContent::Text("Ping".to_string()),
            timestamp: 0,
        });
        let assistant_turn = Message::assistant(AssistantMessage {
            content: vec![ContentBlock::ToolCall(ToolCall {
                id: "tool_1".to_string(),
                name: "search".to_string(),
                arguments: json!({ "q": "rust" }),
                thought_signature: None,
            })],
            api: "bedrock-converse-stream".to_string(),
            provider: "amazon-bedrock".to_string(),
            model: "m".to_string(),
            usage: Usage::default(),
            stop_reason: StopReason::ToolUse,
            error_message: None,
            timestamp: 0,
        });
        let tool_turn = Message::tool_result(ToolResultMessage {
            tool_call_id: "tool_1".to_string(),
            tool_name: "search".to_string(),
            content: vec![text_block("result")],
            details: None,
            is_error: false,
            timestamp: 0,
        });
        let search_tool = ToolDef {
            name: "search".to_string(),
            description: "Search docs".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {"q": {"type": "string"}},
                "required": ["q"]
            }),
        };

        Context {
            system_prompt: Some("You are concise.".to_string().into()),
            messages: vec![user_turn, assistant_turn, tool_turn].into(),
            tools: vec![search_tool].into(),
        }
    }

    #[test]
    fn build_request_includes_system_messages_and_tools() {
        let context = test_context_with_tools();
        let options = StreamOptions {
            max_tokens: Some(321),
            temperature: Some(0.2),
            ..StreamOptions::default()
        };
        let request = BedrockProvider::build_request(&context, &options);
        let value = serde_json::to_value(&request).expect("serialize request");

        assert_eq!(value["system"][0]["text"], "You are concise.");

        let messages = &value["messages"];
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[1]["content"][0]["toolUse"]["name"], "search");
        assert_eq!(
            messages[2]["content"][0]["toolResult"]["status"],
            "success"
        );

        assert_eq!(value["inferenceConfig"]["maxTokens"], 321);
        assert_eq!(
            value["toolConfig"]["tools"][0]["toolSpec"]["name"],
            "search"
        );
    }

    #[test]
    fn converse_url_appends_model_path_and_encodes_model_id() {
        let model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0";
        let endpoint = "https://bedrock-runtime.us-east-1.amazonaws.com";

        let provider = BedrockProvider::new(model_id).with_base_url(endpoint);
        let converse = provider
            .converse_url("us-east-1")
            .expect("build converse URL");

        assert_eq!(
            converse.path(),
            "/model/anthropic.claude-3-5-sonnet-20240620-v1:0/converse"
        );
    }

    #[test]
    fn normalize_model_id_accepts_prefixed_variants() {
        // (raw input, expected normalized id, message used if normalization fails)
        let cases = [
            (
                "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "normalize regional prefix",
            ),
            (
                "model/anthropic.claude-3-5-sonnet-20240620-v1:0/converse",
                "anthropic.claude-3-5-sonnet-20240620-v1:0",
                "normalize model path",
            ),
        ];

        for (raw, expected, failure) in cases {
            assert_eq!(normalize_model_id(raw).expect(failure), expected);
        }
    }

    #[test]
    fn sigv4_headers_include_expected_scope_and_token() {
        let endpoint =
            Url::parse("https://bedrock-runtime.us-west-2.amazonaws.com/model/m.converse/converse")
                .expect("url");
        let signing_time = Utc
            .with_ymd_and_hms(2026, 2, 10, 8, 0, 0)
            .single()
            .expect("datetime");
        let payload = br#"{"messages":[{"role":"user","content":[{"text":"Ping"}]}]}"#;

        let headers = build_sigv4_headers(
            &endpoint,
            payload,
            "AKIDEXAMPLE",
            "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
            Some("session-token"),
            "us-west-2",
            signing_time,
        )
        .expect("sign headers");

        let expected_fragments = [
            "Credential=AKIDEXAMPLE/20260210/us-west-2/bedrock/aws4_request",
            "SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date;x-amz-security-token",
        ];
        for fragment in expected_fragments {
            assert!(headers.authorization.contains(fragment));
        }

        assert_eq!(headers.security_token.as_deref(), Some("session-token"));
        assert_eq!(headers.amz_date, "20260210T080000Z");
        assert_eq!(headers.payload_hash.len(), 64);
    }

    #[test]
    fn response_to_message_maps_tool_use_and_usage() {
        let provider = BedrockProvider::new("anthropic.claude-3-5-sonnet-20240620-v1:0");
        let raw = json!({
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [
                        {"text": "I can help."},
                        {"toolUse": {"toolUseId": "call_1", "name": "search", "input": {"q": "rust"}}}
                    ]
                }
            },
            "stopReason": "tool_use",
            "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}
        });
        let response =
            serde_json::from_value::<BedrockConverseResponse>(raw).expect("parse response");

        let message = provider.response_to_message(response);

        assert_eq!(message.stop_reason, StopReason::ToolUse);
        assert_eq!(message.usage.input, 10);
        assert_eq!(message.usage.output, 5);
        assert_eq!(message.usage.total_tokens, 15);
        // The text block must come first, followed by the tool call.
        assert!(matches!(
            &message.content[..],
            [ContentBlock::Text(_), ContentBlock::ToolCall(_), ..]
        ));
    }

    #[test]
    fn resolve_auth_context_uses_stream_option_api_key_fallback() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let auth_file = temp_dir.path().join("auth.json");
        let provider = BedrockProvider::new("model").with_auth_path(auth_file);

        let options = StreamOptions {
            api_key: Some("bedrock-bearer".to_string()),
            ..StreamOptions::default()
        };
        let context = provider
            .resolve_auth_context(&options)
            .expect("resolve auth context");

        assert!(matches!(context.auth, BedrockAuth::Bearer { .. }));
    }

    /// Builds a tool result for `tool_42` / `test_tool` with the given content
    /// and error flag.
    fn make_bedrock_tool_result(content: Vec<ContentBlock>, is_error: bool) -> ToolResultMessage {
        ToolResultMessage {
            tool_call_id: "tool_42".to_string(),
            tool_name: "test_tool".to_string(),
            content,
            details: None,
            is_error,
            timestamp: 0,
        }
    }

    #[test]
    fn tool_result_text_only_serializes_correctly() {
        let msg = make_bedrock_tool_result(vec![text_block("found it")], false);
        let value = tool_result_json(&msg);

        assert_eq!(value["role"], "user");

        let tool_result = &value["content"][0]["toolResult"];
        assert_eq!(tool_result["toolUseId"], "tool_42");
        assert_eq!(tool_result["status"], "success");
        assert_eq!(tool_result["content"][0]["text"], "found it");
        assert_eq!(tool_result["content"].as_array().map(Vec::len), Some(1));
    }

    #[test]
    fn tool_result_image_only_uses_native_image_format() {
        let msg = make_bedrock_tool_result(vec![image_block("aW1hZ2U=", "image/png")], false);
        let value = tool_result_json(&msg);

        let image = &value["content"][0]["toolResult"]["content"][0]["image"];
        assert_eq!(image["format"], "png");
        assert_eq!(image["source"]["bytes"], "aW1hZ2U=");
    }

    #[test]
    fn tool_result_mixed_text_and_image_preserves_both() {
        let blocks = vec![
            text_block("description"),
            image_block("anBlZw==", "image/jpeg"),
        ];
        let value = tool_result_json(&make_bedrock_tool_result(blocks, false));

        let contents = value["content"][0]["toolResult"]["content"]
            .as_array()
            .expect("content array");
        assert_eq!(contents.len(), 2);
        assert_eq!(contents[0]["text"], "description");
        assert_eq!(contents[1]["image"]["format"], "jpeg");
    }

    #[test]
    fn tool_result_empty_content_falls_back_to_empty_json() {
        let value = tool_result_json(&make_bedrock_tool_result(Vec::new(), false));

        let contents = value["content"][0]["toolResult"]["content"]
            .as_array()
            .expect("content array");
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["text"], "{}");
    }

    #[test]
    fn tool_result_error_sets_error_status() {
        let msg = make_bedrock_tool_result(vec![text_block("not found")], true);
        let value = tool_result_json(&msg);

        assert_eq!(value["content"][0]["toolResult"]["status"], "error");
    }

    #[test]
    fn tool_result_image_mime_extracts_format() {
        let msg = make_bedrock_tool_result(vec![image_block("data", "image/webp")], false);
        let value = tool_result_json(&msg);

        assert_eq!(
            value["content"][0]["toolResult"]["content"][0]["image"]["format"],
            "webp"
        );
    }
}