// codetether_agent/provider/bedrock/discovery.rs

//! Dynamic model discovery via Bedrock management APIs.
//!
//! Merges [`ListFoundationModels`](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_ListFoundationModels.html)
//! and [`ListInferenceProfiles`](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_ListInferenceProfiles.html)
//! so both on-demand foundation models and cross-region inference profiles
//! appear in the returned catalog.

use super::BedrockProvider;
use super::estimates::{estimate_context_window, estimate_max_output};
use crate::provider::ModelInfo;
use anyhow::Result;
use serde_json::Value;
use std::collections::HashMap;

15impl BedrockProvider {
16    /// Dynamically discover available Bedrock models.
17    ///
18    /// Queries both foundation models and system-defined inference profiles,
19    /// filters to text-output chat models, and emits [`ModelInfo`] records.
20    ///
21    /// # Errors
22    ///
23    /// Returns an error only on unrecoverable network setup issues; individual
24    /// API failures are logged and skipped, and the returned list may be empty.
25    ///
26    /// # Examples
27    ///
28    /// ```rust,no_run
29    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
30    /// use codetether_agent::provider::bedrock::{AwsCredentials, BedrockProvider};
31    /// use codetether_agent::provider::Provider;
32    ///
33    /// let creds = AwsCredentials::from_environment().unwrap();
34    /// let p = BedrockProvider::with_credentials(creds, "us-west-2".into()).unwrap();
35    /// let models = p.list_models().await.unwrap();
36    /// assert!(!models.is_empty());
37    /// # });
38    /// ```
39    pub(super) async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
40        let mut models: HashMap<String, ModelInfo> = HashMap::new();
41        self.discover_foundation_models(&mut models).await;
42        self.discover_inference_profiles(&mut models).await;
43
44        let mut result: Vec<ModelInfo> = models.into_values().collect();
45        result.sort_by(|a, b| a.id.cmp(&b.id));
46
47        tracing::info!(
48            provider = "bedrock",
49            model_count = result.len(),
50            "Discovered Bedrock models dynamically"
51        );
52
53        Ok(result)
54    }
55
56    async fn discover_foundation_models(&self, models: &mut HashMap<String, ModelInfo>) {
57        let fm_url = format!("{}/foundation-models", self.management_url());
58        let Ok(resp) = self.send_request("GET", &fm_url, None, "bedrock").await else {
59            return;
60        };
61        if !resp.status().is_success() {
62            return;
63        }
64        let Ok(data) = resp.json::<Value>().await else {
65            return;
66        };
67        let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) else {
68            return;
69        };
70
71        for m in summaries {
72            if let Some(info) = foundation_model_to_info(m) {
73                models.insert(info.id.clone(), info);
74            }
75        }
76    }
77
78    async fn discover_inference_profiles(&self, models: &mut HashMap<String, ModelInfo>) {
79        let ip_url = format!(
80            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
81            self.management_url()
82        );
83        let Ok(resp) = self.send_request("GET", &ip_url, None, "bedrock").await else {
84            return;
85        };
86        if !resp.status().is_success() {
87            return;
88        }
89        let Ok(data) = resp.json::<Value>().await else {
90            return;
91        };
92        let Some(profiles) = data
93            .get("inferenceProfileSummaries")
94            .and_then(|v| v.as_array())
95        else {
96            return;
97        };
98
99        for p in profiles {
100            if let Some(info) = inference_profile_to_info(p, models) {
101                models.insert(info.id.clone(), info);
102            }
103        }
104    }
105}
107fn foundation_model_to_info(m: &Value) -> Option<ModelInfo> {
108    let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
109    let model_name = m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
110
111    let output_modalities: Vec<&str> = m
112        .get("outputModalities")
113        .and_then(|v| v.as_array())
114        .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
115        .unwrap_or_default();
116
117    let input_modalities: Vec<&str> = m
118        .get("inputModalities")
119        .and_then(|v| v.as_array())
120        .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
121        .unwrap_or_default();
122
123    let inference_types: Vec<&str> = m
124        .get("inferenceTypesSupported")
125        .and_then(|v| v.as_array())
126        .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
127        .unwrap_or_default();
128
129    if !output_modalities.contains(&"TEXT")
130        || (!inference_types.contains(&"ON_DEMAND")
131            && !inference_types.contains(&"INFERENCE_PROFILE"))
132    {
133        return None;
134    }
135
136    let name_lower = model_name.to_lowercase();
137    if ["rerank", "embed", "safeguard", "sonic", "pegasus"]
138        .iter()
139        .any(|n| name_lower.contains(n))
140    {
141        return None;
142    }
143
144    let streaming = m
145        .get("responseStreamingSupported")
146        .and_then(|v| v.as_bool())
147        .unwrap_or(false);
148    let vision = input_modalities.contains(&"IMAGE");
149
150    let actual_id = if model_id.starts_with("amazon.") {
151        model_id.to_string()
152    } else if inference_types.contains(&"INFERENCE_PROFILE") {
153        format!("us.{model_id}")
154    } else {
155        model_id.to_string()
156    };
157
158    Some(ModelInfo {
159        id: actual_id.clone(),
160        name: format!("{model_name} (Bedrock)"),
161        provider: "bedrock".to_string(),
162        context_window: estimate_context_window(model_id),
163        max_output_tokens: Some(estimate_max_output(model_id)),
164        supports_vision: vision,
165        supports_tools: true,
166        supports_streaming: streaming,
167        input_cost_per_million: None,
168        output_cost_per_million: None,
169    })
170}
172fn inference_profile_to_info(p: &Value, models: &HashMap<String, ModelInfo>) -> Option<ModelInfo> {
173    let pid = p
174        .get("inferenceProfileId")
175        .and_then(|v| v.as_str())
176        .unwrap_or("");
177    let pname = p
178        .get("inferenceProfileName")
179        .and_then(|v| v.as_str())
180        .unwrap_or("");
181
182    if !pid.starts_with("us.") || models.contains_key(pid) {
183        return None;
184    }
185
186    let name_lower = pname.to_lowercase();
187    let skip_tokens = [
188        "image",
189        "stable ",
190        "upscale",
191        "embed",
192        "marengo",
193        "outpaint",
194        "inpaint",
195        "erase",
196        "recolor",
197        "replace",
198        "style ",
199        "background",
200        "sketch",
201        "control",
202        "transfer",
203        "sonic",
204        "pegasus",
205        "rerank",
206    ];
207    if skip_tokens.iter().any(|t| name_lower.contains(t)) {
208        return None;
209    }
210
211    let vision = pid.contains("llama3-2-11b")
212        || pid.contains("llama3-2-90b")
213        || pid.contains("pixtral")
214        || pid.contains("claude-3")
215        || pid.contains("claude-sonnet-4")
216        || pid.contains("claude-opus-4")
217        || pid.contains("claude-haiku-4");
218
219    let display_name = pname.replace("US ", "");
220    let display_name = format!("{} (Bedrock)", display_name.trim());
221
222    Some(ModelInfo {
223        id: pid.to_string(),
224        name: display_name,
225        provider: "bedrock".to_string(),
226        context_window: estimate_context_window(pid),
227        max_output_tokens: Some(estimate_max_output(pid)),
228        supports_vision: vision,
229        supports_tools: true,
230        supports_streaming: true,
231        input_cost_per_million: None,
232        output_cost_per_million: None,
233    })
234}