pub mod anthropic;
pub mod gemini;
pub mod openai;
pub mod responses;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::body::Body;
use axum::http::{header, Response, StatusCode};
use axum::response::IntoResponse;
use tokio::time::sleep;
use crate::failure;
use crate::format::Provider;
use crate::server::AppState;
/// Whole milliseconds that have passed since `start`.
///
/// `Duration::as_millis` yields a `u128`; values too large for `u64` saturate
/// to `u64::MAX` instead of wrapping.
fn elapsed_ms(start: &Instant) -> u64 {
    let millis: u128 = start.elapsed().as_millis();
    millis.try_into().unwrap_or(u64::MAX)
}
/// Pre-rendered stream frames produced by a [`ProviderHandler`], tagged with the
/// transport `handle_request` uses to deliver them.
pub(crate) enum StreamOutput {
    /// Frames written one by one to a `text/event-stream` body by `stream_sse_frames`.
    Sse(Vec<String>),
    /// Frames joined with `,` and wrapped in `[...]` as a single `application/json`
    /// body by `stream_json_array` (presumably each frame is one serialized JSON value).
    JsonArray(Vec<String>),
}
/// Provider-specific hooks plugged into the shared `handle_request` pipeline.
///
/// Each provider route supplies one implementation that knows how to read its
/// request format and render responses, errors and stream frames in its wire format.
// Several builder methods below take more parameters than clippy's default
// `too_many_arguments` limit; allowing it on the trait covers all of them.
#[allow(clippy::too_many_arguments)]
pub(crate) trait ProviderHandler: Send + Sync {
    /// Provider this handler serves. It is used for fixture matching, for the
    /// `"stream"` type check in `handle_request`, and for template rendering.
    fn provider(&self) -> Provider;
    /// Route path recorded in captured requests and printed in verbose log lines.
    fn route_label(&self) -> &str;
    /// Renders an error payload for the given HTTP `status` and `message`.
    ///
    /// Defaults to the shared `failure::build_error_body`; implementations may override it.
    fn build_error_body(&self, status: u16, message: &str) -> String {
        failure::build_error_body(status, message)
    }
    /// Extracts `(model, user_message)` from the parsed request body.
    ///
    /// An `Err` message is returned to the client verbatim in a 400 response.
    fn extract_request_info(&self, body: &serde_json::Value) -> Result<(String, String), String>;
    /// Whether the request asks for a streamed reply.
    ///
    /// Default: the top-level `"stream"` field. A missing or non-boolean value counts as `false`.
    fn is_streaming(&self, body: &serde_json::Value) -> bool {
        body["stream"].as_bool().unwrap_or(false)
    }
    /// Whether this provider streams over SSE (default `true`).
    ///
    /// `handle_request` consults it to shape a `corrupt_body` failure on streaming requests.
    fn streaming_is_sse(&self) -> bool {
        true
    }
    /// Stop/finish reason used when the fixture sets neither `stop_reason` nor `finish_reason`.
    fn default_stop_reason(&self) -> &str;
    /// Builds a complete non-streaming text response body.
    ///
    /// `has_explicit_reason` is `true` when the fixture set `stop_reason` or
    /// `finish_reason`. In that case `stop_reason` carries that value; otherwise
    /// it is `default_stop_reason()`.
    fn build_response(
        &self,
        state: &AppState,
        model: &str,
        content: &str,
        prompt: &str,
        stop_reason: &str,
        has_explicit_reason: bool,
    ) -> String;
    /// Builds a complete non-streaming response carrying `(name, arguments)` tool calls.
    fn build_tool_call_response(
        &self,
        state: &AppState,
        model: &str,
        tool_calls: &[(&str, serde_json::Value)],
        prompt: &str,
        stop_reason: &str,
        has_explicit_reason: bool,
    ) -> String;
    /// Builds a non-streaming refusal response with the fixture's refusal `reason`.
    fn build_refusal_response(
        &self,
        state: &AppState,
        model: &str,
        reason: &str,
        prompt: &str,
    ) -> String;
    /// Builds the stream frames for a text reply, splitting `content` according to
    /// `chunk_size` (20 unless the fixture overrides it).
    fn build_stream_frames(
        &self,
        state: &AppState,
        model: &str,
        content: &str,
        chunk_size: usize,
        prompt: &str,
        stop_reason: &str,
        has_explicit_reason: bool,
    ) -> StreamOutput;
    /// Builds the stream frames for a tool-call reply.
    fn build_tool_call_stream_frames(
        &self,
        state: &AppState,
        model: &str,
        tool_calls: &[(&str, serde_json::Value)],
        chunk_size: usize,
        prompt: &str,
        stop_reason: &str,
        has_explicit_reason: bool,
    ) -> StreamOutput;
}
/// Flattens an HTTP header map into a plain `name -> value` map for fixture matching.
///
/// Keys come from `HeaderName::as_str()`; the `http` crate stores header names
/// in lowercase, so callers can look headers up by their lowercase name.
/// Values that `HeaderValue::to_str` rejects (non-visible-ASCII bytes) are
/// skipped. A header that appears several times has its values joined with
/// `", "`, in the map's iteration order.
pub(crate) fn header_map_to_lowercase(
    headers: &axum::http::HeaderMap,
) -> std::collections::HashMap<String, String> {
    use std::collections::hash_map::Entry;

    let mut flattened: std::collections::HashMap<String, String> =
        std::collections::HashMap::with_capacity(headers.keys_len());
    for (key, raw) in headers {
        let text = match raw.to_str() {
            Ok(text) => text,
            Err(_) => continue,
        };
        match flattened.entry(key.as_str().to_owned()) {
            Entry::Occupied(mut slot) => {
                let joined = slot.get_mut();
                joined.push_str(", ");
                joined.push_str(text);
            }
            Entry::Vacant(slot) => {
                slot.insert(text.to_owned());
            }
        }
    }
    flattened
}
/// Appends one request to the in-memory capture log.
///
/// Does nothing when `state.capture_capacity` is `Some(0)`. With `Some(cap)`,
/// the oldest entries are evicted first so the log holds at most `cap` entries
/// after the push. With `None` the log is never trimmed. A poisoned lock is
/// recovered rather than propagated.
pub(crate) fn push_captured(
    state: &AppState,
    method: &str,
    path: &str,
    body: String,
    outcome: crate::server::RequestOutcome,
    matched_scenario: Option<String>,
) {
    let limit = state.capture_capacity;
    if limit == Some(0) {
        return;
    }
    let mut log = match state.captured_requests.write() {
        Ok(guard) => guard,
        // Recover the guard from a poisoned lock instead of failing the request.
        Err(poisoned) => poisoned.into_inner(),
    };
    if let Some(cap) = limit {
        // Leave room for exactly one more entry: evict `len + 1 - cap` of the
        // oldest entries when at or above capacity, none otherwise.
        let excess = (log.len() + 1).saturating_sub(cap);
        for _ in 0..excess {
            log.pop_front();
        }
    }
    // The timestamp is taken while holding the lock, so timestamps follow log order.
    log.push_back(crate::server::CapturedRequest {
        method: method.to_owned(),
        path: path.to_owned(),
        body,
        outcome,
        matched_scenario,
        timestamp: std::time::Instant::now(),
    });
}
/// Captures a request that never reached fixture matching (e.g. a bad request),
/// with no associated scenario.
///
/// The capacity check is repeated here so the body is not copied when capture
/// is disabled (`Some(0)`).
pub(crate) fn capture_non_matched(
    state: &AppState,
    method: &str,
    path: &str,
    body: &str,
    outcome: crate::server::RequestOutcome,
) {
    if state.capture_capacity != Some(0) {
        push_captured(state, method, path, body.to_owned(), outcome, None);
    }
}
/// Shared request pipeline for every provider route.
///
/// Steps, in order:
/// 1. Parse `body` as JSON and ask `handler` for `(model, user_message)`.
///    Failures, and (for non-Gemini providers) a non-boolean `"stream"`, produce
///    a 400 and are captured as `BadRequest`.
/// 2. Select a fixture: non-catch-all fixtures first, then catch-all ones. Each
///    group is ordered by descending `priority` (missing = 0); the stable sort
///    keeps fixture order for ties. A matched fixture's scenario `set_state` is
///    applied. Every request that reaches this step is captured as `Matched` or
///    `NoFixtureMatch`.
/// 3. With no match, reply 404. Otherwise the fixture decides the reply:
///    - a refusal (non-streaming only; a streaming request gets a 400);
///    - a configured error (status, extra headers, `application/json` unless
///      the fixture sets a content type);
///    - a normal response.
/// 4. For a normal response:
///    - resolve the content: the rendered `content_template` when the
///      `templating` feature is enabled, else `content`;
///    - resolve the stop reason;
///    - apply the fixture's `failure` latency and `corrupt_body` injection;
///    - then either stream the frames (chunking, chaos plan, truncation,
///      simulated disconnect) or return a single JSON body.
pub(crate) async fn handle_request(
    handler: &dyn ProviderHandler,
    state: Arc<AppState>,
    headers: std::collections::HashMap<String, String>,
    body: String,
) -> Response<Body> {
    // Replies 400 and records the request as `BadRequest`. This closure borrows
    // `body`, so it can only be used before `body` is moved into the capture log below.
    let bad_request = |msg: &str| -> Response<Body> {
        capture_non_matched(
            &state,
            "POST",
            handler.route_label(),
            &body,
            crate::server::RequestOutcome::BadRequest,
        );
        (
            StatusCode::BAD_REQUEST,
            [(header::CONTENT_TYPE, "application/json")],
            handler.build_error_body(400, msg),
        )
            .into_response()
    };
    let json_body: serde_json::Value = match serde_json::from_str(&body) {
        Ok(v) => v,
        Err(_) => return bad_request("Invalid JSON in request body"),
    };
    let (model, user_message) = match handler.extract_request_info(&json_body) {
        Ok(info) => info,
        Err(msg) => return bad_request(&msg),
    };
    // Reject a present-but-non-boolean "stream" field. Gemini is exempt.
    // NOTE(review): presumably Gemini signals streaming some other way (see its `is_streaming`); confirm.
    if handler.provider() != Provider::Gemini {
        if let Some(sv) = json_body.get("stream") {
            if sv.as_bool().is_none() {
                return bad_request("\"stream\" must be a boolean");
            }
        }
    }
    let is_streaming = handler.is_streaming(&json_body);
    // Fixture selection runs in its own scope so both locks are released before any `.await`.
    // NOTE(review): takes `fixtures` (read) then `scenarios` (write); other code that holds
    // both locks should use the same order to avoid deadlock.
    let fixture = {
        let fixtures = state.fixtures.read().unwrap_or_else(|e| e.into_inner());
        let mut scenarios = state.scenarios.write().unwrap_or_else(|e| e.into_inner());
        let ctx = crate::fixture::MatchContext::new(
            &user_message,
            Some(&model),
            Some(handler.provider()),
            Some(&scenarios),
            &headers,
            &json_body,
        );
        // Highest priority first; `sort_by_key` is stable, so equal priorities keep file order.
        let sort_by_priority = |idx: &mut Vec<usize>, all: &[Arc<crate::fixture::Fixture>]| {
            idx.sort_by_key(|&i| std::cmp::Reverse(all[i].priority.unwrap_or(0)));
        };
        let mut primary_idx: Vec<usize> = fixtures
            .iter()
            .enumerate()
            .filter(|(_, f)| !f.catch_all)
            .map(|(i, _)| i)
            .collect();
        sort_by_priority(&mut primary_idx, &fixtures);
        // Catch-all fixtures are considered only when no regular fixture matches.
        let matched = primary_idx
            .into_iter()
            .map(|i| &fixtures[i])
            .find(|f| crate::fixture::fixture_matches(f, &ctx))
            .or_else(|| {
                let mut catch_idx: Vec<usize> = fixtures
                    .iter()
                    .enumerate()
                    .filter(|(_, f)| f.catch_all)
                    .map(|(i, _)| i)
                    .collect();
                sort_by_priority(&mut catch_idx, &fixtures);
                catch_idx
                    .into_iter()
                    .map(|i| &fixtures[i])
                    .find(|f| crate::fixture::fixture_matches(f, &ctx))
            });
        // Apply the matched fixture's scenario transition (if any) and remember its name.
        let (arc_fixture, scenario_name) = if let Some(f) = matched {
            let name = if let Some(ref scenario) = f.scenario {
                if let Some(ref next_state) = scenario.set_state {
                    scenarios.insert(scenario.name.clone(), next_state.clone());
                }
                Some(scenario.name.clone())
            } else {
                None
            };
            (Some(std::sync::Arc::clone(f)), name)
        } else {
            (None, None)
        };
        let outcome = if arc_fixture.is_some() {
            crate::server::RequestOutcome::Matched
        } else {
            crate::server::RequestOutcome::NoFixtureMatch
        };
        // `body` is moved into the capture log here; `bad_request` is no longer usable.
        push_captured(
            &state,
            "POST",
            handler.route_label(),
            body,
            outcome,
            scenario_name,
        );
        arc_fixture
    };
    let fixture = match fixture {
        Some(f) => f,
        None => {
            if state.verbose {
                let char_count = user_message.chars().count();
                eprintln!(
                    "[llmposter] POST {} → no match (model='{}', msg len={} chars)",
                    handler.route_label(),
                    model,
                    char_count
                );
            }
            let msg = format!("No fixture matched for model='{}'", model);
            return (
                StatusCode::NOT_FOUND,
                [(header::CONTENT_TYPE, "application/json")],
                handler.build_error_body(404, &msg),
            )
                .into_response();
        }
    };
    if state.verbose {
        eprintln!(
            "[llmposter] POST {} → fixture matched",
            handler.route_label()
        );
    }
    // Refusal fixtures take precedence over error and response fixtures.
    if let Some(ref refusal) = fixture.refusal {
        if is_streaming {
            return (
                StatusCode::BAD_REQUEST,
                [(header::CONTENT_TYPE, "application/json")],
                handler.build_error_body(
                    400,
                    "refusal fixtures do not currently support streaming — \
                     re-run with `stream: false` or use a regular `response:` fixture",
                ),
            )
                .into_response();
        }
        let body = handler.build_refusal_response(&state, &model, &refusal.reason, &user_message);
        return (
            StatusCode::OK,
            [(header::CONTENT_TYPE, "application/json")],
            body,
        )
            .into_response();
    }
    // Error fixtures: an invalid status code falls back to 500. Fixture headers are
    // copied as-is, and a JSON content type is added unless the fixture sets one.
    if let Some(ref err) = fixture.error {
        let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        let body = handler.build_error_body(status.as_u16(), &err.message);
        let mut builder = Response::builder().status(status);
        for (name, value) in &err.headers {
            builder = builder.header(name.as_str(), value.as_str());
        }
        let has_content_type = err
            .headers
            .keys()
            .any(|k| k.eq_ignore_ascii_case("content-type"));
        if !has_content_type {
            builder = builder.header(header::CONTENT_TYPE, "application/json");
        }
        // The builder only fails here if a fixture header name/value is invalid.
        return match builder.body(Body::from(body)) {
            Ok(resp) => resp.into_response(),
            Err(_) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                [(header::CONTENT_TYPE, "application/json")],
                handler.build_error_body(500, "Fixture contains invalid header name or value"),
            )
                .into_response(),
        };
    }
    let response = match fixture.response.as_ref() {
        Some(r) => r,
        None => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                [(header::CONTENT_TYPE, "application/json")],
                handler.build_error_body(500, "Fixture has neither response nor error"),
            )
                .into_response();
        }
    };
    // With templating enabled, a `content_template` (if present) wins over `content`;
    // a render error becomes a 500.
    #[cfg(feature = "templating")]
    let rendered_template: Option<String> = match response.content_template.as_deref() {
        Some(tmpl) => match crate::templating::render(
            tmpl,
            &response.template_cache,
            &user_message,
            &model,
            handler.provider().as_str(),
            &json_body,
        ) {
            Ok(s) => Some(s),
            Err(e) => {
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    [(header::CONTENT_TYPE, "application/json")],
                    handler.build_error_body(500, &format!("content_template: {}", e)),
                )
                    .into_response();
            }
        },
        None => None,
    };
    #[cfg(feature = "templating")]
    let content = rendered_template
        .as_deref()
        .or(response.content.as_deref())
        .unwrap_or("");
    #[cfg(not(feature = "templating"))]
    let content = response.content.as_deref().unwrap_or("");
    // `stop_reason` wins over `finish_reason`; with neither, use the provider default.
    let has_explicit_reason = response.stop_reason.is_some() || response.finish_reason.is_some();
    let stop_reason = response
        .stop_reason
        .as_deref()
        .or(response.finish_reason.as_deref())
        .unwrap_or(handler.default_stop_reason());
    // Failure injection: optional fixed delay, then optionally replace the whole reply
    // with a non-JSON "overloaded" payload (SSE-framed for SSE streaming).
    if let Some(ref fail) = fixture.failure {
        if let Some(ms) = fail.latency_ms {
            sleep(Duration::from_millis(ms)).await;
        }
        if fail.corrupt_body == Some(true) {
            if is_streaming && handler.streaming_is_sse() {
                return (
                    StatusCode::OK,
                    [(header::CONTENT_TYPE, "text/event-stream")],
                    "data: overloaded\n\n".to_string(),
                )
                    .into_response();
            }
            return (
                StatusCode::OK,
                [(header::CONTENT_TYPE, "text/plain")],
                "overloaded".to_string(),
            )
                .into_response();
        }
    }
    // Tool calls, when present, replace the text content in both reply shapes.
    let tc_pairs: Option<Vec<(&str, serde_json::Value)>> =
        response.tool_calls.as_ref().map(|tool_calls| {
            tool_calls
                .iter()
                .map(|tc| (tc.name.as_str(), tc.arguments.clone()))
                .collect()
        });
    if is_streaming {
        // Defaults: chunk size 20, no inter-frame latency.
        let chunk_size = fixture
            .streaming
            .as_ref()
            .and_then(|s| s.chunk_size)
            .unwrap_or(20);
        let latency = fixture
            .streaming
            .as_ref()
            .and_then(|s| s.latency)
            .unwrap_or(0);
        let truncate_after = fixture
            .failure
            .as_ref()
            .and_then(|f| f.truncate_after_frames);
        let disconnect_after_ms = fixture.failure.as_ref().and_then(|f| f.disconnect_after_ms);
        let stream_output = if let Some(ref tc) = tc_pairs {
            handler.build_tool_call_stream_frames(
                &state,
                &model,
                tc,
                chunk_size,
                &user_message,
                stop_reason,
                has_explicit_reason,
            )
        } else {
            handler.build_stream_frames(
                &state,
                &model,
                content,
                chunk_size,
                &user_message,
                stop_reason,
                has_explicit_reason,
            )
        };
        let failure_ref = fixture.failure.as_ref();
        // The shared counter is bumped only for chaos-enabled fixtures; its value feeds
        // the chaos plan (NOTE(review): presumably to vary chaos across requests).
        let has_chaos = failure_ref.map(|f| f.has_chaos()).unwrap_or(false);
        let chaos_n = if has_chaos {
            state
                .chaos_counter
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
        } else {
            0
        };
        let frame_count = match &stream_output {
            StreamOutput::Sse(v) | StreamOutput::JsonArray(v) => v.len(),
        };
        let plan =
            crate::chaos::ChaosPlan::from_failure(failure_ref, latency, frame_count, chaos_n);
        if state.verbose && plan.active && (plan.duplicate || plan.frame_delays_ms.is_some()) {
            eprintln!("[llmposter] POST {} → chaos active", handler.route_label());
        }
        match stream_output {
            StreamOutput::Sse(frames) => {
                let frames = plan.apply_frame_duplication(frames);
                stream_sse_frames(frames, latency, &plan, truncate_after, disconnect_after_ms).await
            }
            StreamOutput::JsonArray(frames) => {
                let frames = plan.apply_frame_duplication(frames);
                stream_json_array(frames, latency, &plan, truncate_after, disconnect_after_ms).await
            }
        }
    } else {
        let json = if let Some(ref tc) = tc_pairs {
            handler.build_tool_call_response(
                &state,
                &model,
                tc,
                &user_message,
                stop_reason,
                has_explicit_reason,
            )
        } else {
            handler.build_response(
                &state,
                &model,
                content,
                &user_message,
                stop_reason,
                has_explicit_reason,
            )
        };
        (
            StatusCode::OK,
            [(header::CONTENT_TYPE, "application/json")],
            json,
        )
            .into_response()
    }
}
async fn stream_sse_frames(
frames: Vec<String>,
base_latency: u64,
plan: &crate::chaos::ChaosPlan,
truncate_after: Option<u32>,
disconnect_after_ms: Option<u64>,
) -> Response<Body> {
let delays_override = match plan.frame_delays_ms.as_ref() {
Some(v) if v.len() == frames.len() => Some(v.clone()),
Some(v) => {
eprintln!(
"[llmposter] stream_sse_frames: frame_delays_ms length mismatch \
(frames={}, delays={}) — falling back to base latency",
frames.len(),
v.len()
);
None
}
None => None,
};
let (tx, rx) = tokio::sync::mpsc::channel::<Result<String, std::io::Error>>(32);
tokio::spawn(async move {
let send_frames = async {
let total = frames.len();
for (sent, frame) in frames.into_iter().enumerate() {
tokio::task::yield_now().await;
if let Some(max) = truncate_after {
if sent as u32 >= max {
return;
}
}
if tx.send(Ok(frame)).await.is_err() {
return;
}
let delay = delays_override
.as_ref()
.and_then(|v| v.get(sent).copied())
.unwrap_or(base_latency);
if delay > 0 && sent + 1 < total {
sleep(Duration::from_millis(delay)).await;
}
}
};
if let Some(ms) = disconnect_after_ms {
tokio::select! {
biased;
_ = sleep(Duration::from_millis(ms)) => {
let _ = tx
.send(Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"llmposter: simulated disconnect",
)))
.await;
}
_ = send_frames => {}
}
} else {
send_frames.await;
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/event-stream")
.header(header::CACHE_CONTROL, "no-cache")
.body(Body::from_stream(stream))
.expect("static SSE response headers")
}
async fn stream_json_array(
frames: Vec<String>,
base_latency: u64,
plan: &crate::chaos::ChaosPlan,
truncate_after: Option<u32>,
disconnect_after_ms: Option<u64>,
) -> Response<Body> {
let delays_override = match plan.frame_delays_ms.as_ref() {
Some(v) if v.len() == frames.len() => Some(v.as_slice()),
Some(v) => {
eprintln!(
"[llmposter] stream_json_array: frame_delays_ms length mismatch \
(frames={}, delays={}) — falling back to base latency",
frames.len(),
v.len()
);
None
}
None => None,
};
let mut collected: Vec<String> = Vec::new();
let start = Instant::now();
let total = frames.len();
for (i, frame) in frames.into_iter().enumerate() {
tokio::task::yield_now().await;
if let Some(ms) = disconnect_after_ms {
if start.elapsed() >= Duration::from_millis(ms) {
break;
}
}
if let Some(max) = truncate_after {
if i as u32 >= max {
break;
}
}
collected.push(frame);
if i + 1 >= total {
break;
}
let delay = delays_override
.and_then(|v| v.get(i).copied())
.unwrap_or(base_latency);
if delay > 0 {
if let Some(ms) = disconnect_after_ms {
let remaining = ms.saturating_sub(elapsed_ms(&start));
if remaining == 0 {
break;
}
let wait = Duration::from_millis(delay.min(remaining));
sleep(wait).await;
if start.elapsed() >= Duration::from_millis(ms) {
collected.pop();
break;
}
} else {
sleep(Duration::from_millis(delay)).await;
}
}
}
let json = format!("[{}]", collected.join(","));
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
json,
)
.into_response()
}
#[cfg(test)]
mod mod_tests {
    //! Unit tests for the frame-streaming helpers: delay-list fallback and
    //! simulated disconnects.
    use super::*;
    use crate::chaos::ChaosPlan;

    /// Chaos plan carrying two per-frame delays, deliberately fewer than the
    /// three-frame inputs below, so the helpers must fall back to base latency.
    fn mismatched_plan() -> ChaosPlan {
        ChaosPlan {
            frame_delays_ms: Some(vec![5, 5]),
            duplicate: false,
            active: true,
        }
    }

    /// Turns string literals into the owned frames the helpers consume.
    fn frames_of(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    /// Reads a whole response body into a UTF-8 string.
    async fn collect_body(resp: Response<Body>) -> String {
        let raw = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        std::str::from_utf8(&raw).unwrap().to_owned()
    }

    /// Payloads of every `data: ` line in an SSE body, in order.
    fn sse_data_frames(body: &str) -> Vec<&str> {
        let mut payloads = Vec::new();
        for line in body.lines() {
            if let Some(payload) = line.strip_prefix("data: ") {
                payloads.push(payload);
            }
        }
        payloads
    }

    /// Re-serializes each element of a JSON-array body, in order.
    fn json_array_elements(body: &str) -> Vec<String> {
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        let Some(items) = parsed.as_array() else {
            panic!("expected a JSON array body, got {}", parsed);
        };
        items
            .iter()
            .map(|item| serde_json::to_string(item).unwrap())
            .collect()
    }

    /// Asserts that a simulated disconnect cut the three-frame stream short.
    fn assert_truncated(elements: &[String]) {
        assert!(
            elements.len() < 3,
            "expected disconnect to truncate the stream, got {:?}",
            elements
        );
    }

    #[tokio::test]
    async fn stream_sse_frames_falls_back_on_length_mismatch() {
        let frames = frames_of(&["data: a\n\n", "data: b\n\n", "data: c\n\n"]);
        let resp = stream_sse_frames(frames, 0, &mismatched_plan(), None, None).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let text = collect_body(resp).await;
        assert_eq!(sse_data_frames(&text), ["a", "b", "c"]);
    }

    #[tokio::test]
    async fn stream_json_array_falls_back_on_length_mismatch() {
        let frames = frames_of(&["\"a\"", "\"b\"", "\"c\""]);
        let resp = stream_json_array(frames, 0, &mismatched_plan(), None, None).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let text = collect_body(resp).await;
        assert_eq!(json_array_elements(&text), ["\"a\"", "\"b\"", "\"c\""]);
    }

    #[tokio::test]
    async fn stream_sse_frames_uses_override_when_lengths_match() {
        let plan = ChaosPlan {
            frame_delays_ms: Some(vec![0, 0]),
            duplicate: false,
            active: true,
        };
        let frames = frames_of(&["data: a\n\n", "data: b\n\n"]);
        let resp = stream_sse_frames(frames, 100, &plan, None, None).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let text = collect_body(resp).await;
        assert_eq!(sse_data_frames(&text), ["a", "b"]);
    }

    #[tokio::test]
    async fn stream_json_array_disconnect_during_latency_drops_last_frame() {
        let plan = ChaosPlan::PASSTHROUGH;
        let frames = frames_of(&["\"a\"", "\"b\"", "\"c\""]);
        let resp = stream_json_array(frames, 50, &plan, None, Some(10)).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let text = collect_body(resp).await;
        assert_truncated(&json_array_elements(&text));
    }

    #[tokio::test]
    async fn stream_json_array_disconnect_remaining_zero_break() {
        let plan = ChaosPlan::PASSTHROUGH;
        let frames = frames_of(&["\"a\"", "\"b\"", "\"c\""]);
        let resp = stream_json_array(frames, 15, &plan, None, Some(5)).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let text = collect_body(resp).await;
        assert_truncated(&json_array_elements(&text));
    }
}