Skip to main content

llm/catalog/
bedrock.rs

1use std::borrow::Cow;
2use std::str::FromStr;
3
4use crate::ReasoningEffort;
5use crate::catalog::BedrockFoundationModel;
6
7#[derive(Debug, Clone, PartialEq, Eq, Hash)]
8pub enum BedrockModel {
9    Foundation(BedrockFoundationModel),
10    Profile(String),
11}
12
13impl BedrockModel {
14    pub fn model_id(&self) -> Cow<'static, str> {
15        match self {
16            Self::Foundation(m) => Cow::Borrowed(m.model_id()),
17            Self::Profile(s) => Cow::Owned(s.clone()),
18        }
19    }
20
21    pub fn display_name(&self) -> Cow<'static, str> {
22        match self {
23            Self::Foundation(m) => Cow::Borrowed(m.display_name()),
24            Self::Profile(s) => Cow::Owned(format!("Bedrock {s}")),
25        }
26    }
27
28    pub fn context_window(&self) -> Option<u32> {
29        match self {
30            Self::Foundation(m) => Some(m.context_window()),
31            Self::Profile(_) => None,
32        }
33    }
34
35    pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {
36        match self {
37            Self::Foundation(m) => m.reasoning_levels(),
38            Self::Profile(_) => &[],
39        }
40    }
41
42    pub fn supports_reasoning(&self) -> bool {
43        !self.reasoning_levels().is_empty()
44    }
45
46    pub fn supports_prompt_caching(&self) -> bool {
47        match self {
48            Self::Foundation(m) => m.supports_prompt_caching(),
49            Self::Profile(_) => false,
50        }
51    }
52
53    pub fn supports_image(&self) -> bool {
54        match self {
55            Self::Foundation(m) => m.supports_image(),
56            Self::Profile(_) => false,
57        }
58    }
59
60    pub fn supports_audio(&self) -> bool {
61        match self {
62            Self::Foundation(m) => m.supports_audio(),
63            Self::Profile(_) => false,
64        }
65    }
66}
67
68impl FromStr for BedrockModel {
69    type Err = String;
70
71    fn from_str(s: &str) -> Result<Self, Self::Err> {
72        match s.parse::<BedrockFoundationModel>() {
73            Ok(m) => Ok(Self::Foundation(m)),
74            Err(_) if is_bedrock_inference_profile_arn(s) => Err(
75                "Bedrock inference profile ARNs must be configured as providers.bedrock.inferenceProfileArn; keep model as bedrock:<model-id>".to_string(),
76            ),
77            Err(_) => Ok(Self::Profile(s.to_string())),
78        }
79    }
80}
81
/// Reports whether `s` is an ARN for a Bedrock inference profile or
/// application inference profile.
///
/// Expected shape: `arn:<partition>:bedrock:<region>:<account>:<resource>`.
/// - `<partition>` must start with `aws` (so `aws` and `aws-us-gov` both match).
/// - `<resource>` must start with `inference-profile/` or
///   `application-inference-profile/`.
/// - Region and account are not validated.
fn is_bedrock_inference_profile_arn(s: &str) -> bool {
    const PROFILE_RESOURCE_PREFIXES: [&str; 2] =
        ["inference-profile/", "application-inference-profile/"];

    let Some(body) = s.strip_prefix("arn:") else {
        return false;
    };

    // Cap the split at five fields so the resource keeps any colons it contains
    // (profile IDs such as `...-v1:0` do). Fewer than five fields is not a
    // well-formed ARN.
    let mut fields = body.splitn(5, ':');
    let (Some(partition), Some(service), Some(_region), Some(_account), Some(resource)) = (
        fields.next(),
        fields.next(),
        fields.next(),
        fields.next(),
        fields.next(),
    ) else {
        return false;
    };

    service == "bedrock"
        && partition.starts_with("aws")
        && PROFILE_RESOURCE_PREFIXES
            .iter()
            .any(|&prefix| resource.starts_with(prefix))
}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` and asserts that it is rejected with a message pointing
    /// at the `providers.bedrock.inferenceProfileArn` setting.
    fn assert_rejected_as_profile_arn(input: &str) {
        let message = input
            .parse::<BedrockModel>()
            .expect_err("inference profile ARNs must not parse as a model");
        assert!(message.contains("providers.bedrock.inferenceProfileArn"));
    }

    #[test]
    fn foundation_model_parses() {
        let parsed = "anthropic.claude-sonnet-4-5-20250929-v1:0".parse::<BedrockModel>();
        assert!(matches!(parsed, Ok(BedrockModel::Foundation(_))));
    }

    #[test]
    fn unknown_profile_id_falls_through_to_profile_variant() {
        let model = "us.anthropic.claude-future-model-v99:0"
            .parse::<BedrockModel>()
            .expect("unknown IDs should be accepted as profiles");
        assert!(matches!(model, BedrockModel::Profile(_)));
        assert!(model.context_window().is_none());
    }

    #[test]
    fn inference_profile_arn_is_rejected() {
        assert_rejected_as_profile_arn(
            "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-sonnet-4-5-20250929-v1:0",
        );
    }

    #[test]
    fn application_inference_profile_arn_is_rejected() {
        assert_rejected_as_profile_arn(
            "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000",
        );
    }

    #[test]
    fn gov_cloud_arn_is_rejected() {
        assert_rejected_as_profile_arn(
            "arn:aws-us-gov:bedrock:us-gov-west-1:000000000000:application-inference-profile/000000000000",
        );
    }

    #[test]
    fn non_bedrock_arn_falls_through_to_profile() {
        let parsed = "arn:aws:s3:us-west-2:000000000000:bucket/foo".parse::<BedrockModel>();
        assert!(matches!(parsed, Ok(BedrockModel::Profile(_))));
    }
}