// opencode_provider_manager/discovery/provider_api.rs

1//! Direct Provider API queries for discovering available models.
2//!
3//! Supports querying provider APIs directly to list available models:
4//! - OpenAI: GET /v1/models
5//! - Ollama: GET /api/tags
6//! - LM Studio: GET /v1/models
7//! - Extensible via the `ModelDiscovery` trait
8
9use super::error::{DiscoveryError, Result};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::Deserialize;
13
/// Trait for provider-specific model discovery.
///
/// Each implementation knows how to ask one kind of provider which models it
/// currently offers. The trait requires `Send + Sync`, and `#[async_trait]`
/// is applied so that `discover_models` can be declared as an `async fn`.
#[async_trait]
pub trait ModelDiscovery: Send + Sync {
    /// The provider ID this discovery implementation handles.
    ///
    /// The implementations in this file also copy this value into the
    /// `provider_id` field of every model they return.
    fn provider_id(&self) -> &str;

    /// Discover available models from the provider.
    ///
    /// `api_key` is optional. In this file the OpenAI-compatible
    /// implementation sends it as a bearer token when present, while the
    /// Ollama implementation ignores it.
    ///
    /// # Errors
    ///
    /// Returns the discovery module's `Result` error type. The implementations
    /// in this file produce `DiscoveryError::Network` when the request fails
    /// and `DiscoveryError::Parse` when the response body cannot be decoded.
    async fn discover_models(&self, api_key: Option<&str>) -> Result<Vec<DiscoveredModel>>;
}
23
24/// OpenAI-compatible provider model discovery.
25pub struct OpenAICompatibleDiscovery {
26    provider_id: String,
27    base_url: String,
28    client: Client,
29}
30
31impl OpenAICompatibleDiscovery {
32    pub fn new(provider_id: impl Into<String>, base_url: impl Into<String>) -> Self {
33        Self {
34            provider_id: provider_id.into(),
35            base_url: base_url.into(),
36            client: Client::new(),
37        }
38    }
39
40    /// Create a discovery for OpenAI.
41    pub fn openai() -> Self {
42        Self::new("openai", "https://api.openai.com/v1")
43    }
44
45    /// Create a discovery for LM Studio.
46    pub fn lmstudio() -> Self {
47        Self::new("lmstudio", "http://127.0.0.1:1234/v1")
48    }
49}
50
51#[async_trait]
52impl ModelDiscovery for OpenAICompatibleDiscovery {
53    fn provider_id(&self) -> &str {
54        &self.provider_id
55    }
56
57    async fn discover_models(&self, api_key: Option<&str>) -> Result<Vec<DiscoveredModel>> {
58        let mut request = self
59            .client
60            .get(format!("{}/models", self.base_url.trim_end_matches('/')));
61
62        if let Some(key) = api_key {
63            request = request.bearer_auth(key);
64        }
65
66        let response = request
67            .send()
68            .await
69            .map_err(|e| DiscoveryError::Network(e.to_string()))?;
70
71        let models_response: OpenAIModelsResponse = response
72            .json()
73            .await
74            .map_err(|e| DiscoveryError::Parse(e.to_string()))?;
75
76        Ok(models_response
77            .data
78            .into_iter()
79            .map(|model| {
80                let name = model.id.clone();
81                DiscoveredModel {
82                    id: model.id,
83                    name,
84                    provider_id: self.provider_id.clone(),
85                    context_length: None,
86                    max_output_tokens: None,
87                    input_cost_per_million: None,
88                    output_cost_per_million: None,
89                }
90            })
91            .collect())
92    }
93}
94
95/// Ollama model discovery.
96pub struct OllamaDiscovery {
97    base_url: String,
98    client: Client,
99}
100
101impl OllamaDiscovery {
102    pub fn new(base_url: impl Into<String>) -> Self {
103        Self {
104            base_url: base_url.into(),
105            client: Client::new(),
106        }
107    }
108
109    pub fn default_instance() -> Self {
110        Self::new("http://127.0.0.1:11434")
111    }
112}
113
114#[async_trait]
115impl ModelDiscovery for OllamaDiscovery {
116    fn provider_id(&self) -> &str {
117        "ollama"
118    }
119
120    async fn discover_models(&self, _api_key: Option<&str>) -> Result<Vec<DiscoveredModel>> {
121        let response = self
122            .client
123            .get(format!("{}/api/tags", self.base_url.trim_end_matches('/')))
124            .send()
125            .await
126            .map_err(|e| DiscoveryError::Network(e.to_string()))?;
127
128        let ollama_response: OllamaTagsResponse = response
129            .json()
130            .await
131            .map_err(|e| DiscoveryError::Parse(e.to_string()))?;
132
133        Ok(ollama_response
134            .models
135            .into_iter()
136            .map(|model| {
137                let name = model.name.clone();
138                DiscoveredModel {
139                    id: model.name,
140                    name,
141                    provider_id: "ollama".to_string(),
142                    context_length: None,
143                    max_output_tokens: None,
144                    input_cost_per_million: None,
145                    output_cost_per_million: None,
146                }
147            })
148            .collect())
149    }
150}
151
152// Response type definitions
153
/// JSON body decoded by `OpenAICompatibleDiscovery::discover_models`.
///
/// Only the top-level `data` array is declared here; it is the one field the
/// discovery code reads.
#[derive(Debug, Deserialize)]
struct OpenAIModelsResponse {
    /// One entry per model reported by the endpoint.
    data: Vec<OpenAIModel>,
}
158
/// A single entry of `OpenAIModelsResponse::data`.
#[derive(Debug, Deserialize)]
struct OpenAIModel {
    /// Model identifier; used as both `id` and `name` of the resulting `DiscoveredModel`.
    id: String,
}
163
/// JSON body decoded by `OllamaDiscovery::discover_models`.
///
/// Only the top-level `models` array is declared here; it is the one field the
/// discovery code reads.
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    /// One entry per model reported by the endpoint.
    models: Vec<OllamaModel>,
}
168
/// A single entry of `OllamaTagsResponse::models`.
#[derive(Debug, Deserialize)]
struct OllamaModel {
    /// Model name; used as both `id` and `name` of the resulting `DiscoveredModel`.
    name: String,
}
173
174use super::DiscoveredModel;
175
#[cfg(test)]
mod tests {
    use super::*;

    /// The OpenAI preset reports the `openai` provider ID.
    #[test]
    fn test_openai_discovery_creation() {
        assert_eq!(OpenAICompatibleDiscovery::openai().provider_id(), "openai");
    }

    /// The LM Studio preset reports the `lmstudio` provider ID.
    #[test]
    fn test_lmstudio_discovery_creation() {
        assert_eq!(
            OpenAICompatibleDiscovery::lmstudio().provider_id(),
            "lmstudio"
        );
    }

    /// The default Ollama instance reports the `ollama` provider ID.
    #[test]
    fn test_ollama_discovery_creation() {
        assert_eq!(OllamaDiscovery::default_instance().provider_id(), "ollama");
    }
}