Skip to main content

walrus_model/
config.rs

1//! Provider configuration.
2//!
3//! Flat `ProviderConfig` with optional fields for both remote and local
4//! providers. Provider kind inferred from model name prefix via `kind()`.
5//! `Loader` selects which mistralrs builder to use for local models.
6
7use anyhow::{Result, bail};
8use compact_str::CompactString;
9use serde::{Deserialize, Serialize};
10
11/// Flat provider configuration. All fields except `model` are optional.
12/// Provider kind is inferred from the model name — no explicit `provider` tag.
13#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct ProviderConfig {
15    /// Model identifier. Remote models use known prefixes (`deepseek-*`,
16    /// `gpt-*`, `claude-*`, etc.). Local models use HuggingFace repo IDs
17    /// containing `/` (e.g. `microsoft/Phi-3.5-mini-instruct`).
18    pub model: CompactString,
19    /// API key for remote providers. Supports `${ENV_VAR}` expansion at the
20    /// daemon layer.
21    #[serde(default, skip_serializing_if = "Option::is_none")]
22    pub api_key: Option<String>,
23    /// Base URL override for remote providers.
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub base_url: Option<String>,
26    /// Mistralrs model builder to use for local models. Defaults to `Text`.
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub loader: Option<Loader>,
29    /// In-situ quantization for local models.
30    #[serde(default, skip_serializing_if = "Option::is_none")]
31    pub quantization: Option<QuantizationType>,
32    /// Chat template override for local models (path or inline Jinja).
33    #[serde(default, skip_serializing_if = "Option::is_none")]
34    pub chat_template: Option<String>,
35}
36
37impl ProviderConfig {
38    /// Detect the provider kind from the model name.
39    pub fn kind(&self) -> Result<ProviderKind> {
40        ProviderKind::from_model(&self.model)
41    }
42
43    /// Validate field combinations.
44    ///
45    /// Called on startup and on provider add/reload.
46    pub fn validate(&self) -> Result<()> {
47        if self.model.is_empty() {
48            bail!("model is required");
49        }
50
51        let kind = self.kind()?;
52
53        match kind {
54            ProviderKind::Local => {
55                if self.api_key.is_some() {
56                    bail!("local provider '{}' must not have api_key", self.model);
57                }
58            }
59            _ => {
60                // Remote providers: api_key is required unless base_url is set
61                // (e.g. Ollama which is keyless with a local base_url).
62                if self.api_key.is_none() && self.base_url.is_none() {
63                    bail!(
64                        "remote provider '{}' requires api_key or base_url",
65                        self.model
66                    );
67                }
68                if self.loader.is_some() {
69                    bail!(
70                        "remote provider '{}' must not have loader field",
71                        self.model
72                    );
73                }
74                if self.quantization.is_some() {
75                    bail!(
76                        "remote provider '{}' must not have quantization field",
77                        self.model
78                    );
79                }
80                if self.chat_template.is_some() {
81                    bail!(
82                        "remote provider '{}' must not have chat_template field",
83                        self.model
84                    );
85                }
86            }
87        }
88
89        Ok(())
90    }
91}
92
/// The backend family a model belongs to.
///
/// This is a runtime dispatch value derived from the model name; it carries no
/// serde derives and is never written to or read from configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Remote DeepSeek models.
    DeepSeek,
    /// Remote OpenAI models.
    OpenAI,
    /// Remote Claude models.
    Claude,
    /// Remote Grok models.
    Grok,
    /// Remote Qwen models.
    Qwen,
    /// Remote Kimi models.
    Kimi,
    /// Locally loaded model.
    Local,
}
106
107impl ProviderKind {
108    /// Detect provider kind from a model name string.
109    ///
110    /// Rules:
111    /// 1. If model contains `/` → Local (HuggingFace repo ID).
112    /// 2. Otherwise, match known remote prefixes.
113    /// 3. No match → error.
114    pub fn from_model(model: &str) -> Result<Self> {
115        if model.contains('/') {
116            return Ok(Self::Local);
117        }
118
119        let prefixes: &[(&[&str], ProviderKind)] = &[
120            (&["deepseek-"], ProviderKind::DeepSeek),
121            (&["gpt-", "o1-", "o3-", "o4-"], ProviderKind::OpenAI),
122            (&["claude-"], ProviderKind::Claude),
123            (&["grok-"], ProviderKind::Grok),
124            (&["qwen-", "qwq-"], ProviderKind::Qwen),
125            (&["kimi-", "moonshot-"], ProviderKind::Kimi),
126        ];
127
128        for (patterns, kind) in prefixes {
129            for prefix in *patterns {
130                if model.starts_with(prefix) {
131                    return Ok(*kind);
132                }
133            }
134        }
135
136        bail!("unknown model prefix: '{model}' — cannot detect provider kind")
137    }
138
139    /// Human-readable name for logging.
140    pub fn as_str(self) -> &'static str {
141        match self {
142            Self::DeepSeek => "deepseek",
143            Self::OpenAI => "openai",
144            Self::Claude => "claude",
145            Self::Grok => "grok",
146            Self::Qwen => "qwen",
147            Self::Kimi => "kimi",
148            Self::Local => "local",
149        }
150    }
151}
152
153/// Selects which mistralrs model builder to use for local inference.
154///
155/// Defaults to `Text` when omitted in config.
156#[derive(Debug, Serialize, Deserialize, Clone, Copy, Default)]
157#[serde(rename_all = "snake_case")]
158pub enum Loader {
159    /// `TextModelBuilder` — standard text models.
160    #[default]
161    Text,
162    /// `LoraModelBuilder` — LoRA adapter models.
163    Lora,
164    /// `XLoraModelBuilder` — X-LoRA adapter models.
165    #[serde(rename = "xlora")]
166    XLora,
167    /// `GgufModelBuilder` — GGUF quantized models.
168    Gguf,
169    /// `GgufLoraModelBuilder` — GGUF + LoRA.
170    #[serde(rename = "gguf_lora")]
171    GgufLora,
172    /// `GgufXLoraModelBuilder` — GGUF + X-LoRA.
173    #[serde(rename = "gguf_xlora")]
174    GgufXLora,
175    /// `VisionModelBuilder` — vision-language models.
176    Vision,
177}
178
179/// Quantization types supported by mistralrs (maps to `IsqType`).
180#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
181pub enum QuantizationType {
182    #[serde(rename = "q4_0")]
183    Q4_0,
184    #[serde(rename = "q4_1")]
185    Q4_1,
186    #[serde(rename = "q5_0")]
187    Q5_0,
188    #[serde(rename = "q5_1")]
189    Q5_1,
190    #[serde(rename = "q8_0")]
191    Q8_0,
192    #[serde(rename = "q8_1")]
193    Q8_1,
194    #[serde(rename = "q2k")]
195    Q2K,
196    #[serde(rename = "q3k")]
197    Q3K,
198    #[serde(rename = "q4k")]
199    Q4K,
200    #[serde(rename = "q5k")]
201    Q5K,
202    #[serde(rename = "q6k")]
203    Q6K,
204    #[serde(rename = "q8k")]
205    Q8K,
206}
207
#[cfg(feature = "local")]
impl QuantizationType {
    /// Map this config-level quantization onto the identically named
    /// mistralrs `IsqType` variant.
    ///
    /// Only compiled when the `local` feature is enabled.
    pub fn to_isq(self) -> mistralrs::IsqType {
        use mistralrs::IsqType;

        match self {
            Self::Q4_0 => IsqType::Q4_0,
            Self::Q4_1 => IsqType::Q4_1,
            Self::Q5_0 => IsqType::Q5_0,
            Self::Q5_1 => IsqType::Q5_1,
            Self::Q8_0 => IsqType::Q8_0,
            Self::Q8_1 => IsqType::Q8_1,
            Self::Q2K => IsqType::Q2K,
            Self::Q3K => IsqType::Q3K,
            Self::Q4K => IsqType::Q4K,
            Self::Q5K => IsqType::Q5K,
            Self::Q6K => IsqType::Q6K,
            Self::Q8K => IsqType::Q8K,
        }
    }
}