use std::collections::VecDeque;
use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use crate::llm::error::LlmError;
use crate::llm::provider::{
ChatMessage, CompletionRequest, CompletionResponse, LlmProvider, ModelMetadata, Role,
ToolCompletionRequest, ToolCompletionResponse,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceFile {
pub model_name: String,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub memory_snapshot: Vec<MemorySnapshotEntry>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub http_exchanges: Vec<HttpExchange>,
pub steps: Vec<TraceStep>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemorySnapshotEntry {
pub path: String,
pub content: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpExchange {
pub request: HttpExchangeRequest,
pub response: HttpExchangeResponse,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpExchangeRequest {
pub method: String,
pub url: String,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub headers: Vec<(String, String)>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub body: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpExchangeResponse {
pub status: u16,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub headers: Vec<(String, String)>,
pub body: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceStep {
#[serde(skip_serializing_if = "Option::is_none")]
pub request_hint: Option<RequestHint>,
pub response: TraceResponse,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub expected_tool_results: Vec<ExpectedToolResult>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestHint {
#[serde(skip_serializing_if = "Option::is_none")]
pub last_user_message_contains: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_message_count: Option<usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TraceResponse {
Text {
content: String,
input_tokens: u32,
output_tokens: u32,
},
ToolCalls {
tool_calls: Vec<TraceToolCall>,
input_tokens: u32,
output_tokens: u32,
},
UserInput { content: String },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedToolResult {
pub tool_call_id: String,
pub name: String,
pub content: String,
}
#[async_trait]
pub trait HttpInterceptor: Send + Sync + std::fmt::Debug {
async fn before_request(&self, request: &HttpExchangeRequest) -> Option<HttpExchangeResponse>;
async fn after_response(&self, request: &HttpExchangeRequest, response: &HttpExchangeResponse);
}
#[derive(Debug)]
pub struct RecordingHttpInterceptor {
exchanges: Mutex<Vec<HttpExchange>>,
}
impl Default for RecordingHttpInterceptor {
fn default() -> Self {
Self::new()
}
}
impl RecordingHttpInterceptor {
pub fn new() -> Self {
Self {
exchanges: Mutex::new(Vec::new()),
}
}
pub async fn take_exchanges(&self) -> Vec<HttpExchange> {
self.exchanges.lock().await.clone()
}
}
#[async_trait]
impl HttpInterceptor for RecordingHttpInterceptor {
async fn before_request(&self, _request: &HttpExchangeRequest) -> Option<HttpExchangeResponse> {
None
}
async fn after_response(&self, request: &HttpExchangeRequest, response: &HttpExchangeResponse) {
self.exchanges.lock().await.push(HttpExchange {
request: request.clone(),
response: response.clone(),
});
}
}
#[derive(Debug)]
pub struct ReplayingHttpInterceptor {
exchanges: Mutex<VecDeque<HttpExchange>>,
}
impl ReplayingHttpInterceptor {
pub fn new(exchanges: Vec<HttpExchange>) -> Self {
Self {
exchanges: Mutex::new(VecDeque::from(exchanges)),
}
}
}
#[async_trait]
impl HttpInterceptor for ReplayingHttpInterceptor {
async fn before_request(&self, request: &HttpExchangeRequest) -> Option<HttpExchangeResponse> {
let mut queue = self.exchanges.lock().await;
if let Some(exchange) = queue.pop_front() {
if exchange.request.url != request.url || exchange.request.method != request.method {
tracing::warn!(
expected_url = %exchange.request.url,
actual_url = %request.url,
expected_method = %exchange.request.method,
actual_method = %request.method,
"HTTP replay: request mismatch (returning recorded response anyway)"
);
}
Some(exchange.response)
} else {
tracing::error!(
url = %request.url,
method = %request.method,
"HTTP replay: no more recorded exchanges, returning error"
);
Some(HttpExchangeResponse {
status: 599,
headers: Vec::new(),
body: "trace replay: no more recorded HTTP exchanges".to_string(),
})
}
}
async fn after_response(
&self,
_request: &HttpExchangeRequest,
_response: &HttpExchangeResponse,
) {
}
}
pub struct RecordingLlm {
inner: Arc<dyn LlmProvider>,
steps: Mutex<Vec<TraceStep>>,
prev_message_count: Mutex<usize>,
output_path: PathBuf,
model_name: String,
memory_snapshot: Mutex<Vec<MemorySnapshotEntry>>,
http_interceptor: Arc<RecordingHttpInterceptor>,
}
impl RecordingLlm {
pub fn new(inner: Arc<dyn LlmProvider>, output_path: PathBuf, model_name: String) -> Self {
Self {
inner,
steps: Mutex::new(Vec::new()),
prev_message_count: Mutex::new(0),
output_path,
model_name,
memory_snapshot: Mutex::new(Vec::new()),
http_interceptor: Arc::new(RecordingHttpInterceptor::new()),
}
}
pub fn from_env(inner: Arc<dyn LlmProvider>) -> Option<Arc<Self>> {
let enabled = std::env::var("IRONCLAW_RECORD_TRACE")
.ok()
.filter(|v| !v.is_empty());
enabled?;
let output_path = std::env::var("IRONCLAW_TRACE_OUTPUT")
.ok()
.filter(|v| !v.is_empty())
.map(PathBuf::from)
.unwrap_or_else(|| {
let ts = chrono::Local::now().format("%Y%m%dT%H%M%S");
PathBuf::from(format!("trace_{ts}.json"))
});
let model_name = std::env::var("IRONCLAW_TRACE_MODEL_NAME")
.ok()
.filter(|v| !v.is_empty())
.unwrap_or_else(|| format!("recorded-{}", inner.model_name()));
tracing::info!(
output = %output_path.display(),
model = %model_name,
"LLM trace recording enabled"
);
Some(Arc::new(Self::new(inner, output_path, model_name)))
}
pub fn http_interceptor(&self) -> Arc<dyn HttpInterceptor> {
Arc::clone(&self.http_interceptor) as Arc<dyn HttpInterceptor>
}
pub async fn snapshot_memory(&self, workspace: &crate::workspace::Workspace) {
match workspace.list_all().await {
Ok(paths) => {
let mut snapshot = self.memory_snapshot.lock().await;
for path in paths {
match workspace.read(&path).await {
Ok(doc) => {
snapshot.push(MemorySnapshotEntry {
path: doc.path,
content: doc.content,
});
}
Err(e) => {
tracing::debug!(path = %path, error = %e, "Skipped memory doc in snapshot");
}
}
}
tracing::info!(
documents = snapshot.len(),
"Captured memory snapshot for trace recording"
);
}
Err(e) => {
tracing::warn!("Failed to snapshot memory for trace recording: {}", e);
}
}
}
pub async fn flush(&self) -> Result<(), std::io::Error> {
let steps = self.steps.lock().await;
let memory_snapshot = self.memory_snapshot.lock().await;
let http_exchanges = self.http_interceptor.take_exchanges().await;
let trace = TraceFile {
model_name: self.model_name.clone(),
memory_snapshot: memory_snapshot.clone(),
http_exchanges,
steps: steps.clone(),
};
let json = serde_json::to_string_pretty(&trace).map_err(std::io::Error::other)?;
tokio::fs::write(&self.output_path, json).await?;
tracing::info!(
steps = steps.len(),
memory_docs = memory_snapshot.len(),
path = %self.output_path.display(),
"Flushed LLM trace recording"
);
Ok(())
}
async fn capture_new_messages(
&self,
messages: &[ChatMessage],
) -> (Option<RequestHint>, Vec<ExpectedToolResult>) {
let mut prev_count = self.prev_message_count.lock().await;
let current_count = messages.len();
let start = (*prev_count).min(current_count);
let new_messages = &messages[start..];
let new_user_messages: Vec<&ChatMessage> = new_messages
.iter()
.filter(|m| m.role == Role::User)
.collect();
if !new_user_messages.is_empty() {
let mut steps = self.steps.lock().await;
for msg in &new_user_messages {
steps.push(TraceStep {
request_hint: None,
response: TraceResponse::UserInput {
content: msg.content.clone(),
},
expected_tool_results: Vec::new(),
});
}
}
let tool_results: Vec<ExpectedToolResult> = new_messages
.iter()
.filter(|m| m.role == Role::Tool)
.map(|m| ExpectedToolResult {
tool_call_id: m.tool_call_id.clone().unwrap_or_default(),
name: m.name.clone().unwrap_or_default(),
content: m.content.clone(),
})
.collect();
*prev_count = current_count;
let hint = messages
.iter()
.rev()
.find(|m| m.role == Role::User)
.map(|msg| {
let hint_text = if msg.content.len() > 80 {
let mut end = 80;
while end > 0 && !msg.content.is_char_boundary(end) {
end -= 1;
}
msg.content[..end].to_string()
} else {
msg.content.clone()
};
RequestHint {
last_user_message_contains: Some(hint_text),
min_message_count: Some(current_count),
}
});
(hint, tool_results)
}
}
#[async_trait]
impl LlmProvider for RecordingLlm {
fn model_name(&self) -> &str {
self.inner.model_name()
}
fn cost_per_token(&self) -> (Decimal, Decimal) {
self.inner.cost_per_token()
}
fn cache_write_multiplier(&self) -> Decimal {
self.inner.cache_write_multiplier()
}
fn cache_read_discount(&self) -> Decimal {
self.inner.cache_read_discount()
}
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
let (hint, tool_results) = self.capture_new_messages(&request.messages).await;
let response = self.inner.complete(request).await?;
self.steps.lock().await.push(TraceStep {
request_hint: hint,
response: TraceResponse::Text {
content: response.content.clone(),
input_tokens: response.input_tokens,
output_tokens: response.output_tokens,
},
expected_tool_results: tool_results,
});
Ok(response)
}
async fn complete_with_tools(
&self,
request: ToolCompletionRequest,
) -> Result<ToolCompletionResponse, LlmError> {
let (hint, tool_results) = self.capture_new_messages(&request.messages).await;
let response = self.inner.complete_with_tools(request).await?;
let step = if response.tool_calls.is_empty() {
TraceStep {
request_hint: hint,
response: TraceResponse::Text {
content: response.content.clone().unwrap_or_default(),
input_tokens: response.input_tokens,
output_tokens: response.output_tokens,
},
expected_tool_results: tool_results,
}
} else {
TraceStep {
request_hint: hint,
response: TraceResponse::ToolCalls {
tool_calls: response
.tool_calls
.iter()
.map(|tc| TraceToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect(),
input_tokens: response.input_tokens,
output_tokens: response.output_tokens,
},
expected_tool_results: tool_results,
}
};
self.steps.lock().await.push(step);
Ok(response)
}
async fn list_models(&self) -> Result<Vec<String>, LlmError> {
self.inner.list_models().await
}
async fn model_metadata(&self) -> Result<ModelMetadata, LlmError> {
self.inner.model_metadata().await
}
fn effective_model_name(&self, requested_model: Option<&str>) -> String {
self.inner.effective_model_name(requested_model)
}
fn active_model_name(&self) -> String {
self.inner.active_model_name()
}
fn set_model(&self, model: &str) -> Result<(), LlmError> {
self.inner.set_model(model)
}
fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> Decimal {
self.inner.calculate_cost(input_tokens, output_tokens)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::testing::StubLlm;
fn make_recorder(stub: Arc<StubLlm>) -> RecordingLlm {
let dir = tempfile::tempdir().expect("failed to create temp dir");
RecordingLlm::new(
stub,
dir.path().join("test_recording.json"),
"test-recording".to_string(),
)
}
#[tokio::test]
async fn captures_user_input_before_first_response() {
let stub = Arc::new(StubLlm::new("hello back"));
let recorder = make_recorder(stub);
let request = CompletionRequest::new(vec![
ChatMessage::system("You are helpful."),
ChatMessage::user("Hello!"),
]);
recorder.complete(request).await.unwrap();
let steps = recorder.steps.lock().await;
assert_eq!(steps.len(), 2);
assert!(
matches!(&steps[0].response, TraceResponse::UserInput { content } if content == "Hello!")
);
assert!(
matches!(&steps[1].response, TraceResponse::Text { content, .. } if content == "hello back")
);
}
#[tokio::test]
async fn captures_text_response_correctly() {
let stub = Arc::new(StubLlm::new("test response"));
let recorder = make_recorder(stub);
let request = CompletionRequest::new(vec![ChatMessage::user("question")]);
recorder.complete(request).await.unwrap();
let steps = recorder.steps.lock().await;
assert_eq!(steps.len(), 2);
match &steps[1].response {
TraceResponse::Text {
content,
input_tokens,
output_tokens,
} => {
assert_eq!(content, "test response");
let _ = (*input_tokens, *output_tokens);
}
_ => panic!("Expected Text response"),
}
}
#[tokio::test]
async fn captures_tool_calls_response() {
let stub = Arc::new(StubLlm::new("tool result"));
let recorder = make_recorder(stub);
let request = ToolCompletionRequest::new(vec![ChatMessage::user("use a tool")], vec![]);
recorder.complete_with_tools(request).await.unwrap();
let steps = recorder.steps.lock().await;
assert_eq!(steps.len(), 2); }
#[tokio::test]
async fn no_spurious_user_input_for_tool_iterations() {
let stub = Arc::new(StubLlm::new("response"));
let recorder = make_recorder(stub);
let request = CompletionRequest::new(vec![
ChatMessage::system("sys"),
ChatMessage::user("Do something"),
]);
recorder.complete(request).await.unwrap();
let request = CompletionRequest::new(vec![
ChatMessage::system("sys"),
ChatMessage::user("Do something"),
ChatMessage::assistant("I'll use a tool"),
ChatMessage::tool_result("call_1", "echo", "result"),
]);
recorder.complete(request).await.unwrap();
let steps = recorder.steps.lock().await;
assert_eq!(steps.len(), 3);
assert!(matches!(
&steps[0].response,
TraceResponse::UserInput { .. }
));
assert!(matches!(&steps[1].response, TraceResponse::Text { .. }));
assert!(matches!(&steps[2].response, TraceResponse::Text { .. }));
}
#[tokio::test]
async fn captures_tool_results_for_verification() {
let stub = Arc::new(StubLlm::new("response"));
let recorder = make_recorder(stub);
let request = CompletionRequest::new(vec![
ChatMessage::system("sys"),
ChatMessage::user("Do something"),
]);
recorder.complete(request).await.unwrap();
let request = CompletionRequest::new(vec![
ChatMessage::system("sys"),
ChatMessage::user("Do something"),
ChatMessage::assistant("I'll use a tool"),
ChatMessage::tool_result("call_1", "echo", "echoed: hello"),
ChatMessage::tool_result("call_2", "time", "2026-03-04T14:00:00Z"),
]);
recorder.complete(request).await.unwrap();
let steps = recorder.steps.lock().await;
let step = &steps[2];
assert_eq!(step.expected_tool_results.len(), 2);
assert_eq!(step.expected_tool_results[0].name, "echo");
assert_eq!(step.expected_tool_results[0].content, "echoed: hello");
assert_eq!(step.expected_tool_results[1].name, "time");
}
#[tokio::test]
async fn request_hint_extraction() {
let stub = Arc::new(StubLlm::new("response"));
let recorder = make_recorder(stub);
let request = CompletionRequest::new(vec![
ChatMessage::system("sys"),
ChatMessage::user("What time is it?"),
]);
recorder.complete(request).await.unwrap();
let steps = recorder.steps.lock().await;
let text_step = &steps[1];
let hint = text_step.request_hint.as_ref().unwrap();
assert_eq!(
hint.last_user_message_contains.as_deref(),
Some("What time is it?")
);
assert_eq!(hint.min_message_count, Some(2));
}
#[tokio::test]
async fn flush_writes_valid_json_with_all_fields() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("trace.json");
let stub = Arc::new(StubLlm::new("response"));
let recorder = RecordingLlm::new(stub, path.clone(), "flush-test".to_string());
recorder
.memory_snapshot
.lock()
.await
.push(MemorySnapshotEntry {
path: "context/test.md".to_string(),
content: "test content".to_string(),
});
recorder
.http_interceptor
.after_response(
&HttpExchangeRequest {
method: "GET".to_string(),
url: "https://api.example.com/data".to_string(),
headers: Vec::new(),
body: None,
},
&HttpExchangeResponse {
status: 200,
headers: Vec::new(),
body: r#"{"ok": true}"#.to_string(),
},
)
.await;
let request = CompletionRequest::new(vec![ChatMessage::user("hello")]);
recorder.complete(request).await.unwrap();
recorder.flush().await.unwrap();
let content = tokio::fs::read_to_string(&path).await.unwrap();
let trace: TraceFile = serde_json::from_str(&content).unwrap();
assert_eq!(trace.model_name, "flush-test");
assert_eq!(trace.memory_snapshot.len(), 1);
assert_eq!(trace.memory_snapshot[0].path, "context/test.md");
assert_eq!(trace.http_exchanges.len(), 1);
assert_eq!(trace.http_exchanges[0].response.status, 200);
assert_eq!(trace.steps.len(), 2);
}
#[test]
fn from_env_returns_none_when_unset() {
unsafe { std::env::remove_var("IRONCLAW_RECORD_TRACE") };
let stub = Arc::new(StubLlm::new("response"));
let result = RecordingLlm::from_env(stub);
assert!(result.is_none());
}
#[tokio::test]
async fn recording_http_interceptor_passes_through_and_records() {
let interceptor = RecordingHttpInterceptor::new();
let req = HttpExchangeRequest {
method: "GET".to_string(),
url: "https://example.com".to_string(),
headers: Vec::new(),
body: None,
};
assert!(interceptor.before_request(&req).await.is_none());
let resp = HttpExchangeResponse {
status: 200,
headers: Vec::new(),
body: "ok".to_string(),
};
interceptor.after_response(&req, &resp).await;
let exchanges = interceptor.take_exchanges().await;
assert_eq!(exchanges.len(), 1);
assert_eq!(exchanges[0].request.url, "https://example.com");
}
#[tokio::test]
async fn replaying_http_interceptor_returns_recorded_responses() {
let exchanges = vec![HttpExchange {
request: HttpExchangeRequest {
method: "GET".to_string(),
url: "https://api.example.com/data".to_string(),
headers: Vec::new(),
body: None,
},
response: HttpExchangeResponse {
status: 200,
headers: Vec::new(),
body: r#"{"items": []}"#.to_string(),
},
}];
let interceptor = ReplayingHttpInterceptor::new(exchanges);
let req = HttpExchangeRequest {
method: "GET".to_string(),
url: "https://api.example.com/data".to_string(),
headers: Vec::new(),
body: None,
};
let resp = interceptor.before_request(&req).await.unwrap();
assert_eq!(resp.status, 200);
assert_eq!(resp.body, r#"{"items": []}"#);
let resp = interceptor.before_request(&req).await.unwrap();
assert_eq!(resp.status, 599);
}
#[test]
fn serde_roundtrip_extended_format() {
let trace = TraceFile {
model_name: "test".to_string(),
memory_snapshot: vec![MemorySnapshotEntry {
path: "context/vision.md".to_string(),
content: "Be helpful.".to_string(),
}],
http_exchanges: vec![HttpExchange {
request: HttpExchangeRequest {
method: "GET".to_string(),
url: "https://api.example.com".to_string(),
headers: vec![("Accept".to_string(), "application/json".to_string())],
body: None,
},
response: HttpExchangeResponse {
status: 200,
headers: Vec::new(),
body: "{}".to_string(),
},
}],
steps: vec![
TraceStep {
request_hint: None,
response: TraceResponse::UserInput {
content: "hello".to_string(),
},
expected_tool_results: Vec::new(),
},
TraceStep {
request_hint: Some(RequestHint {
last_user_message_contains: Some("hello".to_string()),
min_message_count: Some(2),
}),
response: TraceResponse::ToolCalls {
tool_calls: vec![TraceToolCall {
id: "call_1".to_string(),
name: "echo".to_string(),
arguments: serde_json::json!({"message": "hi"}),
}],
input_tokens: 50,
output_tokens: 20,
},
expected_tool_results: Vec::new(),
},
TraceStep {
request_hint: None,
response: TraceResponse::Text {
content: "done".to_string(),
input_tokens: 80,
output_tokens: 10,
},
expected_tool_results: vec![ExpectedToolResult {
tool_call_id: "call_1".to_string(),
name: "echo".to_string(),
content: "hi".to_string(),
}],
},
],
};
let json = serde_json::to_string_pretty(&trace).unwrap();
let parsed: TraceFile = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.model_name, "test");
assert_eq!(parsed.memory_snapshot.len(), 1);
assert_eq!(parsed.http_exchanges.len(), 1);
assert_eq!(parsed.steps.len(), 3);
assert_eq!(parsed.steps[2].expected_tool_results.len(), 1);
}
#[tokio::test]
async fn request_hint_handles_multibyte_utf8() {
let stub = Arc::new(StubLlm::new("response"));
let recorder = make_recorder(stub);
let long_cjk = "你".repeat(27); assert!(long_cjk.len() > 80);
let request = CompletionRequest::new(vec![
ChatMessage::system("sys"),
ChatMessage::user(&long_cjk),
]);
recorder.complete(request).await.unwrap();
let steps = recorder.steps.lock().await;
let text_step = &steps[1];
let hint = text_step.request_hint.as_ref().unwrap();
let hint_text = hint.last_user_message_contains.as_deref().unwrap();
assert!(hint_text.len() <= 80);
assert!(hint_text.is_ascii() || hint_text.chars().count() > 0);
}
#[test]
fn backward_compatible_with_old_format() {
let json = r#"{
"model_name": "old-trace",
"steps": [
{
"response": {
"type": "text",
"content": "hello",
"input_tokens": 10,
"output_tokens": 5
}
}
]
}"#;
let trace: TraceFile = serde_json::from_str(json).unwrap();
assert_eq!(trace.model_name, "old-trace");
assert!(trace.memory_snapshot.is_empty());
assert!(trace.http_exchanges.is_empty());
assert!(trace.steps[0].expected_tool_results.is_empty());
}
}