use std::sync::Arc;
use async_trait::async_trait;
use tracing::debug;
#[cfg(not(feature = "bedrock"))]
use tracing::info;
#[cfg(feature = "bedrock")]
use tracing::warn;
use super::{
ANTHROPIC_MODEL_PREFIX, AnthropicProvider, BEDROCK_MODEL_PREFIX, LlmProvider,
OPENROUTER_MODEL_PREFIX, OpenRouterProvider, ProviderKind, SmLlmError,
};
use crate::core::sm::config::SmInferenceConfig;
#[cfg(feature = "bedrock")]
use super::BedrockProvider;
/// Which session-manager workload a model is being resolved for.
///
/// Each tier reads its own `[session_manager.inference]` field in
/// `resolve_tier_model`, with tier-specific fallbacks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmModelTier {
    /// Orchestration calls: `sm_model`, falling back to `model` when blank.
    Orchestration,
    /// Summaries: `summary_model` only (no fallback).
    Summary,
    /// Compaction: `compaction_model`, falling back to `summary_model` when blank.
    Compaction,
}
/// Splits an optional provider routing prefix off `model`.
///
/// The prefixes are tried in a fixed order: Anthropic, then Bedrock, then
/// OpenRouter. The first one that matches selects the provider, and the
/// remainder becomes the bare model id. When no prefix matches, `model` is
/// returned unchanged alongside `default_provider`.
pub fn resolve_provider_and_model(
    model: &str,
    default_provider: ProviderKind,
) -> (ProviderKind, String) {
    // Ordered table; evaluation order matches the precedence documented above.
    let routes = [
        (ANTHROPIC_MODEL_PREFIX, ProviderKind::Anthropic),
        (BEDROCK_MODEL_PREFIX, ProviderKind::Bedrock),
        (OPENROUTER_MODEL_PREFIX, ProviderKind::OpenRouter),
    ];
    let routed = routes
        .iter()
        .find_map(|&(prefix, kind)| model.strip_prefix(prefix).map(|rest| (kind, rest)));
    match routed {
        Some((kind, rest)) => (kind, rest.to_string()),
        None => (default_provider, model.to_string()),
    }
}
/// Picks the configured model string for `tier`, applying per-tier fallbacks.
///
/// - `Orchestration`: `sm_model`, falling back to `model` when `sm_model` is blank.
/// - `Summary`: `summary_model` (no fallback).
/// - `Compaction`: `compaction_model`, falling back to `summary_model` when blank.
///
/// Surrounding whitespace is trimmed from the returned value. A padded entry such
/// as `" anthropic/<id> "` therefore still matches its routing prefix in
/// `resolve_provider_and_model`, and no whitespace leaks into the model id.
///
/// # Errors
///
/// Returns `SmLlmError::Validation` when the chosen value (after fallback) is
/// empty or whitespace-only.
pub fn resolve_tier_model(
    cfg: &SmInferenceConfig,
    tier: SmModelTier,
) -> Result<String, SmLlmError> {
    /// Returns the trimmed `primary`, or `fallback` when `primary` is blank.
    fn or_fallback<'a>(primary: &'a str, fallback: &'a str) -> &'a str {
        let primary = primary.trim();
        if primary.is_empty() {
            fallback
        } else {
            primary
        }
    }

    let chosen = match tier {
        SmModelTier::Orchestration => or_fallback(&cfg.sm_model, &cfg.model),
        SmModelTier::Summary => cfg.summary_model.as_str(),
        SmModelTier::Compaction => or_fallback(&cfg.compaction_model, &cfg.summary_model),
    };
    // The emptiness check already treats whitespace as noise, so strip it from the
    // value we hand back as well (fallback values may still be padded).
    let chosen = chosen.trim();
    if chosen.is_empty() {
        return Err(SmLlmError::Validation(format!(
            "no model configured for the {tier:?} tier \
(set sm_model/summary_model/compaction_model in [session_manager.inference])"
        )));
    }
    Ok(chosen.to_string())
}
/// A provider client paired with the model id it should be called with.
///
/// Returned by `ProviderRegistry::build` and `TierResolver::resolve`.
pub struct ResolvedCall {
    /// Backend that was selected. When built by `ProviderRegistry::build`, this is
    /// the kind after prefix routing and `Auto` resolution.
    pub kind: ProviderKind,
    /// Model id to send to `provider`. When built by `ProviderRegistry::build`, any
    /// routing prefix has already been stripped.
    pub model: String,
    /// Shared provider client used to perform the call.
    pub provider: Arc<dyn LlmProvider>,
}
impl std::fmt::Debug for ResolvedCall {
    /// Prints `kind` and `model` directly. The provider is summarized by its
    /// `name()` rather than by formatting the provider object.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("ResolvedCall");
        out.field("kind", &self.kind);
        out.field("model", &self.model);
        out.field("provider", &self.provider.name());
        out.finish()
    }
}
/// Maps a session-manager tier to a ready-to-call provider and model id.
///
/// `ProviderRegistry` implements this by delegating to its `build` method.
/// Exposing it as a trait presumably lets callers swap in other resolvers
/// (e.g. test doubles) — confirm against call sites.
#[async_trait]
pub trait TierResolver: Send + Sync {
    /// Resolves the provider and bare model id to use for `tier` under `config`.
    async fn resolve(
        &self,
        config: &SmInferenceConfig,
        tier: SmModelTier,
    ) -> Result<ResolvedCall, SmLlmError>;
}
#[async_trait]
impl TierResolver for ProviderRegistry {
async fn resolve(
&self,
cfg: &SmInferenceConfig,
tier: SmModelTier,
) -> Result<ResolvedCall, SmLlmError> {
self.build(cfg, tier).await
}
}
/// Credential snapshot used to choose and construct an LLM provider per tier.
///
/// `Debug` is implemented by hand so API keys are never written to logs; only
/// whether each key is present is shown.
#[derive(Clone, Default)]
pub struct ProviderRegistry {
    /// Key for the Anthropic provider; `None` means Anthropic cannot be constructed.
    pub anthropic_api_key: Option<String>,
    /// Whether AWS credentials appear to be available (used for `Auto` precedence
    /// and a warning when Bedrock is selected without them).
    pub aws_credentials_available: bool,
    /// Key for the OpenRouter provider; `None` means OpenRouter cannot be constructed.
    pub openrouter_api_key: Option<String>,
}

impl std::fmt::Debug for ProviderRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets must never reach log output; expose presence only.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("ProviderRegistry")
            .field("anthropic_api_key", &self.anthropic_api_key.as_ref().map(|_| REDACTED))
            .field("aws_credentials_available", &self.aws_credentials_available)
            .field("openrouter_api_key", &self.openrouter_api_key.as_ref().map(|_| REDACTED))
            .finish()
    }
}
impl ProviderRegistry {
pub fn from_env() -> Self {
let non_empty = |k: &str| std::env::var(k).ok().filter(|v| !v.trim().is_empty());
let aws_credentials_available = ["AWS_ACCESS_KEY_ID", "AWS_PROFILE", "AWS_ROLE_ARN"]
.iter()
.any(|k| non_empty(k).is_some());
if !aws_credentials_available {
let aws_hosted_markers = [
"AWS_EXECUTION_ENV",
"ECS_CONTAINER_METADATA_URI",
"ECS_CONTAINER_METADATA_URI_V4",
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
"AWS_WEB_IDENTITY_TOKEN_FILE",
];
if aws_hosted_markers
.iter()
.any(|k| std::env::var_os(k).is_some())
{
debug!(
"AWS-hosted environment detected but no file/role-based AWS credentials \
visible to the env heuristic; set [session_manager.inference].provider = \
\"bedrock\" explicitly to pin Bedrock"
);
}
}
Self {
anthropic_api_key: non_empty("ANTHROPIC_API_KEY"),
aws_credentials_available,
openrouter_api_key: non_empty("OPENROUTER_API_KEY"),
}
}
pub async fn build(
&self,
cfg: &SmInferenceConfig,
tier: SmModelTier,
) -> Result<ResolvedCall, SmLlmError> {
let (effective_kind, bare_model) = self.resolve_kind_and_model(cfg, tier)?;
let provider = self.construct(effective_kind, &bare_model).await?;
Ok(ResolvedCall {
provider,
model: bare_model,
kind: effective_kind,
})
}
fn resolve_kind_and_model(
&self,
cfg: &SmInferenceConfig,
tier: SmModelTier,
) -> Result<(ProviderKind, String), SmLlmError> {
let default_provider = ProviderKind::parse(&cfg.provider)?;
let tier_model = resolve_tier_model(cfg, tier)?;
let (routed_kind, bare_model) = resolve_provider_and_model(&tier_model, default_provider);
let effective_kind = match routed_kind {
ProviderKind::Auto => self.auto_precedence()?,
explicit => explicit,
};
debug!(
?tier,
requested = %tier_model,
?effective_kind,
bare_model = %bare_model,
"sm resolve provider+model"
);
Ok((effective_kind, bare_model))
}
fn auto_precedence(&self) -> Result<ProviderKind, SmLlmError> {
if self.anthropic_api_key.is_some() {
return Ok(ProviderKind::Anthropic);
}
#[cfg(feature = "bedrock")]
if self.aws_credentials_available {
return Ok(ProviderKind::Bedrock);
}
if self.openrouter_api_key.is_some() {
return Ok(ProviderKind::OpenRouter);
}
Err(SmLlmError::Degraded(
"no ANTHROPIC_API_KEY, AWS credentials, or OPENROUTER_API_KEY available".to_string(),
))
}
async fn construct(
&self,
kind: ProviderKind,
bare_model: &str,
) -> Result<Arc<dyn LlmProvider>, SmLlmError> {
match kind {
ProviderKind::Anthropic => {
let key = self.anthropic_api_key.clone().ok_or_else(|| {
SmLlmError::Degraded(
"anthropic provider selected but ANTHROPIC_API_KEY is not set".to_string(),
)
})?;
Ok(Arc::new(AnthropicProvider::new(
key,
bare_model.to_string(),
)?))
}
ProviderKind::OpenRouter => {
let key = self.openrouter_api_key.clone().ok_or_else(|| {
SmLlmError::Degraded(
"openrouter provider selected but OPENROUTER_API_KEY is not set"
.to_string(),
)
})?;
Ok(Arc::new(OpenRouterProvider::new(
key,
bare_model.to_string(),
)?))
}
ProviderKind::Bedrock => self.construct_bedrock(bare_model).await,
ProviderKind::Auto => Err(SmLlmError::Validation(
"internal: ProviderKind::Auto must be resolved before construction".to_string(),
)),
}
}
#[cfg(feature = "bedrock")]
async fn construct_bedrock(
&self,
bare_model: &str,
) -> Result<Arc<dyn LlmProvider>, SmLlmError> {
if !self.aws_credentials_available {
warn!("bedrock selected but no AWS credentials detected; attempting SDK chain anyway");
}
let provider = BedrockProvider::new(bare_model.to_string(), None).await?;
Ok(Arc::new(provider))
}
#[cfg(not(feature = "bedrock"))]
async fn construct_bedrock(
&self,
_bare_model: &str,
) -> Result<Arc<dyn LlmProvider>, SmLlmError> {
info!("bedrock model selected but the `bedrock` cargo feature is not enabled");
Err(SmLlmError::Validation(
"bedrock provider requires building trusty-mpm with `--features bedrock`".to_string(),
))
}
}
#[cfg(test)]
#[path = "resolve_tests.rs"]
mod tests;