// opencode_provider_manager/discovery/provider_api.rs
use super::error::{DiscoveryError, Result};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::Deserialize;
13
14#[async_trait]
16pub trait ModelDiscovery: Send + Sync {
17 fn provider_id(&self) -> &str;
19
20 async fn discover_models(&self, api_key: Option<&str>) -> Result<Vec<DiscoveredModel>>;
22}
23
24pub struct OpenAICompatibleDiscovery {
26 provider_id: String,
27 base_url: String,
28 client: Client,
29}
30
31impl OpenAICompatibleDiscovery {
32 pub fn new(provider_id: impl Into<String>, base_url: impl Into<String>) -> Self {
33 Self {
34 provider_id: provider_id.into(),
35 base_url: base_url.into(),
36 client: Client::new(),
37 }
38 }
39
40 pub fn openai() -> Self {
42 Self::new("openai", "https://api.openai.com/v1")
43 }
44
45 pub fn lmstudio() -> Self {
47 Self::new("lmstudio", "http://127.0.0.1:1234/v1")
48 }
49}
50
51#[async_trait]
52impl ModelDiscovery for OpenAICompatibleDiscovery {
53 fn provider_id(&self) -> &str {
54 &self.provider_id
55 }
56
57 async fn discover_models(&self, api_key: Option<&str>) -> Result<Vec<DiscoveredModel>> {
58 let mut request = self
59 .client
60 .get(format!("{}/models", self.base_url.trim_end_matches('/')));
61
62 if let Some(key) = api_key {
63 request = request.bearer_auth(key);
64 }
65
66 let response = request
67 .send()
68 .await
69 .map_err(|e| DiscoveryError::Network(e.to_string()))?;
70
71 let models_response: OpenAIModelsResponse = response
72 .json()
73 .await
74 .map_err(|e| DiscoveryError::Parse(e.to_string()))?;
75
76 Ok(models_response
77 .data
78 .into_iter()
79 .map(|model| {
80 let name = model.id.clone();
81 DiscoveredModel {
82 id: model.id,
83 name,
84 provider_id: self.provider_id.clone(),
85 context_length: None,
86 max_output_tokens: None,
87 input_cost_per_million: None,
88 output_cost_per_million: None,
89 }
90 })
91 .collect())
92 }
93}
94
95pub struct OllamaDiscovery {
97 base_url: String,
98 client: Client,
99}
100
101impl OllamaDiscovery {
102 pub fn new(base_url: impl Into<String>) -> Self {
103 Self {
104 base_url: base_url.into(),
105 client: Client::new(),
106 }
107 }
108
109 pub fn default_instance() -> Self {
110 Self::new("http://127.0.0.1:11434")
111 }
112}
113
114#[async_trait]
115impl ModelDiscovery for OllamaDiscovery {
116 fn provider_id(&self) -> &str {
117 "ollama"
118 }
119
120 async fn discover_models(&self, _api_key: Option<&str>) -> Result<Vec<DiscoveredModel>> {
121 let response = self
122 .client
123 .get(format!("{}/api/tags", self.base_url.trim_end_matches('/')))
124 .send()
125 .await
126 .map_err(|e| DiscoveryError::Network(e.to_string()))?;
127
128 let ollama_response: OllamaTagsResponse = response
129 .json()
130 .await
131 .map_err(|e| DiscoveryError::Parse(e.to_string()))?;
132
133 Ok(ollama_response
134 .models
135 .into_iter()
136 .map(|model| {
137 let name = model.name.clone();
138 DiscoveredModel {
139 id: model.name,
140 name,
141 provider_id: "ollama".to_string(),
142 context_length: None,
143 max_output_tokens: None,
144 input_cost_per_million: None,
145 output_cost_per_million: None,
146 }
147 })
148 .collect())
149 }
150}
151
152#[derive(Debug, Deserialize)]
155struct OpenAIModelsResponse {
156 data: Vec<OpenAIModel>,
157}
158
159#[derive(Debug, Deserialize)]
160struct OpenAIModel {
161 id: String,
162}
163
164#[derive(Debug, Deserialize)]
165struct OllamaTagsResponse {
166 models: Vec<OllamaModel>,
167}
168
169#[derive(Debug, Deserialize)]
170struct OllamaModel {
171 name: String,
172}
173
174use super::DiscoveredModel;
175
#[cfg(test)]
mod tests {
    use super::*;

    /// The `openai` preset reports the `openai` provider id.
    #[test]
    fn test_openai_discovery_creation() {
        assert_eq!(OpenAICompatibleDiscovery::openai().provider_id(), "openai");
    }

    /// The `lmstudio` preset reports the `lmstudio` provider id.
    #[test]
    fn test_lmstudio_discovery_creation() {
        assert_eq!(
            OpenAICompatibleDiscovery::lmstudio().provider_id(),
            "lmstudio"
        );
    }

    /// The default Ollama instance reports the `ollama` provider id.
    #[test]
    fn test_ollama_discovery_creation() {
        assert_eq!(OllamaDiscovery::default_instance().provider_id(), "ollama");
    }
}