use crate::store::MessageRecord;
use crate::token::estimate_record_tokens;
/// Result of [`select`]: the input messages divided into an older `head` and a recent `tail`.
#[derive(Debug, Clone)]
pub struct SelectionResult {
/// Messages assigned to the older side of the split, in their original order.
pub head: Vec<MessageRecord>,
/// Messages from the most recent turns that were selected for preservation.
pub tail: Vec<MessageRecord>,
/// `id` of the first element of `tail`, or `None` when `tail` is empty.
pub tail_start_id: Option<uuid::Uuid>,
}
/// One conversational turn, expressed as positions in the message slice it was built from.
#[derive(Debug)]
struct Turn {
/// Indices of the messages belonging to this turn, in ascending order.
indices: Vec<usize>,
}
/// Groups the positions of `messages` into consecutive [`Turn`]s.
///
/// A message opens a new turn when its role is `User` and it is not flagged
/// `is_compaction`. Every other message is attached to the turn currently being built.
/// The very first message always opens a turn, whatever its role, so messages that
/// precede the first user message form a turn of their own.
///
/// Returns the turns in conversation order; an empty slice yields an empty vector.
fn turns(messages: &[MessageRecord]) -> Vec<Turn> {
    let mut grouped: Vec<Turn> = Vec::new();
    for (position, record) in messages.iter().enumerate() {
        let opens_turn = !record.is_compaction && record.role == crate::store::MessageRole::User;
        match grouped.last_mut() {
            // A turn is already open and this message does not start a new one.
            Some(open) if !opens_turn => open.indices.push(position),
            // Either nothing is open yet or a user message begins the next turn.
            _ => grouped.push(Turn {
                indices: vec![position],
            }),
        }
    }
    grouped
}
/// Splits `messages` into an older `head` and a recent `tail` that is kept as-is.
///
/// Messages are first grouped into turns by `turns`. At most `tail_turns` of the most
/// recent turns are candidates for the tail. The newest candidate turn is always kept in
/// full, even when it alone exceeds `keep_tokens`, so a non-zero `tail_turns` never yields
/// an empty tail for non-empty input. Messages of older candidate turns are then added
/// newest-first for as long as the running total of `estimate_record_tokens` stays within
/// `keep_tokens`; the first message that does not fit ends the tail, which means an older
/// turn may be split part-way through.
///
/// The result is always an exact partition: `head` followed by `tail` reproduces
/// `messages` with nothing dropped, duplicated, or reordered.
///
/// # Arguments
///
/// * `messages` - the conversation, oldest first.
/// * `tail_turns` - maximum number of trailing turns considered for the tail; `0` puts
///   every message in `head`.
/// * `keep_tokens` - token budget for the tail (the newest turn is exempt from it).
///
/// # Returns
///
/// A [`SelectionResult`] whose `tail_start_id` is the `id` of the first tail message, or
/// `None` when the tail is empty.
#[must_use]
pub fn select(
    messages: &[MessageRecord],
    tail_turns: usize,
    keep_tokens: usize,
) -> SelectionResult {
    let all_turns = turns(messages);
    let preserved = tail_turns.min(all_turns.len());
    let candidates = &all_turns[all_turns.len() - preserved..];

    // Index of the first message that belongs to the tail. Starting at `messages.len()`
    // means "empty tail"; each accepted message moves the boundary one step back.
    // Tracking a single boundary (instead of two index lists) guarantees that every
    // message ends up on exactly one side of the split.
    let mut split = messages.len();
    let mut accumulated = 0usize;

    'walk: for (age, turn) in candidates.iter().rev().enumerate() {
        let is_newest_turn = age == 0;
        for &idx in turn.indices.iter().rev() {
            let tokens = estimate_record_tokens(&messages[idx]);
            // The running total includes messages already accepted from this turn.
            let total = accumulated.saturating_add(tokens);
            if !is_newest_turn && total > keep_tokens {
                // Budget exhausted: this message and everything older goes to the head.
                break 'walk;
            }
            accumulated = total;
            split = idx;
        }
    }

    let (head, tail) = messages.split_at(split);
    SelectionResult {
        head: head.to_vec(),
        tail: tail.to_vec(),
        tail_start_id: tail.first().map(|m| m.id),
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::provider::ContentPart;
    use crate::store::MessageRole;
    use uuid::Uuid;

    /// Builds a record with one text part for `role`, passing a freshly generated v7 UUID
    /// as the first constructor argument.
    fn make_msg(role: MessageRole, text: &str) -> MessageRecord {
        let fresh_uuid = Uuid::now_v7();
        let parts = vec![ContentPart::text(text)];
        MessageRecord::new(fresh_uuid, role, parts)
    }

    /// Shorthand for a user-role record.
    fn make_user(text: &str) -> MessageRecord {
        make_msg(MessageRole::User, text)
    }

    /// Shorthand for an assistant-role record.
    fn make_assistant(text: &str) -> MessageRecord {
        make_msg(MessageRole::Assistant, text)
    }

    /// Shorthand for a tool-role record.
    fn make_tool(text: &str) -> MessageRecord {
        make_msg(MessageRole::Tool, text)
    }

    /// No input messages means nothing on either side of the split.
    #[test]
    fn empty_messages_returns_empty() {
        let SelectionResult {
            head,
            tail,
            tail_start_id,
        } = select(&[], 2, 8_000);
        assert!(head.is_empty());
        assert!(tail.is_empty());
        assert!(tail_start_id.is_none());
    }

    /// A single turn that fits the budget is kept entirely in the tail.
    #[test]
    fn single_turn_within_budget() {
        let messages = vec![make_user("Hello"), make_assistant("Hi there!")];
        let total: usize = messages.iter().map(|m| estimate_record_tokens(m)).sum();
        let outcome = select(&messages, 2, total * 2);
        assert!(outcome.head.is_empty());
        assert_eq!(outcome.tail.len(), 2);
    }

    /// With one preserved turn out of three, older turns go to the head and the tail
    /// starts at a user message.
    #[test]
    fn multiple_turns_preserves_recent() {
        let mut messages = Vec::new();
        messages.push(make_user("Turn 1"));
        messages.push(make_assistant("Response 1"));
        messages.push(make_user("Turn 2"));
        messages.push(make_assistant("Response 2"));
        messages.push(make_user("Turn 3"));
        messages.push(make_assistant("Response 3"));
        let outcome = select(&messages, 1, 8_000);
        assert!(!outcome.head.is_empty(), "head should have older turns");
        assert!(!outcome.tail.is_empty(), "tail should have recent turn");
        assert_eq!(outcome.tail[0].role, MessageRole::User);
    }

    /// Assistant and tool messages following a user message belong to that user's turn.
    #[test]
    fn compact_tool_messages_are_in_turn() {
        let messages = vec![
            make_user("Use tool"),
            make_assistant("Calling tool..."),
            make_tool("Tool result"),
            make_assistant("Done with tool"),
        ];
        let SelectionResult { head, tail, .. } = select(&messages, 2, 8_000);
        assert!(head.is_empty());
        assert_eq!(tail.len(), 4);
    }

    /// A very small budget still leaves both a non-empty head and a non-empty tail.
    #[test]
    fn tiny_budget_forces_compaction() {
        let mut messages: Vec<MessageRecord> = Vec::new();
        for i in 0..5 {
            messages.push(make_user(&format!("Question {i}")));
            messages.push(make_assistant(&format!("Answer {i}")));
        }
        let outcome = select(&messages, 2, 10);
        assert!(!outcome.head.is_empty());
        assert!(!outcome.tail.is_empty());
    }
}