pub mod anthropic;
pub mod gemini;
pub mod openai;
pub mod responses;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::body::Body;
use axum::http::{header, Response, StatusCode};
use axum::response::IntoResponse;
use tokio::time::sleep;
use crate::failure;
use crate::fixture::match_fixture;
use crate::format::Provider;
use crate::server::AppState;
/// Returns the whole milliseconds elapsed since `start`.
///
/// `Duration::as_millis` yields a `u128`; a value too large for `u64` is
/// reported as `u64::MAX` rather than being truncated.
fn elapsed_ms(start: &Instant) -> u64 {
    let millis: u128 = start.elapsed().as_millis();
    match u64::try_from(millis) {
        Ok(ms) => ms,
        Err(_) => u64::MAX,
    }
}
/// Frames produced by a provider for a streaming reply, tagged with the wire
/// format that should be used to deliver them.
pub(crate) enum StreamOutput {
    /// Frames delivered one by one as a `text/event-stream` body.
    Sse(Vec<String>),
    /// Frames joined with commas and wrapped in `[` `]` to form a single
    /// `application/json` body.
    JsonArray(Vec<String>),
}
/// Provider-specific hooks used by [`handle_request`].
///
/// `handle_request` owns the shared flow (parsing, fixture matching, failure
/// simulation, streaming transport); implementors only describe how one
/// provider's requests are read and how its response bodies are rendered.
// The frame/response builders below take eight parameters counting `&self`,
// which trips clippy's argument-count lint; it is silenced for the whole trait.
#[allow(clippy::too_many_arguments)]
pub(crate) trait ProviderHandler: Send + Sync {
    /// The provider this handler speaks for. Used when matching fixtures and
    /// to decide whether the `stream` field is validated.
    fn provider(&self) -> Provider;

    /// Label recorded as the path of captured requests and printed in verbose
    /// log lines.
    fn route_label(&self) -> &str;

    /// Renders an error body for `status` and `message`.
    ///
    /// The default delegates to [`failure::build_error_body`].
    fn build_error_body(&self, status: u16, message: &str) -> String {
        failure::build_error_body(status, message)
    }

    /// Pulls the `(model, user_message)` pair out of the parsed request body.
    ///
    /// # Errors
    ///
    /// Returns a message that `handle_request` sends back in a 400 response.
    fn extract_request_info(&self, body: &serde_json::Value) -> Result<(String, String), String>;

    /// Whether the request asks for a streamed reply.
    ///
    /// The default reads the top-level `stream` field and treats a missing or
    /// non-boolean value as `false`.
    fn is_streaming(&self, body: &serde_json::Value) -> bool {
        body.get("stream")
            .and_then(|flag| flag.as_bool())
            .unwrap_or(false)
    }

    /// Stop reason used when the fixture specifies neither `stop_reason` nor
    /// `finish_reason`.
    fn default_stop_reason(&self) -> &str;

    /// Renders the body of a non-streaming text reply.
    ///
    /// `prompt` is the user message extracted from the request, and
    /// `has_explicit_reason` tells whether `stop_reason` came from the fixture
    /// rather than from [`ProviderHandler::default_stop_reason`].
    fn build_response(
        &self,
        state: &AppState,
        model: &str,
        content: &str,
        prompt: &str,
        stop_reason: &str,
        has_explicit_reason: bool,
    ) -> String;

    /// Renders the body of a non-streaming reply that carries tool calls.
    ///
    /// Each entry of `tool_calls` is a `(name, arguments)` pair taken from the
    /// fixture.
    fn build_tool_call_response(
        &self,
        state: &AppState,
        model: &str,
        tool_calls: &[(&str, serde_json::Value)],
        prompt: &str,
        stop_reason: &str,
        has_explicit_reason: bool,
    ) -> String;

    /// Renders the frames of a streamed text reply.
    ///
    /// `chunk_size` comes from the fixture's streaming settings (20 when
    /// unset); how it is applied to `content` is up to the implementor.
    fn build_stream_frames(
        &self,
        state: &AppState,
        model: &str,
        content: &str,
        chunk_size: usize,
        prompt: &str,
        stop_reason: &str,
        has_explicit_reason: bool,
    ) -> StreamOutput;

    /// Renders the frames of a streamed reply that carries tool calls.
    fn build_tool_call_stream_frames(
        &self,
        state: &AppState,
        model: &str,
        tool_calls: &[(&str, serde_json::Value)],
        chunk_size: usize,
        prompt: &str,
        stop_reason: &str,
        has_explicit_reason: bool,
    ) -> StreamOutput;
}
pub(crate) async fn handle_request(
handler: &dyn ProviderHandler,
state: Arc<AppState>,
body: String,
) -> Response<Body> {
let json_body: serde_json::Value = match serde_json::from_str(&body) {
Ok(v) => v,
Err(_) => {
return (
StatusCode::BAD_REQUEST,
[(header::CONTENT_TYPE, "application/json")],
handler.build_error_body(400, "Invalid JSON in request body"),
)
.into_response();
}
};
let (model, user_message) = match handler.extract_request_info(&json_body) {
Ok(info) => info,
Err(msg) => {
return (
StatusCode::BAD_REQUEST,
[(header::CONTENT_TYPE, "application/json")],
handler.build_error_body(400, &msg),
)
.into_response();
}
};
if handler.provider() != Provider::Gemini {
if let Some(sv) = json_body.get("stream") {
if sv.as_bool().is_none() {
return (
StatusCode::BAD_REQUEST,
[(header::CONTENT_TYPE, "application/json")],
handler.build_error_body(400, "\"stream\" must be a boolean"),
)
.into_response();
}
}
}
let is_streaming = handler.is_streaming(&json_body);
let (fixture, scenario_name) = {
let mut scenarios = state.scenarios.write().unwrap_or_else(|e| e.into_inner());
let matched = match_fixture(
&state.fixtures,
&user_message,
Some(&model),
Some(handler.provider()),
Some(&scenarios),
);
if let Some(f) = matched {
let name = if let Some(ref scenario) = f.scenario {
if let Some(ref next_state) = scenario.set_state {
scenarios.insert(scenario.name.clone(), next_state.clone());
}
Some(scenario.name.clone())
} else {
None
};
(Some(f), name)
} else {
(None, None)
}
};
state
.captured_requests
.write()
.unwrap_or_else(|e| e.into_inner())
.push(crate::server::CapturedRequest {
method: "POST".to_string(),
path: handler.route_label().to_string(),
body,
matched_scenario: scenario_name,
timestamp: std::time::Instant::now(),
});
let fixture = match fixture {
Some(f) => f,
None => {
if state.verbose {
let char_count = user_message.chars().count();
let preview: String = user_message.chars().take(50).collect();
eprintln!(
"[llmposter] POST {} → no match (model='{}', msg='{}...' ({} chars))",
handler.route_label(),
model,
preview,
char_count
);
}
let msg = format!("No fixture matched for model='{}'", model);
return (
StatusCode::NOT_FOUND,
[(header::CONTENT_TYPE, "application/json")],
handler.build_error_body(404, &msg),
)
.into_response();
}
};
if state.verbose {
eprintln!(
"[llmposter] POST {} → fixture matched",
handler.route_label()
);
}
if let Some(ref err) = fixture.error {
let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let body = handler.build_error_body(status.as_u16(), &err.message);
let mut builder = Response::builder().status(status);
for (name, value) in &err.headers {
builder = builder.header(name.as_str(), value.as_str());
}
let has_content_type = err
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("content-type"));
if !has_content_type {
builder = builder.header(header::CONTENT_TYPE, "application/json");
}
return match builder.body(Body::from(body)) {
Ok(resp) => resp.into_response(),
Err(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
[(header::CONTENT_TYPE, "application/json")],
handler.build_error_body(500, "Fixture contains invalid header name or value"),
)
.into_response(),
};
}
let response = match fixture.response.as_ref() {
Some(r) => r,
None => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
[(header::CONTENT_TYPE, "application/json")],
handler.build_error_body(500, "Fixture has neither response nor error"),
)
.into_response();
}
};
let content = response.content.as_deref().unwrap_or("");
let has_explicit_reason = response.stop_reason.is_some() || response.finish_reason.is_some();
let stop_reason = response
.stop_reason
.as_deref()
.or(response.finish_reason.as_deref())
.unwrap_or(handler.default_stop_reason());
if let Some(ref fail) = fixture.failure {
if let Some(ms) = fail.latency_ms {
sleep(Duration::from_millis(ms)).await;
}
if fail.corrupt_body == Some(true) {
return (
StatusCode::OK,
[(header::CONTENT_TYPE, "text/plain")],
"overloaded".to_string(),
)
.into_response();
}
}
let tc_pairs: Option<Vec<(&str, serde_json::Value)>> =
response.tool_calls.as_ref().map(|tool_calls| {
tool_calls
.iter()
.map(|tc| (tc.name.as_str(), tc.arguments.clone()))
.collect()
});
if is_streaming {
let chunk_size = fixture
.streaming
.as_ref()
.and_then(|s| s.chunk_size)
.unwrap_or(20);
let latency = fixture
.streaming
.as_ref()
.and_then(|s| s.latency)
.unwrap_or(0);
let truncate_after = fixture
.failure
.as_ref()
.and_then(|f| f.truncate_after_frames);
let disconnect_after_ms = fixture.failure.as_ref().and_then(|f| f.disconnect_after_ms);
let stream_output = if let Some(ref tc) = tc_pairs {
handler.build_tool_call_stream_frames(
&state,
&model,
tc,
chunk_size,
&user_message,
stop_reason,
has_explicit_reason,
)
} else {
handler.build_stream_frames(
&state,
&model,
content,
chunk_size,
&user_message,
stop_reason,
has_explicit_reason,
)
};
match stream_output {
StreamOutput::Sse(frames) => {
stream_sse_frames(frames, latency, truncate_after, disconnect_after_ms).await
}
StreamOutput::JsonArray(frames) => {
stream_json_array(frames, latency, truncate_after, disconnect_after_ms).await
}
}
} else {
let json = if let Some(ref tc) = tc_pairs {
handler.build_tool_call_response(
&state,
&model,
tc,
&user_message,
stop_reason,
has_explicit_reason,
)
} else {
handler.build_response(
&state,
&model,
content,
&user_message,
stop_reason,
has_explicit_reason,
)
};
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
json,
)
.into_response()
}
}
async fn stream_sse_frames(
frames: Vec<String>,
latency: u64,
truncate_after: Option<u32>,
disconnect_after_ms: Option<u64>,
) -> Response<Body> {
let (tx, rx) = tokio::sync::mpsc::channel::<Result<String, std::io::Error>>(32);
tokio::spawn(async move {
let send_frames = async {
let total = frames.len();
for (sent, frame) in frames.into_iter().enumerate() {
tokio::task::yield_now().await;
if let Some(max) = truncate_after {
if sent as u32 >= max {
return;
}
}
if tx.send(Ok(frame)).await.is_err() {
return;
}
if latency > 0 && sent + 1 < total {
sleep(Duration::from_millis(latency)).await;
}
}
};
if let Some(ms) = disconnect_after_ms {
tokio::select! {
biased;
_ = sleep(Duration::from_millis(ms)) => {
let _ = tx
.send(Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"llmposter: simulated disconnect",
)))
.await;
}
_ = send_frames => {}
}
} else {
send_frames.await;
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/event-stream")
.header(header::CACHE_CONTROL, "no-cache")
.body(Body::from_stream(stream))
.expect("static SSE response headers")
}
async fn stream_json_array(
frames: Vec<String>,
latency: u64,
truncate_after: Option<u32>,
disconnect_after_ms: Option<u64>,
) -> Response<Body> {
let mut collected: Vec<String> = Vec::new();
let start = Instant::now();
for (i, frame) in frames.into_iter().enumerate() {
tokio::task::yield_now().await;
if let Some(ms) = disconnect_after_ms {
if start.elapsed() >= Duration::from_millis(ms) {
break;
}
}
if let Some(max) = truncate_after {
if i as u32 >= max {
break;
}
}
collected.push(frame);
if latency > 0 {
if let Some(ms) = disconnect_after_ms {
let remaining = ms.saturating_sub(elapsed_ms(&start));
if remaining == 0 {
break;
}
let wait = Duration::from_millis(latency.min(remaining));
sleep(wait).await;
if start.elapsed() >= Duration::from_millis(ms) {
collected.pop();
break;
}
} else {
sleep(Duration::from_millis(latency)).await;
}
}
}
let json = format!("[{}]", collected.join(","));
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
json,
)
.into_response()
}