use axum::{
body::Body, extract::State, http::HeaderMap, response::Response, routing::post, Router,
};
use std::{
collections::VecDeque,
net::SocketAddr,
sync::{Arc, Mutex},
};
use tokio::net::TcpListener;
/// One scripted reply for the mock `/v1/chat/completions` endpoint.
///
/// Scenarios are queued in a [`MockState`] and consumed front-first, one per
/// request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MockScenario {
    /// Stream the contained text as a single assistant content chunk.
    TextResponse(String),
    /// Stream a single tool call, then serve `final_text` as a text reply,
    /// queued ahead of the remaining scenarios.
    ToolCallResponse {
        /// Function name reported in the streamed tool call.
        name: String,
        /// Raw arguments string sent as the tool call's `arguments` delta
        /// (presumably JSON — the mock does not validate it).
        args: String,
        /// Text queued as a `TextResponse` to answer the following request.
        final_text: String,
    },
    /// Respond with this HTTP status code and an empty body.
    ErrorResponse(u16),
}
/// State shared between the test and the mock server: the script of replies
/// still to be served, consumed from the front by the request handler.
pub struct MockState {
    /// Pending scenarios, in the order they will be served.
    pub scenarios: Mutex<VecDeque<MockScenario>>,
}

impl MockState {
    /// Builds a shared state whose queue serves `scenarios` in the given order.
    pub fn new(scenarios: Vec<MockScenario>) -> Arc<Self> {
        let queue = VecDeque::from(scenarios);
        Arc::new(MockState {
            scenarios: Mutex::new(queue),
        })
    }
}
/// Renders one SSE `data:` event carrying an OpenAI-style streaming chunk whose
/// delta holds `text` as assistant content, with `finish_reason` set to `null`.
fn sse_text_chunk(text: &str) -> String {
    // Serializing a `&str` yields a quoted, escaped JSON string literal.
    let content = serde_json::to_string(text).unwrap();
    [
        r#"data: {"choices":[{"delta":{"content":"#,
        content.as_str(),
        r#"},"finish_reason":null}]}"#,
        "\n\n",
    ]
    .concat()
}
/// Renders the SSE event that opens tool call index 0: it announces the call
/// `call_id` to function `name` with an empty `arguments` string, which later
/// argument deltas extend.
fn sse_tool_call_open(call_id: &str, name: &str) -> String {
    // Both values become quoted, escaped JSON string literals.
    let id_json = serde_json::to_string(call_id).unwrap();
    let name_json = serde_json::to_string(name).unwrap();
    [
        r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"#,
        id_json.as_str(),
        r#","type":"function","function":{"name":"#,
        name_json.as_str(),
        r#","arguments":""}}]},"finish_reason":null}]}"#,
        "\n\n",
    ]
    .concat()
}
/// Renders the SSE event that appends `args` to the `arguments` of tool call
/// index 0 (the whole string is sent in a single delta).
fn sse_tool_call_args(args: &str) -> String {
    // `args` is embedded as a JSON string, not as a nested JSON object.
    let arguments = serde_json::to_string(args).unwrap();
    [
        r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"#,
        arguments.as_str(),
        r#"}}]},"finish_reason":null}]}"#,
        "\n\n",
    ]
    .concat()
}
/// Terminal SSE event appended to every streamed completion.
const SSE_DONE: &str = "data: [DONE]\n\n";

/// Wraps an already-rendered SSE stream in a `200 OK` `text/event-stream` response.
fn sse_response(body: String) -> Response {
    Response::builder()
        .status(200)
        .header("content-type", "text/event-stream")
        .body(Body::from(body))
        .unwrap()
}

/// Handler for `POST /v1/chat/completions`.
///
/// Ignores the request headers and body and replays the next queued
/// [`MockScenario`]:
/// - empty queue: streams the text `"(no scenario)"`;
/// - `TextResponse`: streams the text as a single content chunk;
/// - `ToolCallResponse`: streams one tool call (fixed id `call_test_001`) and
///   queues `final_text` as a `TextResponse` at the front of the queue, so it
///   answers the next request;
/// - `ErrorResponse(code)`: returns an empty-bodied response with that status.
///
/// Every SSE reply ends with `data: [DONE]`.
///
/// # Panics
/// Panics if the scenario mutex is poisoned, or if an `ErrorResponse` carries a
/// value that is not a valid HTTP status code.
async fn chat_completions(
    State(state): State<Arc<MockState>>,
    _headers: HeaderMap,
    _body: axum::body::Bytes,
) -> Response {
    // Pop the scenario and queue any tool-call follow-up inside ONE critical
    // section. Releasing the lock in between would let a concurrent request pop
    // the next scenario before the follow-up text is queued, which reorders the
    // script.
    let scenario = {
        let mut queue = state.scenarios.lock().unwrap();
        let next = queue.pop_front();
        if let Some(MockScenario::ToolCallResponse { final_text, .. }) = &next {
            queue.push_front(MockScenario::TextResponse(final_text.clone()));
        }
        next
    };

    match scenario {
        None => sse_response(format!("{}{}", sse_text_chunk("(no scenario)"), SSE_DONE)),
        Some(MockScenario::TextResponse(text)) => {
            sse_response(format!("{}{}", sse_text_chunk(&text), SSE_DONE))
        }
        Some(MockScenario::ToolCallResponse { name, args, .. }) => sse_response(format!(
            "{}{}{}",
            sse_tool_call_open("call_test_001", &name),
            sse_tool_call_args(&args),
            SSE_DONE
        )),
        Some(MockScenario::ErrorResponse(code)) => Response::builder()
            .status(code)
            .body(Body::empty())
            // An out-of-range code is a bug in the test script; name it.
            .unwrap_or_else(|err| {
                panic!("MockScenario::ErrorResponse({code}) is not a valid HTTP status: {err}")
            }),
    }
}
/// Starts the mock chat-completions server on an OS-assigned port of
/// `127.0.0.1`, serving `scenarios` in order.
///
/// The server runs as a detached Tokio task (its join handle is dropped).
/// Returns the bound address and a handle to the shared scenario queue, so
/// callers can inspect or extend the script.
///
/// # Panics
/// Panics if binding the listener or reading its local address fails. A
/// failure of the server itself panics inside the spawned task.
pub async fn start_mock_server(scenarios: Vec<MockScenario>) -> (SocketAddr, Arc<MockState>) {
    let shared = MockState::new(scenarios);

    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local = listener.local_addr().unwrap();

    let router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .with_state(Arc::clone(&shared));

    tokio::spawn(async move {
        axum::serve(listener, router).await.unwrap();
    });

    (local, shared)
}