#![allow(dead_code)]
use axum::{
extract::State,
http::{header, StatusCode},
response::IntoResponse,
routing::post,
Router,
};
use serde_json::json;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use tokio_util::sync::CancellationToken;
/// Shared state for the multi-file diff Anthropic mock handlers.
///
/// Cloning is cheap: every clone shares the same counter. The spawning code
/// can therefore keep a handle to `call_count` and observe how many requests
/// the handlers accepted.
#[derive(Clone, Debug, Default)]
pub struct MultiDiffMockState {
    /// Number of requests whose body parsed as JSON; incremented once per
    /// accepted request by the handlers.
    pub call_count: Arc<AtomicUsize>,
}
fn extract_paths_from_request(body: &serde_json::Value) -> Vec<String> {
let mut paths = Vec::new();
let messages = match body.get("messages").and_then(|m| m.as_array()) {
Some(m) => m,
None => return paths,
};
for msg in messages {
let content = match msg.get("content").and_then(|c| c.as_str()) {
Some(c) => c,
None => continue,
};
let mut in_files_section = false;
for line in content.lines() {
if line.trim() == "Files:" {
in_files_section = true;
continue;
}
if in_files_section {
let trimmed = line.trim();
if trimmed.starts_with('/') {
let p = trimmed.to_string();
if !p.is_empty() && !paths.contains(&p) {
paths.push(p);
}
if paths.len() >= 2 {
break;
}
} else if !trimmed.is_empty() {
in_files_section = false;
}
}
}
if paths.len() >= 2 {
break;
}
for line in content.lines() {
if let Some(rest) = line.strip_prefix("=== Current file content (path=") {
if let Some(path) = rest.strip_suffix(") ===") {
let p = path.trim().to_string();
if !p.is_empty() && !paths.contains(&p) {
paths.push(p);
}
}
} else if let Some(rest) = line.strip_prefix("=== File (path=") {
if let Some(path) = rest.strip_suffix(") does not exist yet ===") {
let p = path.trim().to_string();
if !p.is_empty() && !paths.contains(&p) {
paths.push(p);
}
}
}
}
if paths.len() >= 2 {
break;
}
}
paths
}
async fn multi_diff_happy_handler(
State(state): State<MultiDiffMockState>,
body: axum::body::Bytes,
) -> impl IntoResponse {
let req_value = match serde_json::from_slice::<serde_json::Value>(&body) {
Ok(v) => v,
Err(e) => {
eprintln!("[compile_loop_diff_multi_anthropic_mock] bad request body: {e}");
let err_body = json!({ "error": format!("bad request: {e}") }).to_string();
return (
StatusCode::BAD_REQUEST,
[(header::CONTENT_TYPE, "application/json")],
err_body,
);
}
};
let prev = state.call_count.fetch_add(1, Ordering::SeqCst);
let paths = extract_paths_from_request(&req_value);
let (path_a, path_b) = if paths.len() >= 2 {
(paths[0].clone(), paths[1].clone())
} else if paths.len() == 1 {
(paths[0].clone(), "file_b.lua".to_string())
} else {
("file_a.lua".to_string(), "file_b.lua".to_string())
};
let text = format!(
"<<< path={path_a} >>>\n<<<<<<< SEARCH\nprint(\"a-old\")\n=======\nprint(\"a-new\")\n>>>>>>> REPLACE\n\n<<< path={path_b} >>>\n<<<<<<< SEARCH\nprint(\"b-old\")\n=======\nprint(\"b-new\")\n>>>>>>> REPLACE"
);
let response_json = json!({
"id": format!("msg_multi_diff_happy_{}", prev + 1),
"type": "message",
"role": "assistant",
"content": [{ "type": "text", "text": text }],
"model": "claude-haiku-mock",
"stop_reason": "end_turn",
"usage": { "input_tokens": 10, "output_tokens": 20 }
});
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
response_json.to_string(),
)
}
async fn multi_diff_two_iter_handler(
State(state): State<MultiDiffMockState>,
body: axum::body::Bytes,
) -> impl IntoResponse {
let req_value = match serde_json::from_slice::<serde_json::Value>(&body) {
Ok(v) => v,
Err(e) => {
eprintln!("[compile_loop_diff_multi_anthropic_mock] bad request body: {e}");
let err_body = json!({ "error": format!("bad request: {e}") }).to_string();
return (
StatusCode::BAD_REQUEST,
[(header::CONTENT_TYPE, "application/json")],
err_body,
);
}
};
let prev = state.call_count.fetch_add(1, Ordering::SeqCst);
let paths = extract_paths_from_request(&req_value);
let (path_a, path_b) = if paths.len() >= 2 {
(paths[0].clone(), paths[1].clone())
} else if paths.len() == 1 {
(paths[0].clone(), "file_b.lua".to_string())
} else {
("file_a.lua".to_string(), "file_b.lua".to_string())
};
let text = if prev == 0 {
format!(
"<<< path={path_a} >>>\n<<<<<<< SEARCH\nprint(\"WRONG\")\n=======\nprint(\"a-new\")\n>>>>>>> REPLACE"
)
} else {
format!(
"<<< path={path_a} >>>\n<<<<<<< SEARCH\nprint(\"a-old\")\n=======\nprint(\"a-new\")\n>>>>>>> REPLACE\n\n<<< path={path_b} >>>\n<<<<<<< SEARCH\nprint(\"b-old\")\n=======\nprint(\"b-new\")\n>>>>>>> REPLACE"
)
};
let response_json = json!({
"id": format!("msg_multi_diff_two_iter_{}", prev + 1),
"type": "message",
"role": "assistant",
"content": [{ "type": "text", "text": text }],
"model": "claude-haiku-mock",
"stop_reason": "end_turn",
"usage": { "input_tokens": 10, "output_tokens": 20 }
});
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
response_json.to_string(),
)
}
pub async fn spawn_compile_loop_diff_multi_anthropic_mock_server(
) -> (String, Arc<AtomicUsize>, CancellationToken) {
let call_count = Arc::new(AtomicUsize::new(0));
let ct = CancellationToken::new();
let state = MultiDiffMockState {
call_count: call_count.clone(),
};
let router = Router::new()
.route("/v1/messages", post(multi_diff_happy_handler))
.with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind ephemeral port for compile_loop diff multi anthropic mock");
let addr = listener.local_addr().expect("local_addr");
let ct_shutdown = ct.clone();
tokio::spawn(async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct_shutdown.cancelled_owned().await })
.await;
});
(format!("http://{addr}"), call_count, ct)
}
pub async fn spawn_compile_loop_diff_multi_anthropic_mock_two_iter_server(
) -> (String, Arc<AtomicUsize>, CancellationToken) {
let call_count = Arc::new(AtomicUsize::new(0));
let ct = CancellationToken::new();
let state = MultiDiffMockState {
call_count: call_count.clone(),
};
let router = Router::new()
.route("/v1/messages", post(multi_diff_two_iter_handler))
.with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind ephemeral port for compile_loop diff multi anthropic mock two-iter");
let addr = listener.local_addr().expect("local_addr");
let ct_shutdown = ct.clone();
tokio::spawn(async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct_shutdown.cancelled_owned().await })
.await;
});
(format!("http://{addr}"), call_count, ct)
}