Skip to main content

phi_core/provider/
bedrock.rs

1//! Amazon Bedrock ConverseStream provider.
2//!
//! Targets the Bedrock ConverseStream API. AWS SigV4 request signing is NOT
//! implemented in this module yet: callers must either supply pre-computed
//! auth headers via `ModelConfig::headers` or route through an IAM proxy.
//! When no `authorization` header is supplied, the `api_key` is sent as a `Bearer` token.
7//!
8//! The `api_key` field in StreamConfig is expected to be formatted as:
9//! `{access_key_id}:{secret_access_key}` (with optional `:{session_token}`).
10//! The `base_url` in ModelConfig should be the Bedrock endpoint, e.g.:
11//! `https://bedrock-runtime.us-east-1.amazonaws.com`
12/*
ARCHITECTURE: BedrockProvider — AWS-native Claude/Titan/Llama via Bedrock
14
15Amazon Bedrock is AWS's managed AI platform. It hosts Claude (Anthropic), Titan (Amazon),
16Llama (Meta), and others. The Bedrock ConverseStream API is a uniform interface that
17abstracts away per-model differences at the AWS level.
18
19Key differences from other providers:
20
21  Authentication: AWS SigV4 request signing (NOT Bearer/API key headers)
22    SigV4 involves: HMAC-SHA256 of canonical request, string-to-sign, and credential scope.
23    The `api_key` is encoded as `access_key_id:secret_access_key[:session_token]`.
    This module currently only validates that shape; it does not compute a SigV4 signature.
25
26  URL: `{base_url}/model/{model}/converse-stream`
27    Example: `https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet.../converse-stream`
28
29  Request body: ConverseStream format (Amazon's envelope around model-specific content)
30    Messages: `{role: "user"|"assistant", content: [{text: "..."}]}`
31    Tools: `toolConfig: {tools: [{toolSpec: {name, description, inputSchema: {json: {...}}}}]}`
32    System: `system: [{text: "..."}]` (separate top-level field)
33
34  Response: Binary event-stream framing (NOT SSE)
35    Bedrock returns responses as AWS event stream (binary-framed JSON events).
    This module does not decode the binary framing; it reads the body as newline-delimited JSON (e.g. from a re-encoding proxy).
37
38RUST QUIRK: `splitn(3, ':')` — split with a limit
39  `.splitn(3, ':')` splits on `:` but produces AT MOST 3 parts.
40  For `"AKID:SECRET"` → `["AKID", "SECRET"]`
41  For `"AKID:SECRET:TOKEN"` → `["AKID", "SECRET", "TOKEN"]`
42  For `"AKID:SECRET:TOKEN:EXTRA"` → `["AKID", "SECRET", "TOKEN:EXTRA"]` (3rd gets the rest)
43  This lets us parse `access_key:secret[:session_token]` without splitting session tokens
44  that might (theoretically) contain colons.
45  Python analogy: `api_key.split(":", 2)`
46*/
47
48use super::traits::*;
49use crate::types::*;
50use async_trait::async_trait;
51use futures::StreamExt;
52use serde::Deserialize;
53use tokio::sync::mpsc;
54use tracing::{debug, warn};
55
/// Amazon Bedrock ConverseStream provider.
///
/// Unit struct — no state. All logic lives in the `StreamProvider` impl; every
/// call to `stream` takes its configuration from the `StreamConfig` it is given.
pub struct BedrockProvider;
58
59#[async_trait]
60impl StreamProvider for BedrockProvider {
61    fn provider_id(&self) -> &str {
62        "bedrock"
63    }
64
65    async fn stream(
66        &self,
67        config: StreamConfig, // REQUEST — api_key is "access_key:secret[:token]"; uses AWS SigV4 signing
68        tx: mpsc::UnboundedSender<StreamEvent>, // OBSERVER — receives events from ConverseStream binary framing
69        cancel: tokio_util::sync::CancellationToken, // ABORT — races against ConverseStream
70    ) -> Result<Message, ProviderError> {
71        let model_config = &config.model_config;
72        // Resolve via CredentialProvider when set, else use the static `api_key`.
73        // Bound to a local so `.splitn(...)` borrows from a local-lifetime String.
74        let api_key = model_config.resolve_api_key().await?;
75
76        // Structured-output gate. Bedrock's Converse API has no universal JSON-mode;
77        // the canonical workaround is the Anthropic tool-call shape, which only works
78        // for Anthropic foundation models on Bedrock. Detect by model-ID prefix; for
79        // other model families (Cohere, AI21, Meta Llama, Mistral on Bedrock, etc.),
80        // surface SchemaMismatch with a clear reason so the caller can adapt.
81        if !matches!(config.response_format, ResponseFormat::Text)
82            && !model_config.id.contains("anthropic")
83        {
84            return Err(ProviderError::SchemaMismatch {
85                reason: format!(
86                    "Bedrock model `{}` does not support structured output via the \
87                     phi-core Converse adapter (only `anthropic.*` foundation models do). \
88                     Either switch to a Bedrock Anthropic model or set response_format to Text.",
89                    model_config.id
90                ),
91            });
92        }
93
94        let base_url = &model_config.base_url;
95        let url = format!(
96            "{}/model/{}/converse-stream",
97            base_url, config.model_config.id
98        );
99
100        let body = build_bedrock_body(&config);
101        debug!(
102            "Bedrock request: model={} url={}",
103            config.model_config.id, url
104        );
105
106        /*
107        RUST QUIRK: `api_key.splitn(3, ':').collect::<Vec<&str>>()`
108
109        `.splitn(n, delimiter)` — split into at most n parts (see module doc above).
110        `.collect::<Vec<&str>>()` — turbofish: collect into a Vec<&str> (borrowed slices
111          into the resolved `api_key` String; valid as long as the local is alive).
112        `parts.len() < 2` — validate we got at least access_key AND secret_key.
113        Python analogy: `parts = api_key.split(":", 2)` + `if len(parts) < 2: raise ...`
114        */
115        let parts: Vec<&str> = api_key.splitn(3, ':').collect();
116        if parts.len() < 2 {
117            return Err(ProviderError::Auth(
118                "Bedrock api_key must be 'access_key:secret_key[:session_token]'".into(),
119            ));
120        }
121
122        let client = reqwest::Client::new();
123        let mut request = client.post(&url).header("content-type", "application/json");
124
125        // Add AWS auth headers. In a real implementation, this would use SigV4.
126        // For now, we support a simplified auth model where the caller provides
127        // pre-computed auth headers via model_config.headers, or uses an IAM proxy.
128        for (k, v) in &model_config.headers {
129            request = request.header(k, v);
130        }
131
132        // If no auth headers provided, try basic Bearer auth as fallback
133        // (works with some Bedrock proxy configurations)
134        if !model_config.headers.contains_key("authorization") {
135            request = request.header("authorization", format!("Bearer {}", api_key));
136        }
137
138        let response = request
139            .json(&body)
140            .send()
141            .await
142            .map_err(|e| ProviderError::Network(e.to_string()))?;
143
144        if !response.status().is_success() {
145            let status = response.status();
146            let body = response.text().await.unwrap_or_default();
147            return Err(ProviderError::classify(
148                status.as_u16(),
149                &format!("Bedrock error {}: {}", status, body),
150            ));
151        }
152
153        let mut content: Vec<Content> = Vec::new();
154        let mut usage = Usage::default();
155        let mut stop_reason = StopReason::Stop;
156
157        let _ = tx.send(StreamEvent::Start);
158
159        // Bedrock ConverseStream returns event-stream format (application/vnd.amazon.eventstream)
160        // For simplicity, we parse it as newline-delimited JSON chunks.
161        let mut stream = response.bytes_stream();
162        let mut buffer = String::new();
163
164        loop {
165            tokio::select! {
166                _ = cancel.cancelled() => {
167                    return Err(ProviderError::Cancelled);
168                }
169                chunk = stream.next() => {
170                    match chunk {
171                        None => break,
172                        Some(Err(e)) => {
173                            warn!("Bedrock stream error: {}", e);
174                            break;
175                        }
176                        Some(Ok(bytes)) => {
177                            buffer.push_str(&String::from_utf8_lossy(&bytes));
178
179                            // Try to parse complete JSON objects
180                            while let Some(pos) = buffer.find('\n') {
181                                let line = buffer[..pos].trim().to_string();
182                                buffer = buffer[pos + 1..].to_string();
183
184                                if line.is_empty() {
185                                    continue;
186                                }
187
188                                let event: BedrockEvent = match serde_json::from_str(&line) {
189                                    Ok(e) => e,
190                                    Err(_) => continue,
191                                };
192
193                                match event {
194                                    BedrockEvent::ContentBlockDelta { delta, .. } => {
195                                        if let Some(text) = delta.text {
196                                            let text_idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
197                                            let idx = match text_idx {
198                                                Some(i) => i,
199                                                None => {
200                                                    content.push(Content::Text { text: String::new() });
201                                                    content.len() - 1
202                                                }
203                                            };
204                                            if let Some(Content::Text { text: t }) = content.get_mut(idx) {
205                                                t.push_str(&text);
206                                            }
207                                            let _ = tx.send(StreamEvent::TextDelta {
208                                                content_index: idx,
209                                                delta: text,
210                                            });
211                                        }
212                                        if let Some(tool_use) = delta.tool_use {
213                                            let _ = tx.send(StreamEvent::ToolCallDelta {
214                                                content_index: content.len(),
215                                                delta: tool_use.input,
216                                            });
217                                        }
218                                    }
219                                    BedrockEvent::ContentBlockStart { start, .. } => {
220                                        if let Some(tool_use) = start.tool_use {
221                                            let idx = content.len();
222                                            content.push(Content::ToolCall {
223                                                id: tool_use.tool_use_id.clone(),
224                                                name: tool_use.name.clone(),
225                                                arguments: serde_json::Value::Object(Default::default()),
226                                            });
227                                            let _ = tx.send(StreamEvent::ToolCallStart {
228                                                content_index: idx,
229                                                id: tool_use.tool_use_id,
230                                                name: tool_use.name,
231                                            });
232                                        }
233                                    }
234                                    BedrockEvent::ContentBlockStop { .. } => {
235                                        if content.iter().any(|c| matches!(c, Content::ToolCall { .. })) {
236                                            let _ = tx.send(StreamEvent::ToolCallEnd {
237                                                content_index: content.len() - 1,
238                                            });
239                                        }
240                                    }
241                                    BedrockEvent::MessageStop { stop_reason: sr } => {
242                                        stop_reason = match sr.as_deref() {
243                                            Some("end_turn") => StopReason::Stop,
244                                            Some("max_tokens") => StopReason::Length,
245                                            Some("tool_use") => StopReason::ToolUse,
246                                            _ => StopReason::Stop,
247                                        };
248                                    }
249                                    BedrockEvent::Metadata { usage: u } => {
250                                        if let Some(u) = u {
251                                            usage.input = u.input_tokens;
252                                            usage.output = u.output_tokens;
253                                            usage.total_tokens = u.input_tokens + u.output_tokens;
254                                        }
255                                    }
256                                    BedrockEvent::Unknown => {}
257                                }
258                            }
259                        }
260                    }
261                }
262            }
263        }
264
265        let message = Message::Assistant {
266            content,
267            stop_reason,
268            model: config.model_config.id.clone(),
269            provider: model_config.provider.clone(),
270            usage,
271            timestamp: now_ms(),
272            error_message: None,
273        };
274
275        let _ = tx.send(StreamEvent::Done {
276            message: message.clone(),
277        });
278        Ok(message)
279    }
280}
281
282fn build_bedrock_body(config: &StreamConfig) -> serde_json::Value {
283    let mut messages: Vec<serde_json::Value> = Vec::new();
284
285    for msg in &config.messages {
286        match msg {
287            Message::User { content, .. } => {
288                let blocks = content_to_bedrock(content);
289                messages.push(serde_json::json!({"role": "user", "content": blocks}));
290            }
291            Message::Assistant { content, .. } => {
292                let blocks = content_to_bedrock(content);
293                messages.push(serde_json::json!({"role": "assistant", "content": blocks}));
294            }
295            Message::ToolResult {
296                tool_call_id,
297                content,
298                is_error,
299                ..
300            } => {
301                // Build content blocks for tool result (text + images)
302                let tool_content: Vec<serde_json::Value> = content
303                    .iter()
304                    .filter_map(|c| match c {
305                        Content::Text { text } => Some(serde_json::json!({"text": text})),
306                        Content::Image { data, mime_type } => Some(serde_json::json!({
307                            "image": {
308                                "format": mime_type.split('/').nth(1).unwrap_or("png"),
309                                "source": {"bytes": data},
310                            }
311                        })),
312                        _ => None,
313                    })
314                    .collect();
315
316                let tool_content = if tool_content.is_empty() {
317                    vec![serde_json::json!({"text": ""})]
318                } else {
319                    tool_content
320                };
321
322                messages.push(serde_json::json!({
323                    "role": "user",
324                    "content": [{
325                        "toolResult": {
326                            "toolUseId": tool_call_id,
327                            "content": tool_content,
328                            "status": if *is_error { "error" } else { "success" },
329                        }
330                    }],
331                }));
332            }
333        }
334    }
335
336    let mut body = serde_json::json!({"messages": messages});
337
338    if !config.system_prompt.is_empty() {
339        body["system"] = serde_json::json!([{"text": config.system_prompt}]);
340    }
341
342    let mut inference_config = serde_json::json!({});
343    if let Some(max) = config.max_tokens {
344        inference_config["maxTokens"] = serde_json::json!(max);
345    }
346    if let Some(temp) = config.temperature {
347        inference_config["temperature"] = serde_json::json!(temp);
348    }
349    if inference_config != serde_json::json!({}) {
350        body["inferenceConfig"] = inference_config;
351    }
352
353    if !config.tools.is_empty() {
354        let tools: Vec<serde_json::Value> = config
355            .tools
356            .iter()
357            .map(|t| {
358                serde_json::json!({
359                    "toolSpec": {
360                        "name": t.name,
361                        "description": t.description,
362                        "inputSchema": {"json": t.parameters},
363                    }
364                })
365            })
366            .collect();
367        body["toolConfig"] = serde_json::json!({"tools": tools});
368    }
369
370    // Structured-output emulation for Anthropic-on-Bedrock. Same shape as the native
371    // Anthropic provider: inject a `respond_json` synthetic tool spec and force the
372    // model to call it via `toolChoice.tool`. Non-Anthropic Bedrock models are gated
373    // by stream() above and never reach this point with a non-Text format.
374    match &config.response_format {
375        ResponseFormat::Text => {}
376        ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. } => {
377            let (schema, description) = match &config.response_format {
378                ResponseFormat::JsonSchema { schema, name, .. } => (
379                    schema.clone(),
380                    format!("Return the response as a JSON object matching `{}`.", name),
381                ),
382                _ => (
383                    serde_json::json!({"type": "object", "additionalProperties": true}),
384                    "Return the response as a JSON object.".to_string(),
385                ),
386            };
387            let synthetic = serde_json::json!({
388                "toolSpec": {
389                    "name": "respond_json",
390                    "description": description,
391                    "inputSchema": {"json": schema},
392                }
393            });
394            // Append to existing toolConfig.tools or create one.
395            if let Some(tools_arr) = body
396                .get_mut("toolConfig")
397                .and_then(|tc| tc.get_mut("tools"))
398                .and_then(|t| t.as_array_mut())
399            {
400                tools_arr.push(synthetic);
401            } else {
402                body["toolConfig"] = serde_json::json!({"tools": [synthetic]});
403            }
404            // Force tool_choice to the synthetic tool.
405            body["toolConfig"]["toolChoice"] =
406                serde_json::json!({"tool": {"name": "respond_json"}});
407        }
408    }
409
410    body
411}
412
413fn content_to_bedrock(content: &[Content]) -> Vec<serde_json::Value> {
414    content
415        .iter()
416        .filter_map(|c| match c {
417            Content::Text { text } => Some(serde_json::json!({"text": text})),
418            Content::Image { data, mime_type } => Some(serde_json::json!({
419                "image": {
420                    "format": mime_type.split('/').nth(1).unwrap_or("png"),
421                    "source": {"bytes": data},
422                }
423            })),
424            Content::ToolCall {
425                id,
426                name,
427                arguments,
428            } => Some(serde_json::json!({
429                "toolUse": {"toolUseId": id, "name": name, "input": arguments},
430            })),
431            Content::Thinking { .. } => None,
432        })
433        .collect()
434}
435
436// Bedrock event types
437#[derive(Deserialize)]
438#[serde(untagged)]
439enum BedrockEvent {
440    ContentBlockDelta {
441        #[serde(rename = "contentBlockDelta")]
442        delta: BedrockDelta,
443    },
444    ContentBlockStart {
445        #[serde(rename = "contentBlockStart")]
446        start: BedrockBlockStart,
447    },
448    ContentBlockStop {
449        #[serde(rename = "contentBlockStop")]
450        #[allow(dead_code)]
451        stop: serde_json::Value,
452    },
453    MessageStop {
454        #[serde(rename = "messageStop")]
455        stop_reason: Option<String>,
456    },
457    Metadata {
458        #[serde(rename = "metadata")]
459        usage: Option<BedrockUsage>,
460    },
461    Unknown,
462}
463
/// Payload of a `contentBlockDelta` event.
///
/// Both fields default to `None` when absent, so a delta may carry text,
/// tool-call input, both, or neither.
#[derive(Deserialize)]
struct BedrockDelta {
    /// Incremental assistant text, if this delta carries any.
    #[serde(default)]
    text: Option<String>,
    /// Incremental tool-call input, read from the `toolUse` key.
    #[serde(default, rename = "toolUse")]
    tool_use: Option<BedrockToolUseDelta>,
}
471
/// Tool-call part of a content-block delta.
#[derive(Deserialize)]
struct BedrockToolUseDelta {
    /// Fragment of the tool input, forwarded to observers as a `ToolCallDelta`.
    /// Presumably a piece of a JSON document — verify against the producer.
    input: String,
}
476
/// Payload of a `contentBlockStart` event.
#[derive(Deserialize)]
struct BedrockBlockStart {
    /// Identity of the tool call being started, read from the `toolUse` key;
    /// `None` when the key is absent.
    #[serde(default, rename = "toolUse")]
    tool_use: Option<BedrockToolUseStart>,
}
482
/// Identity of a tool call announced by a content-block start.
#[derive(Deserialize)]
struct BedrockToolUseStart {
    /// Id of the tool call (wire key `toolUseId`); becomes `Content::ToolCall::id`.
    #[serde(rename = "toolUseId")]
    tool_use_id: String,
    /// Name of the tool being called; becomes `Content::ToolCall::name`.
    name: String,
}
489
/// Token counts reported by a `metadata` event.
///
/// Missing counters default to 0.
#[derive(Deserialize)]
struct BedrockUsage {
    /// Input token count (wire key `inputTokens`).
    #[serde(default, rename = "inputTokens")]
    input_tokens: u64,
    /// Output token count (wire key `outputTokens`).
    #[serde(default, rename = "outputTokens")]
    output_tokens: u64,
}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain text request yields messages, a system block and maxTokens.
    #[test]
    fn test_build_bedrock_body() {
        let model_config = crate::provider::ModelConfig::anthropic(
            "anthropic.claude-3-sonnet-20240229-v1:0",
            "Claude Sonnet",
            "key:secret",
        );
        let config = StreamConfig {
            model_config,
            system_prompt: "Be helpful".into(),
            messages: vec![Message::user("Hello")],
            tools: vec![],
            thinking_level: ThinkingLevel::Off,
            max_tokens: Some(1024),
            temperature: None,
            cache_config: CacheConfig::default(),
            response_format: ResponseFormat::Text,
        };

        let request = build_bedrock_body(&config);

        let messages = &request["messages"];
        assert!(messages.is_array());
        assert_eq!(messages[0]["role"], "user");

        assert!(request["system"].is_array());
        assert_eq!(request["inferenceConfig"]["maxTokens"], 1024);
    }

    /// Text and tool-call content map to `text` and `toolUse` blocks, in order.
    #[test]
    fn test_content_to_bedrock() {
        let greeting = Content::Text {
            text: "hello".into(),
        };
        let call = Content::ToolCall {
            id: "tc-1".into(),
            name: "bash".into(),
            arguments: serde_json::json!({"command": "ls"}),
        };

        let converted = content_to_bedrock(&[greeting, call]);

        assert_eq!(converted.len(), 2);
        assert_eq!(converted[0]["text"], "hello");
        assert_eq!(converted[1]["toolUse"]["name"], "bash");
    }
}