Skip to main content

oxide_agent/client/
mod.rs

1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures_util::{Stream, StreamExt, TryStreamExt};
5use tokio_util::codec::{FramedRead, LinesCodec};
6use tokio_util::io::StreamReader;
7
8use crate::error::OxideError;
9use crate::types::{
10    ChatRequest, ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse,
11    ListModelsResponse,
12};
13
14/// Convenience alias for a pinned, owned, send-able stream of fallible items.
15pub type BoxStream<T> = Pin<Box<dyn Stream<Item = Result<T, OxideError>> + Send>>;
16
17// ── The core trait ────────────────────────────────────────────────────────────
18
19/// The type-safe bridge to any Ollama-compatible backend.
20///
21/// Decoupling agent logic from transport via a trait means you can:
22/// - Swap in a real HTTP client for production.
23/// - Swap in a `MockOllamaClient` for unit testing — no running server needed.
24/// - Wrap the client with middleware (retry, logging, rate-limiting) transparently.
25#[async_trait]
26pub trait OllamaClient: Send + Sync {
27    // ── Blocking (buffered) variants ──────────────────────────────────────────
28
29    /// Single-turn text completion (`/api/generate`).
30    async fn generate(&self, req: GenerateRequest) -> Result<GenerateResponse, OxideError>;
31
32    /// Multi-turn chat completion (`/api/chat`), including tool/function calling.
33    async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, OxideError>;
34
35    /// Produce dense vector embeddings for one or more texts (`/api/embed`).
36    async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError>;
37
38    /// List models available on the Ollama server (`/api/tags`).
39    async fn list_models(&self) -> Result<ListModelsResponse, OxideError>;
40
41    // ── Zero-copy streaming variants ──────────────────────────────────────────
42
43    /// Stream a generate response token-by-token.
44    ///
45    /// Ollama sends newline-delimited JSON (`NDJSON`) when `stream: true`.
46    /// This method decodes each line into a `GenerateResponse` without buffering
47    /// the entire response body, keeping memory overhead minimal regardless of
48    /// output length.
49    fn stream_generate(&self, req: GenerateRequest) -> BoxStream<GenerateResponse>;
50
51    /// Stream a chat response token-by-token.
52    fn stream_chat(&self, req: ChatRequest) -> BoxStream<ChatResponse>;
53}
54
55// ── HTTP implementation ───────────────────────────────────────────────────────
56
57/// Production client that speaks to a real Ollama server over HTTP.
58pub struct HttpOllamaClient {
59    base_url: String,
60    http: reqwest::Client,
61}
62
63impl HttpOllamaClient {
64    /// `base_url` is typically `"http://localhost:11434"`.
65    pub fn new(base_url: impl Into<String>) -> Self {
66        Self {
67            base_url: base_url.into(),
68            http: reqwest::Client::new(),
69        }
70    }
71
72    fn url(&self, path: &str) -> String {
73        format!("{}{}", self.base_url.trim_end_matches('/'), path)
74    }
75
76    /// Shared helper: POST JSON, check status, return the response.
77    async fn post_json<B: serde::Serialize>(
78        &self,
79        path: &str,
80        body: &B,
81    ) -> Result<reqwest::Response, OxideError> {
82        let resp = self
83            .http
84            .post(self.url(path))
85            .json(body)
86            .send()
87            .await
88            .map_err(OxideError::Http)?;
89
90        if !resp.status().is_success() {
91            let status = resp.status().as_u16();
92            let text = resp.text().await.unwrap_or_default();
93            return Err(OxideError::ApiError(status, text));
94        }
95
96        Ok(resp)
97    }
98
99    /// Convert a reqwest response body into a line-by-line NDJSON stream.
100    ///
101    /// - `bytes_stream()` gives `Stream<Item = Result<Bytes, reqwest::Error>>`
102    /// - `StreamReader` adapts that to `AsyncRead` (zero-copy: the `Bytes`
103    ///   chunks from reqwest are forwarded directly without re-allocation)
104    /// - `FramedRead` + `LinesCodec` splits on `\n` without extra heap copies
105    fn ndjson_lines(resp: reqwest::Response) -> impl Stream<Item = Result<String, OxideError>> {
106        let byte_stream = resp
107            .bytes_stream()
108            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
109
110        let reader = StreamReader::new(byte_stream);
111        FramedRead::new(reader, LinesCodec::new())
112            .map_err(|e| OxideError::Other(e.to_string()))
113    }
114}
115
116#[async_trait]
117impl OllamaClient for HttpOllamaClient {
118    async fn generate(&self, mut req: GenerateRequest) -> Result<GenerateResponse, OxideError> {
119        req.stream = false;
120        let resp = self.post_json("/api/generate", &req).await?;
121        resp.json::<GenerateResponse>().await.map_err(OxideError::Http)
122    }
123
124    async fn chat(&self, mut req: ChatRequest) -> Result<ChatResponse, OxideError> {
125        req.stream = false;
126        let resp = self.post_json("/api/chat", &req).await?;
127        resp.json::<ChatResponse>().await.map_err(OxideError::Http)
128    }
129
130    async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError> {
131        let resp = self.post_json("/api/embed", &req).await?;
132        resp.json::<EmbedResponse>().await.map_err(OxideError::Http)
133    }
134
135    async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
136        let resp = self
137            .http
138            .get(self.url("/api/tags"))
139            .send()
140            .await
141            .map_err(OxideError::Http)?;
142
143        if !resp.status().is_success() {
144            return Err(OxideError::ApiError(
145                resp.status().as_u16(),
146                resp.text().await.unwrap_or_default(),
147            ));
148        }
149
150        resp.json::<ListModelsResponse>().await.map_err(OxideError::Http)
151    }
152
153    fn stream_generate(&self, mut req: GenerateRequest) -> BoxStream<GenerateResponse> {
154        req.stream = true;
155        let http = self.http.clone();
156        let url = self.url("/api/generate");
157
158        Box::pin(async_stream::try_stream! {
159            let resp = http.post(&url).json(&req).send().await.map_err(OxideError::Http)?;
160            let status = resp.status();
161            if status.is_success() {
162                let mut lines = Self::ndjson_lines(resp);
163                while let Some(line) = lines.next().await {
164                    let line = line?;
165                    if line.trim().is_empty() { continue; }
166                    let chunk = serde_json::from_str::<GenerateResponse>(&line)
167                        .map_err(OxideError::Serde)?;
168                    yield chunk;
169                }
170            } else {
171                let text = resp.text().await.unwrap_or_default();
172                Err(OxideError::ApiError(status.as_u16(), text))?;
173            }
174        })
175    }
176
177    fn stream_chat(&self, mut req: ChatRequest) -> BoxStream<ChatResponse> {
178        req.stream = true;
179        let http = self.http.clone();
180        let url = self.url("/api/chat");
181
182        Box::pin(async_stream::try_stream! {
183            let resp = http.post(&url).json(&req).send().await.map_err(OxideError::Http)?;
184            let status = resp.status();
185            if status.is_success() {
186                let mut lines = Self::ndjson_lines(resp);
187                while let Some(line) = lines.next().await {
188                    let line = line?;
189                    if line.trim().is_empty() { continue; }
190                    let chunk = serde_json::from_str::<ChatResponse>(&line)
191                        .map_err(OxideError::Serde)?;
192                    yield chunk;
193                }
194            } else {
195                let text = resp.text().await.unwrap_or_default();
196                Err(OxideError::ApiError(status.as_u16(), text))?;
197            }
198        })
199    }
200}
201
202// ── Tests ─────────────────────────────────────────────────────────────────────
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Message, Role};
    use futures_util::{StreamExt, TryStreamExt};

    // ── Shared mock ───────────────────────────────────────────────────────────

    /// In-memory [`OllamaClient`] that replays a fixed sequence of chat chunks.
    ///
    /// Only the chat endpoints are implemented; every other method panics via
    /// `unimplemented!()`.
    struct MockOllamaClient {
        /// Chunks yielded in order by `stream_chat`; the last one doubles as the
        /// buffered `chat` response.
        chat_chunks: Vec<ChatResponse>,
    }

    #[async_trait]
    impl OllamaClient for MockOllamaClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }

        async fn chat(&self, _: ChatRequest) -> Result<ChatResponse, OxideError> {
            // Return the last (done=true) chunk as the buffered response.
            Ok(self
                .chat_chunks
                .last()
                .expect("mock must be configured with at least one chat chunk")
                .clone())
        }

        async fn embed(&self, _: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            unimplemented!()
        }

        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }

        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }

        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            let chunks: Vec<Result<ChatResponse, OxideError>> =
                self.chat_chunks.iter().cloned().map(Ok).collect();
            Box::pin(futures_util::stream::iter(chunks))
        }
    }

    /// Two-chunk canned reply: `"Hello"` (not done) followed by `", world!"` (done).
    fn make_mock() -> MockOllamaClient {
        MockOllamaClient {
            chat_chunks: vec![
                ChatResponse {
                    model: "llama3".into(),
                    message: Message { role: Role::Assistant, content: "Hello".into(), tool_calls: None },
                    done: false,
                },
                ChatResponse {
                    model: "llama3".into(),
                    message: Message { role: Role::Assistant, content: ", world!".into(), tool_calls: None },
                    done: true,
                },
            ],
        }
    }

    /// Build a `llama3` chat request with the given history and `stream` flag.
    fn chat_request(messages: Vec<Message>, stream: bool) -> ChatRequest {
        ChatRequest {
            model: "llama3".into(),
            messages,
            tools: None,
            stream,
        }
    }

    /// A single user turn asking the model to say hello.
    fn say_hello() -> Vec<Message> {
        vec![Message {
            role: Role::User,
            content: "Say hello.".into(),
            tool_calls: None,
        }]
    }

    // ── Buffered chat ─────────────────────────────────────────────────────────

    #[tokio::test]
    async fn mock_client_returns_canned_response() {
        let mock = make_mock();

        let resp = mock.chat(chat_request(say_hello(), false)).await.unwrap();
        assert_eq!(resp.message.role, Role::Assistant);
        assert!(resp.done);
    }

    // ── Streaming chat ────────────────────────────────────────────────────────

    /// Proves the streaming path yields all chunks in order and the final
    /// chunk carries `done: true` — without a live server.
    #[tokio::test]
    async fn mock_stream_chat_yields_all_chunks() {
        let mock = make_mock();

        let chunks: Vec<_> = mock.stream_chat(chat_request(say_hello(), true)).collect().await;
        assert_eq!(chunks.len(), 2);

        let first = chunks[0].as_ref().unwrap();
        assert_eq!(first.message.content, "Hello");
        assert!(!first.done);

        let last = chunks[1].as_ref().unwrap();
        assert_eq!(last.message.content, ", world!");
        assert!(last.done);
    }

    /// Concatenating the streamed chunk contents reconstructs the full reply.
    ///
    /// Any `Err` item fails the test rather than being silently dropped.
    #[tokio::test]
    async fn stream_content_matches_buffered_content() {
        let mock = make_mock();

        let parts: Vec<String> = mock
            .stream_chat(chat_request(vec![], true))
            .map_ok(|chunk| chunk.message.content)
            .try_collect()
            .await
            .expect("mock stream should yield only Ok chunks");

        assert_eq!(parts.concat(), "Hello, world!");
    }

    // ── Object safety ─────────────────────────────────────────────────────────

    #[test]
    fn trait_is_object_safe() {
        fn accepts_boxed(_: Box<dyn OllamaClient>) {}
        accepts_boxed(Box::new(make_mock()));
    }
}