Skip to main content

codetether_agent/provider/bedrock/
mod.rs

1//! Amazon Bedrock provider for the Converse API.
2//!
3//! Supports both AWS IAM SigV4 credentials and opaque bearer tokens (e.g.
4//! from an API Gateway fronting Bedrock). Dispatches chat completions, tool
5//! calls, and model discovery via the `bedrock-runtime` and `bedrock`
6//! management APIs.
7//!
8//! # Quick Start
9//!
10//! ```rust,no_run
11//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
12//! use codetether_agent::provider::Provider;
13//! use codetether_agent::provider::bedrock::{AwsCredentials, BedrockProvider};
14//!
15//! let creds = AwsCredentials::from_environment().expect("creds");
16//! let region = AwsCredentials::detect_region().unwrap_or_else(|| "us-west-2".into());
17//! let provider = BedrockProvider::with_credentials(creds, region).unwrap();
18//! let models = provider.list_models().await.unwrap();
19//! assert!(!models.is_empty());
20//! # });
21//! ```
22//!
23//! # Architecture
24//!
25//! - [`auth`] — credential loading, SigV4/bearer auth mode enum
26//! - [`aliases`] — short-name to full model ID mapping
27//! - [`sigv4`] — SigV4 signing + HTTP dispatch
28//! - [`estimates`] — context-window / max-output heuristics
29//! - [`convert`] — `Message` → Converse API JSON translation
30//! - [`body`] — Converse request body builder
31//! - [`response`] — Converse response parser
32//! - [`discovery`] — dynamic model list via management APIs
33
34pub mod aliases;
35pub mod auth;
36pub mod body;
37pub mod convert;
38pub mod discovery;
39pub mod estimates;
40pub mod eventstream;
41pub mod response;
42pub mod retry;
43pub mod sigv4;
44pub mod stream;
45
46pub use aliases::resolve_model_id;
47pub use auth::{AwsCredentials, BedrockAuth};
48pub use body::build_converse_body;
49pub use convert::{convert_messages, convert_tools};
50pub use estimates::{estimate_context_window, estimate_max_output};
51pub use response::{BedrockError, parse_converse_response};
52
53use crate::provider::{CompletionRequest, CompletionResponse, ModelInfo, Provider, StreamChunk};
54use crate::util;
55use anyhow::{Context, Result};
56use async_trait::async_trait;
57use reqwest::Client;
58use std::fmt;
59
/// Fallback AWS region, used by [`BedrockProvider::new`] when the caller does
/// not pick a region explicitly.
pub const DEFAULT_REGION: &str = "us-east-1";
62
63/// Amazon Bedrock provider implementation.
64///
65/// Clone-able and cheap to copy (wraps an [`Arc`]-backed [`reqwest::Client`]).
66///
67/// # Examples
68///
69/// ```rust,no_run
70/// use codetether_agent::provider::bedrock::{AwsCredentials, BedrockProvider};
71///
72/// let creds = AwsCredentials::from_environment().unwrap();
73/// let p = BedrockProvider::with_credentials(creds, "us-west-2".into()).unwrap();
74/// assert_eq!(p.region(), "us-west-2");
75/// ```
76#[derive(Clone)]
77pub struct BedrockProvider {
78    pub(crate) client: Client,
79    pub(crate) auth: BedrockAuth,
80    pub(crate) region: String,
81}
82
83impl fmt::Debug for BedrockProvider {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("BedrockProvider")
86            .field(
87                "auth",
88                &match &self.auth {
89                    BedrockAuth::SigV4(_) => "SigV4",
90                    BedrockAuth::BearerToken(_) => "BearerToken",
91                },
92            )
93            .field("region", &self.region)
94            .finish()
95    }
96}
97
98impl BedrockProvider {
99    /// Create a bearer-token provider in the default region.
100    ///
101    /// # Errors
102    ///
103    /// Currently infallible, but returns [`Result`] for API symmetry with
104    /// [`BedrockProvider::with_credentials`] and future validation needs.
105    ///
106    /// # Examples
107    ///
108    /// ```rust
109    /// use codetether_agent::provider::bedrock::BedrockProvider;
110    /// let p = BedrockProvider::new("token-abc".into()).unwrap();
111    /// assert_eq!(p.region(), "us-east-1");
112    /// ```
113    #[allow(dead_code)]
114    pub fn new(api_key: String) -> Result<Self> {
115        Self::with_region(api_key, DEFAULT_REGION.to_string())
116    }
117
118    /// Create a bearer-token provider with an explicit region.
119    ///
120    /// # Examples
121    ///
122    /// ```rust
123    /// use codetether_agent::provider::bedrock::BedrockProvider;
124    /// let p = BedrockProvider::with_region("token".into(), "eu-west-1".into()).unwrap();
125    /// assert_eq!(p.region(), "eu-west-1");
126    /// ```
127    pub fn with_region(api_key: String, region: String) -> Result<Self> {
128        tracing::debug!(
129            provider = "bedrock",
130            region = %region,
131            auth = "bearer_token",
132            "Creating Bedrock provider"
133        );
134        Ok(Self {
135            client: Client::new(),
136            auth: BedrockAuth::BearerToken(api_key),
137            region,
138        })
139    }
140
141    /// Create a SigV4 provider from AWS IAM credentials.
142    ///
143    /// # Examples
144    ///
145    /// ```rust,no_run
146    /// use codetether_agent::provider::bedrock::{AwsCredentials, BedrockProvider};
147    /// let creds = AwsCredentials::from_environment().unwrap();
148    /// let p = BedrockProvider::with_credentials(creds, "us-west-2".into()).unwrap();
149    /// assert_eq!(p.region(), "us-west-2");
150    /// ```
151    pub fn with_credentials(credentials: AwsCredentials, region: String) -> Result<Self> {
152        tracing::debug!(
153            provider = "bedrock",
154            region = %region,
155            auth = "sigv4",
156            "Creating Bedrock provider with AWS credentials"
157        );
158        Ok(Self {
159            client: Client::new(),
160            auth: BedrockAuth::SigV4(credentials),
161            region,
162        })
163    }
164
165    /// Return the configured AWS region.
166    pub fn region(&self) -> &str {
167        &self.region
168    }
169
170    /// Resolve a short model alias to a full Bedrock model ID.
171    ///
172    /// See [`aliases::resolve_model_id`] for details.
173    pub fn resolve_model_id(model: &str) -> &str {
174        aliases::resolve_model_id(model)
175    }
176
177    /// Build a Converse request body. See [`body::build_converse_body`].
178    pub fn build_converse_body(
179        &self,
180        request: &CompletionRequest,
181        model_id: &str,
182    ) -> serde_json::Value {
183        body::build_converse_body(request, model_id)
184    }
185
186    /// Validate that credentials/token are non-empty.
187    ///
188    /// # Errors
189    ///
190    /// Returns [`anyhow::Error`] if the bearer token is empty or AWS keys
191    /// are incomplete.
192    pub(crate) fn validate_auth(&self) -> Result<()> {
193        match &self.auth {
194            BedrockAuth::BearerToken(key) => {
195                if key.is_empty() {
196                    anyhow::bail!("Bedrock API key is empty");
197                }
198            }
199            BedrockAuth::SigV4(creds) => {
200                if creds.access_key_id.is_empty() || creds.secret_access_key.is_empty() {
201                    anyhow::bail!("AWS credentials are incomplete");
202                }
203            }
204        }
205        Ok(())
206    }
207}
208
209#[async_trait]
210impl Provider for BedrockProvider {
211    fn name(&self) -> &str {
212        "bedrock"
213    }
214
215    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
216        self.validate_auth()?;
217        self.discover_models().await
218    }
219
220    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
221        let model_id = Self::resolve_model_id(&request.model);
222
223        tracing::debug!(
224            provider = "bedrock",
225            model = %model_id,
226            original_model = %request.model,
227            message_count = request.messages.len(),
228            tool_count = request.tools.len(),
229            "Starting Bedrock Converse request"
230        );
231
232        self.validate_auth()?;
233
234        let body = self.build_converse_body(&request, model_id);
235
236        // Keep the runtime URL readable; the SigV4 signer canonicalizes path
237        // segments so model suffixes like `:0` are encoded exactly once when
238        // constructing the signature.
239        let url = format!("{}/model/{}/converse", self.base_url(), model_id);
240        tracing::debug!("Bedrock request URL: {}", url);
241
242        let body_bytes = serde_json::to_vec(&body)?;
243        let policy = retry::RetryPolicy::default();
244
245        for attempt in 1..=policy.max_attempts {
246            let response = self
247                .send_request("POST", &url, Some(&body_bytes), "bedrock")
248                .await?;
249
250            let status = response.status();
251            let text = response
252                .text()
253                .await
254                .context("Failed to read Bedrock response")?;
255
256            if status.is_success() {
257                return parse_converse_response(&text);
258            }
259
260            let retryable =
261                retry::should_retry_status(status.as_u16()) && attempt < policy.max_attempts;
262            if retryable {
263                let sleep = policy.delay_for(attempt);
264                tracing::warn!(
265                    provider = "bedrock",
266                    status = %status,
267                    attempt,
268                    sleep_ms = sleep.as_millis() as u64,
269                    "Retrying Bedrock request after transient error"
270                );
271                tokio::time::sleep(sleep).await;
272                continue;
273            }
274
275            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
276                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
277            }
278            anyhow::bail!(
279                "Bedrock API error: {} {}",
280                status,
281                util::truncate_bytes_safe(&text, 500)
282            );
283        }
284
285        unreachable!("retry loop exits via return or bail!");
286    }
287
288    async fn complete_stream(
289        &self,
290        request: CompletionRequest,
291    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
292        let model_id = Self::resolve_model_id(&request.model);
293        self.validate_auth()?;
294
295        let body = self.build_converse_body(&request, model_id);
296        let body_bytes = serde_json::to_vec(&body)?;
297        self.converse_stream(model_id, body_bytes).await
298    }
299}
300
301#[cfg(test)]
302mod tests;