use std::path::Path;
use anyhow::bail;
use yoagent::provider::model::{CostConfig, ModelConfig};
use yoagent::types::Usage;
pub mod anthropic;
pub mod compat;
pub mod generate_models;
pub mod models;
pub mod oauth;
pub mod openai_compat;
/// Outcome of resolving a model id: the model's configuration together with
/// the credential string to use for its provider.
#[derive(Debug, Clone)]
pub struct ResolvedModel {
/// Configuration of the resolved model (cloned from the provider entry).
pub model_config: ModelConfig,
/// Credential chosen for the provider. This may be a stored API key, a stored
/// OAuth token, or an environment-variable value; it is an empty string when
/// none of those sources yielded a value.
pub api_key: String,
}
/// Registry of known providers (each entry carrying its models) plus the
/// credential store used to authenticate against them.
pub struct ProviderRegistry {
/// Provider entries produced by merging the built-in and user catalogs.
entries: Vec<models::ProviderEntry>,
/// Stored credentials (API keys / OAuth tokens) looked up by provider id.
auth_storage: crate::auth::AuthStorage,
}
impl ProviderRegistry {
/// Builds a registry from the `models.json` catalog embedded at compile time,
/// merged with the user catalog at `<agent_dir>/models.json`, together with
/// the credentials returned by `AuthStorage::load`.
///
/// `oauth::register_builtins()` is invoked before anything else
/// (presumably so built-in OAuth providers are known to later lookups;
/// confirm against that module).
///
/// # Errors
///
/// Propagates any failure from loading either catalog or the auth storage.
pub fn load(agent_dir: &Path) -> anyhow::Result<Self> {
    crate::provider::oauth::register_builtins();

    // Catalogs are loaded built-in first, then user, so error precedence is stable.
    let bundled = models::load_builtin(include_str!("models.json"))?;
    let custom = models::load_user(&agent_dir.join("models.json"))?;

    // Struct fields are evaluated in written order: merge first, then credentials.
    Ok(Self {
        entries: models::merge(bundled, custom),
        auth_storage: crate::auth::AuthStorage::load()?,
    })
}
/// Rebuilds the registry from disk via [`Self::load`] and replaces the current
/// contents with the result.
///
/// # Errors
///
/// Returns the error from `load`; in that case `self` is left untouched
/// because nothing is assigned until loading has succeeded.
pub fn reload(&mut self, agent_dir: &Path) -> anyhow::Result<()> {
    // The struct consists solely of `entries` and `auth_storage`, so swapping
    // the whole value replaces exactly those two fields.
    *self = Self::load(agent_dir)?;
    Ok(())
}
pub fn resolve(
&self,
model_id: &str,
preferred_provider: Option<&str>,
) -> anyhow::Result<ResolvedModel> {
if let Some(preferred) = preferred_provider
&& let Some(result) = self.resolve_from_provider(model_id, preferred)
{
return Ok(result);
}
for entry in &self.entries {
if let Some(model_config) = entry.models.iter().find(|m| m.id == model_id) {
let api_key = self
.auth_storage
.api_key(&entry.id)
.or_else(|| {
self.auth_storage.oauth_token(&entry.id)
})
.or_else(|| {
let env_var = entry.env_var_name();
std::env::var(env_var).ok()
})
.unwrap_or_default();
let mut model_config = model_config.clone();
if entry.id == "github-copilot" {
let enterprise_domain =
self.auth_storage
.oauth_credential(&entry.id)
.and_then(|c| match c {
crate::auth::AuthCredential::Oauth { enterprise_url, .. } => {
enterprise_url
}
_ => None,
});
let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
Some(&api_key),
enterprise_domain.as_deref(),
);
model_config.base_url = derived;
}
return Ok(ResolvedModel {
model_config,
api_key,
});
}
}
bail!(
"Unknown model '{}'. Available models: {}",
model_id,
self.list_models().join(", ")
);
}
fn resolve_from_provider(&self, model_id: &str, provider_id: &str) -> Option<ResolvedModel> {
let entry = self.entries.iter().find(|e| e.id == provider_id)?;
let mut model_config = entry.models.iter().find(|m| m.id == model_id)?.clone();
let api_key = self
.auth_storage
.api_key(provider_id)
.or_else(|| {
self.auth_storage.oauth_token(provider_id)
})
.or_else(|| {
let env_var = entry.env_var_name();
std::env::var(env_var).ok()
})
.unwrap_or_default();
if provider_id == "github-copilot" {
let enterprise_domain = self
.auth_storage
.oauth_credential(provider_id)
.and_then(|c| match c {
crate::auth::AuthCredential::Oauth { enterprise_url, .. } => enterprise_url,
_ => None,
});
let derived = crate::provider::oauth::github_copilot::get_copilot_base_url(
Some(&api_key),
enterprise_domain.as_deref(),
);
model_config.base_url = derived;
}
Some(ResolvedModel {
model_config,
api_key,
})
}
/// Returns every model id known to any provider, de-duplicated and in
/// ascending order (the ids pass through a `BTreeSet`).
pub fn list_models(&self) -> Vec<String> {
    self.entries
        .iter()
        .flat_map(|provider| provider.models.iter())
        .map(|model| model.id.clone())
        .collect::<std::collections::BTreeSet<String>>()
        .into_iter()
        .collect()
}
/// Returns the ids of all models offered by providers for which
/// [`Self::provider_has_auth`] reports credentials, de-duplicated and in
/// ascending order.
pub fn list_authenticated_model_ids(&self) -> Vec<String> {
    self.entries
        .iter()
        .filter(|provider| self.provider_has_auth(&provider.id))
        .flat_map(|provider| provider.models.iter().map(|model| model.id.clone()))
        .collect::<std::collections::BTreeSet<String>>()
        .into_iter()
        .collect()
}
/// Returns one `(provider id, model id, model name)` tuple per model, in
/// registry order. A model offered by several providers appears once per
/// provider; nothing is de-duplicated.
pub fn list_model_provider_tuples(&self) -> Vec<(String, String, String)> {
    self.entries
        .iter()
        .flat_map(|provider| {
            provider
                .models
                .iter()
                .map(move |model| (provider.id.clone(), model.id.clone(), model.name.clone()))
        })
        .collect()
}
pub fn provider_for_model(
&self,
model_id: &str,
preferred_provider: Option<&str>,
) -> Option<String> {
if let Some(preferred) = preferred_provider
&& self
.entries
.iter()
.any(|e| e.id == preferred && e.models.iter().any(|m| m.id == model_id))
{
return Some(preferred.to_string());
}
for entry in &self.entries {
if entry.models.iter().any(|m| m.id == model_id) {
return Some(entry.id.clone());
}
}
None
}
/// Returns the API key that the auth storage reports for `provider_id`, if any.
///
/// Only `AuthStorage::api_key` is consulted here; OAuth tokens and
/// environment variables are not considered by this method.
pub fn api_key_for_provider(&self, provider_id: &str) -> Option<String> {
self.auth_storage.api_key(provider_id)
}
/// Returns the number of provider entries in the registry.
pub fn count_providers(&self) -> usize {
self.entries.len()
}
/// Returns `(provider id, provider name)` for every registered provider, in
/// registry order.
pub fn list_providers(&self) -> Vec<(String, String)> {
    let mut providers = Vec::with_capacity(self.entries.len());
    for provider in &self.entries {
        providers.push((provider.id.clone(), provider.name.clone()));
    }
    providers
}
/// Returns the ids of providers for which the auth storage reports an API key,
/// in registry order.
///
/// NOTE(review): only `AuthStorage::api_key` is checked here, whereas
/// `provider_has_auth` also accepts OAuth tokens and environment variables.
/// Confirm that the narrower definition is intended.
pub fn configured_providers(&self) -> Vec<String> {
    self.entries
        .iter()
        .filter(|provider| self.auth_storage.api_key(&provider.id).is_some())
        .map(|provider| provider.id.clone())
        .collect()
}
/// Reports whether any credential is available for `provider_id`.
///
/// A provider counts as authenticated when:
/// - the auth storage holds an API key or an OAuth token for it, or
/// - it is not a built-in OAuth provider, it is registered, and its
///   environment variable can be read.
///
/// Built-in OAuth providers are therefore never authenticated through the
/// environment alone.
pub fn provider_has_auth(&self, provider_id: &str) -> bool {
    let has_stored_credential = self.auth_storage.api_key(provider_id).is_some()
        || self.auth_storage.oauth_token(provider_id).is_some();
    if has_stored_credential {
        return true;
    }

    // At this point no OAuth token is stored, so a built-in OAuth provider has
    // nothing left to authenticate with; the environment is not consulted.
    if crate::provider::oauth::is_built_in(provider_id) {
        return false;
    }

    // Only the first entry with a matching id is inspected.
    self.entries
        .iter()
        .find(|e| e.id == provider_id)
        .is_some_and(|e| std::env::var(e.env_var_name()).is_ok())
}
/// Describes how (if at all) `provider_id` is authenticated, for display in
/// the OAuth selector.
///
/// - `configured` is true when a stored credential exists or the provider's
///   environment variable can be read.
/// - `source` is `"stored"` when a stored API key or OAuth token exists,
///   otherwise `"environment"` when the variable is set, otherwise `None`.
/// - `label` carries the environment variable's name, and only when the
///   source is `"environment"`.
pub fn auth_status_for_provider(
    &self,
    provider_id: &str,
) -> crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
    let has_stored = self.auth_storage.api_key(provider_id).is_some()
        || self.auth_storage.oauth_token(provider_id).is_some();

    // Name of the provider's environment variable, kept only if it is readable.
    let env_var = self
        .entries
        .iter()
        .find(|e| e.id == provider_id)
        .and_then(|e| {
            let var_name = e.env_var_name();
            std::env::var(var_name)
                .is_ok()
                .then(|| var_name.to_string())
        });

    let configured = has_stored || env_var.is_some();

    // A stored credential takes precedence over the environment.
    let (source, label) = match (has_stored, env_var) {
        (true, _) => (Some("stored".to_string()), None),
        (false, Some(var_name)) => (Some("environment".to_string()), Some(var_name)),
        (false, None) => (None, None),
    };

    crate::agent::ui::components::oauth_selector::ProviderAuthStatus {
        configured,
        source,
        label,
    }
}
}
/// Prices a usage record against per-million rates.
///
/// Each component is `(rate_per_million / 1_000_000) * count`.
///
/// Returns `(input, output, cache_read, cache_write, total)`, where `total`
/// is the sum of the four components.
pub fn calculate_cost(cost_config: &CostConfig, usage: &Usage) -> (f64, f64, f64, f64, f64) {
    // Divide first, then multiply, so results match the per-unit price exactly.
    let price = |rate_per_million: f64, count: f64| (rate_per_million / 1_000_000.0) * count;

    let input_cost = price(cost_config.input_per_million, usage.input as f64);
    let output_cost = price(cost_config.output_per_million, usage.output as f64);
    let cache_read_cost = price(cost_config.cache_read_per_million, usage.cache_read as f64);
    let cache_write_cost = price(cost_config.cache_write_per_million, usage.cache_write as f64);

    (
        input_cost,
        output_cost,
        cache_read_cost,
        cache_write_cost,
        input_cost + output_cost + cache_read_cost + cache_write_cost,
    )
}
/// Returns the agent directory: `<home>/.rab/agent`.
///
/// When `directories::BaseDirs::new()` yields nothing, the fixed path
/// `/tmp/.rab/agent` is returned instead.
pub fn get_agent_dir() -> std::path::PathBuf {
    match directories::BaseDirs::new() {
        Some(base_dirs) => base_dirs.home_dir().join(".rab").join("agent"),
        None => std::path::PathBuf::from("/tmp/.rab/agent"),
    }
}