codetether_agent/provider/bedrock/
mod.rs1pub mod aliases;
35pub mod auth;
36pub mod body;
37pub mod convert;
38pub mod discovery;
39pub mod estimates;
40pub mod eventstream;
41pub mod response;
42pub mod retry;
43pub mod sigv4;
44pub mod stream;
45
46pub use aliases::resolve_model_id;
47pub use auth::{AwsCredentials, BedrockAuth};
48pub use body::build_converse_body;
49pub use convert::{convert_messages, convert_tools};
50pub use estimates::{estimate_context_window, estimate_max_output};
51pub use response::{BedrockError, parse_converse_response};
52
53use crate::provider::{CompletionRequest, CompletionResponse, ModelInfo, Provider, StreamChunk};
54use crate::util;
55use anyhow::{Context, Result};
56use async_trait::async_trait;
57use reqwest::Client;
58use std::fmt;
59
/// AWS region used by [`BedrockProvider::new`] when the caller does not pick one.
pub const DEFAULT_REGION: &str = "us-east-1";

63#[derive(Clone)]
77pub struct BedrockProvider {
78 pub(crate) client: Client,
79 pub(crate) auth: BedrockAuth,
80 pub(crate) region: String,
81}
82
83impl fmt::Debug for BedrockProvider {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_struct("BedrockProvider")
86 .field(
87 "auth",
88 &match &self.auth {
89 BedrockAuth::SigV4(_) => "SigV4",
90 BedrockAuth::BearerToken(_) => "BearerToken",
91 },
92 )
93 .field("region", &self.region)
94 .finish()
95 }
96}
97
98impl BedrockProvider {
99 #[allow(dead_code)]
114 pub fn new(api_key: String) -> Result<Self> {
115 Self::with_region(api_key, DEFAULT_REGION.to_string())
116 }
117
118 pub fn with_region(api_key: String, region: String) -> Result<Self> {
128 tracing::debug!(
129 provider = "bedrock",
130 region = %region,
131 auth = "bearer_token",
132 "Creating Bedrock provider"
133 );
134 Ok(Self {
135 client: Client::new(),
136 auth: BedrockAuth::BearerToken(api_key),
137 region,
138 })
139 }
140
141 pub fn with_credentials(credentials: AwsCredentials, region: String) -> Result<Self> {
152 tracing::debug!(
153 provider = "bedrock",
154 region = %region,
155 auth = "sigv4",
156 "Creating Bedrock provider with AWS credentials"
157 );
158 Ok(Self {
159 client: Client::new(),
160 auth: BedrockAuth::SigV4(credentials),
161 region,
162 })
163 }
164
165 pub fn region(&self) -> &str {
167 &self.region
168 }
169
170 pub fn resolve_model_id(model: &str) -> &str {
174 aliases::resolve_model_id(model)
175 }
176
177 pub fn build_converse_body(
179 &self,
180 request: &CompletionRequest,
181 model_id: &str,
182 ) -> serde_json::Value {
183 body::build_converse_body(request, model_id)
184 }
185
186 pub(crate) fn validate_auth(&self) -> Result<()> {
193 match &self.auth {
194 BedrockAuth::BearerToken(key) => {
195 if key.is_empty() {
196 anyhow::bail!("Bedrock API key is empty");
197 }
198 }
199 BedrockAuth::SigV4(creds) => {
200 if creds.access_key_id.is_empty() || creds.secret_access_key.is_empty() {
201 anyhow::bail!("AWS credentials are incomplete");
202 }
203 }
204 }
205 Ok(())
206 }
207}
208
209#[async_trait]
210impl Provider for BedrockProvider {
211 fn name(&self) -> &str {
212 "bedrock"
213 }
214
215 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
216 self.validate_auth()?;
217 self.discover_models().await
218 }
219
220 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
221 let model_id = Self::resolve_model_id(&request.model);
222
223 tracing::debug!(
224 provider = "bedrock",
225 model = %model_id,
226 original_model = %request.model,
227 message_count = request.messages.len(),
228 tool_count = request.tools.len(),
229 "Starting Bedrock Converse request"
230 );
231
232 self.validate_auth()?;
233
234 let body = self.build_converse_body(&request, model_id);
235
236 let url = format!("{}/model/{}/converse", self.base_url(), model_id);
240 tracing::debug!("Bedrock request URL: {}", url);
241
242 let body_bytes = serde_json::to_vec(&body)?;
243 let policy = retry::RetryPolicy::default();
244
245 for attempt in 1..=policy.max_attempts {
246 let response = self
247 .send_request("POST", &url, Some(&body_bytes), "bedrock")
248 .await?;
249
250 let status = response.status();
251 let text = response
252 .text()
253 .await
254 .context("Failed to read Bedrock response")?;
255
256 if status.is_success() {
257 return parse_converse_response(&text);
258 }
259
260 let retryable =
261 retry::should_retry_status(status.as_u16()) && attempt < policy.max_attempts;
262 if retryable {
263 let sleep = policy.delay_for(attempt);
264 tracing::warn!(
265 provider = "bedrock",
266 status = %status,
267 attempt,
268 sleep_ms = sleep.as_millis() as u64,
269 "Retrying Bedrock request after transient error"
270 );
271 tokio::time::sleep(sleep).await;
272 continue;
273 }
274
275 if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
276 anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
277 }
278 anyhow::bail!(
279 "Bedrock API error: {} {}",
280 status,
281 util::truncate_bytes_safe(&text, 500)
282 );
283 }
284
285 unreachable!("retry loop exits via return or bail!");
286 }
287
288 async fn complete_stream(
289 &self,
290 request: CompletionRequest,
291 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
292 let model_id = Self::resolve_model_id(&request.model);
293 self.validate_auth()?;
294
295 let body = self.build_converse_body(&request, model_id);
296 let body_bytes = serde_json::to_vec(&body)?;
297 self.converse_stream(model_id, body_bytes).await
298 }
299}
300
301#[cfg(test)]
302mod tests;