agcodex_ollama/
client.rs

use bytes::BytesMut;
use futures::StreamExt;
use futures::stream::BoxStream;
use serde_json::Value as JsonValue;
use std::io;

use crate::parser::pull_events_from_value;
use crate::pull::PullEvent;
use crate::pull::PullProgressReporter;
use crate::url::base_url_to_host_root;
use crate::url::is_openai_compatible_base_url;
use agcodex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
use agcodex_core::ModelProviderInfo;
use agcodex_core::config::Config;

const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";

/// Client for interacting with a local Ollama instance.
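///
/// A minimal usage sketch, assuming a loaded `Config` and a running local
/// server. The import path is a guess at how this crate exposes the type, so
/// the example is marked `ignore`:
///
/// ```ignore
/// use agcodex_ollama::OllamaClient; // hypothetical path; adjust to the crate's re-exports
///
/// async fn list_local_models(config: &agcodex_core::config::Config) -> std::io::Result<Vec<String>> {
///     // Fails with a helpful message when no Ollama server is reachable.
///     let client = OllamaClient::try_from_oss_provider(config).await?;
///     client.fetch_models().await
/// }
/// ```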
pub struct OllamaClient {
    client: reqwest::Client,
    host_root: String,
    uses_openai_compat: bool,
}

impl OllamaClient {
    /// Construct a client for the built‑in open‑source ("oss") model provider
    /// and verify that a local Ollama server is reachable. If no server is
    /// detected, returns an error with helpful installation/run instructions.
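    ///
    /// Hypothetical `config.toml` override sketch; the exact keys depend on how
    /// `Config` deserializes `model_providers` and on the built-in provider id,
    /// so treat it as illustrative only:
    ///
    /// ```toml
    /// # Hypothetical: point the built-in "oss" provider at a non-default Ollama host.
    /// [model_providers.oss]
    /// base_url = "http://192.168.1.10:11434/v1"
    /// ```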
    pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
        // Note that we must look up the provider from the Config to ensure that
        // any overrides the user has in their config.toml are taken into
        // account.
        let provider = config
            .model_providers
            .get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found"),
                )
            })?;

        Self::try_from_provider(provider).await
    }

    #[cfg(test)]
    async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
        let provider = agcodex_core::create_oss_provider_with_base_url(base_url);
        Self::try_from_provider(&provider).await
    }

    /// Build a client from a provider definition and verify the server is reachable.
    async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
        #![expect(clippy::expect_used)]
        let base_url = provider
            .base_url
            .as_ref()
            .expect("oss provider must have a base_url");
        // The original `WireApi::Chat` clause only re-applied the same URL test,
        // so the whole condition reduces to checking the base URL.
        let uses_openai_compat = is_openai_compatible_base_url(base_url);
        let host_root = base_url_to_host_root(base_url);
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        let client = Self {
            client,
            host_root,
            uses_openai_compat,
        };
        client.probe_server().await?;
        Ok(client)
    }

    /// Probe whether the server is reachable by hitting the appropriate health endpoint.
    async fn probe_server(&self) -> io::Result<()> {
        let url = if self.uses_openai_compat {
            format!("{}/v1/models", self.host_root.trim_end_matches('/'))
        } else {
            format!("{}/api/tags", self.host_root.trim_end_matches('/'))
        };
        let resp = self.client.get(url).send().await.map_err(|err| {
            tracing::warn!("Failed to connect to Ollama server: {err:?}");
            io::Error::other(OLLAMA_CONNECTION_ERROR)
        })?;
        if resp.status().is_success() {
            Ok(())
        } else {
            tracing::warn!(
                "Failed to probe server at {}: HTTP {}",
                self.host_root,
                resp.status()
            );
            Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
        }
    }

    /// Return the list of model names known to the local Ollama instance.
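    ///
    /// The parsing below only relies on `GET /api/tags` returning JSON of
    /// roughly this shape (any extra fields are ignored):
    ///
    /// ```json
    /// { "models": [ { "name": "llama3.2:3b" }, { "name": "mistral" } ] }
    /// ```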
    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
        let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .get(tags_url)
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Ok(Vec::new());
        }
        let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
        let names = val
            .get("models")
            .and_then(|m| m.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.get("name").and_then(|n| n.as_str()))
                    .map(|s| s.to_string())
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        Ok(names)
    }

    /// Start a model pull and emit streaming events. The returned stream ends
    /// when a `Success` or `Error` event is observed or the server closes the
    /// connection.
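    ///
    /// Consumption sketch (marked `ignore`; it assumes a `client` is already in
    /// scope and uses `futures::StreamExt` for `next`):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut events = client.pull_model_stream("llama3.2:3b").await?;
    /// while let Some(event) = events.next().await {
    ///     match event {
    ///         PullEvent::Status(msg) => eprintln!("{msg}"),
    ///         PullEvent::ChunkProgress { .. } => { /* update a progress bar */ }
    ///         PullEvent::Success => break,
    ///         PullEvent::Error(err) => return Err(std::io::Error::other(err)),
    ///     }
    /// }
    /// ```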
    #[allow(tail_expr_drop_order)]
    pub async fn pull_model_stream(
        &self,
        model: &str,
    ) -> io::Result<BoxStream<'static, PullEvent>> {
        let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .post(url)
            .json(&serde_json::json!({"model": model, "stream": true}))
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Err(io::Error::other(format!(
                "failed to start pull: HTTP {}",
                resp.status()
            )));
        }

        let mut stream = resp.bytes_stream();
        let mut buf = BytesMut::new();

        // Adapt the byte stream into pull events by parsing the newline-delimited
        // JSON response as it arrives.
        let s = async_stream::stream! {
            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        buf.extend_from_slice(&bytes);
                        while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
                            let line = buf.split_to(pos + 1);
                            if let Ok(text) = std::str::from_utf8(&line) {
                                let text = text.trim();
                                if text.is_empty() { continue; }
                                if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
                                    for ev in pull_events_from_value(&value) { yield ev; }
                                    if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
                                        yield PullEvent::Error(err_msg.to_string());
                                        return;
                                    }
                                    if let Some(status) = value.get("status").and_then(|s| s.as_str())
                                        && status == "success" { yield PullEvent::Success; return; }
                                }
                            }
                        }
                    }
                    Err(_) => {
                        // Connection error: end the stream.
                        return;
                    }
                }
            }
        };

        Ok(Box::pin(s))
    }

    /// High-level helper to pull a model and drive a progress reporter.
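    ///
    /// Reporter sketch (marked `ignore`): it assumes `PullProgressReporter` has a
    /// single `on_event(&mut self, &PullEvent) -> io::Result<()>` method, which is
    /// consistent with how it is called below but not confirmed here.
    ///
    /// ```ignore
    /// struct StderrReporter;
    ///
    /// impl PullProgressReporter for StderrReporter {
    ///     fn on_event(&mut self, event: &PullEvent) -> std::io::Result<()> {
    ///         if let PullEvent::Status(msg) = event {
    ///             eprintln!("{msg}");
    ///         }
    ///         Ok(())
    ///     }
    /// }
    ///
    /// client.pull_with_reporter("llama3.2:3b", &mut StderrReporter).await?;
    /// ```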
    pub async fn pull_with_reporter(
        &self,
        model: &str,
        reporter: &mut dyn PullProgressReporter,
    ) -> io::Result<()> {
        reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
        let mut stream = self.pull_model_stream(model).await?;
        while let Some(event) = stream.next().await {
            reporter.on_event(&event)?;
            match event {
                PullEvent::Success => {
                    return Ok(());
                }
                PullEvent::Error(err) => {
                    // Empirically, ollama returns a 200 OK response even when
                    // the output stream includes an error message. Verify with:
                    //
                    // `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'`
                    //
                    // As such, we have to check the event stream, not the
                    // HTTP response status, to determine whether to return Err.
                    return Err(io::Error::other(format!("Pull failed: {err}")));
                }
                PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => {
                    continue;
                }
            }
        }
        Err(io::Error::other(
            "Pull stream ended unexpectedly without success.",
        ))
    }

    /// Low-level constructor given a raw host root, e.g. `http://localhost:11434`.
    #[cfg(test)]
    fn from_host_root(host_root: impl Into<String>) -> Self {
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            client,
            host_root: host_root.into(),
            uses_openai_compat: false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
    #[tokio::test]
    async fn test_fetch_models_happy_path() {
        if std::env::var(agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_fetch_models_happy_path",
                agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(
                wiremock::ResponseTemplate::new(200).set_body_raw(
                    serde_json::json!({
                        "models": [ {"name": "llama3.2:3b"}, {"name": "mistral"} ]
                    })
                    .to_string(),
                    "application/json",
                ),
            )
            .mount(&server)
            .await;

        let client = OllamaClient::from_host_root(server.uri());
        let models = client.fetch_models().await.expect("fetch models");
        assert!(models.contains(&"llama3.2:3b".to_string()));
        assert!(models.contains(&"mistral".to_string()));
    }

    #[tokio::test]
    async fn test_probe_server_happy_path_openai_compat_and_native() {
        if std::env::var(agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
                agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        // Native endpoint
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let native = OllamaClient::from_host_root(server.uri());
        native.probe_server().await.expect("probe native");

        // OpenAI compatibility endpoint
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let ollama_client =
            OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
                .await
                .expect("create OpenAI-compat client");
        ollama_client
            .probe_server()
            .await
            .expect("probe OpenAI compat");
    }

    #[tokio::test]
    async fn test_try_from_oss_provider_ok_when_server_running() {
        if std::env::var(agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_try_from_oss_provider_ok_when_server_running",
                agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        // OpenAI‑compat models endpoint responds OK.
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;

        OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
            .await
            .expect("client should be created when probe succeeds");
    }

    #[tokio::test]
    async fn test_try_from_oss_provider_err_when_server_missing() {
        if std::env::var(agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_try_from_oss_provider_err_when_server_missing",
                agcodex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
            .await
            .err()
            .expect("expected error");
        assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
    }
}