//! oxide-agent 0.1.0
//!
//! Type-safe, high-performance Rust crate for building agentic systems on Ollama.
//!
//! Documentation
use std::pin::Pin;

use async_trait::async_trait;
use futures_util::{Stream, StreamExt, TryStreamExt};
use tokio_util::codec::{FramedRead, LinesCodec};
use tokio_util::io::StreamReader;

use crate::error::OxideError;
use crate::types::{
    ChatRequest, ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse,
    ListModelsResponse,
};

/// Boxed, pinned stream whose items are `Result<T, OxideError>`.
///
/// This is what the streaming methods of [`OllamaClient`] return. The trait object is
/// `Send`, so the stream can be handed to another task. It is also `'static`, which is
/// the default object lifetime for `Box<dyn …>`, so it borrows nothing from the client
/// that produced it.
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = Result<T, OxideError>> + Send + 'static>>;

// ── The core trait ────────────────────────────────────────────────────────────

/// The type-safe bridge to any Ollama-compatible backend.
///
/// Decoupling agent logic from transport via a trait means you can:
/// - Swap in a real HTTP client for production.
/// - Swap in a `MockOllamaClient` for unit testing — no running server needed.
/// - Wrap the client with middleware (retry, logging, rate-limiting) transparently.
///
/// Implementors must be `Send + Sync`. The trait is object-safe, so a client can be
/// held as `Box<dyn OllamaClient>` and shared between tasks.
#[async_trait]
pub trait OllamaClient: Send + Sync {
    // ── Buffered (non-streaming) variants ─────────────────────────────────────
    //
    // These are `async` methods. "Buffered" means each call resolves once, with
    // the complete response, rather than yielding it piece by piece.

    /// Single-turn text completion (`/api/generate`).
    ///
    /// `HttpOllamaClient` forces `req.stream = false` so the server replies with a
    /// single JSON object.
    async fn generate(&self, req: GenerateRequest) -> Result<GenerateResponse, OxideError>;

    /// Multi-turn chat completion (`/api/chat`), including tool/function calling.
    ///
    /// `HttpOllamaClient` forces `req.stream = false`.
    async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, OxideError>;

    /// Produce dense vector embeddings for one or more texts (`/api/embed`).
    async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError>;

    /// List models available on the Ollama server (`GET /api/tags`).
    async fn list_models(&self) -> Result<ListModelsResponse, OxideError>;

    // ── Streaming (NDJSON) variants ───────────────────────────────────────────
    //
    // These are plain (non-async) methods that return a `BoxStream` immediately.
    // In `HttpOllamaClient`:
    // - the request is only sent once the stream is first polled;
    // - each response line is decoded as soon as it arrives;
    // - the first error is yielded as an `Err` item and ends the stream.

    /// Stream a generate response token-by-token.
    ///
    /// Ollama sends newline-delimited JSON (`NDJSON`) when `stream: true`.
    /// Each line is decoded into a `GenerateResponse` as it arrives, so the client
    /// never holds the whole response body at once. This is incremental, not
    /// zero-copy: each line is copied into an owned `String` before it is decoded.
    fn stream_generate(&self, req: GenerateRequest) -> BoxStream<GenerateResponse>;

    /// Stream a chat response token-by-token (`/api/chat` with `stream: true`).
    ///
    /// Framing and error behaviour match `stream_generate`, but each item is a
    /// `ChatResponse` chunk.
    fn stream_chat(&self, req: ChatRequest) -> BoxStream<ChatResponse>;
}

// ── HTTP implementation ───────────────────────────────────────────────────────

/// Production client that speaks to a real Ollama server over HTTP.
///
/// Cloning is cheap. The inner `reqwest::Client` shares its connection pool
/// between clones (the streaming methods already clone it for every stream).
#[derive(Clone, Debug)]
pub struct HttpOllamaClient {
    /// Server root, e.g. `"http://localhost:11434"`; endpoint paths are appended to it.
    base_url: String,
    /// HTTP client, and therefore connection pool, used for every request.
    http: reqwest::Client,
}

impl HttpOllamaClient {
    /// Largest NDJSON line, in bytes, accepted from a streaming response.
    ///
    /// `LinesCodec::new()` would buffer a line of any length. A misbehaving or hostile
    /// server could then exhaust memory simply by never sending `\n`. The cap is
    /// deliberately generous so that ordinary chunks fit comfortably. A longer line
    /// ends the stream with an error instead of growing the buffer without bound.
    const MAX_NDJSON_LINE_BYTES: usize = 16 * 1024 * 1024;

    /// Create a client for the server at `base_url` with default `reqwest::Client`
    /// settings.
    ///
    /// `base_url` is typically `"http://localhost:11434"`. Trailing `/` characters
    /// are ignored when endpoint paths are appended.
    pub fn new(base_url: impl Into<String>) -> Self {
        Self::with_client(base_url, reqwest::Client::new())
    }

    /// Create a client that sends every request through a caller-configured
    /// `reqwest::Client`.
    ///
    /// Use this to set timeouts, proxies or default headers. It also lets you share
    /// one connection pool with other parts of the application.
    pub fn with_client(base_url: impl Into<String>, http: reqwest::Client) -> Self {
        Self {
            base_url: base_url.into(),
            http,
        }
    }

    /// Join `path` onto the base URL. `path` should start with `/`.
    ///
    /// Trailing `/` characters are trimmed from the base first, so a configured
    /// `"http://host:11434/"` does not produce a `//` in the final URL.
    fn url(&self, path: &str) -> String {
        format!("{}{}", self.base_url.trim_end_matches('/'), path)
    }

    /// POST `body` as JSON to `path` and return the response once its status has
    /// been checked.
    ///
    /// # Errors
    ///
    /// - `OxideError::Http` if the request cannot be sent.
    /// - `OxideError::ApiError(status, body_text)` for a non-2xx status. Reading the
    ///   body text is best-effort; it is empty if the body cannot be read.
    async fn post_json<B: serde::Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<reqwest::Response, OxideError> {
        let resp = self
            .http
            .post(self.url(path))
            .json(body)
            .send()
            .await
            .map_err(OxideError::Http)?;

        if !resp.status().is_success() {
            let status = resp.status().as_u16();
            let text = resp.text().await.unwrap_or_default();
            return Err(OxideError::ApiError(status, text));
        }

        Ok(resp)
    }

    /// Turn a response body into a stream of NDJSON lines.
    ///
    /// - `bytes_stream()` yields the body incrementally as `Bytes` chunks, so the
    ///   whole body is never held in memory at once.
    /// - `StreamReader` adapts that chunk stream to `AsyncRead`. reqwest errors are
    ///   wrapped in `std::io::Error` because `AsyncRead` can only report I/O errors.
    /// - `FramedRead` + `LinesCodec` split the bytes on `\n`. Each line is decoded
    ///   as UTF-8 into an owned `String`, which means one allocation and copy per
    ///   line. Lines longer than [`Self::MAX_NDJSON_LINE_BYTES`] yield an error.
    ///
    /// Any codec error is reported as `OxideError::Other` carrying the codec's
    /// message. This covers transport failures, invalid UTF-8 and over-long lines.
    fn ndjson_lines(resp: reqwest::Response) -> impl Stream<Item = Result<String, OxideError>> {
        let byte_stream = resp
            .bytes_stream()
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));

        let reader = StreamReader::new(byte_stream);
        FramedRead::new(
            reader,
            LinesCodec::new_with_max_length(Self::MAX_NDJSON_LINE_BYTES),
        )
        .map_err(|e| OxideError::Other(e.to_string()))
    }
}

#[async_trait]
impl OllamaClient for HttpOllamaClient {
    async fn generate(&self, mut req: GenerateRequest) -> Result<GenerateResponse, OxideError> {
        req.stream = false;
        let resp = self.post_json("/api/generate", &req).await?;
        resp.json::<GenerateResponse>().await.map_err(OxideError::Http)
    }

    async fn chat(&self, mut req: ChatRequest) -> Result<ChatResponse, OxideError> {
        req.stream = false;
        let resp = self.post_json("/api/chat", &req).await?;
        resp.json::<ChatResponse>().await.map_err(OxideError::Http)
    }

    async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, OxideError> {
        let resp = self.post_json("/api/embed", &req).await?;
        resp.json::<EmbedResponse>().await.map_err(OxideError::Http)
    }

    async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
        let resp = self
            .http
            .get(self.url("/api/tags"))
            .send()
            .await
            .map_err(OxideError::Http)?;

        if !resp.status().is_success() {
            return Err(OxideError::ApiError(
                resp.status().as_u16(),
                resp.text().await.unwrap_or_default(),
            ));
        }

        resp.json::<ListModelsResponse>().await.map_err(OxideError::Http)
    }

    fn stream_generate(&self, mut req: GenerateRequest) -> BoxStream<GenerateResponse> {
        req.stream = true;
        let http = self.http.clone();
        let url = self.url("/api/generate");

        Box::pin(async_stream::try_stream! {
            let resp = http.post(&url).json(&req).send().await.map_err(OxideError::Http)?;
            let status = resp.status();
            if status.is_success() {
                let mut lines = Self::ndjson_lines(resp);
                while let Some(line) = lines.next().await {
                    let line = line?;
                    if line.trim().is_empty() { continue; }
                    let chunk = serde_json::from_str::<GenerateResponse>(&line)
                        .map_err(OxideError::Serde)?;
                    yield chunk;
                }
            } else {
                let text = resp.text().await.unwrap_or_default();
                Err(OxideError::ApiError(status.as_u16(), text))?;
            }
        })
    }

    fn stream_chat(&self, mut req: ChatRequest) -> BoxStream<ChatResponse> {
        req.stream = true;
        let http = self.http.clone();
        let url = self.url("/api/chat");

        Box::pin(async_stream::try_stream! {
            let resp = http.post(&url).json(&req).send().await.map_err(OxideError::Http)?;
            let status = resp.status();
            if status.is_success() {
                let mut lines = Self::ndjson_lines(resp);
                while let Some(line) = lines.next().await {
                    let line = line?;
                    if line.trim().is_empty() { continue; }
                    let chunk = serde_json::from_str::<ChatResponse>(&line)
                        .map_err(OxideError::Serde)?;
                    yield chunk;
                }
            } else {
                let text = resp.text().await.unwrap_or_default();
                Err(OxideError::ApiError(status.as_u16(), text))?;
            }
        })
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Message, Role};
    use futures_util::TryStreamExt;

    // ── Shared mock ───────────────────────────────────────────────────────────

    /// In-memory `OllamaClient` that replays a fixed list of chat chunks.
    ///
    /// Only the chat endpoints are implemented. Every other method hits
    /// `unimplemented!()`, so an unexpected call fails the test loudly.
    struct MockOllamaClient {
        /// Chunks yielded, in order, by `stream_chat`. The last one carries `done: true`.
        chat_chunks: Vec<ChatResponse>,
    }

    #[async_trait]
    impl OllamaClient for MockOllamaClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }

        /// Buffered reply equivalent to the streamed one.
        ///
        /// It takes the final chunk's metadata (`done: true`) and concatenates the
        /// content of every chunk, the same text a caller would get by
        /// accumulating `stream_chat`.
        async fn chat(&self, _: ChatRequest) -> Result<ChatResponse, OxideError> {
            let mut reply = self
                .chat_chunks
                .last()
                .expect("mock must be configured with at least one chat chunk")
                .clone();
            reply.message.content = self
                .chat_chunks
                .iter()
                .map(|chunk| chunk.message.content.as_str())
                .collect();
            Ok(reply)
        }

        async fn embed(&self, _: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            unimplemented!()
        }

        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }

        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }

        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            let chunks: Vec<Result<ChatResponse, OxideError>> =
                self.chat_chunks.iter().cloned().map(Ok).collect();
            Box::pin(futures_util::stream::iter(chunks))
        }
    }

    /// Mock whose stream is `"Hello"` (not done) followed by `", world!"` (done).
    fn make_mock() -> MockOllamaClient {
        MockOllamaClient {
            chat_chunks: vec![
                ChatResponse {
                    model: "llama3".into(),
                    message: Message { role: Role::Assistant, content: "Hello".into(), tool_calls: None },
                    done: false,
                },
                ChatResponse {
                    model: "llama3".into(),
                    message: Message { role: Role::Assistant, content: ", world!".into(), tool_calls: None },
                    done: true,
                },
            ],
        }
    }

    /// Single-message user request for `llama3`, with the given `stream` flag.
    fn chat_request(stream: bool) -> ChatRequest {
        ChatRequest {
            model: "llama3".into(),
            messages: vec![Message {
                role: Role::User,
                content: "Say hello.".into(),
                tool_calls: None,
            }],
            tools: None,
            stream,
        }
    }

    // ── Buffered chat ─────────────────────────────────────────────────────────

    #[tokio::test]
    async fn mock_client_returns_canned_response() {
        let resp = make_mock().chat(chat_request(false)).await.unwrap();
        assert_eq!(resp.message.role, Role::Assistant);
        assert_eq!(resp.message.content, "Hello, world!");
        assert!(resp.done);
    }

    // ── Streaming chat ────────────────────────────────────────────────────────

    /// Proves the streaming path yields all chunks in order, with no errors, and
    /// that only the final chunk carries `done: true`. No live server is needed.
    #[tokio::test]
    async fn mock_stream_chat_yields_all_chunks() {
        let chunks: Vec<ChatResponse> = make_mock()
            .stream_chat(chat_request(true))
            .try_collect()
            .await
            .expect("mock stream yields no errors");
        assert_eq!(chunks.len(), 2);

        assert_eq!(chunks[0].message.content, "Hello");
        assert!(!chunks[0].done);

        assert_eq!(chunks[1].message.content, ", world!");
        assert!(chunks[1].done);
    }

    /// Accumulating streamed chunk content yields exactly the buffered response's
    /// content.
    ///
    /// Any stream error fails the test instead of being silently dropped.
    #[tokio::test]
    async fn stream_content_matches_buffered_content() {
        let mock = make_mock();

        let streamed: String = mock
            .stream_chat(chat_request(true))
            .map_ok(|chunk| chunk.message.content)
            .try_collect()
            .await
            .expect("mock stream yields no errors");
        let buffered = mock.chat(chat_request(false)).await.unwrap();

        assert_eq!(streamed, buffered.message.content);
        assert_eq!(streamed, "Hello, world!");
    }

    // ── Object safety ─────────────────────────────────────────────────────────

    #[test]
    fn trait_is_object_safe() {
        fn accepts_boxed(_: Box<dyn OllamaClient>) {}
        accepts_boxed(Box::new(make_mock()));
    }
}