// llmoxide 0.1.0
//
// Provider-agnostic Rust SDK for OpenAI, Anthropic, Gemini, and Ollama (streaming + tools).
// Documentation
use crate::error::Result;
use crate::types::{Event, Message, Model, Response, ResponseRequest, Role};

/// Serializable conversation state: an optional model, an optional cap on
/// output tokens, and the ordered list of messages exchanged so far.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Chat {
    /// Model to request; `None` when the chat was created via `new_auto`
    /// and no model has been set since.
    model: Option<Model>,
    /// Value copied into each request's `max_output_tokens`; `None` when unset.
    max_output_tokens: Option<u32>,
    /// Conversation history, in the order messages were pushed.
    messages: Vec<Message>,
}

impl Chat {
    /// Creates an empty conversation that targets the named model.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: Some(Model::new(model)),
            ..Self::new_auto()
        }
    }

    /// Creates an empty conversation with no model and no output-token cap.
    pub fn new_auto() -> Self {
        Self {
            messages: Vec::new(),
            max_output_tokens: None,
            model: None,
        }
    }

    /// Builder-style: replaces the model and returns the updated chat.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        let chosen = Model::new(model);
        self.model = Some(chosen);
        self
    }

    /// Builder-style: sets the output-token cap and returns the updated chat.
    pub fn max_output_tokens(mut self, max: u32) -> Self {
        self.set_max_output_tokens(max);
        self
    }

    /// Sets the output-token cap in place; returns `self` for chaining.
    pub fn set_max_output_tokens(&mut self, max: u32) -> &mut Self {
        let cap = Some(max);
        self.max_output_tokens = cap;
        self
    }

    /// Appends `message` to the end of the history; returns `self` for chaining.
    pub fn push_message(&mut self, message: Message) -> &mut Self {
        self.messages.push(message);
        self
    }

    /// Appends a text message authored by `role`.
    pub fn push_text(&mut self, role: Role, text: impl Into<String>) -> &mut Self {
        let message = Message::text(role, text);
        self.push_message(message)
    }

    /// Appends a text message with the `User` role.
    pub fn push_user(&mut self, text: impl Into<String>) -> &mut Self {
        self.push_message(Message::text(Role::User, text))
    }

    /// Shorthand for [`Chat::push_user`].
    pub fn push(&mut self, text: impl Into<String>) -> &mut Self {
        self.push_text(Role::User, text)
    }

    /// Borrows the conversation history.
    pub fn messages(&self) -> &[Message] {
        self.messages.as_slice()
    }

    /// Consumes the chat and returns its history.
    pub fn into_messages(self) -> Vec<Message> {
        let Self { messages, .. } = self;
        messages
    }

    /// Builds a request from a snapshot of the current state; the tool list
    /// is always empty.
    fn to_request(&self) -> ResponseRequest {
        let model = self.model.clone();
        let messages = self.messages.clone();
        ResponseRequest {
            model,
            messages,
            max_output_tokens: self.max_output_tokens,
            tools: Vec::new(),
        }
    }

    /// Sends the current conversation through `client`.
    ///
    /// On success the response's message is appended to the history and the
    /// response is returned. On error the history is left as it was.
    pub async fn send(&mut self, client: &crate::client::Client) -> Result<Response> {
        let request = self.to_request();
        let reply = client.send(request).await?;
        self.messages.push(reply.message.clone());
        Ok(reply)
    }

    /// Streaming counterpart of [`Chat::send`]: `on_event` is handed to
    /// `client.stream` together with the request.
    ///
    /// On success the response's message is appended to the history and the
    /// response is returned. On error the history is left as it was.
    pub async fn stream<F>(
        &mut self,
        client: &crate::client::Client,
        on_event: F,
    ) -> Result<Response>
    where
        F: FnMut(Event),
    {
        let request = self.to_request();
        let reply = client.stream(request, on_event).await?;
        self.messages.push(reply.message.clone());
        Ok(reply)
    }
}

/// A [`Chat`] paired with a borrowed client, so that sending and streaming
/// do not need the client passed on every call.
pub struct ChatSession<'a> {
    /// Client used by `send` and `stream`.
    client: &'a crate::client::Client,
    /// Conversation state owned by this session.
    chat: Chat,
}

impl<'a> ChatSession<'a> {
    pub(crate) fn new(client: &'a crate::client::Client, chat: Chat) -> Self {
        Self { client, chat }
    }

    pub fn chat(&self) -> &Chat {
        &self.chat
    }

    pub fn chat_mut(&mut self) -> &mut Chat {
        &mut self.chat
    }

    pub fn into_chat(self) -> Chat {
        self.chat
    }

    pub fn push_message(&mut self, message: Message) -> &mut Self {
        self.chat.push_message(message);
        self
    }

    pub fn push_text(&mut self, role: Role, text: impl Into<String>) -> &mut Self {
        self.chat.push_text(role, text);
        self
    }

    pub fn push_user(&mut self, text: impl Into<String>) -> &mut Self {
        self.chat.push_user(text);
        self
    }

    pub fn push(&mut self, text: impl Into<String>) -> &mut Self {
        self.push_user(text)
    }

    pub async fn send(&mut self) -> Result<Response> {
        self.chat.send(self.client).await
    }

    pub async fn stream<F>(&mut self, on_event: F) -> Result<Response>
    where
        F: FnMut(Event),
    {
        self.chat.stream(self.client, on_event).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a populated `Chat` to JSON and back, checking that every
    /// piece of state configured here survives the round trip.
    #[test]
    fn chat_serde_roundtrip() {
        let mut chat = Chat::new("x");
        chat.push_text(Role::User, "hi")
            .push_text(Role::Assistant, "hello")
            .set_max_output_tokens(123);

        let bytes = serde_json::to_vec(&chat).unwrap();
        let back: Chat = serde_json::from_slice(&bytes).unwrap();

        assert_eq!(back.messages().len(), 2);
        // The fields below have no public getter, but a child module can read
        // them directly. Without these checks a regression that dropped the
        // token cap or the model during (de)serialization would go unnoticed.
        assert_eq!(back.max_output_tokens, Some(123));
        assert!(back.model.is_some());
    }
}