1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures_util::{Stream, StreamExt, TryStreamExt};
5use tokio_util::codec::{FramedRead, LinesCodec};
6use tokio_util::io::StreamReader;
7
8use crate::error::OxideError;
9use crate::types::{
10 ChatRequest, ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse,
11 ListModelsResponse,
12};
13
/// Heap-allocated, pinned, `Send` stream whose items are either a `T` or an
/// [`OxideError`].
///
/// This is the return type of the streaming methods on [`OllamaClient`]; boxing
/// the stream behind `dyn Stream` lets those methods be called through a
/// `dyn OllamaClient` trait object.
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = Result<T, OxideError>> + Send>>;
16
/// Abstraction over a client for an Ollama-style API.
///
/// Implementors must be `Send + Sync`. The trait is usable as
/// `Box<dyn OllamaClient>`: the async methods go through `#[async_trait]` and
/// the streaming methods return a boxed [`BoxStream`].
#[async_trait]
pub trait OllamaClient: Send + Sync {
    /// Sends a generate request and resolves to a single [`GenerateResponse`].
    ///
    /// # Errors
    /// Returns an [`OxideError`] if the implementation cannot produce a response.
    async fn generate(&self, req: GenerateRequest) -> Result<GenerateResponse, OxideError>;

    /// Sends a chat request and resolves to a single [`ChatResponse`].
    ///
    /// # Errors
    /// Returns an [`OxideError`] if the implementation cannot produce a response.
    async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, OxideError>;

    /// Sends an embed request and resolves to an [`EmbedResponse`].
    ///
    /// # Errors
    /// Returns an [`OxideError`] if the implementation cannot produce a response.
    async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError>;

    /// Fetches the model listing as a [`ListModelsResponse`].
    ///
    /// # Errors
    /// Returns an [`OxideError`] if the implementation cannot produce a response.
    async fn list_models(&self) -> Result<ListModelsResponse, OxideError>;

    /// Starts a generate request and returns its output as a stream of
    /// [`GenerateResponse`] items.
    ///
    /// This method is not `async`: it returns the stream directly, and failures
    /// are reported as `Err` items of the stream.
    fn stream_generate(&self, req: GenerateRequest) -> BoxStream<GenerateResponse>;

    /// Starts a chat request and returns its output as a stream of
    /// [`ChatResponse`] items.
    ///
    /// This method is not `async`: it returns the stream directly, and failures
    /// are reported as `Err` items of the stream.
    fn stream_chat(&self, req: ChatRequest) -> BoxStream<ChatResponse>;
}
54
/// [`OllamaClient`] implementation that talks to a server over HTTP using
/// `reqwest`.
pub struct HttpOllamaClient {
    /// Root URL of the server; endpoint paths such as `/api/chat` are appended
    /// to it (a trailing `/` is tolerated and stripped when building URLs).
    base_url: String,
    /// HTTP client used for every request; it is cloned into the streaming
    /// calls so the returned streams do not borrow from `self`.
    http: reqwest::Client,
}
62
63impl HttpOllamaClient {
64 pub fn new(base_url: impl Into<String>) -> Self {
66 Self {
67 base_url: base_url.into(),
68 http: reqwest::Client::new(),
69 }
70 }
71
72 fn url(&self, path: &str) -> String {
73 format!("{}{}", self.base_url.trim_end_matches('/'), path)
74 }
75
76 async fn post_json<B: serde::Serialize>(
78 &self,
79 path: &str,
80 body: &B,
81 ) -> Result<reqwest::Response, OxideError> {
82 let resp = self
83 .http
84 .post(self.url(path))
85 .json(body)
86 .send()
87 .await
88 .map_err(OxideError::Http)?;
89
90 if !resp.status().is_success() {
91 let status = resp.status().as_u16();
92 let text = resp.text().await.unwrap_or_default();
93 return Err(OxideError::ApiError(status, text));
94 }
95
96 Ok(resp)
97 }
98
99 fn ndjson_lines(resp: reqwest::Response) -> impl Stream<Item = Result<String, OxideError>> {
106 let byte_stream = resp
107 .bytes_stream()
108 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
109
110 let reader = StreamReader::new(byte_stream);
111 FramedRead::new(reader, LinesCodec::new())
112 .map_err(|e| OxideError::Other(e.to_string()))
113 }
114}
115
116#[async_trait]
117impl OllamaClient for HttpOllamaClient {
118 async fn generate(&self, mut req: GenerateRequest) -> Result<GenerateResponse, OxideError> {
119 req.stream = false;
120 let resp = self.post_json("/api/generate", &req).await?;
121 resp.json::<GenerateResponse>().await.map_err(OxideError::Http)
122 }
123
124 async fn chat(&self, mut req: ChatRequest) -> Result<ChatResponse, OxideError> {
125 req.stream = false;
126 let resp = self.post_json("/api/chat", &req).await?;
127 resp.json::<ChatResponse>().await.map_err(OxideError::Http)
128 }
129
130 async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError> {
131 let resp = self.post_json("/api/embed", &req).await?;
132 resp.json::<EmbedResponse>().await.map_err(OxideError::Http)
133 }
134
135 async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
136 let resp = self
137 .http
138 .get(self.url("/api/tags"))
139 .send()
140 .await
141 .map_err(OxideError::Http)?;
142
143 if !resp.status().is_success() {
144 return Err(OxideError::ApiError(
145 resp.status().as_u16(),
146 resp.text().await.unwrap_or_default(),
147 ));
148 }
149
150 resp.json::<ListModelsResponse>().await.map_err(OxideError::Http)
151 }
152
153 fn stream_generate(&self, mut req: GenerateRequest) -> BoxStream<GenerateResponse> {
154 req.stream = true;
155 let http = self.http.clone();
156 let url = self.url("/api/generate");
157
158 Box::pin(async_stream::try_stream! {
159 let resp = http.post(&url).json(&req).send().await.map_err(OxideError::Http)?;
160 let status = resp.status();
161 if status.is_success() {
162 let mut lines = Self::ndjson_lines(resp);
163 while let Some(line) = lines.next().await {
164 let line = line?;
165 if line.trim().is_empty() { continue; }
166 let chunk = serde_json::from_str::<GenerateResponse>(&line)
167 .map_err(OxideError::Serde)?;
168 yield chunk;
169 }
170 } else {
171 let text = resp.text().await.unwrap_or_default();
172 Err(OxideError::ApiError(status.as_u16(), text))?;
173 }
174 })
175 }
176
177 fn stream_chat(&self, mut req: ChatRequest) -> BoxStream<ChatResponse> {
178 req.stream = true;
179 let http = self.http.clone();
180 let url = self.url("/api/chat");
181
182 Box::pin(async_stream::try_stream! {
183 let resp = http.post(&url).json(&req).send().await.map_err(OxideError::Http)?;
184 let status = resp.status();
185 if status.is_success() {
186 let mut lines = Self::ndjson_lines(resp);
187 while let Some(line) = lines.next().await {
188 let line = line?;
189 if line.trim().is_empty() { continue; }
190 let chunk = serde_json::from_str::<ChatResponse>(&line)
191 .map_err(OxideError::Serde)?;
192 yield chunk;
193 }
194 } else {
195 let text = resp.text().await.unwrap_or_default();
196 Err(OxideError::ApiError(status.as_u16(), text))?;
197 }
198 })
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Message, Role};
    use futures_util::StreamExt;

    /// Test double that replays a fixed list of chat chunks.
    ///
    /// Only `chat` and `stream_chat` are implemented; the other trait methods
    /// are `unimplemented!()` because no test calls them.
    struct MockOllamaClient {
        chat_chunks: Vec<ChatResponse>,
    }

    #[async_trait]
    impl OllamaClient for MockOllamaClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }

        /// Returns the last canned chunk, or an error if the mock was built
        /// without any chunks (instead of panicking inside the mock).
        async fn chat(&self, _: ChatRequest) -> Result<ChatResponse, OxideError> {
            self.chat_chunks
                .last()
                .cloned()
                .ok_or_else(|| OxideError::Other("mock has no chat chunks".to_string()))
        }

        async fn embed(&self, _: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            unimplemented!()
        }

        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }

        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }

        /// Yields every canned chunk, in order, as an `Ok` item.
        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            let chunks: Vec<Result<ChatResponse, OxideError>> =
                self.chat_chunks.iter().cloned().map(Ok).collect();
            Box::pin(futures_util::stream::iter(chunks))
        }
    }

    /// Builds one assistant chunk for the `llama3` model.
    fn assistant_chunk(content: &str, done: bool) -> ChatResponse {
        ChatResponse {
            model: "llama3".into(),
            message: Message {
                role: Role::Assistant,
                content: content.into(),
                tool_calls: None,
            },
            done,
        }
    }

    /// Builds a `llama3` chat request with the given messages and stream flag.
    fn chat_request(messages: Vec<Message>, stream: bool) -> ChatRequest {
        ChatRequest {
            model: "llama3".into(),
            messages,
            tools: None,
            stream,
        }
    }

    /// The single user message used by the request-based tests.
    fn say_hello() -> Vec<Message> {
        vec![Message {
            role: Role::User,
            content: "Say hello.".into(),
            tool_calls: None,
        }]
    }

    /// Mock whose chunks concatenate to `"Hello, world!"`; only the last chunk
    /// is marked `done`.
    fn make_mock() -> MockOllamaClient {
        MockOllamaClient {
            chat_chunks: vec![
                assistant_chunk("Hello", false),
                assistant_chunk(", world!", true),
            ],
        }
    }

    #[tokio::test]
    async fn mock_client_returns_canned_response() {
        let mock = make_mock();
        let req = chat_request(say_hello(), false);

        let resp = mock.chat(req).await.unwrap();
        assert_eq!(resp.message.role, Role::Assistant);
        assert!(resp.done);
    }

    #[tokio::test]
    async fn mock_stream_chat_yields_all_chunks() {
        let mock = make_mock();
        let req = chat_request(say_hello(), true);

        let chunks: Vec<_> = mock.stream_chat(req).collect().await;
        assert_eq!(chunks.len(), 2);

        let first = chunks[0].as_ref().unwrap();
        assert_eq!(first.message.content, "Hello");
        assert!(!first.done);

        let last = chunks[1].as_ref().unwrap();
        assert_eq!(last.message.content, ", world!");
        assert!(last.done);
    }

    #[tokio::test]
    async fn stream_content_matches_buffered_content() {
        let mock = make_mock();
        let req = chat_request(vec![], true);

        let chunks: Vec<_> = mock.stream_chat(req).collect().await;

        // Fail loudly on an `Err` chunk instead of silently dropping it, so a
        // broken stream cannot slip past the content comparison below.
        let full_text: String = chunks
            .into_iter()
            .map(|chunk| chunk.expect("stream yielded an error").message.content)
            .collect();

        assert_eq!(full_text, "Hello, world!");
    }

    /// Compile-time check that the trait can be used as `Box<dyn OllamaClient>`.
    #[test]
    fn trait_is_object_safe() {
        fn accepts_boxed(_: Box<dyn OllamaClient>) {}
        accepts_boxed(Box::new(make_mock()));
    }
}