use axum::{
body::Body,
extract::State,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};
use crate::image_normalizer::{ImageInput, ImageNormalizer, Mode, NormalizeError, walker};
use sqlx::PgPool;
/// Shared state handed to [`image_normalizer_middleware`] through axum's `State` extractor.
///
/// The struct is `Clone` so it can be passed to `middleware::from_fn_with_state`.
#[derive(Clone)]
pub struct ImageNormalizerMiddlewareState {
/// Master switch: when `false` the middleware forwards every request untouched.
pub enabled: bool,
/// Normaliser used to ingest each image URL and to sign the resulting token.
pub normalizer: Arc<dyn ImageNormalizer>,
/// TTL passed to `ImageNormalizer::sign` when producing the replacement URL.
pub realtime_ttl: Duration,
/// Optional database pool. When present (and the request carries a bearer token) it is used
/// to resolve image attribution and to record image accesses; `None` skips both.
pub pool: Option<PgPool>,
}
/// Extracts the credential from an `Authorization: Bearer <token>` request header.
///
/// The scheme is matched case-insensitively. Whitespace surrounding the header value and any
/// extra whitespace between the scheme and the token is discarded, so `Bearer   abc` yields
/// `abc` (RFC 6750 permits one *or more* spaces after the scheme).
///
/// Returns `None` when the header is missing, cannot be viewed as a string (`to_str` fails),
/// uses a different scheme, or carries no token.
fn extract_bearer_token(request: &Request<Body>) -> Option<String> {
    const SCHEME: &str = "bearer ";

    let header = request.headers().get(axum::http::header::AUTHORIZATION)?;
    let value = header.to_str().ok()?.trim();

    // Checked slicing: yields `None` (instead of panicking) if the value is shorter than the
    // scheme or the cut would not land on a character boundary.
    let scheme = value.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    // `get` succeeded above, so `SCHEME.len()` is a valid boundary for this slice.
    let token = value[SCHEME.len()..].trim_start();
    (!token.is_empty()).then(|| token.to_string())
}
pub async fn image_normalizer_middleware(
State(state): State<ImageNormalizerMiddlewareState>,
mut request: Request<Body>,
next: Next,
) -> Response {
if !state.enabled {
return next.run(request).await;
}
if !path_accepts_images(request.uri().path()) {
return next.run(request).await;
}
let body_bytes = match axum::body::to_bytes(std::mem::take(request.body_mut()), usize::MAX).await {
Ok(b) => b,
Err(e) => {
warn!(error = %e, "Failed to read request body in image_normalizer middleware");
let body = serde_json::json!({
"error": {
"message": format!("failed to read request body: {e}"),
"type": "invalid_request_error",
"code": "body_read_failed",
}
});
return (StatusCode::BAD_REQUEST, axum::Json(body)).into_response();
}
};
let mut body_value: Value = match serde_json::from_slice(&body_bytes) {
Ok(v) => v,
Err(_) => {
*request.body_mut() = Body::from(body_bytes);
return next.run(request).await;
}
};
let mode = Mode::All;
let attribution_for_access = match (state.pool.as_ref(), extract_bearer_token(&request)) {
(Some(pool), Some(bearer)) => crate::api::handlers::images::resolve_image_attribution(pool, &bearer).await,
_ => None,
};
let normalizer = state.normalizer.clone();
let realtime_ttl = state.realtime_ttl;
let pool_for_access = state.pool.clone();
let substitute = move |url: String| {
let normalizer = normalizer.clone();
let pool_for_access = pool_for_access.clone();
let is_data_uri = url.starts_with("data:");
async move {
let input = if is_data_uri {
ImageInput::DataUri(url)
} else {
ImageInput::HttpUrl(url)
};
let ingested = normalizer.ingest(input).await?;
let signed = normalizer.sign(ingested.token, realtime_ttl).await?;
if let (Some(pool), Some(attribution)) = (pool_for_access, attribution_for_access) {
let mime = ingested.mime.clone();
let bytes_len = ingested.bytes_len;
let token = ingested.token;
tokio::spawn(async move {
crate::api::handlers::images::record_image_access(&pool, attribution, token, &mime, bytes_len).await;
});
}
Ok::<String, NormalizeError>(signed.url)
}
};
let substituted = match walker::substitute_with(&mut body_value, mode, substitute).await {
Ok(n) => n,
Err(e) => {
warn!(error = %e, "image normalisation failed");
return normalize_error_response(e);
}
};
if substituted == 0 {
*request.body_mut() = Body::from(body_bytes);
} else {
debug!(substituted, "image normaliser replaced URLs in request body");
let new_bytes = match serde_json::to_vec(&body_value) {
Ok(b) => b,
Err(e) => {
warn!(error = %e, "Failed to re-serialise body after image normalisation");
let body = serde_json::json!({
"error": {
"message": format!("failed to re-serialise request body: {e}"),
"type": "internal_error",
"code": "body_reserialize_failed",
}
});
return (StatusCode::INTERNAL_SERVER_ERROR, axum::Json(body)).into_response();
}
};
let len = new_bytes.len();
request.headers_mut().insert(
axum::http::header::CONTENT_LENGTH,
len.to_string().parse().expect("digit string is a valid header value"),
);
*request.body_mut() = Body::from(new_bytes);
}
next.run(request).await
}
/// Replaces every image URL found in `body` with the `dw-img` token URI of its ingested image.
///
/// Unlike the middleware, no signing takes place here: the replacement is whatever
/// `to_dw_img_uri` returns for the ingested token.
///
/// * `body` – JSON document rewritten in place by `walker::substitute_with` (`Mode::All`).
/// * `normalizer` – normaliser used to ingest each URL.
/// * `access_pool` / `attribution` – when both are `Some`, each ingested image is reported via
///   `record_image_access` on a detached `tokio` task.
///
/// Returns whatever `walker::substitute_with` returns: the substitution count on success, or
/// the first `NormalizeError` propagated from ingestion.
pub(crate) async fn normalize_value_to_tokens(
    body: &mut Value,
    normalizer: &Arc<dyn ImageNormalizer>,
    access_pool: Option<PgPool>,
    attribution: Option<crate::api::handlers::images::ImageAttribution>,
) -> Result<usize, NormalizeError> {
    let shared_normalizer = Arc::clone(normalizer);

    // Called once per URL; every invocation clones its own handles for the returned future.
    let to_token_uri = move |url: String| {
        let worker = Arc::clone(&shared_normalizer);
        let pool_handle = access_pool.clone();
        async move {
            let input = if url.starts_with("data:") {
                ImageInput::DataUri(url)
            } else {
                ImageInput::HttpUrl(url)
            };
            let ingested = worker.ingest(input).await?;
            let token = ingested.token;

            if let (Some(pool), Some(who)) = (pool_handle, attribution) {
                let mime = ingested.mime.clone();
                let bytes_len = ingested.bytes_len;
                // Detached: the join handle is dropped, the caller never waits on it.
                tokio::spawn(async move {
                    crate::api::handlers::images::record_image_access(&pool, who, token, &mime, bytes_len)
                        .await;
                });
            }

            Ok::<String, NormalizeError>(token.to_dw_img_uri())
        }
    };

    walker::substitute_with(body, Mode::All, to_token_uri).await
}
/// Converts a [`NormalizeError`] into a JSON error response.
///
/// Status / `code` mapping:
/// * `BadInput`    → 400 `image_url_rejected`
/// * `Transient`   → 503 `image_fetch_transient`
/// * `FetchFailed` → 502 `image_fetch_failed`
/// * `StoreFailed` → 503 `image_store_failed`
/// * `NotFound`    → 500 `image_token_not_found`
///
/// The body is `{"error": {"message", "type", "code"}}` where `message` is the error's
/// `Display` output.
pub(crate) fn normalize_error_response(err: NormalizeError) -> Response {
    // Capture the human-readable text first; the match below consumes the error.
    let message = err.to_string();

    let (code, status) = match err {
        NormalizeError::BadInput(_) => ("image_url_rejected", StatusCode::BAD_REQUEST),
        NormalizeError::Transient(_) => ("image_fetch_transient", StatusCode::SERVICE_UNAVAILABLE),
        NormalizeError::FetchFailed(_) => ("image_fetch_failed", StatusCode::BAD_GATEWAY),
        NormalizeError::StoreFailed(_) => ("image_store_failed", StatusCode::SERVICE_UNAVAILABLE),
        NormalizeError::NotFound => ("image_token_not_found", StatusCode::INTERNAL_SERVER_ERROR),
    };

    // NOTE(review): `type` is "invalid_request_error" for every variant, including the 5xx
    // ones — confirm this is intentional before clients start keying on it.
    let payload = serde_json::json!({
        "error": {
            "message": message,
            "type": "invalid_request_error",
            "code": code,
        }
    });

    let mut response = axum::Json(payload).into_response();
    *response.status_mut() = status;
    response
}
/// Reports whether `path` ends with one of the route suffixes the middleware rewrites
/// (`/chat/completions` or `/responses`). The match is exact on the suffix, so a trailing
/// slash or an extra path segment does not qualify.
fn path_accepts_images(path: &str) -> bool {
    const IMAGE_ROUTE_SUFFIXES: [&str; 2] = ["/chat/completions", "/responses"];
    IMAGE_ROUTE_SUFFIXES
        .iter()
        .any(|suffix| path.ends_with(*suffix))
}
// Unit tests for the image-normaliser middleware. They drive the middleware through an
// in-process axum `Router` (via `tower::ServiceExt::oneshot`) whose handlers echo the request
// body, so each test can inspect exactly what the inner handler would have received.
#[cfg(test)]
mod tests {
use super::*;
use crate::image_normalizer::{DefaultImageNormalizer, DisabledNormalizer, ImageNormalizer, MemoryStore, config::FetcherConfig};
use axum::{Router, body::to_bytes, http::Method, middleware, routing::post};
use serde_json::json;
use std::sync::Arc;
use tower::ServiceExt;
// Base64 `data:` URI of a small PNG, used as an image input that needs no network access.
const TINY_PNG_DATA_URI: &str =
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=";
// Builds a router with three POST routes that echo the received body with `200 OK`, all
// wrapped in the middleware under test. `/embeddings` is a path the middleware should ignore.
fn build_router(state: ImageNormalizerMiddlewareState) -> Router {
let inner = post(|body: axum::body::Bytes| async move { (StatusCode::OK, body) });
Router::new()
.route("/chat/completions", inner.clone())
.route("/responses", inner.clone())
.route("/embeddings", inner)
.layer(middleware::from_fn_with_state(state, image_normalizer_middleware))
}
// Enabled middleware state backed by a `MemoryStore` (base URL `http://test.local/dw-img`),
// a default `FetcherConfig`, a 900 s signing TTL and no database pool.
fn state_for_tests() -> ImageNormalizerMiddlewareState {
let store = Arc::new(MemoryStore::new().with_base_url("http://test.local/dw-img"));
let normalizer = Arc::new(DefaultImageNormalizer::new(FetcherConfig::default(), store));
ImageNormalizerMiddlewareState {
enabled: true,
normalizer,
realtime_ttl: Duration::from_secs(900),
pool: None,
}
}
// POSTs `body` as JSON to `path` and returns the response status together with the response
// body parsed as JSON (`Value::Null` when the response body is not valid JSON).
async fn post_json(router: Router, path: &str, body: Value) -> (StatusCode, Value) {
let resp = router
.oneshot(
Request::builder()
.method(Method::POST)
.uri(path)
.header("content-type", "application/json")
.body(Body::from(serde_json::to_vec(&body).unwrap()))
.unwrap(),
)
.await
.unwrap();
let status = resp.status();
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let v: Value = serde_json::from_slice(&bytes).unwrap_or(Value::Null);
(status, v)
}
// A JSON body without any image part reaches the handler with its content intact.
#[tokio::test]
async fn passes_through_when_no_image_urls_present() {
let router = build_router(state_for_tests());
let (status, echoed) = post_json(
router,
"/chat/completions",
json!({ "model": "vision", "messages": [ { "role": "user", "content": "hi" } ] }),
)
.await;
assert_eq!(status, StatusCode::OK);
assert_eq!(echoed["messages"][0]["content"], "hi");
}
// A `data:` URI in an `image_url` part is replaced by a URL under the store's base URL.
#[tokio::test]
async fn substitutes_data_uri_with_signed_url_in_all_mode() {
let router = build_router(state_for_tests());
let (status, echoed) = post_json(
router,
"/chat/completions",
json!({
"model": "vision",
"messages": [{
"role": "user",
"content": [{ "type": "image_url", "image_url": { "url": TINY_PNG_DATA_URI } }]
}]
}),
)
.await;
assert_eq!(status, StatusCode::OK);
let substituted = echoed["messages"][0]["content"][0]["image_url"]["url"]
.as_str()
.expect("image_url.url should still be a string");
assert_ne!(
substituted, TINY_PNG_DATA_URI,
"data: URI should be substituted, not passed through"
);
assert!(
substituted.starts_with("http://test.local/dw-img/"),
"expected MemoryStore-backed signed URL, got: {substituted}",
);
}
// An HTTP URL pointing at 169.254.169.254 is refused with `400` / `image_url_rejected`
// rather than being fetched.
#[tokio::test]
async fn rejects_http_url_to_link_local_with_400() {
let router = build_router(state_for_tests());
let (status, body) = post_json(
router,
"/chat/completions",
json!({
"model": "vision",
"messages": [{
"role": "user",
"content": [{
"type": "image_url",
"image_url": { "url": "http://169.254.169.254/latest/meta-data/" }
}]
}]
}),
)
.await;
assert_eq!(status, StatusCode::BAD_REQUEST);
assert_eq!(body["error"]["code"], "image_url_rejected");
}
// Requests to a path outside the image routes are forwarded byte-for-byte, even when the
// body is not JSON.
#[tokio::test]
async fn does_not_touch_unrelated_paths() {
let router = build_router(state_for_tests());
let resp = router
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/embeddings")
.header("content-type", "application/json")
.body(Body::from("not json at all"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
assert_eq!(&bytes[..], b"not json at all");
}
// On an image route, a body that fails JSON parsing is restored and forwarded unchanged
// instead of being rejected by the middleware.
#[tokio::test]
async fn non_json_body_on_image_path_passes_through() {
let router = build_router(state_for_tests());
let resp = router
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/chat/completions")
.header("content-type", "application/json")
.body(Body::from("garbage"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
assert_eq!(&bytes[..], b"garbage");
}
// With `enabled: false` an image-bearing request is echoed back exactly as sent.
#[tokio::test]
async fn disabled_passes_image_request_through_unchanged() {
let state = ImageNormalizerMiddlewareState {
enabled: false,
normalizer: Arc::new(DisabledNormalizer),
realtime_ttl: Duration::from_secs(900),
pool: None,
};
let router = build_router(state);
let body = json!({
"model": "m",
"messages": [{"role": "user", "content": [
{"type": "image_url", "image_url": {"url": TINY_PNG_DATA_URI}}
]}]
});
let (status, echoed) = post_json(router, "/chat/completions", body.clone()).await;
assert_eq!(status, StatusCode::OK);
assert_eq!(echoed, body);
}
// `normalize_value_to_tokens` swaps the image for a `dw-img://` token URI, reports one
// substitution, and leaves no base64 payload behind.
#[tokio::test]
async fn normalize_value_to_tokens_replaces_image_with_token() {
let store = Arc::new(MemoryStore::new());
let normalizer: Arc<dyn ImageNormalizer> = Arc::new(DefaultImageNormalizer::new(FetcherConfig::default(), store));
let mut body = json!({
"model": "m",
"messages": [{"role": "user", "content": [
{"type": "text", "text": "hi"},
{"type": "image_url", "image_url": {"url": TINY_PNG_DATA_URI}}
]}]
});
let n = normalize_value_to_tokens(&mut body, &normalizer, None, None).await.unwrap();
assert_eq!(n, 1);
let url = body["messages"][0]["content"][1]["image_url"]["url"].as_str().unwrap();
assert!(url.starts_with("dw-img://"), "expected a dw-img token, got {url}");
assert!(!url.contains("base64"), "raw base64 must be replaced");
}
// Pins the HTTP status produced by `normalize_error_response` for every error variant.
#[test]
fn error_response_maps_each_variant_to_the_right_status() {
let cases = [
(NormalizeError::BadInput("x".into()), StatusCode::BAD_REQUEST),
(NormalizeError::Transient("x".into()), StatusCode::SERVICE_UNAVAILABLE),
(NormalizeError::FetchFailed("x".into()), StatusCode::BAD_GATEWAY),
(NormalizeError::StoreFailed("x".into()), StatusCode::SERVICE_UNAVAILABLE),
(NormalizeError::NotFound, StatusCode::INTERNAL_SERVER_ERROR),
];
for (err, expected) in cases {
assert_eq!(normalize_error_response(err).status(), expected);
}
}
}