use crate::product::agent::context_manager::is_user_turn_boundary;
use crate::product::protocol::protocol::EventMsg;
use crate::product::protocol::protocol::RolloutItem;
/// Returns the indices into `items` of every transcript item that starts a user
/// turn (as decided by `is_user_turn_boundary`), after replaying any
/// `ThreadRolledBack` markers found along the way.
///
/// Each rollback marker discards the most recent `num_turns` positions recorded
/// before it. If `num_turns` is at least the current count, or does not fit in a
/// `usize`, all of them are discarded. Positions recorded after the marker are
/// unaffected.
///
/// The result is in ascending order and indexes the unmodified `items` slice.
pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec<usize> {
    items
        .iter()
        .enumerate()
        .fold(Vec::new(), |mut boundaries, (position, rollout_item)| {
            match rollout_item {
                RolloutItem::TranscriptItem(transcript_item)
                    if is_user_turn_boundary(transcript_item) =>
                {
                    boundaries.push(position);
                }
                RolloutItem::EventMsg(EventMsg::ThreadRolledBack(event)) => {
                    // Rolled-back turns must no longer be reported as user turns.
                    // An out-of-range count clears every boundary seen so far.
                    let dropped = usize::try_from(event.num_turns).unwrap_or(usize::MAX);
                    boundaries.truncate(boundaries.len().saturating_sub(dropped));
                }
                _ => {}
            }
            boundaries
        })
}
/// Returns a copy of the prefix of `items` that ends just before the
/// `n_from_start`-th (0-based) user turn, as located by
/// `user_message_positions_in_rollout`.
///
/// - `n_from_start == usize::MAX` is a sentinel for "do not truncate": the whole
///   rollout is cloned and returned.
/// - If the rollout contains `n_from_start` or fewer user turns, the requested
///   turn does not exist and an empty vector is returned (not the full rollout).
///
/// Rolled-back turns are not counted when locating the cut. The returned prefix,
/// however, is a verbatim copy of `items`, so any rolled-back items and rollback
/// markers before the cut are kept.
pub(crate) fn truncate_rollout_before_nth_user_message_from_start(
    items: &[RolloutItem],
    n_from_start: usize,
) -> Vec<RolloutItem> {
    if n_from_start == usize::MAX {
        return items.to_vec();
    }
    match user_message_positions_in_rollout(items)
        .get(n_from_start)
        .copied()
    {
        Some(cut_at) => items[..cut_at].to_vec(),
        // The requested user turn does not exist in this rollout.
        None => Vec::new(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::product::agent::codex::make_session_and_context;
    use crate::product::agent::compact::active_goal_plan_reminder_items;
    use crate::product::agent::compact::proposed_plan_backfill_items;
    use crate::product::protocol::models::ContentItem;
    use crate::product::protocol::models::ReasoningItemReasoningSummary;
    use crate::product::protocol::models::TranscriptItem;
    use crate::product::protocol::protocol::ThreadRolledBackEvent;
    use assert_matches::assert_matches;
    use pretty_assertions::assert_eq;
    use std::path::Path;

    /// Builds a `role: "user"` message whose only content item is `text`.
    fn user_msg(text: &str) -> TranscriptItem {
        TranscriptItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::OutputText {
                text: text.to_string(),
            }],
            end_turn: None,
        }
    }

    /// Builds a `role: "assistant"` message whose only content item is `text`.
    fn assistant_msg(text: &str) -> TranscriptItem {
        TranscriptItem::Message {
            id: None,
            role: "assistant".to_string(),
            content: vec![ContentItem::OutputText {
                text: text.to_string(),
            }],
            end_turn: None,
        }
    }

    /// Only user messages act as cut points. Assistant messages, reasoning and tool
    /// calls never do, and asking for a user turn that does not exist yields an
    /// empty rollout.
    #[test]
    fn truncates_rollout_from_start_before_nth_user_only() {
        let items = [
            user_msg("u1"),
            assistant_msg("a1"),
            assistant_msg("a2"),
            user_msg("u2"),
            assistant_msg("a3"),
            TranscriptItem::Reasoning {
                id: "r1".to_string(),
                summary: vec![ReasoningItemReasoningSummary::SummaryText {
                    text: "s".to_string(),
                }],
                content: None,
                encrypted_content: None,
            },
            TranscriptItem::ToolCall {
                id: None,
                call_id: "c1".to_string(),
                tool_name: "tool".to_string(),
                payload: lha_llm::ToolCallPayload::JsonArguments {
                    arguments: "{}".to_string(),
                },
            },
            assistant_msg("a4"),
        ];
        let rollout: Vec<RolloutItem> = items
            .iter()
            .cloned()
            .map(RolloutItem::TranscriptItem)
            .collect();

        // Cut just before the second user message ("u2" at index 3).
        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, 1);
        let expected = vec![
            RolloutItem::TranscriptItem(items[0].clone()),
            RolloutItem::TranscriptItem(items[1].clone()),
            RolloutItem::TranscriptItem(items[2].clone()),
        ];
        assert_eq!(
            serde_json::to_value(&truncated).unwrap(),
            serde_json::to_value(&expected).unwrap()
        );

        // There is no third user message, so the result is empty.
        let truncated2 = truncate_rollout_before_nth_user_message_from_start(&rollout, 2);
        assert_matches!(truncated2.as_slice(), []);
    }

    /// `usize::MAX` is the "do not truncate" sentinel and returns the rollout unchanged.
    #[test]
    fn truncation_max_keeps_full_rollout() {
        let rollout = vec![
            RolloutItem::TranscriptItem(user_msg("u1")),
            RolloutItem::TranscriptItem(assistant_msg("a1")),
            RolloutItem::TranscriptItem(user_msg("u2")),
        ];
        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, usize::MAX);
        assert_eq!(
            serde_json::to_value(&truncated).unwrap(),
            serde_json::to_value(&rollout).unwrap()
        );
    }

    /// Items produced by `proposed_plan_backfill_items` must not count as user turns.
    #[test]
    fn ignores_backfilled_plan_reminder_when_truncating_rollout_from_start() {
        let mut items = vec![user_msg("u1"), assistant_msg("a1"), assistant_msg("a2")];
        items.extend(proposed_plan_backfill_items("- Step 1\n"));
        // Record the index instead of hard-coding it, so the test does not depend on
        // how many items the backfill produces.
        let second_user_idx = items.len();
        items.extend([user_msg("u2"), assistant_msg("a3")]);
        let rollout_items: Vec<RolloutItem> = items
            .iter()
            .cloned()
            .map(RolloutItem::TranscriptItem)
            .collect();
        let user_positions = user_message_positions_in_rollout(&rollout_items);
        assert_eq!(user_positions, vec![0, second_user_idx]);
        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1);
        let expected = rollout_items[..second_user_idx].to_vec();
        assert_eq!(
            serde_json::to_value(&truncated).unwrap(),
            serde_json::to_value(&expected).unwrap()
        );
    }

    /// Items produced by `active_goal_plan_reminder_items` must not count as user turns.
    #[test]
    fn ignores_active_goal_plan_reminder_when_truncating_rollout_from_start() {
        let mut items = vec![user_msg("u1"), assistant_msg("a1")];
        items.extend(active_goal_plan_reminder_items(Path::new(
            "/tmp/proposed_plan.md",
        )));
        // Record the index instead of hard-coding it, so the test does not depend on
        // how many items the reminder produces.
        let second_user_idx = items.len();
        items.extend([user_msg("u2"), assistant_msg("a2")]);
        let rollout_items: Vec<RolloutItem> = items
            .iter()
            .cloned()
            .map(RolloutItem::TranscriptItem)
            .collect();
        let user_positions = user_message_positions_in_rollout(&rollout_items);
        assert_eq!(user_positions, vec![0, second_user_idx]);
        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1);
        let expected = rollout_items[..second_user_idx].to_vec();
        assert_eq!(
            serde_json::to_value(&truncated).unwrap(),
            serde_json::to_value(&expected).unwrap()
        );
    }

    /// A `ThreadRolledBack` marker removes the rolled-back turn from the boundary list.
    /// The cut index still refers to the raw rollout, which keeps the rolled-back items
    /// and the marker itself.
    #[test]
    fn truncates_rollout_from_start_applies_thread_rollback_markers() {
        let rollout_items = vec![
            RolloutItem::TranscriptItem(user_msg("u1")),
            RolloutItem::TranscriptItem(assistant_msg("a1")),
            RolloutItem::TranscriptItem(user_msg("u2")),
            RolloutItem::TranscriptItem(assistant_msg("a2")),
            RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent {
                num_turns: 1,
            })),
            RolloutItem::TranscriptItem(user_msg("u3")),
            RolloutItem::TranscriptItem(assistant_msg("a3")),
            RolloutItem::TranscriptItem(user_msg("u4")),
            RolloutItem::TranscriptItem(assistant_msg("a4")),
        ];
        // "u2" (index 2) is rolled back, leaving u1, u3 and u4 as the user turns.
        assert_eq!(
            user_message_positions_in_rollout(&rollout_items),
            vec![0, 5, 7]
        );
        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 2);
        let expected = rollout_items[..7].to_vec();
        assert_eq!(
            serde_json::to_value(&truncated).unwrap(),
            serde_json::to_value(&expected).unwrap()
        );
    }

    /// Items produced by `build_initial_context` (the session prefix) must not count as
    /// user turns. Truncating before user turn 1 should cut right before the second real
    /// user message and keep the whole prefix.
    #[tokio::test]
    async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() {
        let (session, turn_context) = make_session_and_context().await;
        let mut items = session.build_initial_context(&turn_context).await;
        // Record the real user-message indices instead of hard-coding them, so the test
        // does not depend on how many items the initial context contains.
        let first_user_idx = items.len();
        items.push(user_msg("feature request"));
        items.push(assistant_msg("ack"));
        let second_user_idx = items.len();
        items.push(user_msg("second question"));
        items.push(assistant_msg("answer"));
        let rollout_items: Vec<RolloutItem> = items
            .iter()
            .cloned()
            .map(RolloutItem::TranscriptItem)
            .collect();

        // Only the two real user messages count as turn boundaries.
        assert_eq!(
            user_message_positions_in_rollout(&rollout_items),
            vec![first_user_idx, second_user_idx]
        );

        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1);
        let expected = rollout_items[..second_user_idx].to_vec();
        assert_eq!(
            serde_json::to_value(&truncated).unwrap(),
            serde_json::to_value(&expected).unwrap()
        );
    }
}