use std::pin::Pin;
use async_trait::async_trait;
use futures_util::{Stream, StreamExt, TryStreamExt};
use tokio_util::codec::{FramedRead, LinesCodec};
use tokio_util::io::StreamReader;
use crate::error::OxideError;
use crate::types::{
ChatRequest, ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse,
ListModelsResponse,
};
/// Pinned, heap-allocated, `Send` stream whose items are `Result<T, OxideError>`.
///
/// The streaming methods of [`OllamaClient`] return this boxed trait object instead of
/// `impl Stream`, which keeps the trait usable as `dyn OllamaClient`.
pub type BoxStream<T> =
    Pin<Box<dyn Stream<Item = Result<T, OxideError>> + Send + 'static>>;
/// Abstraction over an Ollama-compatible backend.
///
/// `generate` and `chat` each come in two flavours: a buffered call that resolves to a
/// single response, and a `stream_*` call that yields partial responses as they arrive.
/// The streaming variants return [`BoxStream`] so the trait can be used as
/// `dyn OllamaClient` (e.g. to swap in a mock in tests).
#[async_trait]
pub trait OllamaClient: Send + Sync {
    /// Lists the models the backend knows about (the HTTP client calls `GET /api/tags`).
    async fn list_models(&self) -> Result<ListModelsResponse, OxideError>;

    /// Runs a completion and returns the whole result as one response.
    async fn generate(&self, req: GenerateRequest) -> Result<GenerateResponse, OxideError>;

    /// Runs a completion and yields partial responses as they are produced.
    fn stream_generate(&self, req: GenerateRequest) -> BoxStream<GenerateResponse>;

    /// Runs a chat turn and returns the whole reply as one response.
    async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, OxideError>;

    /// Runs a chat turn and yields partial replies as they are produced.
    fn stream_chat(&self, req: ChatRequest) -> BoxStream<ChatResponse>;

    /// Computes embeddings for `req` (the HTTP client calls `POST /api/embed`).
    async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError>;
}
/// [`OllamaClient`] implementation that talks to an Ollama server over HTTP.
///
/// Cloning is cheap: `reqwest::Client` keeps its connection pool behind an internal
/// reference count, so clones share the same pool.
#[derive(Debug, Clone)]
pub struct HttpOllamaClient {
    /// Server root URL; trailing slashes are trimmed when request URLs are built.
    base_url: String,
    /// HTTP client used for every request.
    http: reqwest::Client,
}
impl HttpOllamaClient {
    /// Creates a client for the Ollama server at `base_url` using a `reqwest::Client`
    /// with default settings.
    pub fn new(base_url: impl Into<String>) -> Self {
        Self::with_client(base_url, reqwest::Client::new())
    }

    /// Creates a client for the Ollama server at `base_url` that sends every request
    /// through the caller-supplied `http` client.
    ///
    /// Use this to apply timeouts, proxies or default headers configured via
    /// `reqwest::ClientBuilder`, or to reuse an existing client.
    pub fn with_client(base_url: impl Into<String>, http: reqwest::Client) -> Self {
        Self {
            base_url: base_url.into(),
            http,
        }
    }

    /// Appends `path` to the base URL, dropping any trailing `/` on the base first.
    ///
    /// `path` is expected to start with `/` (e.g. `"/api/chat"`); it is not normalised.
    fn url(&self, path: &str) -> String {
        format!("{}{}", self.base_url.trim_end_matches('/'), path)
    }

    /// POSTs `body` as JSON to `path` and returns the response if its status is 2xx.
    ///
    /// # Errors
    ///
    /// * `OxideError::Http` if the request could not be sent.
    /// * `OxideError::ApiError(status, body)` for a non-success status. The body text is
    ///   best-effort: if it cannot be read, an empty string is used instead.
    async fn post_json<B: serde::Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<reqwest::Response, OxideError> {
        let resp = self
            .http
            .post(self.url(path))
            .json(body)
            .send()
            .await
            .map_err(OxideError::Http)?;
        if !resp.status().is_success() {
            // Capture the status before `text()` consumes the response.
            let status = resp.status().as_u16();
            let text = resp.text().await.unwrap_or_default();
            return Err(OxideError::ApiError(status, text));
        }
        Ok(resp)
    }

    /// Splits a response body into text lines, for newline-delimited JSON decoding.
    ///
    /// Body transport errors are adapted to `std::io::Error` so the byte stream can feed a
    /// [`StreamReader`]. Any [`LinesCodec`] error (I/O failure or invalid UTF-8) is
    /// flattened into `OxideError::Other` carrying its message. Yielded lines have their
    /// terminator removed and may be empty, so callers should skip blank lines.
    fn ndjson_lines(resp: reqwest::Response) -> impl Stream<Item = Result<String, OxideError>> {
        let byte_stream = resp
            .bytes_stream()
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
        let reader = StreamReader::new(byte_stream);
        FramedRead::new(reader, LinesCodec::new())
            .map_err(|e| OxideError::Other(e.to_string()))
    }
}
#[async_trait]
impl OllamaClient for HttpOllamaClient {
async fn generate(&self, mut req: GenerateRequest) -> Result<GenerateResponse, OxideError> {
req.stream = false;
let resp = self.post_json("/api/generate", &req).await?;
resp.json::<GenerateResponse>().await.map_err(OxideError::Http)
}
async fn chat(&self, mut req: ChatRequest) -> Result<ChatResponse, OxideError> {
req.stream = false;
let resp = self.post_json("/api/chat", &req).await?;
resp.json::<ChatResponse>().await.map_err(OxideError::Http)
}
async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError> {
let resp = self.post_json("/api/embed", &req).await?;
resp.json::<EmbedResponse>().await.map_err(OxideError::Http)
}
async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
let resp = self
.http
.get(self.url("/api/tags"))
.send()
.await
.map_err(OxideError::Http)?;
if !resp.status().is_success() {
return Err(OxideError::ApiError(
resp.status().as_u16(),
resp.text().await.unwrap_or_default(),
));
}
resp.json::<ListModelsResponse>().await.map_err(OxideError::Http)
}
fn stream_generate(&self, mut req: GenerateRequest) -> BoxStream<GenerateResponse> {
req.stream = true;
let http = self.http.clone();
let url = self.url("/api/generate");
Box::pin(async_stream::try_stream! {
let resp = http.post(&url).json(&req).send().await.map_err(OxideError::Http)?;
let status = resp.status();
if status.is_success() {
let mut lines = Self::ndjson_lines(resp);
while let Some(line) = lines.next().await {
let line = line?;
if line.trim().is_empty() { continue; }
let chunk = serde_json::from_str::<GenerateResponse>(&line)
.map_err(OxideError::Serde)?;
yield chunk;
}
} else {
let text = resp.text().await.unwrap_or_default();
Err(OxideError::ApiError(status.as_u16(), text))?;
}
})
}
fn stream_chat(&self, mut req: ChatRequest) -> BoxStream<ChatResponse> {
req.stream = true;
let http = self.http.clone();
let url = self.url("/api/chat");
Box::pin(async_stream::try_stream! {
let resp = http.post(&url).json(&req).send().await.map_err(OxideError::Http)?;
let status = resp.status();
if status.is_success() {
let mut lines = Self::ndjson_lines(resp);
while let Some(line) = lines.next().await {
let line = line?;
if line.trim().is_empty() { continue; }
let chunk = serde_json::from_str::<ChatResponse>(&line)
.map_err(OxideError::Serde)?;
yield chunk;
}
} else {
let text = resp.text().await.unwrap_or_default();
Err(OxideError::ApiError(status.as_u16(), text))?;
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Message, Role};
    use futures_util::{StreamExt, TryStreamExt};

    /// In-memory `OllamaClient` that replays a fixed list of chat chunks.
    ///
    /// Only the chat methods are implemented; every other method panics if called.
    struct MockOllamaClient {
        chat_chunks: Vec<ChatResponse>,
    }

    #[async_trait]
    impl OllamaClient for MockOllamaClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }

        // Mimics a buffered reply by returning the final (done) chunk.
        async fn chat(&self, _: ChatRequest) -> Result<ChatResponse, OxideError> {
            let last = self
                .chat_chunks
                .last()
                .expect("mock must be configured with at least one chat chunk");
            Ok(last.clone())
        }

        async fn embed(&self, _: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            unimplemented!()
        }

        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }

        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }

        // Replays every configured chunk, in order, as a successful item.
        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            let chunks: Vec<Result<ChatResponse, OxideError>> =
                self.chat_chunks.iter().cloned().map(Ok).collect();
            Box::pin(futures_util::stream::iter(chunks))
        }
    }

    /// One assistant chunk from the "llama3" model.
    fn assistant_chunk(content: &'static str, done: bool) -> ChatResponse {
        ChatResponse {
            model: "llama3".into(),
            message: Message {
                role: Role::Assistant,
                content: content.into(),
                tool_calls: None,
            },
            done,
        }
    }

    /// A user message with no tool calls.
    fn user_message(content: &'static str) -> Message {
        Message {
            role: Role::User,
            content: content.into(),
            tool_calls: None,
        }
    }

    /// A tool-less chat request for the "llama3" model.
    fn chat_request(messages: Vec<Message>, stream: bool) -> ChatRequest {
        ChatRequest {
            model: "llama3".into(),
            messages,
            tools: None,
            stream,
        }
    }

    /// Mock that streams "Hello" then ", world!" (the latter marked `done`).
    fn make_mock() -> MockOllamaClient {
        MockOllamaClient {
            chat_chunks: vec![
                assistant_chunk("Hello", false),
                assistant_chunk(", world!", true),
            ],
        }
    }

    #[tokio::test]
    async fn mock_client_returns_canned_response() {
        let mock = make_mock();
        let resp = mock
            .chat(chat_request(vec![user_message("Say hello.")], false))
            .await
            .unwrap();
        assert_eq!(resp.message.role, Role::Assistant);
        assert!(resp.done);
    }

    #[tokio::test]
    async fn mock_stream_chat_yields_all_chunks() {
        let mock = make_mock();
        let chunks: Vec<_> = mock
            .stream_chat(chat_request(vec![user_message("Say hello.")], true))
            .collect()
            .await;
        assert_eq!(chunks.len(), 2);

        let first = chunks[0].as_ref().unwrap();
        assert_eq!(first.message.content, "Hello");
        assert!(!first.done);

        let last = chunks[1].as_ref().unwrap();
        assert_eq!(last.message.content, ", world!");
        assert!(last.done);
    }

    #[tokio::test]
    async fn stream_chunks_concatenate_to_full_message() {
        let mock = make_mock();
        // `try_collect` makes any stream error fail the test instead of being dropped.
        let chunks = mock
            .stream_chat(chat_request(vec![], true))
            .try_collect::<Vec<_>>()
            .await
            .expect("mock stream should not yield errors");
        assert!(chunks.last().expect("at least one chunk").done);

        let full_text = chunks
            .into_iter()
            .map(|c| c.message.content)
            .collect::<Vec<_>>()
            .join("");
        assert_eq!(full_text, "Hello, world!");
    }

    #[test]
    fn url_trims_trailing_slashes_from_base() {
        let client = HttpOllamaClient::new("http://localhost:11434/");
        assert_eq!(client.url("/api/tags"), "http://localhost:11434/api/tags");

        let client = HttpOllamaClient::new("http://localhost:11434");
        assert_eq!(client.url("/api/chat"), "http://localhost:11434/api/chat");
    }

    #[test]
    fn trait_is_object_safe() {
        fn accepts_boxed(_: Box<dyn OllamaClient>) {}
        accepts_boxed(Box::new(make_mock()));
    }
}