use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use crate::extract::RequestContext;
mod client;
mod cookie;
mod recorder;
mod request;
mod response;
mod sse;
mod websocket;
pub use client::{TestClient, TestClientBuilder};
pub use recorder::{LogRecord, LogRecorder};
pub use request::{TestMultipartBuilder, TestRequestBuilder};
pub use response::TestResponse;
pub use sse::{TestSseEvent, TestSseStream};
pub use websocket::{TestWebSocket, TestWebSocketBuilder};
/// Type-erased constructor for one override value.
///
/// Every call yields a newly boxed value; the concrete type is recovered by
/// downcasting the returned `Box<dyn Any + Send>`.
type OverrideFactory = Arc<dyn Fn() -> Box<dyn Any + Send> + Sync + Send>;

/// Registry of value factories, at most one per produced type.
///
/// Entries are keyed by the [`TypeId`] of the type the factory produces.
/// Cloning the registry clones the map, while the factories themselves are
/// shared through their `Arc`s.
#[derive(Clone, Default)]
pub(crate) struct TestOverrides {
    /// Factory for each overridden type, keyed by that type's `TypeId`.
    factories: HashMap<TypeId, OverrideFactory>,
}
impl TestOverrides {
#[allow(dead_code)]
pub(crate) fn insert<T, F>(&mut self, factory: F)
where
T: Send + 'static,
F: Fn() -> T + Send + Sync + 'static,
{
self.factories
.insert(TypeId::of::<T>(), Arc::new(move || Box::new(factory())));
}
fn produce<T: 'static>(&self) -> Option<T> {
let factory = self.factories.get(&TypeId::of::<T>())?;
factory().downcast::<T>().ok().map(|boxed| *boxed)
}
#[allow(dead_code)]
pub(crate) fn is_empty(&self) -> bool {
self.factories.is_empty()
}
}
/// Produces an override value of type `T` from the request's state, if one
/// is registered.
///
/// Returns `None` when the request state holds no `TestOverrides` registry,
/// or when that registry has no factory for `T`. Despite the name, nothing
/// is removed: the registered factory is invoked anew on every call.
// NOTE(review): hidden from docs, presumably because it is only meant to be
// called from generated code — confirm.
#[doc(hidden)]
pub fn __take_override<T: 'static>(ctx: &RequestContext) -> Option<T> {
    let registry = ctx.state().get::<TestOverrides>()?;
    registry.produce::<T>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::extract::PathParams;
    use crate::state::StateMap;
    use bytes::Bytes;
    use http_body_util::Full;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a minimal `RequestContext` (default request head, no path
    /// params, empty body) around the given state map.
    fn context_with_state(state: StateMap) -> RequestContext {
        let head = http::Request::new(()).into_parts().0;
        RequestContext::new(
            head,
            PathParams::new(),
            Arc::new(state),
            box_body(Full::new(Bytes::new())),
        )
    }

    /// Builds a minimal `RequestContext` whose state holds `overrides`.
    fn context_with_overrides(overrides: TestOverrides) -> RequestContext {
        let mut state = StateMap::new();
        state.insert(overrides);
        context_with_state(state)
    }

    #[test]
    fn override_registry_reports_empty_and_produces_fresh_values() {
        let mut overrides = TestOverrides::default();
        assert!(overrides.is_empty());
        overrides.insert::<String, _>(|| "hello".to_owned());
        assert!(!overrides.is_empty());
        assert_eq!(overrides.produce::<String>().as_deref(), Some("hello"));

        // "Fresh" means the factory runs on every `produce`, not once at
        // registration; a counting factory makes that observable.
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        overrides.insert::<usize, _>(move || counter.fetch_add(1, Ordering::SeqCst));
        assert_eq!(overrides.produce::<usize>(), Some(0));
        assert_eq!(overrides.produce::<usize>(), Some(1));
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        // Types without a registered factory yield nothing.
        assert_eq!(overrides.produce::<u8>(), None);
    }

    #[test]
    fn later_insert_replaces_earlier_factory() {
        let mut overrides = TestOverrides::default();
        overrides.insert::<u32, _>(|| 1u32);
        overrides.insert::<u32, _>(|| 2u32);
        assert_eq!(overrides.produce::<u32>(), Some(2));
    }

    #[test]
    fn take_override_reads_registered_override() {
        let mut overrides = TestOverrides::default();
        overrides.insert::<usize, _>(|| 7usize);
        let ctx = context_with_overrides(overrides);
        assert_eq!(__take_override::<usize>(&ctx), Some(7));
        assert_eq!(__take_override::<String>(&ctx), None);
    }

    #[test]
    fn take_override_without_registry_returns_none() {
        let ctx = context_with_state(StateMap::new());
        assert_eq!(__take_override::<usize>(&ctx), None);
    }
}