use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use oxihttp_core::OxiHttpError;
use oxihttp_server::router::Request;
use oxihttp_server::{Router, Server};
/// Serves `router` on an ephemeral loopback port (`127.0.0.1:0`).
///
/// Returns the address the server actually bound, plus the sending half of a
/// oneshot channel that drives the server's graceful-shutdown future. Sending
/// `()` *or simply dropping* the sender completes that future, because a tokio
/// oneshot receiver resolves (with an error) once its sender is gone. Tests
/// therefore keep the sender alive for their whole body by binding it to a
/// named variable.
async fn spawn_test_server(
    router: Router,
) -> (std::net::SocketAddr, tokio::sync::oneshot::Sender<()>) {
    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
    // Either outcome of the receive (value sent, or sender dropped) ends the wait.
    let shutdown_signal = async move {
        let _ = shutdown_rx.await;
    };
    let (bound_addr, _server_handle) = Server::bind("127.0.0.1:0")
        .with_graceful_shutdown(shutdown_signal)
        .serve_with_addr(router)
        .await
        .expect("server bind");
    // NOTE(review): short pause, presumably so the accept loop is running before
    // the first request; confirm whether `serve_with_addr` already guarantees it.
    tokio::time::sleep(Duration::from_millis(10)).await;
    (bound_addr, shutdown_tx)
}
/// Shared state for the request-counter test.
///
/// The counter sits behind an `Arc`, so every clone of this state increments the
/// same atomic the test keeps a handle to.
#[derive(Clone, Debug)]
struct CounterState {
    /// Number of requests handled so far; the handler bumps it with `fetch_add`.
    counter: Arc<AtomicU32>,
}
/// Configuration-style state used by the nesting and fallback tests.
#[derive(Clone, Debug)]
struct AppConfig {
    /// Text the handlers echo back as the response body.
    greeting: String,
}
/// State registered with `Router::with_state` is visible to handlers. Because the
/// counter inside it is an `Arc`, the test can observe the handlers' mutations.
#[tokio::test]
async fn test_state_injection() {
    let hits = Arc::new(AtomicU32::new(0));
    let router = Router::new()
        .with_state(CounterState {
            counter: Arc::clone(&hits),
        })
        .get("/count", |req: Request| async move {
            let current = req
                .state::<CounterState>()
                .ok_or_else(|| OxiHttpError::Server("missing state".into()))?;
            // `fetch_add` yields the value from *before* the increment.
            let before = current.counter.fetch_add(1, Ordering::SeqCst);
            oxihttp_server::response::text_response(format!("{before}"))
        });
    let (addr, _shutdown) = spawn_test_server(router).await;
    let client = oxihttp_client::Client::builder().build().expect("client");
    let url = format!("http://{addr}/count");

    let first = client
        .get(&url)
        .expect("builder")
        .send()
        .await
        .expect("request 1");
    assert_eq!(first.body_text().await.expect("body 1"), "0");

    let second = client
        .get(&url)
        .expect("builder")
        .send()
        .await
        .expect("request 2");
    assert_eq!(second.body_text().await.expect("body 2"), "1");

    // Both increments landed on the Arc the test still holds.
    assert_eq!(hits.load(Ordering::SeqCst), 2);
}
/// A value inserted into a request's extensions can be read back, within the same
/// handler, through the typed `Request::extension` accessor.
#[tokio::test]
async fn test_extension_read() {
    #[derive(Clone)]
    struct RequestTag(String);

    let router = Router::new().get("/tag", |mut req: Request| async move {
        let inserted = RequestTag("hello-ext".into());
        req.extensions_mut().insert(inserted);
        let found = req
            .extension::<RequestTag>()
            .ok_or_else(|| OxiHttpError::Server("extension missing".into()))?;
        oxihttp_server::response::text_response(found.0)
    });
    let (addr, _shutdown) = spawn_test_server(router).await;
    let client = oxihttp_client::Client::builder().build().expect("client");
    let url = format!("http://{addr}/tag");
    let response = client
        .get(&url)
        .expect("builder")
        .send()
        .await
        .expect("request");
    let text = response.body_text().await.expect("body");
    assert_eq!(text, "hello-ext");
}
/// The raw `Request::extensions` map exposes values inserted through
/// `extensions_mut`, looked up by type.
#[tokio::test]
async fn test_extensions_accessor() {
    #[derive(Clone, Debug, PartialEq)]
    struct Marker(u32);

    let router = Router::new().get("/ext", |mut req: Request| async move {
        req.extensions_mut().insert(Marker(42));
        // Copy the inner `u32` out so nothing borrowed from `req` outlives this lookup.
        let value = req
            .extensions()
            .get::<Marker>()
            .map(|marker| marker.0)
            .ok_or_else(|| OxiHttpError::Server("no marker".into()))?;
        oxihttp_server::response::text_response(value.to_string())
    });
    let (addr, _shutdown) = spawn_test_server(router).await;
    let client = oxihttp_client::Client::builder().build().expect("client");
    let url = format!("http://{addr}/ext");
    let response = client
        .get(&url)
        .expect("builder")
        .send()
        .await
        .expect("request");
    let text = response.body_text().await.expect("body");
    assert_eq!(text, "42");
}
/// A router mounted with `nest` sees the parent's state when it registers none
/// of its own.
#[tokio::test]
async fn test_nested_router_inherits_state() {
    let child = Router::new().get("/greet", |req: Request| async move {
        let config = req
            .state::<AppConfig>()
            .ok_or_else(|| OxiHttpError::Server("no config".into()))?;
        let greeting = config.greeting.clone();
        oxihttp_server::response::text_response(greeting)
    });
    let parent = Router::new()
        .with_state(AppConfig {
            greeting: "hello-nested".into(),
        })
        .nest("/api", child);
    let (addr, _shutdown) = spawn_test_server(parent).await;
    let client = oxihttp_client::Client::builder().build().expect("client");
    let url = format!("http://{addr}/api/greet");
    let response = client
        .get(&url)
        .expect("builder")
        .send()
        .await
        .expect("request");
    let text = response.body_text().await.expect("body");
    assert_eq!(text, "hello-nested");
}
/// When the parent and the nested router both register state of the same type,
/// the nested router's handlers see the nested router's value.
#[tokio::test]
async fn test_nested_router_own_state_wins() {
    let parent_state = AppConfig {
        greeting: "parent".into(),
    };
    let child_state = AppConfig {
        greeting: "child".into(),
    };
    let child = Router::new()
        .with_state(child_state)
        .get("/greet", |req: Request| async move {
            let config = req
                .state::<AppConfig>()
                .ok_or_else(|| OxiHttpError::Server("no config".into()))?;
            let greeting = config.greeting.clone();
            oxihttp_server::response::text_response(greeting)
        });
    let parent = Router::new().with_state(parent_state).nest("/api", child);
    let (addr, _shutdown) = spawn_test_server(parent).await;
    let client = oxihttp_client::Client::builder().build().expect("client");
    let url = format!("http://{addr}/api/greet");
    let response = client
        .get(&url)
        .expect("builder")
        .send()
        .await
        .expect("request");
    let text = response.body_text().await.expect("body");
    assert_eq!(text, "child");
}
/// Asking for state of a type that was never registered yields `None`; it does
/// not fail the request.
#[tokio::test]
async fn test_state_missing_returns_none() {
    let router = Router::new().get("/nostate", |req: Request| async move {
        let label = match req.state::<AppConfig>() {
            Some(_) => "present",
            None => "absent",
        };
        oxihttp_server::response::text_response(label)
    });
    let (addr, _shutdown) = spawn_test_server(router).await;
    let client = oxihttp_client::Client::builder().build().expect("client");
    let url = format!("http://{addr}/nostate");
    let response = client
        .get(&url)
        .expect("builder")
        .send()
        .await
        .expect("request");
    let text = response.body_text().await.expect("body");
    assert_eq!(text, "absent");
}
#[tokio::test]
async fn test_state_in_fallback_handler() {
let config = AppConfig {
greeting: "fallback-state".into(),
};
let router = Router::new()
.with_state(config)
.fallback(|req: Request| async move {
let cfg = req
.state::<AppConfig>()
.ok_or_else(|| OxiHttpError::Server("no config".into()))?;
let body = Bytes::from(cfg.greeting.clone().into_bytes());
hyper::Response::builder()
.status(http::StatusCode::NOT_FOUND)
.header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
.body(http_body_util::Full::new(body))
.map_err(|e| OxiHttpError::Http(std::sync::Arc::new(e)))
});
let (addr, _shutdown) = spawn_test_server(router).await;
let client = oxihttp_client::Client::builder()
.redirect_policy(oxihttp_client::RedirectPolicy::None)
.build()
.expect("client");
let resp = client
.get(&format!("http://{addr}/unknown"))
.expect("builder")
.send()
.await
.expect("request");
assert_eq!(resp.status(), http::StatusCode::NOT_FOUND);
let body = resp.body_text().await.expect("body");
assert_eq!(body, "fallback-state");
}