use std::sync::Arc;
use crate::client::OllamaClient;
use crate::error::OxideError;
use crate::types::{ChatRequest, Message, Role};
/// Cheap token estimate: one token per four Unicode scalar values, rounded up.
///
/// Counts `char`s rather than bytes, so multi-byte text is not over-counted.
fn estimate_tokens(text: &str) -> usize {
    let char_count = text.chars().count();
    // Whole groups of four, plus one more token for any leftover partial group.
    char_count / 4 + usize::from(char_count % 4 != 0)
}
/// Sums the [`estimate_tokens`] estimate over the content of every message.
///
/// Only `content` is counted; role and tool-call data contribute nothing.
fn messages_token_count(messages: &[Message]) -> usize {
    messages
        .iter()
        .fold(0, |running_total, message| running_total + estimate_tokens(&message.content))
}
/// How a session shrinks its history once the estimated token count exceeds
/// the configured compression limit.
#[derive(Debug, Clone, Default)]
pub enum CompressionStrategy {
    /// Drop the oldest non-system messages until the history fits.
    /// This is the default strategy.
    #[default]
    TruncateOldest,
    /// Replace older conversation turns with a summary produced by `model`.
    Summarize { model: String },
}
/// Tuning knobs for a session's history management.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Token budget for the conversation history, measured with the module's
    /// character-based estimate (roughly four characters per token), not a
    /// real tokenizer.
    pub max_tokens: usize,
    /// Fraction of `max_tokens` above which history is compressed before a
    /// chat request is sent; e.g. `0.80` compresses once the estimate exceeds
    /// 80% of the budget. Not validated here — presumably meant to lie in
    /// (0, 1]; confirm with callers.
    pub compression_threshold: f32,
    /// Strategy used to shrink the history once the threshold is crossed.
    pub compression_strategy: CompressionStrategy,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
max_tokens: 8_000,
compression_threshold: 0.80,
compression_strategy: CompressionStrategy::default(),
}
}
}
/// A stateful chat conversation with a single model.
///
/// The session keeps the message history, sends all of it with every chat
/// request, and compresses it according to its [`SessionConfig`] when the
/// estimated token count grows too large.
pub struct Session {
    /// Backend used for both chat turns and summarisation requests.
    client: Arc<dyn OllamaClient>,
    /// Model name sent with each conversational chat request.
    model: String,
    /// Budget and compression settings.
    config: SessionConfig,
    /// History in send order. System messages (the prompt and any summaries)
    /// are placed ahead of the conversation turns.
    messages: Vec<Message>,
}
impl Session {
pub fn new<C: OllamaClient + 'static>(
client: Arc<C>,
model: impl Into<String>,
config: SessionConfig,
) -> Self {
let client: Arc<dyn OllamaClient> = client;
Self {
client,
model: model.into(),
config,
messages: Vec::new(),
}
}
pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
self.messages.retain(|m| m.role != Role::System);
self.messages.insert(
0,
Message {
role: Role::System,
content: prompt.into(),
tool_calls: None,
},
);
}
pub async fn ask(&mut self, user_input: impl Into<String>) -> Result<String, OxideError> {
self.messages.push(Message {
role: Role::User,
content: user_input.into(),
tool_calls: None,
});
self.maybe_compress().await?;
let req = ChatRequest {
model: self.model.clone(),
messages: self.messages.clone(),
tools: None,
stream: false,
};
let resp = self.client.chat(req).await?;
let content = resp.message.content.clone();
self.messages.push(resp.message);
Ok(content)
}
pub fn history(&self) -> &[Message] {
&self.messages
}
pub fn estimated_tokens(&self) -> usize {
messages_token_count(&self.messages)
}
async fn maybe_compress(&mut self) -> Result<(), OxideError> {
let limit = (self.config.max_tokens as f32 * self.config.compression_threshold) as usize;
if self.estimated_tokens() <= limit {
return Ok(());
}
match &self.config.compression_strategy.clone() {
CompressionStrategy::TruncateOldest => self.truncate_oldest(limit),
CompressionStrategy::Summarize { model } => {
self.summarize_oldest(model.clone(), limit).await?
}
}
Ok(())
}
fn truncate_oldest(&mut self, limit: usize) {
while self.estimated_tokens() > limit {
let pos = self.messages.iter().position(|m| m.role != Role::System);
match pos {
Some(i) => {
self.messages.remove(i);
}
None => break, }
}
}
async fn summarize_oldest(&mut self, model: String, limit: usize) -> Result<(), OxideError> {
let non_system: Vec<usize> = self
.messages
.iter()
.enumerate()
.filter(|(_, m)| m.role != Role::System)
.map(|(i, _)| i)
.collect();
if non_system.len() < 2 {
self.truncate_oldest(limit);
return Ok(());
}
let half = non_system.len() / 2;
let to_summarise_indices: Vec<usize> = non_system[..half].to_vec();
let transcript: String = to_summarise_indices
.iter()
.map(|&i| {
let m = &self.messages[i];
format!("{:?}: {}", m.role, m.content)
})
.collect::<Vec<_>>()
.join("\n");
let summary_prompt = format!(
"Summarise the following conversation history concisely, preserving key facts:\n\n{transcript}"
);
let summary_req = ChatRequest {
model: model.clone(),
messages: vec![Message {
role: Role::User,
content: summary_prompt,
tool_calls: None,
}],
tools: None,
stream: false,
};
let summary_resp = self.client.chat(summary_req).await?;
let summary = summary_resp.message.content;
for &i in to_summarise_indices.iter().rev() {
self.messages.remove(i);
}
let insert_pos = self
.messages
.iter()
.position(|m| m.role != Role::System)
.unwrap_or(0);
self.messages.insert(
insert_pos,
Message {
role: Role::System,
content: format!("[Conversation summary]\n{summary}"),
tool_calls: None,
},
);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::BoxStream;
    use crate::client::OllamaClient;
    use crate::types::{
        ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse,
        ListModelsResponse,
    };
    use async_trait::async_trait;

    /// Test double: every chat reply is the final request message's content
    /// prefixed with "echo: ". Endpoints the session never calls are left
    /// unimplemented.
    struct EchoClient;

    #[async_trait]
    impl OllamaClient for EchoClient {
        async fn generate(&self, _: GenerateRequest) -> Result<GenerateResponse, OxideError> {
            unimplemented!()
        }

        async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, OxideError> {
            let echoed = format!("echo: {}", req.messages.last().unwrap().content);
            let reply = Message {
                role: Role::Assistant,
                content: echoed,
                tool_calls: None,
            };
            Ok(ChatResponse {
                model: req.model,
                message: reply,
                done: true,
            })
        }

        async fn embed(&self, _: EmbedRequest) -> Result<EmbedResponse, OxideError> {
            unimplemented!()
        }

        async fn list_models(&self) -> Result<ListModelsResponse, OxideError> {
            unimplemented!()
        }

        fn stream_generate(&self, _: GenerateRequest) -> BoxStream<GenerateResponse> {
            unimplemented!()
        }

        fn stream_chat(&self, _: ChatRequest) -> BoxStream<ChatResponse> {
            unimplemented!()
        }
    }

    /// Builds a session against `EchoClient` with the given configuration.
    fn echo_session(config: SessionConfig) -> Session {
        Session::new(Arc::new(EchoClient), "llama3", config)
    }

    #[tokio::test]
    async fn session_tracks_history() {
        let mut session = echo_session(SessionConfig::default());

        let first_reply = session.ask("Hello").await.unwrap();
        assert_eq!(first_reply, "echo: Hello");
        assert_eq!(session.history().len(), 2);

        session.ask("Again").await.unwrap();
        assert_eq!(session.history().len(), 4);
    }

    #[tokio::test]
    async fn system_prompt_is_prepended() {
        let mut session = echo_session(SessionConfig::default());
        session.set_system_prompt("You are helpful.");
        session.ask("Hi").await.unwrap();

        let leading_roles: Vec<&Role> = session
            .history()
            .iter()
            .take(3)
            .map(|message| &message.role)
            .collect();
        assert_eq!(leading_roles, [&Role::System, &Role::User, &Role::Assistant]);
    }

    #[tokio::test]
    async fn truncation_drops_oldest_messages() {
        let mut session = echo_session(SessionConfig {
            max_tokens: 20,
            // Compress once the estimate passes half of the 20-token budget.
            compression_threshold: 0.5,
            compression_strategy: CompressionStrategy::TruncateOldest,
        });

        for turn in 0..15 {
            session.ask(format!("msg{turn}")).await.unwrap();
        }
        assert!(session.estimated_tokens() <= 20);
    }
}