use bytes::BytesMut;
use futures::StreamExt;
use futures::stream::BoxStream;
use serde_json::Value as JsonValue;
use std::io;

use crate::parser::pull_events_from_value;
use crate::pull::PullEvent;
use crate::pull::PullProgressReporter;
use crate::url::base_url_to_host_root;
use crate::url::is_openai_compatible_base_url;
use agcodex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
use agcodex_core::ModelProviderInfo;
use agcodex_core::config::Config;

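/// Error returned when no local Ollama server responds to the startup probe.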
const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";

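/// Thin HTTP client for a local Ollama server, reached either through its
/// native API (`/api/...`) or its OpenAI-compatible endpoints (`/v1/...`).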
pub struct OllamaClient {
    client: reqwest::Client,
    host_root: String,
    uses_openai_compat: bool,
}

impl OllamaClient {
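    /// Builds a client from the built-in OSS model provider in `config` and
    /// verifies that the server is reachable before returning.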
    pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
        let provider = config
            .model_providers
            .get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found"),
                )
            })?;

        Self::try_from_provider(provider).await
    }

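    /// Test-only helper that builds an OSS provider for `base_url` and
    /// delegates to `Self::try_from_provider`.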
    #[cfg(test)]
    async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
        let provider = agcodex_core::create_oss_provider_with_base_url(base_url);
        Self::try_from_provider(&provider).await
    }

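    /// Probes the server behind `provider` and returns a client bound to its
    /// host root, failing with `OLLAMA_CONNECTION_ERROR` if the probe fails.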
    async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
        #![expect(clippy::expect_used)]
        let base_url = provider
            .base_url
            .as_ref()
            .expect("oss provider must have a base_url");
        // An OpenAI-compatible base URL decides which endpoints get probed.
        let uses_openai_compat = is_openai_compatible_base_url(base_url);
        let host_root = base_url_to_host_root(base_url);
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        let client = Self {
            client,
            host_root,
            uses_openai_compat,
        };
        client.probe_server().await?;
        Ok(client)
    }

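    /// Issues a cheap GET (`/v1/models` or `/api/tags`) to confirm the server
    /// is reachable; connection failures and non-2xx responses both map to
    /// `OLLAMA_CONNECTION_ERROR`.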
    async fn probe_server(&self) -> io::Result<()> {
        let url = if self.uses_openai_compat {
            format!("{}/v1/models", self.host_root.trim_end_matches('/'))
        } else {
            format!("{}/api/tags", self.host_root.trim_end_matches('/'))
        };
        let resp = self.client.get(url).send().await.map_err(|err| {
            tracing::warn!("Failed to connect to Ollama server: {err:?}");
            io::Error::other(OLLAMA_CONNECTION_ERROR)
        })?;
        if resp.status().is_success() {
            Ok(())
        } else {
            tracing::warn!(
                "Failed to probe server at {}: HTTP {}",
                self.host_root,
                resp.status()
            );
            Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
        }
    }

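    /// Lists locally available model names via `/api/tags`. A non-success
    /// status yields an empty list rather than an error.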
    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
        let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .get(tags_url)
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Ok(Vec::new());
        }
        let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
        let names = val
            .get("models")
            .and_then(|m| m.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.get("name").and_then(|n| n.as_str()))
                    .map(|s| s.to_string())
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        Ok(names)
    }

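    /// Starts a streaming `/api/pull` request for `model` and converts the
    /// newline-delimited JSON response into a stream of `PullEvent`s, ending
    /// on the first error or `"success"` status.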
    #[allow(tail_expr_drop_order)]
    pub async fn pull_model_stream(
        &self,
        model: &str,
    ) -> io::Result<BoxStream<'static, PullEvent>> {
        let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .post(url)
            .json(&serde_json::json!({"model": model, "stream": true}))
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Err(io::Error::other(format!(
                "failed to start pull: HTTP {}",
                resp.status()
            )));
        }

        let mut stream = resp.bytes_stream();
        let mut buf = BytesMut::new();

        let s = async_stream::stream! {
            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        buf.extend_from_slice(&bytes);
                        // The pull endpoint streams newline-delimited JSON; emit
                        // events for every complete line sitting in the buffer.
                        while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
                            let line = buf.split_to(pos + 1);
                            if let Ok(text) = std::str::from_utf8(&line) {
                                let text = text.trim();
                                if text.is_empty() { continue; }
                                if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
                                    for ev in pull_events_from_value(&value) { yield ev; }
                                    if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
                                        yield PullEvent::Error(err_msg.to_string());
                                        return;
                                    }
                                    if let Some(status) = value.get("status").and_then(|s| s.as_str())
                                        && status == "success" { yield PullEvent::Success; return; }
                                }
                            }
                        }
                    }
                    Err(_) => {
                        // Transport error: end the stream; callers treat a stream
                        // that ends without `Success` as a failed pull.
                        return;
                    }
                }
            }
        };

        Ok(Box::pin(s))
    }

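    /// Pulls `model`, forwarding every event to `reporter`. Resolves on
    /// success; an error event or a stream that ends early becomes an
    /// `io::Error`.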
    pub async fn pull_with_reporter(
        &self,
        model: &str,
        reporter: &mut dyn PullProgressReporter,
    ) -> io::Result<()> {
        reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
        let mut stream = self.pull_model_stream(model).await?;
        while let Some(event) = stream.next().await {
            reporter.on_event(&event)?;
            match event {
                PullEvent::Success => {
                    return Ok(());
                }
                PullEvent::Error(err) => {
                    return Err(io::Error::other(format!("Pull failed: {err}")));
                }
                PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => {
                    continue;
                }
            }
        }
        Err(io::Error::other(
            "Pull stream ended unexpectedly without success.",
        ))
    }

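    /// Test-only constructor that skips the startup probe and always talks to
    /// the native API at `host_root`.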
    #[cfg(test)]
    fn from_host_root(host_root: impl Into<String>) -> Self {
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            client,
            host_root: host_root.into(),
            uses_openai_compat: false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_fetch_models_happy_path() {
        if std::env::var(agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_fetch_models_happy_path",
                agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(
                wiremock::ResponseTemplate::new(200).set_body_raw(
                    serde_json::json!({
                        "models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ]
                    })
                    .to_string(),
                    "application/json",
                ),
            )
            .mount(&server)
            .await;

        let client = OllamaClient::from_host_root(server.uri());
        let models = client.fetch_models().await.expect("fetch models");
        assert!(models.contains(&"llama3.2:3b".to_string()));
        assert!(models.contains(&"mistral".to_string()));
    }

    #[tokio::test]
    async fn test_probe_server_happy_path_openai_compat_and_native() {
        if std::env::var(agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
                agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let native = OllamaClient::from_host_root(server.uri());
        native.probe_server().await.expect("probe native");

        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let ollama_client =
            OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
                .await
                .expect("probe OpenAI compat");
        ollama_client
            .probe_server()
            .await
            .expect("probe OpenAI compat");
    }

    #[tokio::test]
    async fn test_try_from_oss_provider_ok_when_server_running() {
        if std::env::var(agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_try_from_oss_provider_ok_when_server_running",
                agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;

        OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
            .await
            .expect("client should be created when probe succeeds");
    }

    #[tokio::test]
    async fn test_try_from_oss_provider_err_when_server_missing() {
        if std::env::var(agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_try_from_oss_provider_err_when_server_missing",
                agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
            .await
            .err()
            .expect("expected error");
        assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
    }
}