Skip to main content

distri_types/
models.rs

1//! Core model, provider, and audio types used across the entire system.
2
3use serde::{Deserialize, Serialize};
4
5// ── Provider identity ───────────────────────────────────────────────────
6
7/// Known provider types. Used for identity and dispatch.
8#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
9#[serde(rename_all = "snake_case")]
10pub enum ProviderType {
11    #[serde(rename = "openai")]
12    OpenAI,
13    Anthropic,
14    Azure,
15    Gemini,
16    AzureAiFoundry,
17    AwsBedrock,
18    GoogleVertex,
19    AlibabaCloud,
20    #[serde(rename = "elevenlabs")]
21    ElevenLabs,
22    #[serde(rename = "fal_ai")]
23    FalAi,
24    /// User-defined provider (LangDB-compatible / OpenAI-compatible)
25    #[serde(untagged)]
26    Custom(String),
27}
28
29impl ProviderType {
30    pub fn as_str(&self) -> &str {
31        match self {
32            Self::OpenAI => "openai",
33            Self::Anthropic => "anthropic",
34            Self::Azure => "azure",
35            Self::Gemini => "gemini",
36            Self::AzureAiFoundry => "azure_ai_foundry",
37            Self::AwsBedrock => "aws_bedrock",
38            Self::GoogleVertex => "google_vertex",
39            Self::AlibabaCloud => "alibaba_cloud",
40            Self::ElevenLabs => "elevenlabs",
41            Self::FalAi => "fal_ai",
42            Self::Custom(id) => id.as_str(),
43        }
44    }
45
46    pub fn display_name(&self) -> &str {
47        match self {
48            Self::OpenAI => "OpenAI",
49            Self::Anthropic => "Anthropic",
50            Self::Azure => "Azure",
51            Self::Gemini => "Google Gemini",
52            Self::AzureAiFoundry => "Azure AI Foundry",
53            Self::AwsBedrock => "AWS Bedrock",
54            Self::GoogleVertex => "Google Vertex AI",
55            Self::AlibabaCloud => "Alibaba Cloud",
56            Self::ElevenLabs => "ElevenLabs",
57            Self::FalAi => "fal.ai",
58            Self::Custom(id) => id.as_str(),
59        }
60    }
61
62    pub fn from_id(id: &str) -> Self {
63        match id {
64            "openai" => Self::OpenAI,
65            "anthropic" => Self::Anthropic,
66            "azure" | "azure_openai" | "azure_speech" => Self::Azure,
67            "gemini" => Self::Gemini,
68            "azure_ai_foundry" => Self::AzureAiFoundry,
69            "aws_bedrock" => Self::AwsBedrock,
70            "google_vertex" => Self::GoogleVertex,
71            "alibaba_cloud" => Self::AlibabaCloud,
72            "elevenlabs" => Self::ElevenLabs,
73            "fal_ai" => Self::FalAi,
74            other => Self::Custom(other.to_string()),
75        }
76    }
77}
78
79impl std::fmt::Display for ProviderType {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.write_str(self.as_str())
82    }
83}
84
85// ── Model types ─────────────────────────────────────────────────────────
86
87/// What a model can do.
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
89#[serde(rename_all = "snake_case")]
90pub enum ModelCapability {
91    Completion,
92    Tts,
93    Stt,
94    Image,
95}
96
97/// Pricing varies by capability type.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99#[serde(tag = "type", rename_all = "snake_case")]
100pub enum ModelPricing {
101    /// Completion model pricing — per 1M tokens (USD).
102    Completion {
103        input: f64,
104        output: f64,
105        #[serde(default, skip_serializing_if = "Option::is_none")]
106        cached_input: Option<f64>,
107    },
108    /// TTS pricing — per 1M characters (USD).
109    Tts { per_1m_chars: f64 },
110    /// STT pricing — per minute of audio (USD).
111    Stt { per_minute: f64 },
112    /// Image generation pricing — per image (USD), with optional per-quality
113    /// overrides keyed by quality tier name (`"low"` / `"medium"` / `"high"`
114    /// for gpt-image-1, `"standard"` / `"hd"` for dall-e-3, etc.).
115    Image {
116        per_image: f64,
117        #[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
118        per_quality: std::collections::BTreeMap<String, f64>,
119    },
120}
121
122/// A model with its capability, pricing, and metadata.
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct Model {
125    pub id: String,
126    /// Human-readable name. Optional in config sources — when omitted it is
127    /// backfilled from `id` by `register_provider_extensions`.
128    #[serde(default)]
129    pub name: String,
130    pub capability: ModelCapability,
131    #[serde(default, skip_serializing_if = "Option::is_none")]
132    pub context_window: Option<u32>,
133    #[serde(default, skip_serializing_if = "Option::is_none")]
134    pub pricing: Option<ModelPricing>,
135    #[serde(default, skip_serializing_if = "Vec::is_empty")]
136    pub voices: Vec<TtsVoiceInfo>,
137    #[serde(default, skip_serializing_if = "Vec::is_empty")]
138    pub formats: Vec<String>,
139}
140
141/// A model with denormalized provider info — returned by GET /v1/models.
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct ModelWithProvider {
144    #[serde(flatten)]
145    pub model: Model,
146    pub provider_id: String,
147    pub provider_label: String,
148    pub configured: bool,
149}
150
151// ── Provider definition ─────────────────────────────────────────────────
152
153/// Secret key definition for a provider.
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct ProviderKeyDefinition {
156    pub key: String,
157    pub label: String,
158    #[serde(default)]
159    pub placeholder: String,
160    #[serde(default = "default_true")]
161    pub required: bool,
162    #[serde(default = "default_true")]
163    pub sensitive: bool,
164    /// When set, the UI renders this field as a resource segment embedded in
165    /// the URL template (`{}` marks the editable segment), showing the full
166    /// endpoint read-only around it. Azure AI Foundry uses this: the user
167    /// edits only the resource name and that is all we store.
168    #[serde(default, skip_serializing_if = "Option::is_none")]
169    pub url_template: Option<String>,
170}
171
172/// A provider definition with its keys and available models.
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct ModelProviderDefinition {
175    pub id: String,
176    pub label: String,
177    pub keys: Vec<ProviderKeyDefinition>,
178    pub models: Vec<Model>,
179    #[serde(default)]
180    pub is_custom: bool,
181    /// Per-provider override of how `/v1/providers/test` validates the API
182    /// key. When omitted, the test endpoint probes `GET {base_url}/models`.
183    /// fal.ai sets this because it has no `/models` listing endpoint.
184    #[serde(default, skip_serializing_if = "Option::is_none")]
185    pub test: Option<ProviderTestConfig>,
186}
187
188/// Per-provider override of the `/v1/providers/test` probe.
189///
190/// Default behavior (when omitted): `GET {base_url}/models` with both
191/// `Authorization: Bearer <key>` and `api-key: <key>` headers, parsing
192/// `{data: [{id}]}`.
193///
194/// Set this when a provider has no `/models` listing endpoint (fal.ai).
195/// The probe sends the configured request and treats any response status
196/// outside the configured fail set (default: 401/403) as proof the auth
197/// header was accepted.
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct ProviderTestConfig {
200    /// Full URL, or a template containing `{base_url}`.
201    pub url: String,
202    /// HTTP method. Default `GET`.
203    #[serde(default = "default_test_method")]
204    pub method: String,
205    /// Auth header style: `bearer` (default), `key` (fal.ai), or `api_key`.
206    #[serde(default = "default_test_auth")]
207    pub auth: String,
208    /// Optional JSON body (POST/PUT). For fal.ai we send a body that fails
209    /// validation (`{}`) so we never pay for a generation.
210    #[serde(default, skip_serializing_if = "Option::is_none")]
211    pub body: Option<serde_json::Value>,
212    /// HTTP status codes that count as success. When empty (default), any
213    /// status other than 401/403 passes — the auth header reached the server.
214    #[serde(default, skip_serializing_if = "Vec::is_empty")]
215    pub accept_status: Vec<u16>,
216}
217
/// Serde default for `ProviderTestConfig::method`: the probe uses `GET`
/// unless a method is specified.
fn default_test_method() -> String {
    String::from("GET")
}
221
/// Serde default for `ProviderTestConfig::auth`: the probe uses the `bearer`
/// auth style unless another is specified.
fn default_test_auth() -> String {
    String::from("bearer")
}
225
226// ── TTS voice info ──────────────────────────────────────────────────────
227
228/// Information about a TTS voice.
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct TtsVoiceInfo {
231    pub id: String,
232    pub name: String,
233    #[serde(skip_serializing_if = "Option::is_none")]
234    pub description: Option<String>,
235    #[serde(default, skip_serializing_if = "Vec::is_empty")]
236    pub languages: Vec<String>,
237}
238
/// Serde default helper for boolean fields that should be `true` when the
/// field is absent from the input (a plain `#[serde(default)]` would give
/// `false`).
fn default_true() -> bool {
    true
}