use std::collections::HashMap;
use crate::provider::{ContentPart, Message};
use crate::session::relevance::extract;
/// Result of replaying a message trace through the oracle.
#[derive(Debug, Clone)]
pub struct OracleReport {
    /// One entry per turn of the replayed trace; `demand[t]` holds the indices
    /// of earlier turns recorded as referenced for turn `t`.
    pub demand: Vec<Vec<usize>>,
    /// The look-ahead horizon the report was computed with.
    pub horizon: usize,
}

impl OracleReport {
    /// Total number of recorded references, summed over every turn.
    pub fn reference_count(&self) -> usize {
        self.demand
            .iter()
            .fold(0, |total, per_turn| total + per_turn.len())
    }
}
pub fn replay_oracle(messages: &[Message], h: usize) -> OracleReport {
let files_per_turn = index_files_per_turn(messages);
let tool_call_owner = index_tool_call_owners(messages);
let mut demand: Vec<Vec<usize>> = vec![Vec::new(); messages.len()];
for t in 0..messages.len() {
let window_end = (t + h).min(messages.len().saturating_sub(1));
for future_t in t..=window_end {
let future = &messages[future_t];
for u in references_for_turn(future, t, &files_per_turn, &tool_call_owner) {
if u < t && !demand[t].contains(&u) {
demand[t].push(u);
}
}
}
demand[t].sort_unstable();
}
OracleReport { demand, horizon: h }
}
/// Runs `extract` on every message and keeps its `files` list, so that
/// entry `i` of the result belongs to `messages[i]`.
fn index_files_per_turn(messages: &[Message]) -> Vec<Vec<String>> {
    let mut per_turn = Vec::with_capacity(messages.len());
    for message in messages {
        per_turn.push(extract(message).files);
    }
    per_turn
}
/// Maps each tool-call id to the index of the message that contains the call.
///
/// If the same id appears more than once, the index of the last message
/// containing it is kept.
fn index_tool_call_owners(messages: &[Message]) -> HashMap<String, usize> {
    messages
        .iter()
        .enumerate()
        .flat_map(|(turn, message)| {
            message.content.iter().filter_map(move |part| match part {
                ContentPart::ToolCall { id, .. } => Some((id.clone(), turn)),
                _ => None,
            })
        })
        // Pairs are inserted in trace order, so a later duplicate id overwrites
        // the earlier entry.
        .collect()
}
/// Lists the turn indices that the message `future` refers to, as seen from
/// turn `current_idx`.
///
/// * A tool result contributes the index recorded for its `tool_call_id` in
///   `tool_call_owner`, if any (no bound check against `current_idx` here).
/// * A text part contributes, for each file token extracted from it, the
///   first turn before `current_idx` whose entry in `files_per_turn` contains
///   that exact token.
///
/// Other part kinds contribute nothing. The result may contain duplicates and
/// is in discovery order.
fn references_for_turn(
    future: &Message,
    current_idx: usize,
    files_per_turn: &[Vec<String>],
    tool_call_owner: &HashMap<String, usize>,
) -> Vec<usize> {
    let mut referenced = Vec::new();
    for part in &future.content {
        match part {
            ContentPart::ToolResult { tool_call_id, .. } => {
                // Unknown call ids simply add nothing.
                referenced.extend(tool_call_owner.get(tool_call_id).copied());
            }
            ContentPart::Text { text } => {
                for path in extract_text_file_tokens(text) {
                    // Only turns strictly before `current_idx` are searched,
                    // and only the earliest match is reported.
                    let earliest = files_per_turn
                        .iter()
                        .take(current_idx)
                        .position(|files| files.contains(&path));
                    referenced.extend(earliest);
                }
            }
            _ => {}
        }
    }
    referenced
}
/// Returns the `files` that `extract` reports for a bare piece of text.
///
/// `extract` takes a whole `Message`, so the text is wrapped in a throwaway
/// user message with a single text part.
fn extract_text_file_tokens(text: &str) -> Vec<String> {
    use crate::provider::Role;

    let part = ContentPart::Text {
        text: text.to_owned(),
    };
    let probe = Message {
        role: Role::User,
        content: vec![part],
    };
    extract(&probe).files
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{ContentPart, Message, Role};

    /// Builds a message with the given role and a single text part.
    fn text(role: Role, s: &str) -> Message {
        let part = ContentPart::Text { text: s.to_owned() };
        Message {
            role,
            content: vec![part],
        }
    }

    /// Builds an assistant message holding one tool call with `{}` arguments.
    fn tool_call(id: &str, name: &str) -> Message {
        let part = ContentPart::ToolCall {
            id: id.to_owned(),
            name: name.to_owned(),
            arguments: String::from("{}"),
            thought_signature: None,
        };
        Message {
            role: Role::Assistant,
            content: vec![part],
        }
    }

    /// Builds a tool message holding one result for the call `id`.
    fn tool_result(id: &str, body: &str) -> Message {
        let part = ContentPart::ToolResult {
            tool_call_id: id.to_owned(),
            content: body.to_owned(),
        };
        Message {
            role: Role::Tool,
            content: vec![part],
        }
    }

    /// A tool result at turn 2 must demand the turn that issued the call.
    #[test]
    fn oracle_binds_tool_results_to_their_call_owners() {
        let trace = [
            text(Role::User, "do the thing"),
            tool_call("call-1", "Shell"),
            tool_result("call-1", "ok"),
        ];

        let report = replay_oracle(&trace, 5);

        assert!(report.demand[2].contains(&1));
        assert!(report.reference_count() >= 1);
    }

    /// Mentioning a file again must demand the earlier turn that mentioned it.
    #[test]
    fn oracle_tracks_file_references_across_turns() {
        let trace = [
            text(Role::User, "edit src/lib.rs"),
            text(Role::Assistant, "ok"),
            text(Role::User, "now open src/lib.rs again"),
        ];

        let report = replay_oracle(&trace, 5);

        assert!(report.demand[2].contains(&0));
    }

    /// A longer horizon can never record fewer references than a shorter one.
    #[test]
    fn oracle_respects_horizon_bound() {
        let trace = [
            text(Role::User, "edit src/lib.rs"),
            text(Role::Assistant, "noop"),
            text(Role::Assistant, "noop"),
            text(Role::User, "reopen src/lib.rs"),
        ];

        let short = replay_oracle(&trace, 1);
        let long = replay_oracle(&trace, 10);

        assert!(long.reference_count() >= short.reference_count());
    }

    /// An empty trace produces an empty report that still carries the horizon.
    #[test]
    fn report_over_empty_trace_is_empty() {
        let report = replay_oracle(&[], 4);

        assert!(report.demand.is_empty());
        assert_eq!(report.horizon, 4);
        assert_eq!(report.reference_count(), 0);
    }
}