Skip to main content

oxide_agent/session/
mod.rs

1use std::sync::Arc;
2
3use crate::client::OllamaClient;
4use crate::error::OxideError;
5use crate::types::{ChatRequest, Message, Role};
6
7// ── Configuration ─────────────────────────────────────────────────────────────
8
/// Cheap token estimate for `text`, assuming roughly four characters per token.
///
/// Characters are counted as Unicode scalar values (`char`s), not bytes, and
/// the quotient is rounded up so any non-empty string costs at least one token.
fn estimate_tokens(text: &str) -> usize {
    let char_count = text.chars().count();
    let whole_groups = char_count / 4;
    // A trailing partial group of 1–3 characters still counts as one token.
    let partial_group = usize::from(char_count % 4 != 0);
    whole_groups + partial_group
}
13
14fn messages_token_count(messages: &[Message]) -> usize {
15    messages.iter().map(|m| estimate_tokens(&m.content)).sum()
16}
17
/// Policy applied when the message history outgrows its token budget.
///
/// The default is [`CompressionStrategy::TruncateOldest`]; it needs no extra
/// model call and therefore cannot fail.
#[derive(Debug, Clone, Default)]
pub enum CompressionStrategy {
    /// Drop the oldest non-system messages until the budget is met.
    #[default]
    TruncateOldest,
    /// Ask Ollama itself to summarise the oldest half of the history into one
    /// compact system message, then discard the originals.
    Summarize {
        /// Model to use for summarisation (can differ from the chat model).
        model: String,
    },
}
36
37#[derive(Debug, Clone)]
38pub struct SessionConfig {
39    /// Soft limit on the estimated number of tokens in the message history.
40    /// Compression triggers when history exceeds `max_tokens * threshold`.
41    pub max_tokens: usize,
42    /// Fraction of `max_tokens` at which compression is triggered (0.0–1.0).
43    pub compression_threshold: f32,
44    pub compression_strategy: CompressionStrategy,
45}
46
47impl Default for SessionConfig {
48    fn default() -> Self {
49        Self {
50            max_tokens: 8_000,
51            compression_threshold: 0.80,
52            compression_strategy: CompressionStrategy::default(),
53        }
54    }
55}
56
57// ── Session ───────────────────────────────────────────────────────────────────
58
59/// Stateful multi-turn conversation manager.
60///
61/// Automatically tracks message history and compresses the context when it
62/// approaches the configured token budget, so callers never have to think
63/// about context windows.
64pub struct Session {
65    client: Arc<dyn OllamaClient>,
66    model: String,
67    config: SessionConfig,
68    /// Full message history, including system prompt if set.
69    messages: Vec<Message>,
70}
71
72impl Session {
73    pub fn new<C: OllamaClient + 'static>(
74        client: Arc<C>,
75        model: impl Into<String>,
76        config: SessionConfig,
77    ) -> Self {
78        let client: Arc<dyn OllamaClient> = client;
79        Self {
80            client,
81            model: model.into(),
82            config,
83            messages: Vec::new(),
84        }
85    }
86
87    /// Prepend a system prompt. Replaces any existing system message.
88    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
89        self.messages.retain(|m| m.role != Role::System);
90        self.messages.insert(
91            0,
92            Message {
93                role: Role::System,
94                content: prompt.into(),
95                tool_calls: None,
96            },
97        );
98    }
99
100    /// Send a user message and return the assistant's reply.
101    /// History is updated automatically on both sides.
102    pub async fn ask(&mut self, user_input: impl Into<String>) -> Result<String, OxideError> {
103        self.messages.push(Message {
104            role: Role::User,
105            content: user_input.into(),
106            tool_calls: None,
107        });
108
109        // Compress before sending if we're over the threshold.
110        self.maybe_compress().await?;
111
112        let req = ChatRequest {
113            model: self.model.clone(),
114            messages: self.messages.clone(),
115            tools: None,
116            stream: false,
117        };
118
119        let resp = self.client.chat(req).await?;
120        let content = resp.message.content.clone();
121
122        self.messages.push(resp.message);
123        Ok(content)
124    }
125
126    /// Expose read-only view of the current history.
127    pub fn history(&self) -> &[Message] {
128        &self.messages
129    }
130
131    /// Estimated tokens currently in the context.
132    pub fn estimated_tokens(&self) -> usize {
133        messages_token_count(&self.messages)
134    }
135
136    // ── Compression ───────────────────────────────────────────────────────────
137
138    async fn maybe_compress(&mut self) -> Result<(), OxideError> {
139        let limit = (self.config.max_tokens as f32 * self.config.compression_threshold) as usize;
140        if self.estimated_tokens() <= limit {
141            return Ok(());
142        }
143
144        match &self.config.compression_strategy.clone() {
145            CompressionStrategy::TruncateOldest => self.truncate_oldest(limit),
146            CompressionStrategy::Summarize { model } => {
147                self.summarize_oldest(model.clone(), limit).await?
148            }
149        }
150
151        Ok(())
152    }
153
154    /// Drop oldest non-system messages one at a time until under `limit`.
155    fn truncate_oldest(&mut self, limit: usize) {
156        while self.estimated_tokens() > limit {
157            // Find the first non-system message and remove it.
158            let pos = self.messages.iter().position(|m| m.role != Role::System);
159            match pos {
160                Some(i) => {
161                    self.messages.remove(i);
162                }
163                None => break, // Only system message left; nothing to drop.
164            }
165        }
166    }
167
168    /// Summarise the oldest half of non-system messages using Ollama, replacing
169    /// them with a compact summary injected as a system message.
170    async fn summarize_oldest(&mut self, model: String, limit: usize) -> Result<(), OxideError> {
171        // Collect oldest non-system messages up to half the history.
172        let non_system: Vec<usize> = self
173            .messages
174            .iter()
175            .enumerate()
176            .filter(|(_, m)| m.role != Role::System)
177            .map(|(i, _)| i)
178            .collect();
179
180        if non_system.len() < 2 {
181            // Fall back to truncation — not enough to summarise.
182            self.truncate_oldest(limit);
183            return Ok(());
184        }
185
186        let half = non_system.len() / 2;
187        let to_summarise_indices: Vec<usize> = non_system[..half].to_vec();
188
189        // Build a transcript to summarise.
190        let transcript: String = to_summarise_indices
191            .iter()
192            .map(|&i| {
193                let m = &self.messages[i];
194                format!("{:?}: {}", m.role, m.content)
195            })
196            .collect::<Vec<_>>()
197            .join("\n");
198
199        let summary_prompt = format!(
200            "Summarise the following conversation history concisely, preserving key facts:\n\n{transcript}"
201        );
202
203        let summary_req = ChatRequest {
204            model: model.clone(),
205            messages: vec![Message {
206                role: Role::User,
207                content: summary_prompt,
208                tool_calls: None,
209            }],
210            tools: None,
211            stream: false,
212        };
213
214        let summary_resp = self.client.chat(summary_req).await?;
215        let summary = summary_resp.message.content;
216
217        // Remove summarised messages (in reverse to preserve indices).
218        for &i in to_summarise_indices.iter().rev() {
219            self.messages.remove(i);
220        }
221
222        // Insert summary as a system message right after any existing system prompt.
223        let insert_pos = self
224            .messages
225            .iter()
226            .position(|m| m.role != Role::System)
227            .unwrap_or(0);
228
229        self.messages.insert(
230            insert_pos,
231            Message {
232                role: Role::System,
233                content: format!("[Conversation summary]\n{summary}"),
234                tool_calls: None,
235            },
236        );
237
238        Ok(())
239    }
240}
241
242// ── Tests ─────────────────────────────────────────────────────────────────────
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::BoxStream;
    use crate::client::OllamaClient;
    use crate::types::{
        ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse,
        ListModelsResponse,
    };
    use async_trait::async_trait;

    /// Test double: `chat` answers with `"echo: "` followed by the content of
    /// the last message in the request. All other endpoints are unimplemented.
    struct EchoClient;

    #[async_trait]
    impl OllamaClient for EchoClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }

        async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, OxideError> {
            // Build the reply text first, then hand the model name back.
            let echoed = format!("echo: {}", req.messages.last().unwrap().content);
            let message = Message {
                role: Role::Assistant,
                content: echoed,
                tool_calls: None,
            };
            Ok(ChatResponse {
                model: req.model,
                message,
                done: true,
            })
        }

        async fn embed(&self, _: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            unimplemented!()
        }

        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }

        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }

        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            unimplemented!()
        }
    }

    /// Session backed by [`EchoClient`] talking to the "llama3" model.
    fn echo_session(config: SessionConfig) -> Session {
        Session::new(Arc::new(EchoClient), "llama3", config)
    }

    #[tokio::test]
    async fn session_tracks_history() {
        let mut session = echo_session(SessionConfig::default());

        let first_reply = session.ask("Hello").await.unwrap();
        assert_eq!(first_reply, "echo: Hello");
        // One user message plus one assistant reply.
        assert_eq!(session.history().len(), 2);

        session.ask("Again").await.unwrap();
        // The second turn adds another pair.
        assert_eq!(session.history().len(), 4);
    }

    #[tokio::test]
    async fn system_prompt_is_prepended() {
        let mut session = echo_session(SessionConfig::default());
        session.set_system_prompt("You are helpful.");
        session.ask("Hi").await.unwrap();

        let history = session.history();
        assert_eq!(history[0].role, Role::System);
        assert_eq!(history[1].role, Role::User);
        assert_eq!(history[2].role, Role::Assistant);
    }

    #[tokio::test]
    async fn truncation_drops_oldest_messages() {
        // Compression triggers at 20 * 0.5 = 10 estimated tokens.
        let mut session = echo_session(SessionConfig {
            max_tokens: 20,
            compression_threshold: 0.5,
            compression_strategy: CompressionStrategy::TruncateOldest,
        });

        // Every turn adds a few estimated tokens, so fifteen turns would
        // overshoot the budget several times over without pruning.
        for turn in 0..15 {
            session.ask(format!("msg{turn}")).await.unwrap();
        }

        // History should be well under the max.
        assert!(session.estimated_tokens() <= 20);
    }
}