Skip to main content

walrus_model/
provider.rs

1//! Provider implementation.
2//!
3//! Unified `Provider` enum with enum dispatch over concrete backends.
4//! `build_provider()` matches on `ProviderKind` detected from the model name.
5
6use crate::claude::Claude;
7use crate::config::{ProviderConfig, ProviderKind};
8use crate::deepseek::DeepSeek;
9use crate::openai::OpenAI;
10use anyhow::Result;
11use async_stream::try_stream;
12use compact_str::CompactString;
13use futures_core::Stream;
14use futures_util::StreamExt;
15use wcore::model::{Model, Response, StreamChunk};
16
/// Unified LLM provider enum.
///
/// The gateway constructs the appropriate variant based on `ProviderKind`
/// detected from the model name. The runtime is monomorphized on `Provider`.
//
// `Clone` is derived because `Model::stream` clones the provider so the
// returned stream owns its backend rather than borrowing `&self` (see the
// `impl Model for Provider` below).
#[derive(Clone)]
pub enum Provider {
    /// DeepSeek API.
    DeepSeek(DeepSeek),
    /// OpenAI-compatible API (covers OpenAI, Grok, Qwen, Kimi, Ollama).
    OpenAI(OpenAI),
    /// Anthropic Messages API.
    Claude(Claude),
    /// Local inference via mistralrs.
    ///
    /// Only compiled in with the `local` feature; without it,
    /// `build_provider` rejects `ProviderKind::Local` with an error.
    #[cfg(feature = "local")]
    Local(crate::local::Local),
}
33
34impl Provider {
35    /// Query the context length for a given model ID.
36    ///
37    /// Local providers delegate to mistralrs; remote providers return None
38    /// (callers fall back to the static map in `wcore::model::default_context_limit`).
39    pub fn context_length(&self, _model: &str) -> Option<usize> {
40        match self {
41            Self::DeepSeek(_) | Self::OpenAI(_) | Self::Claude(_) => None,
42            #[cfg(feature = "local")]
43            Self::Local(p) => p.context_length(_model),
44        }
45    }
46}
47
48/// Construct a `Provider` from config and a shared HTTP client.
49///
50/// This function is async because local providers need to load model weights
51/// asynchronously.
52pub async fn build_provider(config: &ProviderConfig, client: reqwest::Client) -> Result<Provider> {
53    let kind = config.kind()?;
54    let api_key = config.api_key.as_deref().unwrap_or("");
55    let base_url = config.base_url.as_deref();
56
57    let provider = match kind {
58        ProviderKind::DeepSeek => match base_url {
59            Some(url) => Provider::OpenAI(OpenAI::custom(client, api_key, url)?),
60            None => Provider::DeepSeek(DeepSeek::new(client, api_key)?),
61        },
62        ProviderKind::OpenAI => match base_url {
63            Some(url) => Provider::OpenAI(OpenAI::custom(client, api_key, url)?),
64            None => Provider::OpenAI(OpenAI::api(client, api_key)?),
65        },
66        ProviderKind::Claude => match base_url {
67            Some(url) => Provider::Claude(Claude::custom(client, api_key, url)?),
68            None => Provider::Claude(Claude::anthropic(client, api_key)?),
69        },
70        ProviderKind::Grok => match base_url {
71            Some(url) => Provider::OpenAI(OpenAI::custom(client, api_key, url)?),
72            None => Provider::OpenAI(OpenAI::grok(client, api_key)?),
73        },
74        ProviderKind::Qwen => match base_url {
75            Some(url) => Provider::OpenAI(OpenAI::custom(client, api_key, url)?),
76            None => Provider::OpenAI(OpenAI::qwen(client, api_key)?),
77        },
78        ProviderKind::Kimi => match base_url {
79            Some(url) => Provider::OpenAI(OpenAI::custom(client, api_key, url)?),
80            None => Provider::OpenAI(OpenAI::kimi(client, api_key)?),
81        },
82        #[cfg(feature = "local")]
83        ProviderKind::Local => {
84            use crate::config::Loader;
85            let loader = config.loader.unwrap_or_default();
86            let isq = config.quantization.map(|q| q.to_isq());
87            let chat_template = config.chat_template.as_deref();
88            let local = match loader {
89                Loader::Text => {
90                    crate::local::Local::from_text(&config.model, isq, chat_template).await?
91                }
92                Loader::Gguf => {
93                    crate::local::Local::from_gguf(&config.model, chat_template).await?
94                }
95                Loader::Vision => {
96                    crate::local::Local::from_vision(&config.model, isq, chat_template).await?
97                }
98                Loader::Lora | Loader::XLora | Loader::GgufLora | Loader::GgufXLora => {
99                    anyhow::bail!(
100                        "loader {:?} requires adapter configuration (not yet supported)",
101                        loader
102                    );
103                }
104            };
105            Provider::Local(local)
106        }
107        #[cfg(not(feature = "local"))]
108        ProviderKind::Local => {
109            anyhow::bail!("local provider requires the 'local' feature");
110        }
111    };
112    Ok(provider)
113}
114
impl Model for Provider {
    /// Forward a one-shot (non-streaming) request to the active backend.
    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
        match self {
            Self::DeepSeek(p) => p.send(request).await,
            Self::OpenAI(p) => p.send(request).await,
            Self::Claude(p) => p.send(request).await,
            #[cfg(feature = "local")]
            Self::Local(p) => p.send(request).await,
        }
    }

    /// Forward a streaming request, yielding `StreamChunk`s from the backend.
    ///
    /// The provider is cloned up front so the returned stream owns its
    /// backend instead of borrowing from `&self` (the `try_stream!` block
    /// takes `this` by move). Each backend's `stream()` returns a distinct
    /// opaque stream type, so the four arms cannot be collapsed; every inner
    /// stream is stack-pinned with `std::pin::pin!` before being polled
    /// through `StreamExt::next`.
    fn stream(
        &self,
        request: wcore::model::Request,
    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
        let this = self.clone();
        try_stream! {
            match this {
                Provider::DeepSeek(p) => {
                    let mut stream = std::pin::pin!(p.stream(request));
                    while let Some(chunk) = stream.next().await {
                        // `?` inside try_stream! ends the stream with an Err item.
                        yield chunk?;
                    }
                }
                Provider::OpenAI(p) => {
                    let mut stream = std::pin::pin!(p.stream(request));
                    while let Some(chunk) = stream.next().await {
                        yield chunk?;
                    }
                }
                Provider::Claude(p) => {
                    let mut stream = std::pin::pin!(p.stream(request));
                    while let Some(chunk) = stream.next().await {
                        yield chunk?;
                    }
                }
                #[cfg(feature = "local")]
                Provider::Local(p) => {
                    let mut stream = std::pin::pin!(p.stream(request));
                    while let Some(chunk) = stream.next().await {
                        yield chunk?;
                    }
                }
            }
        }
    }

    /// Effective context window: backend-reported when available (local
    /// backends only — see `Provider::context_length`), otherwise the static
    /// per-model default from `wcore::model`.
    fn context_limit(&self, model: &str) -> usize {
        self.context_length(model)
            .unwrap_or_else(|| wcore::model::default_context_limit(model))
    }

    /// The model identifier the active backend is configured with.
    fn active_model(&self) -> CompactString {
        match self {
            Self::DeepSeek(p) => p.active_model(),
            Self::OpenAI(p) => p.active_model(),
            Self::Claude(p) => p.active_model(),
            #[cfg(feature = "local")]
            Self::Local(p) => p.active_model(),
        }
    }
}