use std::sync::Mutex;
use axum::Router;
use axum::body::Body;
use http::header::{COOKIE, HeaderName, HeaderValue, SET_COOKIE};
use http::{HeaderMap, Method, Request, StatusCode};
use http_body_util::BodyExt;
use serde::Serialize;
use serde::de::DeserializeOwned;
use sqlx::SqlitePool;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use tempfile::TempDir;
use tower::ServiceExt;
pub struct TempPool {
pool: SqlitePool,
_dir: TempDir,
}
impl TempPool {
pub async fn new() -> Self {
Self::with_max_connections(5).await
}
pub async fn with_max_connections(n: u32) -> Self {
let dir = tempfile::tempdir().expect("tempdir for TempPool");
let path = dir.path().join("umbral_test.sqlite");
let pool = SqlitePoolOptions::new()
.max_connections(n)
.connect_with(
SqliteConnectOptions::new()
.filename(&path)
.create_if_missing(true),
)
.await
.expect("connect to tempfile sqlite");
Self { pool, _dir: dir }
}
pub fn handle(&self) -> &SqlitePool {
&self.pool
}
pub fn clone_handle(&self) -> SqlitePool {
self.pool.clone()
}
}
/// A minimal in-memory cookie store used by the test client.
///
/// Only `name=value` pairs are kept. Attributes such as `Path`, `Domain` and
/// `Secure` are discarded, so `cookie_header` emits every stored cookie.
/// Replacing a cookie moves it to the end of the emitted order.
#[derive(Default)]
struct CookieJar {
    cookies: Vec<(String, String)>,
}

impl CookieJar {
    /// Records the cookie carried by one `Set-Cookie` header value.
    ///
    /// - Only the leading `name=value` pair is stored, with surrounding
    ///   whitespace trimmed from both name and value.
    /// - A pair without `=`, or with an empty name, is ignored (RFC 6265 §5.2).
    /// - A non-positive `Max-Age` is treated as a deletion: any stored cookie
    ///   with that name is removed and nothing is stored. When several
    ///   `Max-Age` attributes are present, the last valid one wins.
    ///
    /// NOTE(review): a past `Expires` date without `Max-Age` is not treated
    /// as a deletion, because no HTTP-date parser is available here.
    fn set_from_header(&mut self, header: &str) {
        let mut parts = header.split(';');
        let pair = parts.next().unwrap_or("").trim();
        let (name, value) = match pair.split_once('=') {
            Some((n, v)) => (n.trim(), v.trim()),
            None => return,
        };
        if name.is_empty() {
            return;
        }

        // Servers clear a cookie by re-sending it with `Max-Age=0` (or a
        // negative value). Unparseable values are ignored, as in RFC 6265.
        let max_age = parts
            .filter_map(|attr| attr.split_once('='))
            .filter(|(key, _)| key.trim().eq_ignore_ascii_case("max-age"))
            .filter_map(|(_, secs)| secs.trim().parse::<i64>().ok())
            .last();
        let expired = matches!(max_age, Some(secs) if secs <= 0);

        self.cookies.retain(|(n, _)| n != name);
        if !expired {
            self.cookies.push((name.to_string(), value.to_string()));
        }
    }

    /// Builds a `Cookie` request-header value (`a=1; b=2`).
    ///
    /// Returns `None` when the jar is empty.
    fn cookie_header(&self) -> Option<String> {
        if self.cookies.is_empty() {
            return None;
        }
        Some(
            self.cookies
                .iter()
                .map(|(n, v)| format!("{n}={v}"))
                .collect::<Vec<_>>()
                .join("; "),
        )
    }

    /// Returns the stored value of cookie `name`, if present.
    fn get(&self, name: &str) -> Option<&str> {
        self.cookies
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, v)| v.as_str())
    }
}
pub struct TestClient {
router: Router,
jar: Mutex<CookieJar>,
default_headers: Mutex<HeaderMap>,
}
impl TestClient {
pub fn new(router: Router) -> Self {
Self {
router,
jar: Mutex::new(CookieJar::default()),
default_headers: Mutex::new(HeaderMap::new()),
}
}
pub fn set_default_header(&self, name: HeaderName, value: HeaderValue) {
self.default_headers
.lock()
.expect("default headers poisoned")
.insert(name, value);
}
pub fn cookie(&self, name: &str) -> Option<String> {
self.jar
.lock()
.expect("cookie jar poisoned")
.get(name)
.map(str::to_string)
}
pub async fn get(&self, uri: &str) -> TestResponse {
self.request(Method::GET, uri, Body::empty(), None).await
}
pub async fn post(&self, uri: &str, body: Body) -> TestResponse {
self.request(Method::POST, uri, body, None).await
}
pub async fn post_json<T: Serialize + ?Sized>(&self, uri: &str, body: &T) -> TestResponse {
let bytes = serde_json::to_vec(body).expect("serialize body");
self.request(
Method::POST,
uri,
Body::from(bytes),
Some(("content-type", "application/json")),
)
.await
}
pub async fn put_json<T: Serialize + ?Sized>(&self, uri: &str, body: &T) -> TestResponse {
let bytes = serde_json::to_vec(body).expect("serialize body");
self.request(
Method::PUT,
uri,
Body::from(bytes),
Some(("content-type", "application/json")),
)
.await
}
pub async fn delete(&self, uri: &str) -> TestResponse {
self.request(Method::DELETE, uri, Body::empty(), None).await
}
pub async fn send(&self, method: Method, uri: &str, body: Body) -> TestResponse {
self.request(method, uri, body, None).await
}
async fn request(
&self,
method: Method,
uri: &str,
body: Body,
content_type: Option<(&str, &str)>,
) -> TestResponse {
let mut builder = Request::builder().method(method).uri(uri);
for (k, v) in self.default_headers.lock().expect("dh").iter() {
builder = builder.header(k, v);
}
if let Some((k, v)) = content_type {
builder = builder.header(k, v);
}
if let Some(c) = self.jar.lock().expect("jar").cookie_header() {
builder = builder.header(COOKIE, c);
}
let req = builder.body(body).expect("build request");
let resp = self
.router
.clone()
.oneshot(req)
.await
.expect("router oneshot");
let status = resp.status();
let headers = resp.headers().clone();
for v in headers.get_all(SET_COOKIE) {
if let Ok(s) = v.to_str() {
self.jar.lock().expect("jar set").set_from_header(s);
}
}
let bytes = resp
.into_body()
.collect()
.await
.expect("collect body")
.to_bytes();
TestResponse {
status,
headers,
body: bytes.to_vec(),
}
}
}
/// A fully buffered HTTP response produced by the test client.
pub struct TestResponse {
    /// Status code returned by the router.
    pub status: StatusCode,
    /// Response headers as returned by the router.
    pub headers: HeaderMap,
    /// Complete response body.
    pub body: Vec<u8>,
}

impl TestResponse {
    /// Returns the response status code.
    pub fn status(&self) -> StatusCode {
        self.status
    }

    /// Borrows the response headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Borrows the raw response body.
    pub fn body_bytes(&self) -> &[u8] {
        self.body.as_slice()
    }

    /// Decodes the body as UTF-8, replacing invalid sequences with U+FFFD.
    pub fn body_text(&self) -> String {
        let decoded = String::from_utf8_lossy(self.body_bytes());
        decoded.into_owned()
    }

    /// Deserializes the body as JSON into `T`.
    ///
    /// # Panics
    /// Panics with the parse error and the raw body if the body is not valid
    /// JSON for `T`.
    pub fn body_json<T: DeserializeOwned>(&self) -> T {
        match serde_json::from_slice::<T>(self.body_bytes()) {
            Ok(parsed) => parsed,
            Err(e) => panic!(
                "body_json: failed to parse response as JSON ({e}). raw body:\n{}",
                self.body_text()
            ),
        }
    }

    /// Returns the first value of header `name` as a `String`.
    ///
    /// Returns `None` if the header is absent, the name is invalid, or the
    /// value is not visible ASCII.
    pub fn header(&self, name: &str) -> Option<String> {
        let raw = self.headers.get(name)?;
        raw.to_str().ok().map(String::from)
    }

    /// Asserts the status equals `expected`. On failure the body is printed.
    pub fn assert_status(&self, expected: StatusCode) -> &Self {
        let got = self.status();
        assert_eq!(
            got,
            expected,
            "expected status {expected}, got {} with body:\n{}",
            got,
            self.body_text()
        );
        self
    }

    /// Asserts the status is `200 OK`.
    pub fn assert_status_ok(&self) -> &Self {
        self.assert_status(StatusCode::OK)
    }

    /// Asserts the decoded body contains `needle`.
    pub fn assert_body_contains(&self, needle: &str) -> &Self {
        let body = self.body_text();
        let found = body.contains(needle);
        assert!(
            found,
            "expected body to contain {needle:?}\n--- got ---\n{body}\n-----------"
        );
        self
    }

    /// Asserts header `name` is present with exactly the value `expected`.
    pub fn assert_header(&self, name: &str, expected: &str) -> &Self {
        let actual = self.header(name);
        let wanted = Some(expected);
        assert_eq!(
            actual.as_deref(),
            wanted,
            "expected header {name} to be {expected:?}, got {actual:?}"
        );
        self
    }
}
pub use fake;
use std::sync::atomic::{AtomicU64, Ordering};
/// Returns the next value of a process-wide counter that starts at 1.
///
/// Every call yields a distinct value.
pub fn seq() -> u64 {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    // Relaxed is enough: the read-modify-write is atomic, so callers never
    // observe the same value. No other memory is synchronized through this
    // counter.
    let previous = COUNTER.fetch_add(1, Ordering::Relaxed);
    previous + 1
}
/// Error returned by the [`Factory`] persistence helpers.
#[derive(Debug)]
pub enum FactoryError {
    /// Wraps the `WriteError` returned when the ORM manager's `create` fails.
    Write(umbral::orm::write::WriteError),
}

impl std::fmt::Display for FactoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Write(inner) => write!(f, "factory write failed: {inner}"),
        }
    }
}

// NOTE(review): `source()` is not overridden. If `WriteError` implements
// `std::error::Error`, returning it here would preserve the error chain.
impl std::error::Error for FactoryError {}

impl From<umbral::orm::write::WriteError> for FactoryError {
    fn from(inner: umbral::orm::write::WriteError) -> Self {
        Self::Write(inner)
    }
}
/// Test-data factory for an ORM model.
///
/// Implementors supply [`build`](Factory::build). The provided async helpers
/// pass the built instance to
/// `umbral::orm::Manager::<Self::Model>::default().create(..)`.
#[async_trait::async_trait]
pub trait Factory {
    /// The model type this factory produces.
    type Model: umbral::orm::Model
        + serde::Serialize
        + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
        + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
        + umbral::orm::HydrateRelated;

    /// Returns a new model instance. The helpers below then persist it.
    fn build() -> Self::Model;

    /// Builds and persists one instance without modification.
    async fn create() -> Result<Self::Model, FactoryError> {
        Self::create_with(|_model| {}).await
    }

    /// Builds one instance, lets `tweak` adjust it, then persists it.
    ///
    /// # Errors
    /// Returns [`FactoryError::Write`] if the manager's `create` fails.
    async fn create_with<F>(tweak: F) -> Result<Self::Model, FactoryError>
    where
        F: FnOnce(&mut Self::Model) + Send,
    {
        let mut model = Self::build();
        tweak(&mut model);
        let manager = umbral::orm::Manager::<Self::Model>::default();
        manager.create(model).await.map_err(FactoryError::Write)
    }

    /// Creates `n` instances one after another.
    ///
    /// Stops at the first error. This helper does not roll back instances
    /// created before the failure.
    async fn create_batch(n: usize) -> Result<Vec<Self::Model>, FactoryError> {
        let mut created = Vec::with_capacity(n);
        while created.len() < n {
            created.push(Self::create().await?);
        }
        Ok(created)
    }
}