Skip to main content

agent_core_runtime/client/providers/bedrock/
mod.rs

1//! Amazon Bedrock API provider implementation.
2//!
3//! Uses the Converse API for chat completions with AWS SigV4 authentication.
4
5mod signing;
6mod sse;
7mod types;
8
9use async_stream::stream;
10use futures::Stream;
11
12use crate::client::error::LlmError;
13use crate::client::http::HttpClient;
14use crate::client::models::{Message, MessageOptions, StreamEvent};
15use crate::client::traits::LlmProvider;
16use std::future::Future;
17use std::pin::Pin;
18
19// =============================================================================
20// Provider
21// =============================================================================
22
23/// AWS Bedrock credentials.
24#[derive(Clone)]
25pub struct BedrockCredentials {
26    /// AWS access key ID.
27    pub access_key_id: String,
28    /// AWS secret access key.
29    pub secret_access_key: String,
30    /// Optional session token for temporary credentials.
31    pub session_token: Option<String>,
32}
33
34impl BedrockCredentials {
35    /// Create new credentials with access key and secret.
36    pub fn new(access_key_id: impl Into<String>, secret_access_key: impl Into<String>) -> Self {
37        Self {
38            access_key_id: access_key_id.into(),
39            secret_access_key: secret_access_key.into(),
40            session_token: None,
41        }
42    }
43
44    /// Create new credentials with session token (for temporary/assumed role credentials).
45    pub fn with_session_token(
46        access_key_id: impl Into<String>,
47        secret_access_key: impl Into<String>,
48        session_token: impl Into<String>,
49    ) -> Self {
50        Self {
51            access_key_id: access_key_id.into(),
52            secret_access_key: secret_access_key.into(),
53            session_token: Some(session_token.into()),
54        }
55    }
56}
57
58/// Amazon Bedrock API provider.
59///
60/// Uses the Converse API for chat completions.
61/// Requires AWS credentials and region configuration.
62pub struct BedrockProvider {
63    /// AWS credentials.
64    credentials: BedrockCredentials,
65    /// AWS region (e.g., "us-east-1").
66    region: String,
67    /// Model identifier (e.g., "anthropic.claude-3-sonnet-20240229-v1:0").
68    model: String,
69}
70
71impl BedrockProvider {
72    /// Create a new Bedrock provider.
73    ///
74    /// # Arguments
75    /// * `credentials` - AWS credentials
76    /// * `region` - AWS region (e.g., "us-east-1")
77    /// * `model` - Bedrock model ID (e.g., "anthropic.claude-3-sonnet-20240229-v1:0")
78    pub fn new(credentials: BedrockCredentials, region: String, model: String) -> Self {
79        Self {
80            credentials,
81            region,
82            model,
83        }
84    }
85
86    /// Returns the model identifier.
87    pub fn model(&self) -> &str {
88        &self.model
89    }
90
91    /// Returns the AWS region.
92    pub fn region(&self) -> &str {
93        &self.region
94    }
95}
96
97impl LlmProvider for BedrockProvider {
98    fn send_msg(
99        &self,
100        client: &HttpClient,
101        messages: &[Message],
102        options: &MessageOptions,
103    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
104        // Clone data for the async block
105        let client = client.clone();
106        let credentials = self.credentials.clone();
107        let region = self.region.clone();
108        let model = options.model.as_deref().unwrap_or(&self.model).to_string();
109        let messages = messages.to_vec();
110        let options = options.clone();
111
112        Box::pin(async move {
113            // Build request body
114            let body = types::build_request_body(&messages, &options)?;
115
116            // Get the API URL for this model
117            let url = types::get_converse_url(&region, &model);
118
119            // Sign the request with AWS SigV4
120            let headers = signing::sign_request(
121                &credentials,
122                &region,
123                "POST",
124                &url,
125                &body,
126                false, // not streaming
127            )?;
128
129            let headers_ref: Vec<(&str, &str)> = headers
130                .iter()
131                .map(|(k, v)| (k.as_str(), v.as_str()))
132                .collect();
133
134            // Make the API call
135            let response = client.post(&url, &headers_ref, &body).await?;
136
137            // Parse and return the response
138            types::parse_response(&response)
139        })
140    }
141
142    fn send_msg_stream(
143        &self,
144        client: &HttpClient,
145        messages: &[Message],
146        options: &MessageOptions,
147    ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
148        // Clone data for the async block
149        let client = client.clone();
150        let credentials = self.credentials.clone();
151        let region = self.region.clone();
152        let model = options.model.as_deref().unwrap_or(&self.model).to_string();
153        let messages = messages.to_vec();
154        let options = options.clone();
155
156        Box::pin(async move {
157            // Build request body (same format for streaming)
158            let body = types::build_request_body(&messages, &options)?;
159
160            // Get the streaming API URL for this model
161            let url = types::get_converse_stream_url(&region, &model);
162
163            // Sign the request with AWS SigV4
164            let headers = signing::sign_request(
165                &credentials,
166                &region,
167                "POST",
168                &url,
169                &body,
170                true, // streaming
171            )?;
172
173            let headers_ref: Vec<(&str, &str)> = headers
174                .iter()
175                .map(|(k, v)| (k.as_str(), v.as_str()))
176                .collect();
177
178            // Make the streaming API call
179            let byte_stream = client.post_stream(&url, &headers_ref, &body).await?;
180
181            // Convert byte stream to events
182            // Bedrock uses a different format than SSE - it uses event-stream encoding
183            use futures::StreamExt;
184            let event_stream = stream! {
185                let mut buffer = Vec::new();
186                let mut byte_stream = byte_stream;
187                let mut message_started = false;
188                let mut stream_state = sse::StreamState::default();
189
190                while let Some(chunk_result) = byte_stream.next().await {
191                    match chunk_result {
192                        Ok(bytes) => {
193                            buffer.extend_from_slice(&bytes);
194
195                            // Parse complete events from buffer
196                            let (events, remaining) = sse::parse_event_stream(&buffer);
197                            buffer = remaining;
198
199                            for event in events {
200                                match sse::parse_stream_event(&event, &mut stream_state) {
201                                    Ok(stream_events) => {
202                                        // Emit MessageStart on first content
203                                        if !message_started && !stream_events.is_empty() {
204                                            message_started = true;
205                                            yield Ok(StreamEvent::MessageStart {
206                                                message_id: String::new(),
207                                                model: model.clone(),
208                                            });
209                                        }
210
211                                        for stream_event in stream_events {
212                                            yield Ok(stream_event);
213                                        }
214                                    }
215                                    Err(e) => {
216                                        yield Err(e);
217                                        return;
218                                    }
219                                }
220                            }
221                        }
222                        Err(e) => {
223                            yield Err(e);
224                            break;
225                        }
226                    }
227                }
228
229                // Emit MessageStop at the end
230                if message_started {
231                    yield Ok(StreamEvent::MessageStop);
232                }
233            };
234
235            Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
236        })
237    }
238}