1use std::sync::Arc;
2
3use crate::client::OllamaClient;
4use crate::error::OxideError;
5use crate::types::{ChatRequest, Message, Role};
6
/// Rough token estimate for `text`: one token per four characters, rounded up.
///
/// Counts Unicode scalar values (`char`s), not bytes, so multi-byte text is not
/// over-counted.
fn estimate_tokens(text: &str) -> usize {
    let chars = text.chars().count();
    let whole = chars / 4;
    if chars % 4 == 0 {
        whole
    } else {
        whole + 1
    }
}
13
14fn messages_token_count(messages: &[Message]) -> usize {
15 messages.iter().map(|m| estimate_tokens(&m.content)).sum()
16}
17
/// How a session shrinks its message history once it exceeds the configured threshold.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompressionStrategy {
    /// Drop the oldest non-system messages until the history fits the budget.
    TruncateOldest,
    /// Replace the older half of the non-system history with a summary written by
    /// `model`.
    Summarize {
        /// Model used to write the summary; it may differ from the session's chat model.
        model: String,
    },
}

impl Default for CompressionStrategy {
    /// Defaults to [`CompressionStrategy::TruncateOldest`], which needs no extra model calls.
    fn default() -> Self {
        Self::TruncateOldest
    }
}
36
37#[derive(Debug, Clone)]
38pub struct SessionConfig {
39 pub max_tokens: usize,
42 pub compression_threshold: f32,
44 pub compression_strategy: CompressionStrategy,
45}
46
47impl Default for SessionConfig {
48 fn default() -> Self {
49 Self {
50 max_tokens: 8_000,
51 compression_threshold: 0.80,
52 compression_strategy: CompressionStrategy::default(),
53 }
54 }
55}
56
/// A stateful chat conversation with a single model.
///
/// Keeps the whole message history, sends it with every request, and compresses it
/// according to its [`SessionConfig`] when the estimated size grows too large.
pub struct Session {
    // Stored type-erased so `Session` is not generic over the client implementation.
    client: Arc<dyn OllamaClient>,
    // Model name used for the main conversation requests.
    model: String,
    config: SessionConfig,
    // Full conversation, oldest first, including system messages (prompt and summaries).
    messages: Vec<Message>,
}
71
72impl Session {
73 pub fn new<C: OllamaClient + 'static>(
74 client: Arc<C>,
75 model: impl Into<String>,
76 config: SessionConfig,
77 ) -> Self {
78 let client: Arc<dyn OllamaClient> = client;
79 Self {
80 client,
81 model: model.into(),
82 config,
83 messages: Vec::new(),
84 }
85 }
86
87 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
89 self.messages.retain(|m| m.role != Role::System);
90 self.messages.insert(
91 0,
92 Message {
93 role: Role::System,
94 content: prompt.into(),
95 tool_calls: None,
96 },
97 );
98 }
99
100 pub async fn ask(&mut self, user_input: impl Into<String>) -> Result<String, OxideError> {
103 self.messages.push(Message {
104 role: Role::User,
105 content: user_input.into(),
106 tool_calls: None,
107 });
108
109 self.maybe_compress().await?;
111
112 let req = ChatRequest {
113 model: self.model.clone(),
114 messages: self.messages.clone(),
115 tools: None,
116 stream: false,
117 };
118
119 let resp = self.client.chat(req).await?;
120 let content = resp.message.content.clone();
121
122 self.messages.push(resp.message);
123 Ok(content)
124 }
125
126 pub fn history(&self) -> &[Message] {
128 &self.messages
129 }
130
131 pub fn estimated_tokens(&self) -> usize {
133 messages_token_count(&self.messages)
134 }
135
136 async fn maybe_compress(&mut self) -> Result<(), OxideError> {
139 let limit = (self.config.max_tokens as f32 * self.config.compression_threshold) as usize;
140 if self.estimated_tokens() <= limit {
141 return Ok(());
142 }
143
144 match &self.config.compression_strategy.clone() {
145 CompressionStrategy::TruncateOldest => self.truncate_oldest(limit),
146 CompressionStrategy::Summarize { model } => {
147 self.summarize_oldest(model.clone(), limit).await?
148 }
149 }
150
151 Ok(())
152 }
153
154 fn truncate_oldest(&mut self, limit: usize) {
156 while self.estimated_tokens() > limit {
157 let pos = self.messages.iter().position(|m| m.role != Role::System);
159 match pos {
160 Some(i) => {
161 self.messages.remove(i);
162 }
163 None => break, }
165 }
166 }
167
168 async fn summarize_oldest(&mut self, model: String, limit: usize) -> Result<(), OxideError> {
171 let non_system: Vec<usize> = self
173 .messages
174 .iter()
175 .enumerate()
176 .filter(|(_, m)| m.role != Role::System)
177 .map(|(i, _)| i)
178 .collect();
179
180 if non_system.len() < 2 {
181 self.truncate_oldest(limit);
183 return Ok(());
184 }
185
186 let half = non_system.len() / 2;
187 let to_summarise_indices: Vec<usize> = non_system[..half].to_vec();
188
189 let transcript: String = to_summarise_indices
191 .iter()
192 .map(|&i| {
193 let m = &self.messages[i];
194 format!("{:?}: {}", m.role, m.content)
195 })
196 .collect::<Vec<_>>()
197 .join("\n");
198
199 let summary_prompt = format!(
200 "Summarise the following conversation history concisely, preserving key facts:\n\n{transcript}"
201 );
202
203 let summary_req = ChatRequest {
204 model: model.clone(),
205 messages: vec![Message {
206 role: Role::User,
207 content: summary_prompt,
208 tool_calls: None,
209 }],
210 tools: None,
211 stream: false,
212 };
213
214 let summary_resp = self.client.chat(summary_req).await?;
215 let summary = summary_resp.message.content;
216
217 for &i in to_summarise_indices.iter().rev() {
219 self.messages.remove(i);
220 }
221
222 let insert_pos = self
224 .messages
225 .iter()
226 .position(|m| m.role != Role::System)
227 .unwrap_or(0);
228
229 self.messages.insert(
230 insert_pos,
231 Message {
232 role: Role::System,
233 content: format!("[Conversation summary]\n{summary}"),
234 tool_calls: None,
235 },
236 );
237
238 Ok(())
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{BoxStream, OllamaClient};
    use crate::types::{
        ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse,
        ListModelsResponse,
    };
    use async_trait::async_trait;

    /// Test double whose `chat` answers `"echo: <content of the last message>"`.
    /// Every other endpoint is unimplemented.
    struct EchoClient;

    #[async_trait]
    impl OllamaClient for EchoClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }

        async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, OxideError> {
            let echoed = format!("echo: {}", req.messages.last().unwrap().content);
            Ok(ChatResponse {
                model: req.model,
                message: Message {
                    role: Role::Assistant,
                    content: echoed,
                    tool_calls: None,
                },
                done: true,
            })
        }

        async fn embed(&self, _: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            unimplemented!()
        }

        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }

        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }

        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            unimplemented!()
        }
    }

    /// Builds a session for model `"llama3"` backed by [`EchoClient`].
    fn echo_session(config: SessionConfig) -> Session {
        Session::new(Arc::new(EchoClient), "llama3", config)
    }

    #[tokio::test]
    async fn session_tracks_history() {
        let mut session = echo_session(SessionConfig::default());

        assert_eq!(session.ask("Hello").await.unwrap(), "echo: Hello");
        assert_eq!(session.history().len(), 2);

        session.ask("Again").await.unwrap();
        assert_eq!(session.history().len(), 4);
    }

    #[tokio::test]
    async fn system_prompt_is_prepended() {
        let mut session = echo_session(SessionConfig::default());
        session.set_system_prompt("You are helpful.");
        session.ask("Hi").await.unwrap();

        let expected = [Role::System, Role::User, Role::Assistant];
        for (i, role) in expected.iter().enumerate() {
            assert_eq!(&session.history()[i].role, role);
        }
    }

    #[tokio::test]
    async fn truncation_drops_oldest_messages() {
        let mut session = echo_session(SessionConfig {
            max_tokens: 20,
            compression_threshold: 0.5,
            compression_strategy: CompressionStrategy::TruncateOldest,
        });

        for i in 0..15 {
            session.ask(format!("msg{i}")).await.unwrap();
        }

        assert!(session.estimated_tokens() <= 20);
    }
}
339}