use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use crate::error::SdkResult;
use super::AuthMethod;
/// Wire protocol (API dialect) used to talk to a catalog provider or model.
///
/// Two textual forms exist:
/// - `Display` / [`CatalogProtocol::as_str`] render kebab-case names such as
///   `"anthropic-messages"`.
/// - The derived `Serialize`/`Deserialize` use serde's default representation,
///   which is the Rust variant identifier (e.g. `"AnthropicMessages"`), because
///   no `#[serde(rename_all = ...)]` is applied.
///
/// Do not assume that one form round-trips through the other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CatalogProtocol {
    /// Wire name `"anthropic-messages"`; maps to `oxi_ai::Api::AnthropicMessages`.
    AnthropicMessages,
    /// Wire name `"openai-completions"`; maps to `oxi_ai::Api::OpenAiCompletions`.
    OpenAiCompletions,
    /// Wire name `"openai-responses"`; maps to `oxi_ai::Api::OpenAiResponses`.
    OpenAiResponses,
    /// Wire name `"azure-openai-responses"`; maps to `oxi_ai::Api::AzureOpenAiResponses`.
    AzureOpenAiResponses,
    /// Wire name `"google-generative-ai"`; maps to `oxi_ai::Api::GoogleGenerativeAi`.
    GoogleGenerativeAi,
    /// Wire name `"google-vertex"`; maps to `oxi_ai::Api::GoogleVertex`.
    GoogleVertex,
    /// Wire name `"mistral-conversations"`; maps to `oxi_ai::Api::MistralConversations`.
    MistralConversations,
    /// Wire name `"bedrock-converse-stream"`; maps to `oxi_ai::Api::BedrockConverseStream`.
    BedrockConverseStream,
    /// Wire name `"openai-compatible"`. This variant has no dedicated `oxi_ai::Api`
    /// variant; `as_oxi_api` maps it to `oxi_ai::Api::OpenAiCompletions`.
    OpenAiCompatible,
}
impl CatalogProtocol {
pub fn default_auth(&self) -> AuthMethod {
match self {
CatalogProtocol::AnthropicMessages => AuthMethod::XApiKey,
CatalogProtocol::AzureOpenAiResponses => AuthMethod::ApiKey,
CatalogProtocol::GoogleGenerativeAi
| CatalogProtocol::GoogleVertex
| CatalogProtocol::BedrockConverseStream => AuthMethod::None,
CatalogProtocol::OpenAiCompletions
| CatalogProtocol::OpenAiResponses
| CatalogProtocol::OpenAiCompatible
| CatalogProtocol::MistralConversations => AuthMethod::Bearer,
}
}
pub fn as_oxi_api(&self) -> oxi_ai::Api {
use CatalogProtocol::*;
match self {
AnthropicMessages => oxi_ai::Api::AnthropicMessages,
OpenAiCompletions => oxi_ai::Api::OpenAiCompletions,
OpenAiResponses => oxi_ai::Api::OpenAiResponses,
AzureOpenAiResponses => oxi_ai::Api::AzureOpenAiResponses,
GoogleGenerativeAi => oxi_ai::Api::GoogleGenerativeAi,
GoogleVertex => oxi_ai::Api::GoogleVertex,
MistralConversations => oxi_ai::Api::MistralConversations,
BedrockConverseStream => oxi_ai::Api::BedrockConverseStream,
OpenAiCompatible => oxi_ai::Api::OpenAiCompletions,
}
}
pub fn as_str(&self) -> &'static str {
use CatalogProtocol::*;
match self {
AnthropicMessages => "anthropic-messages",
OpenAiCompletions => "openai-completions",
OpenAiResponses => "openai-responses",
AzureOpenAiResponses => "azure-openai-responses",
GoogleGenerativeAi => "google-generative-ai",
GoogleVertex => "google-vertex",
MistralConversations => "mistral-conversations",
BedrockConverseStream => "bedrock-converse-stream",
OpenAiCompatible => "openai-compatible",
}
}
}
/// Formats the protocol as its kebab-case wire name, the same string returned
/// by [`CatalogProtocol::as_str`].
///
/// [`std::fmt::Formatter::pad`] is used instead of `write_str` so that width,
/// fill, alignment and precision flags are honoured. For example,
/// `format!("{:<24}", proto)` pads the name to 24 columns. Plain `{}`
/// formatting produces exactly the same output as before.
impl std::fmt::Display for CatalogProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// One model known to the catalog, with its protocol, pricing, limits and
/// capability flags.
///
/// Field order matters for serde output in non-self-describing formats, so
/// do not reorder fields casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CatalogModelEntry {
    /// Id of the owning provider. Presumably this equals the matching
    /// `CatalogProviderEntry::id`; confirm.
    pub provider: String,
    /// Model identifier. `ModelCatalog::get_model` looks models up by
    /// `(provider_id, model_id)`.
    pub model_id: String,
    /// Human-readable model name (presumably for display only).
    pub name: String,
    /// Protocol used to talk to this model.
    pub protocol: CatalogProtocol,
    /// Which catalog layer this entry came from.
    pub source: CatalogSource,
    /// Optional endpoint for this model. Presumably it overrides the provider's
    /// `base_url` when set; confirm.
    pub base_url: Option<String>,
    /// Whether the model is flagged as a reasoning model.
    pub reasoning: bool,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    // NOTE(review): the currency and unit of the `cost_*` fields are not visible
    // here. A common convention is USD per million tokens; confirm.
    /// Price for input tokens.
    pub cost_input: f64,
    /// Price for output tokens.
    pub cost_output: f64,
    /// Price for prompt-cache reads.
    pub cost_cache_read: f64,
    /// Price for prompt-cache writes.
    pub cost_cache_write: f64,
    /// Context window size (presumably in tokens).
    pub context_window: u32,
    /// Maximum output tokens (presumably per response).
    pub max_tokens: u32,
    /// Supported input modalities as free-form strings (e.g. `"text"`, `"image"`;
    /// the exact vocabulary is not defined here).
    pub input_modalities: Vec<String>,
    /// Release date, if known. The string format is not enforced here.
    pub release_date: Option<String>,
    /// Free-form lifecycle status, if known (the vocabulary is not defined here).
    pub status: Option<String>,
}
/// One provider known to the catalog: identity, connection defaults and
/// presentation metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CatalogProviderEntry {
    /// Provider id, as accepted by `ModelCatalog::get_provider`.
    pub id: String,
    /// Human-readable provider name.
    pub display_name: String,
    /// Alternative names. Presumably these resolve to this provider as well; confirm.
    pub aliases: Vec<String>,
    /// Protocol used to talk to this provider.
    pub protocol: CatalogProtocol,
    /// Name of the environment variable that presumably holds the primary credential.
    pub env_key: Option<String>,
    /// Additional environment variable names consulted for this provider
    /// (exact use is not visible here).
    pub extra_env_keys: Vec<String>,
    /// Default endpoint for the provider, if any.
    pub base_url: Option<String>,
    /// Extra headers as `(name, value)` pairs. Presumably these are sent with
    /// every request; confirm.
    pub extra_headers: Vec<(String, String)>,
    /// Grouping label (free-form string).
    pub category: String,
    /// Short description of the provider.
    pub description: String,
    /// Whether the provider is enabled without explicit user opt-in.
    pub default_enabled: bool,
}
/// Outcome reported by `ModelCatalog::refresh`.
///
/// `refresh` returns `SdkResult<RefreshOutcome>`, so a failure may surface
/// either as `Err(..)` or as [`RefreshOutcome::Failed`]. Which of the two is
/// used is up to the implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RefreshOutcome {
    /// Nothing changed. `NoopModelCatalog` always returns this.
    Unchanged,
    /// The catalog was updated. Presumably the counts describe the new catalog;
    /// confirm.
    Updated {
        provider_count: usize,
        model_count: usize,
    },
    /// The refresh was skipped because no network was available.
    ///
    /// NOTE(review): because `reason` is `&'static str`, the derived
    /// `Deserialize` can only borrow from `'static` input. It is not
    /// `DeserializeOwned`, so deserializing from an owned buffer will not
    /// compile. Confirm whether deserialization of this type is actually
    /// needed before relying on it.
    Offline { reason: &'static str },
    /// The refresh was attempted and failed, with a human-readable reason.
    Failed { reason: String },
}
/// The catalog layer an entry originated from.
///
/// The variant meanings below are inferred from names and from related
/// `CatalogEvent` variants. Confirm them against the loader that sets this field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CatalogSource {
    /// Presumably data bundled with the binary.
    Embedded,
    /// Presumably a previously persisted copy of fetched data.
    Cache,
    /// Presumably freshly fetched remote data.
    Live,
    /// Presumably discovered from a local server (cf. `CatalogEvent::LocalDiscovered`).
    Local,
    /// Presumably from a user override file (cf. `CatalogEvent::OverrideApplied`).
    Override,
}
/// Notification delivered to receivers obtained from `ModelCatalog::subscribe`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CatalogEvent {
    /// The catalog contents changed. Presumably the counts are the new totals; confirm.
    Updated {
        provider_count: usize,
        model_count: usize,
    },
    /// A refresh failed. The counts presumably describe the catalog that remains
    /// in use; confirm.
    RefreshFailed {
        reason: String,
        provider_count: usize,
        model_count: usize,
    },
    /// Overrides were loaded from `path`, with the given numbers of provider and
    /// model overrides.
    OverrideApplied {
        path: PathBuf,
        provider_overrides: usize,
        model_overrides: usize,
    },
    /// Models were discovered at a local endpoint `base_url`.
    LocalDiscovered {
        base_url: String,
        model_count: usize,
    },
}
/// Queryable, refreshable source of provider and model metadata.
///
/// Async methods return boxed `Send` futures rather than using `async fn`,
/// which keeps the trait object-safe (usable as `dyn ModelCatalog`).
///
/// By lifetime elision, every returned future (`+ '_`) borrows only `self`,
/// not the `&str` arguments. Implementations must therefore copy arguments
/// (e.g. `provider_id.to_owned()`) before moving them into the future.
///
/// The `*_sync` methods have defaults that return empty results (`Vec::new()`,
/// `None`, `0`). Implementations that can answer without awaiting should
/// override them. Callers cannot distinguish "not overridden" from "really
/// empty".
pub trait ModelCatalog: Send + Sync + 'static {
    /// Lists known providers. Presumably these are provider ids usable with
    /// `get_provider`; confirm.
    fn list_providers(&self) -> Pin<Box<dyn Future<Output = SdkResult<Vec<String>>> + Send + '_>>;
    /// Looks up one provider, resolving to `Ok(None)` when it is unknown.
    fn get_provider(
        &self,
        provider_id: &str,
    ) -> Pin<Box<dyn Future<Output = SdkResult<Option<CatalogProviderEntry>>> + Send + '_>>;
    /// Lists the models of one provider.
    fn list_models(
        &self,
        provider_id: &str,
    ) -> Pin<Box<dyn Future<Output = SdkResult<Vec<CatalogModelEntry>>> + Send + '_>>;
    /// Looks up one model by `(provider_id, model_id)`, resolving to `Ok(None)`
    /// when it is unknown.
    fn get_model(
        &self,
        provider_id: &str,
        model_id: &str,
    ) -> Pin<Box<dyn Future<Output = SdkResult<Option<CatalogModelEntry>>> + Send + '_>>;
    /// Returns models matching `pattern`. The matching semantics (substring,
    /// glob, regex, …) are implementation-defined.
    fn search(
        &self,
        pattern: &str,
    ) -> Pin<Box<dyn Future<Output = SdkResult<Vec<CatalogModelEntry>>> + Send + '_>>;
    /// Total number of models in the catalog.
    fn model_count(&self) -> Pin<Box<dyn Future<Output = SdkResult<usize>> + Send + '_>>;
    /// Attempts to bring the catalog up to date and reports what happened.
    fn refresh(&self) -> Pin<Box<dyn Future<Output = SdkResult<RefreshOutcome>> + Send + '_>>;
    /// Returns a new receiver for [`CatalogEvent`]s.
    ///
    /// Tokio broadcast semantics apply:
    /// - Only events sent after this call are received.
    /// - A receiver that falls behind the channel capacity sees `RecvError::Lagged`.
    /// - A receiver sees `RecvError::Closed` once every sender is dropped.
    fn subscribe(&self) -> broadcast::Receiver<CatalogEvent>;
    /// Synchronous counterpart of `list_providers`. The default returns an empty list.
    fn list_providers_sync(&self) -> Vec<String> {
        Vec::new()
    }
    /// Synchronous counterpart of `get_provider`. The default returns `None`.
    fn get_provider_sync(&self, _provider_id: &str) -> Option<CatalogProviderEntry> {
        None
    }
    /// Synchronous counterpart of `list_models`. The default returns an empty list.
    fn list_models_sync(&self, _provider_id: &str) -> Vec<CatalogModelEntry> {
        Vec::new()
    }
    /// Synchronous counterpart of `get_model`. The default returns `None`.
    fn get_model_sync(&self, _provider_id: &str, _model_id: &str) -> Option<CatalogModelEntry> {
        None
    }
    /// Synchronous counterpart of `search`. The default returns an empty list.
    fn search_sync(&self, _pattern: &str) -> Vec<CatalogModelEntry> {
        Vec::new()
    }
    /// Synchronous counterpart of `model_count`. The default returns `0`.
    fn model_count_sync(&self) -> usize {
        0
    }
}
/// A [`ModelCatalog`] that knows no providers or models.
///
/// Every query resolves to an empty result, and `refresh` always reports
/// [`RefreshOutcome::Unchanged`]. None of this type's methods publish events,
/// so subscribers only ever observe the channel closing when the catalog is
/// dropped.
pub struct NoopModelCatalog {
    // Kept alive so `subscribe()` can create receivers. The methods of this
    // type never send on it.
    tx: broadcast::Sender<CatalogEvent>,
}
impl NoopModelCatalog {
    /// Creates an empty catalog, already wrapped in an [`Arc`](std::sync::Arc).
    ///
    /// The result can be shared directly or coerced to `Arc<dyn ModelCatalog>`.
    pub fn new() -> std::sync::Arc<Self> {
        // The capacity barely matters: this catalog never publishes events.
        const EVENT_CAPACITY: usize = 16;

        // Keep only the sender; `subscribe` mints receivers on demand, so the
        // initial receiver is discarded right away.
        let tx = broadcast::channel(EVENT_CAPACITY).0;
        let catalog = Self { tx };
        std::sync::Arc::new(catalog)
    }
}
impl std::fmt::Debug for NoopModelCatalog {
    /// Renders as `NoopModelCatalog { .. }`. The internal event sender is
    /// intentionally not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("NoopModelCatalog");
        // No fields are added; the `..` marker signals that some were omitted.
        builder.finish_non_exhaustive()
    }
}
// The `*_sync` methods are deliberately not overridden: the trait defaults
// already return empty results, which is exactly what this catalog wants.
impl ModelCatalog for NoopModelCatalog {
    /// Always resolves to an empty provider list.
    fn list_providers(&self) -> Pin<Box<dyn Future<Output = SdkResult<Vec<String>>> + Send + '_>> {
        Box::pin(async { Ok(Vec::new()) })
    }

    /// Always resolves to `Ok(None)`: no provider is ever known.
    fn get_provider(
        &self,
        _provider_id: &str,
    ) -> Pin<Box<dyn Future<Output = SdkResult<Option<CatalogProviderEntry>>> + Send + '_>> {
        Box::pin(async { Ok(None) })
    }

    /// Always resolves to an empty model list, whatever the provider.
    fn list_models(
        &self,
        _provider_id: &str,
    ) -> Pin<Box<dyn Future<Output = SdkResult<Vec<CatalogModelEntry>>> + Send + '_>> {
        Box::pin(async { Ok(Vec::new()) })
    }

    /// Always resolves to `Ok(None)`: no model is ever known.
    fn get_model(
        &self,
        _provider_id: &str,
        _model_id: &str,
    ) -> Pin<Box<dyn Future<Output = SdkResult<Option<CatalogModelEntry>>> + Send + '_>> {
        Box::pin(async { Ok(None) })
    }

    /// Always resolves to an empty result set, whatever the pattern.
    fn search(
        &self,
        _pattern: &str,
    ) -> Pin<Box<dyn Future<Output = SdkResult<Vec<CatalogModelEntry>>> + Send + '_>> {
        Box::pin(async { Ok(Vec::new()) })
    }

    /// Always resolves to zero.
    fn model_count(&self) -> Pin<Box<dyn Future<Output = SdkResult<usize>> + Send + '_>> {
        Box::pin(async { Ok(0) })
    }

    /// Nothing to refresh, so the outcome is always `Unchanged`.
    fn refresh(&self) -> Pin<Box<dyn Future<Output = SdkResult<RefreshOutcome>> + Send + '_>> {
        Box::pin(async { Ok(RefreshOutcome::Unchanged) })
    }

    /// Hands out a receiver that will never see an event. It observes the channel
    /// closing once this catalog is dropped.
    fn subscribe(&self) -> broadcast::Receiver<CatalogEvent> {
        let tx = &self.tx;
        tx.subscribe()
    }
}