oxide-agent 0.1.0

Type-safe, high-performance Rust crate for building agentic systems on Ollama
Documentation
use std::sync::Arc;

use crate::client::OllamaClient;
use crate::error::OxideError;
use crate::types::{ChatRequest, Message, Role};

// ── Configuration ─────────────────────────────────────────────────────────────

/// Approximates the token cost of `text` as one token per four characters,
/// rounded up, so any non-empty string costs at least one token.
///
/// Characters are counted as Unicode scalar values (`str::chars`), not bytes.
fn estimate_tokens(text: &str) -> usize {
    let char_count = text.chars().count();
    let full_groups = char_count / 4;
    if char_count % 4 == 0 {
        full_groups
    } else {
        full_groups + 1
    }
}

/// Adds up the estimated token cost of the `content` of every message.
fn messages_token_count(messages: &[Message]) -> usize {
    let mut total: usize = 0;
    for message in messages {
        total += estimate_tokens(&message.content);
    }
    total
}

/// What to do when the context window fills up.
///
/// The default strategy is [`CompressionStrategy::TruncateOldest`].
#[derive(Debug, Clone, Default)]
pub enum CompressionStrategy {
    /// Drop oldest non-system messages until the budget is met.
    #[default]
    TruncateOldest,
    /// Ask Ollama itself to summarise the oldest half of the history into one
    /// compact system message, then discard the originals.
    Summarize {
        /// Model to use for summarisation (can differ from the chat model).
        model: String,
    },
}

/// Token budget and compression settings for a `Session`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Soft limit on the estimated number of tokens in the message history.
    /// Compression triggers when the estimate exceeds
    /// `max_tokens * compression_threshold`.
    pub max_tokens: usize,
    /// Fraction of `max_tokens` at which compression is triggered (0.0–1.0).
    pub compression_threshold: f32,
    /// How the history is shrunk once the trigger point is exceeded.
    pub compression_strategy: CompressionStrategy,
}

impl Default for SessionConfig {
    fn default() -> Self {
        Self {
            max_tokens: 8_000,
            compression_threshold: 0.80,
            compression_strategy: CompressionStrategy::default(),
        }
    }
}

// ── Session ───────────────────────────────────────────────────────────────────

/// Stateful multi-turn conversation manager.
///
/// Automatically tracks message history and compresses the context when it
/// approaches the configured token budget, so callers never have to think
/// about context windows.
pub struct Session {
    /// Shared, type-erased client used for chat and summarisation requests.
    client: Arc<dyn OllamaClient>,
    /// Model name placed in every chat request built by `ask`.
    model: String,
    /// Token budget and compression settings.
    config: SessionConfig,
    /// Full message history, including system prompt if set.
    messages: Vec<Message>,
}

impl Session {
    /// Creates a session with an empty history that talks to `model` through
    /// `client`, using `config` for its token budget.
    ///
    /// The concrete client type is erased to `Arc<dyn OllamaClient>`.
    pub fn new<C: OllamaClient + 'static>(
        client: Arc<C>,
        model: impl Into<String>,
        config: SessionConfig,
    ) -> Self {
        let erased: Arc<dyn OllamaClient> = client;
        let model = model.into();
        let messages = Vec::new();
        Self {
            client: erased,
            model,
            config,
            messages,
        }
    }

    /// Installs `prompt` as the only system message, placed at the front of
    /// the history.
    ///
    /// Every existing `Role::System` message is removed first; that includes
    /// any summary message injected by the `Summarize` compression strategy.
    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
        self.messages
            .retain(|existing| existing.role != Role::System);

        let system_message = Message {
            role: Role::System,
            content: prompt.into(),
            tool_calls: None,
        };
        self.messages.insert(0, system_message);
    }

    /// Send a user message and return the assistant's reply.
    ///
    /// The user message is appended to the history, the history is compressed
    /// if it exceeds the configured threshold, and the whole history is sent
    /// to the model. On success the assistant's message is appended as well,
    /// so the history is updated on both sides.
    ///
    /// # Errors
    ///
    /// Returns the client's error if the summarisation request or the chat
    /// request fails. The pending user message is then removed again, so a
    /// retry does not leave a duplicate, unanswered user turn in the history.
    pub async fn ask(&mut self, user_input: impl Into<String>) -> Result<String, OxideError> {
        self.messages.push(Message {
            role: Role::User,
            content: user_input.into(),
            tool_calls: None,
        });

        // Compress before sending if we're over the threshold, then issue the
        // chat request; both failure paths are handled together below.
        let outcome = match self.maybe_compress().await {
            Ok(()) => {
                let req = ChatRequest {
                    model: self.model.clone(),
                    messages: self.messages.clone(),
                    tools: None,
                    stream: false,
                };
                self.client.chat(req).await
            }
            Err(err) => Err(err),
        };

        match outcome {
            Ok(resp) => {
                // The reply is both returned and stored, hence the clone.
                let content = resp.message.content.clone();
                self.messages.push(resp.message);
                Ok(content)
            }
            Err(err) => {
                // Roll back the unanswered user turn. The role check avoids
                // popping an unrelated message if compression already
                // discarded the pending one.
                if matches!(self.messages.last(), Some(m) if m.role == Role::User) {
                    self.messages.pop();
                }
                Err(err)
            }
        }
    }

    /// Borrows the stored messages in their current order without copying.
    pub fn history(&self) -> &[Message] {
        self.messages.as_slice()
    }

    /// Estimated tokens currently in the context.
    pub fn estimated_tokens(&self) -> usize {
        messages_token_count(&self.messages)
    }

    // ── Compression ───────────────────────────────────────────────────────────

    /// Compresses the history if its estimated size exceeds the trigger point.
    ///
    /// The trigger point is `max_tokens * compression_threshold`, truncated to
    /// a whole number of tokens. Nothing happens while the estimate is at or
    /// below that value; otherwise the configured strategy is applied once.
    ///
    /// # Errors
    ///
    /// Propagates the client error when the `Summarize` strategy's request
    /// fails. Truncation cannot fail.
    async fn maybe_compress(&mut self) -> Result<(), OxideError> {
        let limit = (self.config.max_tokens as f32 * self.config.compression_threshold) as usize;
        if self.estimated_tokens() <= limit {
            return Ok(());
        }

        // Clone the strategy once so it does not borrow `self` while the
        // `&mut self` helpers below run; the model name is then moved, not
        // cloned a second time.
        match self.config.compression_strategy.clone() {
            CompressionStrategy::TruncateOldest => self.truncate_oldest(limit),
            CompressionStrategy::Summarize { model } => {
                self.summarize_oldest(model, limit).await?;
            }
        }

        Ok(())
    }

    /// Drop oldest non-system messages one at a time until the estimated size
    /// is at or under `limit`.
    ///
    /// System messages are never dropped, and neither is the newest non-system
    /// message: when called from `ask` that is the user turn about to be sent,
    /// and discarding it would send the request without the user's prompt.
    /// The history may therefore stay above `limit` if those protected
    /// messages alone exceed it.
    fn truncate_oldest(&mut self, limit: usize) {
        while self.estimated_tokens() > limit {
            // Index of the newest non-system message, which must survive.
            let newest = match self.messages.iter().rposition(|m| m.role != Role::System) {
                Some(i) => i,
                None => break, // Only system messages left; nothing to drop.
            };

            // Oldest droppable message strictly before the protected one.
            let oldest = self.messages[..newest]
                .iter()
                .position(|m| m.role != Role::System);
            match oldest {
                Some(i) => {
                    self.messages.remove(i);
                }
                None => break, // Only the protected message remains.
            }
        }
    }

    /// Summarise the oldest half of the non-system messages using Ollama and
    /// replace them with one compact summary injected as a system message.
    ///
    /// Summary messages injected by earlier calls (system messages whose
    /// content starts with `[Conversation summary]`) are included in the
    /// transcript and replaced too, so the history holds at most one summary
    /// instead of accumulating one uncompressible system message per call.
    ///
    /// With fewer than two non-system messages there is nothing worth
    /// summarising and the method falls back to `truncate_oldest(limit)`.
    ///
    /// # Errors
    ///
    /// Propagates the client error if the summarisation request fails; the
    /// history is not modified in that case.
    async fn summarize_oldest(&mut self, model: String, limit: usize) -> Result<(), OxideError> {
        /// Marker that starts every summary message this method injects.
        const SUMMARY_HEADER: &str = "[Conversation summary]";

        let non_system_count = self
            .messages
            .iter()
            .filter(|m| m.role != Role::System)
            .count();

        if non_system_count < 2 {
            // Fall back to truncation — not enough to summarise.
            self.truncate_oldest(limit);
            return Ok(());
        }

        let half = non_system_count / 2;

        // One pass decides, per message, whether it gets folded into the new
        // summary (earlier summaries plus the oldest `half` non-system
        // messages) and builds the transcript in history order.
        let mut compress: Vec<bool> = Vec::with_capacity(self.messages.len());
        let mut transcript_lines: Vec<String> = Vec::new();
        let mut seen_non_system = 0usize;
        for m in &self.messages {
            let selected = if m.role == Role::System {
                m.content.starts_with(SUMMARY_HEADER)
            } else {
                seen_non_system += 1;
                seen_non_system <= half
            };
            if selected {
                transcript_lines.push(format!("{:?}: {}", m.role, m.content));
            }
            compress.push(selected);
        }
        let transcript = transcript_lines.join("\n");

        let summary_prompt = format!(
            "Summarise the following conversation history concisely, preserving key facts:\n\n{transcript}"
        );

        let summary_req = ChatRequest {
            model,
            messages: vec![Message {
                role: Role::User,
                content: summary_prompt,
                tool_calls: None,
            }],
            tools: None,
            stream: false,
        };

        // Nothing has been mutated yet, so an error here leaves the history
        // exactly as it was.
        let summary = self.client.chat(summary_req).await?.message.content;

        // `retain` visits elements once, in order, so the flags line up with
        // the messages they were computed for.
        let mut selected = compress.into_iter();
        self.messages
            .retain(|_| !selected.next().unwrap_or(false));

        // Insert the summary right after the leading system messages (or at
        // the end if only system messages remain).
        let insert_pos = self
            .messages
            .iter()
            .position(|m| m.role != Role::System)
            .unwrap_or(self.messages.len());

        self.messages.insert(
            insert_pos,
            Message {
                role: Role::System,
                content: format!("{}\n{}", SUMMARY_HEADER, summary),
                tool_calls: None,
            },
        );

        Ok(())
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::OllamaClient;
    use crate::types::{
        ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse,
        ListModelsResponse,
    };
    use crate::client::BoxStream;
    use async_trait::async_trait;

    /// Test double: `chat` echoes the final message of the request, every
    /// other endpoint is left unimplemented.
    struct EchoClient;

    #[async_trait]
    impl OllamaClient for EchoClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }
        async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, OxideError> {
            // Reply with the content of the last message in the request
            // (whatever its role), prefixed with "echo: ".
            let echoed = format!("echo: {}", req.messages.last().unwrap().content);
            let message = Message {
                role: Role::Assistant,
                content: echoed,
                tool_calls: None,
            };
            Ok(ChatResponse {
                model: req.model,
                message,
                done: true,
            })
        }
        async fn embed(&self, _: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            unimplemented!()
        }
        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }
        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }
        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            unimplemented!()
        }
    }

    /// Every successful `ask` appends two messages: the user turn and the
    /// assistant reply.
    #[tokio::test]
    async fn session_tracks_history() {
        let client = Arc::new(EchoClient);
        let mut session = Session::new(client, "llama3", SessionConfig::default());

        let first_reply = session.ask("Hello").await.unwrap();
        assert_eq!(first_reply, "echo: Hello");
        // One user message plus one assistant message.
        assert_eq!(session.history().len(), 2);

        session.ask("Again").await.unwrap();
        assert_eq!(session.history().len(), 4);
    }

    /// A system prompt set before the first turn stays at the front of the
    /// history, followed by the user and assistant messages.
    #[tokio::test]
    async fn system_prompt_is_prepended() {
        let config = SessionConfig::default();
        let mut session = Session::new(Arc::new(EchoClient), "llama3", config);

        session.set_system_prompt("You are helpful.");
        session.ask("Hi").await.unwrap();

        let history = session.history();
        assert_eq!(history[0].role, Role::System);
        assert_eq!(history[1].role, Role::User);
        assert_eq!(history[2].role, Role::Assistant);
    }

    #[tokio::test]
    async fn truncation_drops_oldest_messages() {
        let config = SessionConfig {
            max_tokens: 20,
            compression_threshold: 0.5, // trigger at 10 estimated tokens
            compression_strategy: CompressionStrategy::TruncateOldest,
        };
        let mut session = Session::new(Arc::new(EchoClient), "llama3", config);

        // Each message content is ~4 chars ≈ 1 token. After enough turns the
        // oldest messages should be pruned to stay under the budget.
        for i in 0..15 {
            session.ask(format!("msg{i}")).await.unwrap();
        }

        // History should be well under the max.
        assert!(session.estimated_tokens() <= 20);
    }
}