use crate::tokenizer::{CharApproxTokenizer, Tokenizer};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A single conversation entry: a role label plus its text.
///
/// Both fields are serialized/deserialized through the serde derives below.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message {
/// Role label as a free-form string; the constructors in this file produce
/// `"system"`, `"user"` and `"assistant"`.
pub role: String,
/// Text of the message.
pub content: String,
}
impl Message {
pub fn system(text: impl Into<String>) -> Self {
Self { role: "system".into(), content: text.into() }
}
pub fn user(text: impl Into<String>) -> Self {
Self { role: "user".into(), content: text.into() }
}
pub fn assistant(text: impl Into<String>) -> Self {
Self { role: "assistant".into(), content: text.into() }
}
}
/// How [`Fitter::fit`] shrinks a conversation that is over its token budget.
///
/// Whatever the strategy, a leading `"system"` message and a trailing
/// `"user"` message are left untouched.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Strategy {
    /// Remove whole messages one at a time, earliest first.
    DropOldest,
    /// Remove whole messages one at a time, each time taking the one in the
    /// middle of the range that may still be dropped.
    DropMiddle,
    /// Keep every message, but repeatedly cut the longest one (by character
    /// count) down to half its characters and append `" …[truncated]"`.
    TruncateLargest,
}
/// Trims a list of messages so that its estimated token count fits a budget.
///
/// A clone shares the tokenizer with the original through the [`Arc`].
#[derive(Clone)]
pub struct Fitter {
    /// Token budget that [`Fitter::fit`] tries to get the conversation under.
    max_tokens: usize,
    /// Counts the tokens in each message's content.
    tokenizer: Arc<dyn Tokenizer>,
    /// Fixed number of tokens charged for every message on top of its content.
    per_message_overhead: usize,
}
impl Fitter {
    /// Creates a fitter with a budget of `max_tokens`, a
    /// [`CharApproxTokenizer`] and a per-message overhead of 4 tokens.
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            tokenizer: Arc::new(CharApproxTokenizer),
            per_message_overhead: 4,
        }
    }

    /// Replaces the tokenizer used to measure message content.
    pub fn with_tokenizer<T: Tokenizer + 'static>(mut self, t: T) -> Self {
        self.tokenizer = Arc::new(t);
        self
    }

    /// Sets the fixed number of tokens charged for every message.
    pub fn with_per_message_overhead(mut self, n: usize) -> Self {
        self.per_message_overhead = n;
        self
    }

    /// Returns the estimated token count of `messages`: for each message, the
    /// tokenizer's count of its content plus the per-message overhead.
    ///
    /// The sum saturates at `usize::MAX` instead of overflowing.
    pub fn count(&self, messages: &[Message]) -> usize {
        let total: u128 = messages.iter().map(|m| self.message_cost(m)).sum();
        // Clamp first so the narrowing cast below cannot lose information.
        total.min(usize::MAX as u128) as usize
    }

    /// Shrinks `messages` according to `strategy` until the estimated token
    /// count is at most the budget, or until nothing more can be removed.
    ///
    /// A conversation that already fits is returned unchanged. Otherwise a
    /// leading `"system"` message and a trailing `"user"` message are never
    /// dropped or truncated; everything in between is fair game:
    ///
    /// * [`Strategy::DropOldest`] removes messages from the front.
    /// * [`Strategy::DropMiddle`] repeatedly removes the middle message.
    /// * [`Strategy::TruncateLargest`] repeatedly halves the longest message
    ///   (by character count) and appends `" …[truncated]"`; it stops once
    ///   doing so would no longer make that message shorter.
    ///
    /// The result may still exceed the budget when the protected messages, or
    /// messages too short to truncate, are already too large on their own.
    ///
    /// Each message is tokenized once up front and again only after it has
    /// been truncated; this assumes the tokenizer returns the same count for
    /// the same text every time.
    pub fn fit(&self, mut messages: Vec<Message>, strategy: Strategy) -> Vec<Message> {
        // Per-message costs are cached so the loops below adjust a running
        // total instead of re-tokenizing the whole conversation each step.
        let mut costs: Vec<u128> = messages.iter().map(|m| self.message_cost(m)).collect();
        let mut total: u128 = costs.iter().sum();
        let budget = self.max_tokens as u128;
        if total <= budget {
            return messages;
        }

        // Being over budget implies at least one message, so `len() - 1` below
        // cannot underflow when the last message is a protected user message.
        let has_system = messages.first().map_or(false, |m| m.role == "system");
        let last_is_user = messages.last().map_or(false, |m| m.role == "user");
        let first_droppable = usize::from(has_system);
        let end_droppable = messages.len() - usize::from(last_is_user);

        match strategy {
            Strategy::DropOldest => {
                // Work out how many leading droppable messages must go, then
                // remove them in one pass.
                let mut dropped = 0;
                for cost in &costs[first_droppable..end_droppable] {
                    if total <= budget {
                        break;
                    }
                    total -= *cost;
                    dropped += 1;
                }
                messages.drain(first_droppable..first_droppable + dropped);
            }
            Strategy::DropMiddle => {
                let mut end = end_droppable;
                while total > budget && end > first_droppable {
                    let mid = first_droppable + (end - first_droppable) / 2;
                    total -= costs.remove(mid);
                    messages.remove(mid);
                    end -= 1;
                }
            }
            Strategy::TruncateLargest => {
                const MARKER: &str = " …[truncated]";
                let marker_chars = MARKER.chars().count();
                // Character lengths of the truncatable messages, kept in step
                // with every truncation so they are not recounted each pass.
                let mut char_lens: Vec<usize> = messages[first_droppable..end_droppable]
                    .iter()
                    .map(|m| m.content.chars().count())
                    .collect();

                while total > budget {
                    // `max_by_key` yields the last maximum, so among equally
                    // long messages the latest one is truncated first.
                    let longest = char_lens
                        .iter()
                        .copied()
                        .enumerate()
                        .max_by_key(|&(_, len)| len);
                    let (offset, cur_len) = match longest {
                        Some(found) => found,
                        None => break,
                    };

                    let keep = cur_len / 2;
                    let new_len = keep + marker_chars;
                    // Halving plus the marker no longer shrinks the longest
                    // message, so no other candidate can shrink either.
                    if new_len >= cur_len {
                        break;
                    }

                    let idx = first_droppable + offset;
                    let content = &mut messages[idx].content;
                    // Byte offset of the first character to discard; always
                    // lands on a character boundary.
                    let cut = content
                        .char_indices()
                        .nth(keep)
                        .map_or(content.len(), |(at, _)| at);
                    content.truncate(cut);
                    content.push_str(MARKER);
                    char_lens[offset] = new_len;

                    // Only the message that changed needs re-tokenizing.
                    let new_cost = self.message_cost(&messages[idx]);
                    total = total - costs[idx] + new_cost;
                    costs[idx] = new_cost;
                }
            }
        }
        messages
    }

    /// Token cost of one message (content tokens plus the per-message
    /// overhead), widened to `u128` so neither this addition nor a sum over
    /// many messages can overflow.
    fn message_cost(&self, message: &Message) -> u128 {
        self.tokenizer.count(&message.content) as u128 + self.per_message_overhead as u128
    }
}