lha 1.0.2

Long-Horizon Agent command-line package that installs the lha binary.
Documentation
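//! Helpers for truncating rollouts based on "user turn" boundaries.
//!
//! In core, "user turns" are detected by scanning `TranscriptItem::Message` items and
//! filtering them through the context manager's `is_user_turn_boundary` helper, which
//! skips synthetic context so that only real user turns are counted. `ThreadRolledBack`
//! events in the rollout are applied as well, so turn positions reflect the effective
//! (post-rollback) history.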

use crate::product::agent::context_manager::is_user_turn_boundary;
use crate::product::protocol::protocol::EventMsg;
use crate::product::protocol::protocol::RolloutItem;

/// Collect the indices of real user-turn boundaries in `items`.
///
/// An index is reported for every `RolloutItem::TranscriptItem` accepted by
/// `is_user_turn_boundary`; transcript items that only carry synthetic context are skipped.
///
/// `ThreadRolledBack` events are replayed in stream order: each one discards the most
/// recently recorded `num_turns` boundaries (all of them if `num_turns` exceeds the number
/// recorded so far). The result therefore describes the post-rollback history, while each
/// index still points into the raw `items` slice.
pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec<usize> {
    items
        .iter()
        .enumerate()
        .fold(Vec::new(), |mut boundaries, (position, rollout_item)| {
            match rollout_item {
                RolloutItem::TranscriptItem(transcript_item)
                    if is_user_turn_boundary(transcript_item) =>
                {
                    boundaries.push(position);
                }
                RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
                    // A count that does not fit in `usize` saturates, dropping every
                    // boundary recorded so far.
                    let dropped = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX);
                    let kept = boundaries.len().saturating_sub(dropped);
                    boundaries.truncate(kept);
                }
                _ => {}
            }
            boundaries
        })
}

/// Return the prefix of `items` that ends strictly before the `n_from_start`-th user turn.
///
/// `n_from_start` is a 0-based index over user turns (as reported by
/// `user_message_positions_in_rollout`, i.e. after applying rollback markers), not over the
/// entries of `items`. For example, `n_from_start = 0` yields everything before the first
/// user message, excluding that message itself.
///
/// Special cases:
/// - `n_from_start == usize::MAX` returns a copy of the whole rollout (no truncation).
/// - If there are at most `n_from_start` user turns, the index is out of range and an
///   empty vector is returned.
pub(crate) fn truncate_rollout_before_nth_user_message_from_start(
    items: &[RolloutItem],
    n_from_start: usize,
) -> Vec<RolloutItem> {
    // `usize::MAX` is the "keep everything" sentinel. It has to be handled first, because
    // the lookup below would otherwise treat it as an out-of-range index.
    if n_from_start == usize::MAX {
        return items.to_vec();
    }

    let boundaries = user_message_positions_in_rollout(items);
    match boundaries.get(n_from_start) {
        // Keep everything strictly before the selected user message.
        Some(&cut) => items[..cut].to_vec(),
        // Not enough user turns: out of range.
        None => Vec::new(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::product::agent::codex::make_session_and_context;
    use crate::product::agent::compact::active_goal_plan_reminder_items;
    use crate::product::agent::compact::proposed_plan_backfill_items;
    use crate::product::protocol::models::ContentItem;
    use crate::product::protocol::models::ReasoningItemReasoningSummary;
    use crate::product::protocol::models::TranscriptItem;
    use crate::product::protocol::protocol::ThreadRolledBackEvent;
    use assert_matches::assert_matches;
    use pretty_assertions::assert_eq;
    use std::path::Path;

    /// Build a transcript message with the given role and a single text part.
    fn message(role: &str, text: &str) -> TranscriptItem {
        TranscriptItem::Message {
            id: None,
            role: role.to_string(),
            content: vec![ContentItem::OutputText {
                text: text.to_string(),
            }],
            end_turn: None,
        }
    }

    fn user_msg(text: &str) -> TranscriptItem {
        message("user", text)
    }

    fn assistant_msg(text: &str) -> TranscriptItem {
        message("assistant", text)
    }

    /// Wrap each transcript item as a rollout item, preserving order.
    fn as_rollout(items: &[TranscriptItem]) -> Vec<RolloutItem> {
        items
            .iter()
            .cloned()
            .map(RolloutItem::TranscriptItem)
            .collect()
    }

    /// Assert that two rollouts serialize to the same JSON value.
    #[track_caller]
    fn assert_same_json(actual: &[RolloutItem], expected: &[RolloutItem]) {
        assert_eq!(
            serde_json::to_value(actual).unwrap(),
            serde_json::to_value(expected).unwrap()
        );
    }

    #[test]
    fn truncates_rollout_from_start_before_nth_user_only() {
        let items = [
            user_msg("u1"),
            assistant_msg("a1"),
            assistant_msg("a2"),
            user_msg("u2"),
            assistant_msg("a3"),
            TranscriptItem::Reasoning {
                id: "r1".to_string(),
                summary: vec![ReasoningItemReasoningSummary::SummaryText {
                    text: "s".to_string(),
                }],
                content: None,
                encrypted_content: None,
            },
            TranscriptItem::ToolCall {
                id: None,
                call_id: "c1".to_string(),
                tool_name: "tool".to_string(),
                payload: lha_llm::ToolCallPayload::JsonArguments {
                    arguments: "{}".to_string(),
                },
            },
            assistant_msg("a4"),
        ];
        let rollout = as_rollout(&items);

        // Cutting before the second user message keeps only the first turn.
        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, 1);
        assert_same_json(&truncated, &as_rollout(&items[..3]));

        // Only two user messages exist, so index 2 is out of range.
        let truncated2 = truncate_rollout_before_nth_user_message_from_start(&rollout, 2);
        assert_matches!(truncated2.as_slice(), []);
    }

    #[test]
    fn truncation_max_keeps_full_rollout() {
        let rollout = as_rollout(&[user_msg("u1"), assistant_msg("a1"), user_msg("u2")]);

        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, usize::MAX);

        assert_same_json(&truncated, &rollout);
    }

    #[test]
    fn ignores_backfilled_plan_reminder_when_truncating_rollout_from_start() {
        let mut items = vec![user_msg("u1"), assistant_msg("a1"), assistant_msg("a2")];
        items.extend(proposed_plan_backfill_items("- Step 1\n"));
        items.extend([user_msg("u2"), assistant_msg("a3")]);
        let rollout_items = as_rollout(&items);

        // The backfilled plan reminder must not be counted as a user turn.
        assert_eq!(user_message_positions_in_rollout(&rollout_items), vec![0, 5]);

        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1);
        assert_same_json(&truncated, &rollout_items[..5]);
    }

    #[test]
    fn ignores_active_goal_plan_reminder_when_truncating_rollout_from_start() {
        let mut items = vec![user_msg("u1"), assistant_msg("a1")];
        items.extend(active_goal_plan_reminder_items(Path::new(
            "/tmp/proposed_plan.md",
        )));
        items.extend([user_msg("u2"), assistant_msg("a2")]);
        let rollout_items = as_rollout(&items);

        // The active-goal reminder must not be counted as a user turn.
        assert_eq!(user_message_positions_in_rollout(&rollout_items), vec![0, 3]);

        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1);
        assert_same_json(&truncated, &rollout_items[..3]);
    }

    #[test]
    fn truncates_rollout_from_start_applies_thread_rollback_markers() {
        let rollout_items = vec![
            RolloutItem::TranscriptItem(user_msg("u1")),
            RolloutItem::TranscriptItem(assistant_msg("a1")),
            RolloutItem::TranscriptItem(user_msg("u2")),
            RolloutItem::TranscriptItem(assistant_msg("a2")),
            RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent {
                num_turns: 1,
            })),
            RolloutItem::TranscriptItem(user_msg("u3")),
            RolloutItem::TranscriptItem(assistant_msg("a3")),
            RolloutItem::TranscriptItem(user_msg("u4")),
            RolloutItem::TranscriptItem(assistant_msg("a4")),
        ];

        // Effective user history after applying rollback(1) is: u1, u3, u4.
        // So n_from_start=2 should cut before u4 (not u3).
        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 2);
        assert_same_json(&truncated, &rollout_items[..7]);
    }

    #[tokio::test]
    async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() {
        let (session, turn_context) = make_session_and_context().await;
        let mut items = session.build_initial_context(&turn_context).await;
        // Record the prefix length instead of assuming a fixed size, so the expectation
        // stays correct if the initial context gains or loses items.
        let prefix_len = items.len();
        items.extend([
            user_msg("feature request"),
            assistant_msg("ack"),
            user_msg("second question"),
            assistant_msg("answer"),
        ]);
        let rollout_items = as_rollout(&items);

        // Session-prefix items are not user turns, so the cut lands right before
        // "second question": the prefix plus the first user/assistant exchange survive.
        let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1);
        assert_same_json(&truncated, &as_rollout(&items[..prefix_len + 2]));
    }
}