agent_core_runtime/client/providers/bedrock/
mod.rs1mod signing;
6mod sse;
7mod types;
8
9use async_stream::stream;
10use futures::Stream;
11
12use crate::client::error::LlmError;
13use crate::client::http::HttpClient;
14use crate::client::models::{Message, MessageOptions, StreamEvent};
15use crate::client::traits::LlmProvider;
16use std::future::Future;
17use std::pin::Pin;
18
/// AWS credentials used to sign Bedrock requests.
///
/// `Debug` is implemented by hand (not derived) so that the secret access key
/// and the session token are never written to logs or panic messages; only the
/// access key ID is shown verbatim.
#[derive(Clone)]
pub struct BedrockCredentials {
    /// AWS access key ID.
    pub access_key_id: String,
    /// AWS secret access key. Redacted in `Debug` output.
    pub secret_access_key: String,
    /// Optional session token (e.g. for temporary credentials); `None` when the
    /// key pair is used on its own. Redacted in `Debug` output.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for BedrockCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed instead of secret material.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("BedrockCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &REDACTED)
            // Keep the Some/None distinction visible without leaking the value.
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
33
34impl BedrockCredentials {
35 pub fn new(access_key_id: impl Into<String>, secret_access_key: impl Into<String>) -> Self {
37 Self {
38 access_key_id: access_key_id.into(),
39 secret_access_key: secret_access_key.into(),
40 session_token: None,
41 }
42 }
43
44 pub fn with_session_token(
46 access_key_id: impl Into<String>,
47 secret_access_key: impl Into<String>,
48 session_token: impl Into<String>,
49 ) -> Self {
50 Self {
51 access_key_id: access_key_id.into(),
52 secret_access_key: secret_access_key.into(),
53 session_token: Some(session_token.into()),
54 }
55 }
56}
57
58pub struct BedrockProvider {
63 credentials: BedrockCredentials,
65 region: String,
67 model: String,
69}
70
71impl BedrockProvider {
72 pub fn new(credentials: BedrockCredentials, region: String, model: String) -> Self {
79 Self {
80 credentials,
81 region,
82 model,
83 }
84 }
85
86 pub fn model(&self) -> &str {
88 &self.model
89 }
90
91 pub fn region(&self) -> &str {
93 &self.region
94 }
95}
96
97impl LlmProvider for BedrockProvider {
98 fn send_msg(
99 &self,
100 client: &HttpClient,
101 messages: &[Message],
102 options: &MessageOptions,
103 ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
104 let client = client.clone();
106 let credentials = self.credentials.clone();
107 let region = self.region.clone();
108 let model = options.model.as_deref().unwrap_or(&self.model).to_string();
109 let messages = messages.to_vec();
110 let options = options.clone();
111
112 Box::pin(async move {
113 let body = types::build_request_body(&messages, &options)?;
115
116 let url = types::get_converse_url(®ion, &model);
118
119 let headers = signing::sign_request(
121 &credentials,
122 ®ion,
123 "POST",
124 &url,
125 &body,
126 false, )?;
128
129 let headers_ref: Vec<(&str, &str)> = headers
130 .iter()
131 .map(|(k, v)| (k.as_str(), v.as_str()))
132 .collect();
133
134 let response = client.post(&url, &headers_ref, &body).await?;
136
137 types::parse_response(&response)
139 })
140 }
141
142 fn send_msg_stream(
143 &self,
144 client: &HttpClient,
145 messages: &[Message],
146 options: &MessageOptions,
147 ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
148 let client = client.clone();
150 let credentials = self.credentials.clone();
151 let region = self.region.clone();
152 let model = options.model.as_deref().unwrap_or(&self.model).to_string();
153 let messages = messages.to_vec();
154 let options = options.clone();
155
156 Box::pin(async move {
157 let body = types::build_request_body(&messages, &options)?;
159
160 let url = types::get_converse_stream_url(®ion, &model);
162
163 let headers = signing::sign_request(
165 &credentials,
166 ®ion,
167 "POST",
168 &url,
169 &body,
170 true, )?;
172
173 let headers_ref: Vec<(&str, &str)> = headers
174 .iter()
175 .map(|(k, v)| (k.as_str(), v.as_str()))
176 .collect();
177
178 let byte_stream = client.post_stream(&url, &headers_ref, &body).await?;
180
181 use futures::StreamExt;
184 let event_stream = stream! {
185 let mut buffer = Vec::new();
186 let mut byte_stream = byte_stream;
187 let mut message_started = false;
188 let mut stream_state = sse::StreamState::default();
189
190 while let Some(chunk_result) = byte_stream.next().await {
191 match chunk_result {
192 Ok(bytes) => {
193 buffer.extend_from_slice(&bytes);
194
195 let (events, remaining) = sse::parse_event_stream(&buffer);
197 buffer = remaining;
198
199 for event in events {
200 match sse::parse_stream_event(&event, &mut stream_state) {
201 Ok(stream_events) => {
202 if !message_started && !stream_events.is_empty() {
204 message_started = true;
205 yield Ok(StreamEvent::MessageStart {
206 message_id: String::new(),
207 model: model.clone(),
208 });
209 }
210
211 for stream_event in stream_events {
212 yield Ok(stream_event);
213 }
214 }
215 Err(e) => {
216 yield Err(e);
217 return;
218 }
219 }
220 }
221 }
222 Err(e) => {
223 yield Err(e);
224 break;
225 }
226 }
227 }
228
229 if message_started {
231 yield Ok(StreamEvent::MessageStop);
232 }
233 };
234
235 Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
236 })
237 }
238}