use crate::core::providers::unified_provider::ProviderError;
use std::collections::HashMap;
use std::sync::LazyLock;
/// Region identifiers this module recognises.
///
/// [`validate_region`] accepts exactly these strings, and
/// [`is_model_available_in_region`] falls back to this list for any model
/// family that has no dedicated entry in the per-model mapping.
pub const AWS_REGIONS: &[&str] = &[
    // United States
    "us-east-1",
    "us-east-2",
    "us-west-1",
    "us-west-2",
    // Europe
    "eu-west-1",
    "eu-west-2",
    "eu-west-3",
    "eu-central-1",
    "eu-central-2",
    "eu-north-1",
    "eu-south-1",
    "eu-south-2",
    // Asia Pacific
    "ap-northeast-1",
    "ap-northeast-2",
    "ap-northeast-3",
    "ap-south-1",
    "ap-south-2",
    "ap-southeast-1",
    "ap-southeast-2",
    "ap-southeast-3",
    "ap-southeast-4",
    "ap-southeast-5",
    // Canada and South America
    "ca-central-1",
    "sa-east-1",
    // `us-gov-*` regions
    "us-gov-west-1",
    "us-gov-east-1",
];
/// Lazily built lookup from a model-family prefix to the regions listed for
/// that family.
///
/// Keys are the same prefixes that `extract_model_prefix` returns for known
/// families. A family that is absent from this map is handled by the caller
/// (`is_model_available_in_region` falls back to `AWS_REGIONS`).
static MODEL_REGION_MAPPING: LazyLock<HashMap<&'static str, &'static [&'static str]>> =
    LazyLock::new(|| {
        // The explicit element type lets each `&[..; N]` literal coerce to a slice.
        let table: [(&'static str, &'static [&'static str]); 8] = [
            ("anthropic.claude", AWS_REGIONS),
            (
                "amazon.titan",
                &["us-east-1", "us-east-2", "us-west-1", "us-west-2"],
            ),
            ("amazon.nova", &["us-east-1", "us-west-2"]),
            (
                "ai21",
                &["us-east-1", "us-west-2", "eu-west-1", "ap-southeast-2"],
            ),
            (
                "cohere",
                &["us-east-1", "us-west-2", "eu-west-1", "ap-southeast-2"],
            ),
            ("mistral", &["us-east-1", "us-west-2", "eu-west-1"]),
            (
                "meta.llama",
                &["us-east-1", "us-west-2", "eu-west-1", "ap-southeast-2"],
            ),
            (
                "stability",
                &["us-east-1", "us-west-2", "eu-west-1", "ap-southeast-1"],
            ),
        ];
        HashMap::from(table)
    });
/// Checks that `region` is one of the entries in [`AWS_REGIONS`].
///
/// The comparison is an exact string match against the table.
///
/// # Errors
///
/// Returns the error produced by `ProviderError::configuration` for the
/// `"bedrock"` provider when `region` is not listed. The message names the
/// rejected region and lists every supported one.
pub fn validate_region(region: &str) -> Result<(), ProviderError> {
    // `contains` takes a reference to the element type (`&&str`), hence `&region`.
    if AWS_REGIONS.contains(&region) {
        return Ok(());
    }

    Err(ProviderError::configuration(
        "bedrock",
        format!(
            "Invalid AWS region: {}. Supported regions: {:?}",
            region, AWS_REGIONS
        ),
    ))
}
/// Reports whether `model_id` is listed as available in `region`.
///
/// The model's family prefix is derived with `extract_model_prefix` and looked
/// up in `MODEL_REGION_MAPPING`. When the family has no entry there, the full
/// [`AWS_REGIONS`] table is used instead, so an unknown model counts as
/// available in every recognised region.
pub fn is_model_available_in_region(model_id: &str, region: &str) -> bool {
    let model_prefix = extract_model_prefix(model_id);

    MODEL_REGION_MAPPING
        .get(model_prefix)
        .copied()
        // No dedicated entry: fall back to the complete region table.
        .unwrap_or(AWS_REGIONS)
        .contains(&region)
}
/// Returns the model-family prefix used as the lookup key for `model_id`.
///
/// If `model_id` starts with one of the known family prefixes, that prefix is
/// returned. Otherwise the result is the part of `model_id` before its first
/// `-`, or the whole string when it contains no `-`.
fn extract_model_prefix(model_id: &str) -> &str {
    // None of these is a prefix of another, so at most one can match.
    const KNOWN_PREFIXES: [&str; 8] = [
        "anthropic.claude",
        "amazon.titan",
        "amazon.nova",
        "ai21",
        "cohere",
        "mistral",
        "meta.llama",
        "stability",
    ];

    KNOWN_PREFIXES
        .iter()
        .copied()
        .find(|known| model_id.starts_with(*known))
        .unwrap_or_else(|| model_id.split('-').next().unwrap_or(model_id))
}
/// Test-only helper: a fixed list of the four `us-east-*` / `us-west-*`
/// regions (the `us-gov-*` regions are not included).
#[cfg(test)]
pub fn get_us_regions() -> &'static [&'static str] {
    const US_REGIONS: &[&str] = &["us-east-1", "us-east-2", "us-west-1", "us-west-2"];
    US_REGIONS
}
/// Test-only helper: a fixed list of the eight `eu-*` regions.
#[cfg(test)]
pub fn get_eu_regions() -> &'static [&'static str] {
    const EU_REGIONS: &[&str] = &[
        "eu-west-1",
        "eu-west-2",
        "eu-west-3",
        "eu-central-1",
        "eu-central-2",
        "eu-north-1",
        "eu-south-1",
        "eu-south-2",
    ];
    EU_REGIONS
}
/// Test-only helper: a fixed list of the ten `ap-*` regions.
#[cfg(test)]
pub fn get_ap_regions() -> &'static [&'static str] {
    const AP_REGIONS: &[&str] = &[
        "ap-northeast-1",
        "ap-northeast-2",
        "ap-northeast-3",
        "ap-south-1",
        "ap-south-2",
        "ap-southeast-1",
        "ap-southeast-2",
        "ap-southeast-3",
        "ap-southeast-4",
        "ap-southeast-5",
    ];
    AP_REGIONS
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Listed regions pass validation; an unlisted string is rejected.
    #[test]
    fn test_region_validation() {
        for known in ["us-east-1", "eu-west-1"] {
            assert!(validate_region(known).is_ok());
        }
        assert!(validate_region("invalid-region").is_err());
    }

    /// Availability follows the per-family region lists.
    #[test]
    fn test_model_availability() {
        // (model id, region, expected availability)
        let cases = [
            ("anthropic.claude-3-opus", "us-east-1", true),
            ("anthropic.claude-3-opus", "eu-west-1", true),
            ("amazon.nova-pro", "us-east-1", true),
            ("amazon.nova-pro", "ap-south-1", false),
        ];
        for (model_id, region, expected) in cases {
            assert_eq!(is_model_available_in_region(model_id, region), expected);
        }
    }

    /// Full model ids collapse to their family prefix.
    #[test]
    fn test_model_prefix_extraction() {
        // (model id, expected prefix)
        let cases = [
            ("anthropic.claude-3-opus-20240229", "anthropic.claude"),
            ("amazon.titan-text-express-v1", "amazon.titan"),
            ("meta.llama3-70b-instruct-v1:0", "meta.llama"),
        ];
        for (model_id, expected) in cases {
            assert_eq!(extract_model_prefix(model_id), expected);
        }
    }

    /// Each regional getter contains a representative region of its group.
    #[test]
    fn test_regional_getters() {
        let checks: [(&[&str], &str); 3] = [
            (get_us_regions(), "us-east-1"),
            (get_eu_regions(), "eu-west-1"),
            (get_ap_regions(), "ap-southeast-1"),
        ];
        for (regions, member) in checks {
            assert!(regions.contains(&member));
        }
    }
}