use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use serde_json::Value;
use crate::msg::LlmEvent;
use crate::request::{Content, Message, ToolCall};
use crate::server::anthropic::outbound::SseState;
use super::session::MCP_SERVER_NAME;
/// One recorded tool invocation together with the result to hand back on replay.
struct ToolMatch {
/// Tool name exactly as recorded (without the `mcp__…__` prefix used in the replayed SSE).
name: String,
/// Recorded arguments parsed as JSON; an empty object when they were unparseable or `null`.
args: Value,
/// Text of the recorded tool result for this call; empty when no result was recorded.
result: String,
/// Set once this entry has been returned by `ReplayState::take_tool_result`.
taken: bool,
}
/// What to do with the next incoming request, as decided by `ReplayState::next_action`.
pub(crate) enum TurnAction {
/// Answer with this pre-rendered SSE body built from a recorded assistant turn.
Fake(String),
/// Returned for the first request after all recorded turns have been served.
/// NOTE(review): presumably the caller forwards this request to the real backend — confirm.
Passthrough,
/// Returned for every request after the passthrough one; carries the SSE body rendered
/// from a lone `LlmEvent::Done`.
Halt(String),
}
/// Replay plan built from a recorded conversation by `build_replay`.
pub(crate) struct ReplayState {
/// Pre-rendered SSE body for each recorded assistant turn, in recorded order.
fakes: Vec<String>,
/// SSE body handed out with `TurnAction::Halt`.
halt: String,
/// Recorded tool results per turn, index-aligned with `fakes`; each turn has its own lock.
tool_queues: Vec<Mutex<Vec<ToolMatch>>>,
/// Count of `next_action` calls so far, i.e. the index the next request will receive.
next_post: AtomicUsize,
/// Index of the most recently served fake turn; selects the queue read by
/// `take_tool_result`.
current_turn: AtomicUsize,
}
impl ReplayState {
    /// Number of recorded assistant turns that will be answered with a fake response.
    pub(crate) fn fake_count(&self) -> usize {
        self.fakes.len()
    }

    /// Claims the next request slot and reports how that request must be handled.
    ///
    /// Slots below the number of recorded turns yield [`TurnAction::Fake`] and make
    /// that turn the current one for [`Self::take_tool_result`]. The slot right after
    /// the last recorded turn yields [`TurnAction::Passthrough`]; every later slot
    /// yields [`TurnAction::Halt`].
    pub(crate) fn next_action(&self) -> TurnAction {
        let slot = self.next_post.fetch_add(1, Ordering::SeqCst);
        match self.fakes.get(slot) {
            Some(sse) => {
                self.current_turn.store(slot, Ordering::SeqCst);
                TurnAction::Fake(sse.clone())
            }
            None if slot == self.fakes.len() => TurnAction::Passthrough,
            None => TurnAction::Halt(self.halt.clone()),
        }
    }

    /// Hands out the recorded result for a tool call made during the current turn.
    ///
    /// An untaken entry whose name and arguments both match is preferred; failing
    /// that, the first untaken entry with the same name is used. The chosen entry is
    /// marked as taken so it is never returned twice. Returns `None` when the current
    /// turn has no queue, the queue's lock is poisoned, or nothing matches.
    pub(crate) fn take_tool_result(&self, name: &str, args: &Value) -> Option<String> {
        let turn = self.current_turn.load(Ordering::SeqCst);
        let mut pending = self.tool_queues.get(turn)?.lock().ok()?;

        // Single scan: stop at the first exact match, but remember the first
        // name-only candidate in case no exact match exists.
        let mut exact = None;
        let mut by_name_only = None;
        for (idx, entry) in pending.iter().enumerate() {
            if entry.taken || entry.name != name {
                continue;
            }
            if entry.args == *args {
                exact = Some(idx);
                break;
            }
            if by_name_only.is_none() {
                by_name_only = Some(idx);
            }
        }

        let idx = exact.or(by_name_only)?;
        let entry = &mut pending[idx];
        entry.taken = true;
        Some(entry.result.clone())
    }
}
/// Builds a [`ReplayState`] from a recorded conversation.
///
/// Each assistant message becomes one fake turn. The tool-result messages that
/// immediately follow it are gathered (keyed by call id) and attached to the tool
/// calls of that turn. Messages of any other kind are skipped. Returns `None` when
/// `recorded` contains no assistant message at all.
pub(crate) fn build_replay(recorded: &[Message], model: &str) -> Option<ReplayState> {
    let mut fakes: Vec<String> = Vec::new();
    let mut tool_queues: Vec<Mutex<Vec<ToolMatch>>> = Vec::new();

    let mut remaining = recorded.iter().peekable();
    while let Some(message) = remaining.next() {
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = message
        else {
            continue;
        };

        // Consume the run of tool results that directly follows this assistant turn.
        let mut results = std::collections::HashMap::<String, String>::new();
        while let Some(Message::ToolResult { call_id, content }) = remaining.peek().copied() {
            results.insert(call_id.clone(), join_text(content));
            remaining.next();
        }

        let (sse, matches) = render_turn(content.as_deref(), tool_calls, &results, model);
        fakes.push(sse);
        tool_queues.push(Mutex::new(matches));
    }

    if fakes.is_empty() {
        return None;
    }

    let halt = render_sse(model, vec![LlmEvent::Done]);
    Some(ReplayState {
        fakes,
        halt,
        tool_queues,
        next_post: AtomicUsize::new(0),
        current_turn: AtomicUsize::new(0),
    })
}
/// Renders one recorded assistant turn.
///
/// Returns the SSE body for the turn and, for each tool call in it, a
/// [`ToolMatch`] holding the recorded result looked up in `results` by call id.
/// In the SSE, tool names are prefixed with `mcp__{MCP_SERVER_NAME}__` and ids are
/// passed through `normalise_id`; the `ToolMatch` keeps the plain recorded name.
fn render_turn(
    content: Option<&str>,
    tool_calls: &[ToolCall],
    results: &std::collections::HashMap<String, String>,
    model: &str,
) -> (String, Vec<ToolMatch>) {
    // Empty text produces no token event.
    let text_event = content
        .filter(|text| !text.is_empty())
        .map(|text| LlmEvent::Token(text.to_string()));

    let (call_events, matches): (Vec<LlmEvent>, Vec<ToolMatch>) = tool_calls
        .iter()
        .map(|call| {
            let parsed = serde_json::from_str::<Value>(&call.arguments).unwrap_or(Value::Null);
            let event = LlmEvent::ToolCall(ToolCall {
                id: normalise_id(&call.id),
                name: format!("mcp__{MCP_SERVER_NAME}__{}", call.name),
                arguments: call.arguments.clone(),
            });
            // Unparseable or `null` arguments are matched as an empty object.
            let args = if parsed.is_null() {
                serde_json::json!({})
            } else {
                parsed
            };
            let recorded = ToolMatch {
                name: call.name.clone(),
                args,
                result: results.get(&call.id).cloned().unwrap_or_default(),
                taken: false,
            };
            (event, recorded)
        })
        .unzip();

    let mut events: Vec<LlmEvent> = Vec::with_capacity(call_events.len() + 2);
    events.extend(text_event);
    events.extend(call_events);
    events.push(LlmEvent::Done);

    (render_sse(model, events), matches)
}
/// Feeds `events` through a fresh `SseState` for `model` and serialises every frame
/// it emits as `event: <name>\ndata: <payload>\n\n`, concatenated in emission order.
fn render_sse(model: &str, events: Vec<LlmEvent>) -> String {
    let mut state = SseState::new(model.to_string());
    events.into_iter().fold(String::new(), |mut body, event| {
        for (name, payload) in state.on_event(event) {
            let data = payload.to_string();
            body.push_str("event: ");
            body.push_str(name);
            body.push_str("\ndata: ");
            body.push_str(&data);
            body.push_str("\n\n");
        }
        body
    })
}
/// Returns `id` unchanged when it already starts with `toolu_`; otherwise returns a
/// freshly generated `toolu_`-prefixed id based on a random UUID.
fn normalise_id(id: &str) -> String {
    if !id.starts_with("toolu_") {
        return format!("toolu_{}", uuid::Uuid::new_v4().simple());
    }
    id.to_string()
}
/// Concatenates the text parts of `content`, separated by newlines.
/// Non-text parts are ignored.
fn join_text(content: &[Content]) -> String {
    let mut joined = String::new();
    let mut seen_text = false;
    for part in content {
        if let Content::Text { text } = part {
            // The separator goes between text parts, never before the first one.
            if seen_text {
                joined.push('\n');
            }
            joined.push_str(text);
            seen_text = true;
        }
    }
    joined
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assistant message carrying a single tool call and no text.
    fn asst(tool: &str, id: &str, args: &str) -> Message {
        Message::Assistant {
            content: None,
            reasoning: None,
            tool_calls: vec![ToolCall {
                id: id.into(),
                name: tool.into(),
                arguments: args.into(),
            }],
            provider_data: None,
        }
    }

    /// Tool-result message with a single text part for the given call id.
    fn tr(id: &str, text: &str) -> Message {
        Message::ToolResult {
            call_id: id.into(),
            content: vec![Content::text(text)],
        }
    }

    #[test]
    fn no_recorded_turns_returns_none() {
        assert!(build_replay(&[], "m").is_none());
    }

    #[test]
    fn replays_each_turn_then_passthrough_then_halt() {
        let recorded = vec![
            asst("bash", "c1", "{\"cmd\":\"ls\"}"),
            tr("c1", "file.txt"),
            asst("bash", "c2", "{\"cmd\":\"cat\"}"),
            tr("c2", "hello"),
        ];
        let st = build_replay(&recorded, "m").expect("two turns");

        // First recorded turn.
        let a0 = st.next_action();
        assert!(matches!(a0, TurnAction::Fake(_)));
        assert_eq!(
            st.take_tool_result("bash", &serde_json::json!({"cmd":"ls"}))
                .as_deref(),
            Some("file.txt")
        );

        // Second recorded turn.
        assert!(matches!(st.next_action(), TurnAction::Fake(_)));
        assert_eq!(
            st.take_tool_result("bash", &serde_json::json!({"cmd":"cat"}))
                .as_deref(),
            Some("hello")
        );

        // One passthrough, then halt for good.
        assert!(matches!(st.next_action(), TurnAction::Passthrough));
        assert!(matches!(st.next_action(), TurnAction::Halt(_)));
        assert!(matches!(st.next_action(), TurnAction::Halt(_)));
    }

    #[test]
    fn fake_sse_is_wellformed_and_namespaced() {
        let recorded = vec![asst("bash", "c1", "{\"cmd\":\"ls\"}"), tr("c1", "ok")];
        let st = build_replay(&recorded, "m").unwrap();
        let TurnAction::Fake(sse) = st.next_action() else {
            panic!("expected fake")
        };
        assert!(sse.contains("event: message_start"));
        assert!(sse.contains("event: content_block_start"));
        assert!(sse.contains("event: message_stop"));
        assert!(sse.contains("mcp__agentix__bash"));
        assert!(sse.contains("toolu_"));
    }

    #[test]
    fn duplicate_tool_calls_drain_in_order() {
        let recorded = vec![
            Message::Assistant {
                content: None,
                reasoning: None,
                tool_calls: vec![
                    ToolCall {
                        id: "a".into(),
                        name: "t".into(),
                        arguments: "{}".into(),
                    },
                    ToolCall {
                        id: "b".into(),
                        name: "t".into(),
                        arguments: "{}".into(),
                    },
                ],
                provider_data: None,
            },
            tr("a", "first"),
            tr("b", "second"),
        ];
        let st = build_replay(&recorded, "m").unwrap();
        let _ = st.next_action();
        let empty = serde_json::json!({});

        // Identical calls must be served in recorded order. The results are compared
        // as returned (not sorted) so that a reordering would make this test fail.
        let got1 = st.take_tool_result("t", &empty).unwrap();
        let got2 = st.take_tool_result("t", &empty).unwrap();
        assert_eq!(got1, "first");
        assert_eq!(got2, "second");

        // Both entries are now taken.
        assert!(st.take_tool_result("t", &empty).is_none());
    }
}