Skip to main content

walrus_model/
provider.rs

1//! Provider implementation.
2//!
3//! Unified `Provider` enum with enum dispatch over concrete backends.
4//! `build_provider()` uses a URL lookup table for OpenAI-compatible kinds,
5//! eliminating repeated match arms for each variant.
6
7use crate::{
8    config::{ProviderConfig, ProviderKind},
9    remote::{
10        claude::{self, Claude},
11        openai::{self, OpenAI},
12    },
13};
14use anyhow::Result;
15use async_stream::try_stream;
16use compact_str::CompactString;
17use futures_core::Stream;
18use futures_util::StreamExt;
19use wcore::model::{Model, Response, StreamChunk};
20
/// Unified LLM provider enum.
///
/// The gateway constructs the appropriate variant based on `ProviderKind`
/// detected from the model name. The runtime is monomorphized on `Provider`,
/// so dispatch between backends is a plain `match` — no trait objects.
///
/// Variants are built by [`build_provider`]; `Local` exists only when the
/// crate is compiled with the `local` feature.
#[derive(Clone)]
pub enum Provider {
    /// OpenAI-compatible API (covers OpenAI, DeepSeek, Grok, Qwen, Kimi, Ollama).
    OpenAI(OpenAI),
    /// Anthropic Messages API.
    Claude(Claude),
    /// Local inference via mistralrs.
    #[cfg(feature = "local")]
    Local(crate::local::Local),
}
35
36impl Provider {
37    /// Query the context length for a given model ID.
38    ///
39    /// Local providers delegate to mistralrs; remote providers return None
40    /// (callers fall back to the static map in `wcore::model::default_context_limit`).
41    pub fn context_length(&self, _model: &str) -> Option<usize> {
42        match self {
43            Self::OpenAI(_) | Self::Claude(_) => None,
44            #[cfg(feature = "local")]
45            Self::Local(p) => p.context_length(_model),
46        }
47    }
48}
49
50/// Construct a `Provider` from config and a shared HTTP client.
51///
52/// OpenAI-compatible kinds use a URL lookup table — no repeated arms.
53/// The `model` string from config is stored in the provider for accurate
54/// `active_model()` reporting.
55pub async fn build_provider(config: &ProviderConfig, client: reqwest::Client) -> Result<Provider> {
56    let kind = config.kind()?;
57    let api_key = config.api_key.as_deref().unwrap_or("");
58    let model = config.model.as_str();
59
60    match kind {
61        ProviderKind::Claude => {
62            let url = config.base_url.as_deref().unwrap_or(claude::ENDPOINT);
63            return Ok(Provider::Claude(Claude::custom(
64                client, api_key, url, model,
65            )?));
66        }
67        #[cfg(feature = "local")]
68        ProviderKind::Local => {
69            use crate::config::Loader;
70            let loader = config.loader.unwrap_or_default();
71            let isq = config.quantization.map(|q| q.to_isq());
72            let chat_template = config.chat_template.as_deref();
73            let local = match loader {
74                Loader::Text => {
75                    crate::local::Local::from_text(&config.model, isq, chat_template).await?
76                }
77                Loader::Gguf => {
78                    crate::local::Local::from_gguf(&config.model, chat_template).await?
79                }
80                Loader::Vision => {
81                    crate::local::Local::from_vision(&config.model, isq, chat_template).await?
82                }
83                Loader::Lora | Loader::XLora | Loader::GgufLora | Loader::GgufXLora => {
84                    anyhow::bail!(
85                        "loader {:?} requires adapter configuration (not yet supported)",
86                        loader
87                    );
88                }
89            };
90            return Ok(Provider::Local(local));
91        }
92        #[cfg(not(feature = "local"))]
93        ProviderKind::Local => {
94            anyhow::bail!("local provider requires the 'local' feature");
95        }
96        _ => {}
97    }
98
99    // All remaining kinds are OpenAI-compatible. Look up the default endpoint URL.
100    let default_url: &str = match kind {
101        ProviderKind::OpenAI => openai::endpoint::OPENAI,
102        ProviderKind::DeepSeek => openai::endpoint::DEEPSEEK,
103        ProviderKind::Grok => openai::endpoint::GROK,
104        ProviderKind::Qwen => openai::endpoint::QWEN,
105        ProviderKind::Kimi => openai::endpoint::KIMI,
106        // Claude and Local are handled above; this arm is unreachable.
107        _ => unreachable!(),
108    };
109    let url = config.base_url.as_deref().unwrap_or(default_url);
110    let provider = if api_key.is_empty() {
111        OpenAI::no_auth(client, url, model)
112    } else {
113        OpenAI::custom(client, api_key, url, model)?
114    };
115    Ok(Provider::OpenAI(provider))
116}
117
118impl Model for Provider {
119    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
120        match self {
121            Self::OpenAI(p) => p.send(request).await,
122            Self::Claude(p) => p.send(request).await,
123            #[cfg(feature = "local")]
124            Self::Local(p) => p.send(request).await,
125        }
126    }
127
128    fn stream(
129        &self,
130        request: wcore::model::Request,
131    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
132        let this = self.clone();
133        try_stream! {
134            match this {
135                Provider::OpenAI(p) => {
136                    let mut stream = std::pin::pin!(p.stream(request));
137                    while let Some(chunk) = stream.next().await {
138                        yield chunk?;
139                    }
140                }
141                Provider::Claude(p) => {
142                    let mut stream = std::pin::pin!(p.stream(request));
143                    while let Some(chunk) = stream.next().await {
144                        yield chunk?;
145                    }
146                }
147                #[cfg(feature = "local")]
148                Provider::Local(p) => {
149                    let mut stream = std::pin::pin!(p.stream(request));
150                    while let Some(chunk) = stream.next().await {
151                        yield chunk?;
152                    }
153                }
154            }
155        }
156    }
157
158    fn context_limit(&self, model: &str) -> usize {
159        self.context_length(model)
160            .unwrap_or_else(|| wcore::model::default_context_limit(model))
161    }
162
163    fn active_model(&self) -> CompactString {
164        match self {
165            Self::OpenAI(p) => p.active_model(),
166            Self::Claude(p) => p.active_model(),
167            #[cfg(feature = "local")]
168            Self::Local(p) => p.active_model(),
169        }
170    }
171}