use std::any::Any;
use std::future::poll_fn;
use std::marker::PhantomData;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use async_trait::async_trait;
#[cfg(feature = "cache")]
use cot::config::CacheUrl;
#[cfg(feature = "redis")]
use deadpool_redis::Connection;
use derive_more::Debug;
#[cfg(feature = "redis")]
use redis::AsyncCommands;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tower::Service;
use tower_sessions::MemoryStore;
#[cfg(feature = "db")]
use crate::auth::db::DatabaseUserBackend;
use crate::auth::{Auth, AuthBackend, NoAuthBackend, User, UserId};
#[cfg(feature = "cache")]
use crate::cache::Cache;
#[cfg(feature = "cache")]
use crate::cache::store::memory::Memory;
#[cfg(feature = "redis")]
use crate::cache::store::redis::Redis;
use crate::config::ProjectConfig;
#[cfg(feature = "cache")]
use crate::config::Timeout;
#[cfg(feature = "db")]
use crate::db::Database;
#[cfg(feature = "db")]
use crate::db::migrations::{
DynMigration, MigrationDependency, MigrationEngine, MigrationWrapper, Operation,
};
#[cfg(feature = "email")]
use crate::email::Email;
#[cfg(feature = "email")]
use crate::email::transport::console::Console;
#[cfg(feature = "redis")]
use crate::error::error_impl::impl_into_cot_error;
use crate::handler::BoxedHandler;
use crate::project::{prepare_request, prepare_request_for_error_handler, run_at_with_shutdown};
use crate::request::Request;
use crate::response::Response;
use crate::router::Router;
use crate::session::Session;
use crate::static_files::{StaticFile, StaticFiles};
use crate::{Body, Bootstrapper, Project, ProjectContext, Result};
/// In-process test client that dispatches requests straight to a
/// bootstrapped project's handlers, without opening a network connection.
#[derive(Debug)]
pub struct Client {
    /// Project context attached to every request via `prepare_request`.
    context: Arc<ProjectContext>,
    /// Main request handler of the bootstrapped project.
    handler: BoxedHandler,
    /// Handler invoked with the original request head when `handler` fails.
    error_handler: BoxedHandler,
}
impl Client {
    /// Boots `project` with its `"test"` configuration and returns a client
    /// bound to the resulting handlers and context.
    ///
    /// # Panics
    ///
    /// Panics if the `"test"` config cannot be obtained or if booting fails.
    #[must_use]
    #[expect(clippy::future_not_send)]
    pub async fn new<P>(project: P) -> Self
    where
        P: Project + 'static,
    {
        let test_config = project.config("test").expect("Could not get test config");
        let bootstrapped = Bootstrapper::new(project)
            .with_config(test_config)
            .boot()
            .await
            .expect("Could not boot project")
            .finish();

        Self {
            context: Arc::new(bootstrapped.context),
            handler: bootstrapped.handler,
            error_handler: bootstrapped.error_handler,
        }
    }

    /// Sends an empty-bodied `GET` request for `path`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Client::request`].
    pub async fn get(&mut self, path: &str) -> Result<Response> {
        let Ok(get_request) = http::Request::get(path).body(Body::empty()) else {
            unreachable!("Test request should be valid")
        };
        self.request(get_request).await
    }

    /// Runs `request` through the project's handler.
    ///
    /// If the main handler returns an error, the error handler is called with
    /// a copy of the original request head (and an empty body) and its result
    /// is returned instead.
    ///
    /// # Errors
    ///
    /// Returns an error if either handler is not ready, or if the error
    /// handler itself fails.
    pub async fn request(&mut self, mut request: Request) -> Result<Response> {
        prepare_request(&mut request, Arc::clone(&self.context));

        // Keep a copy of the head: the body is consumed by the main handler,
        // but the error handler still needs the request metadata.
        let (head, body) = request.into_parts();
        let mut error_head = head.clone();

        poll_fn(|cx| self.handler.poll_ready(cx)).await?;
        let handler_error = match self.handler.call(Request::from_parts(head, body)).await {
            Ok(response) => return Ok(response),
            Err(error) => error,
        };

        prepare_request_for_error_handler(&mut error_head, handler_error);
        poll_fn(|cx| self.error_handler.poll_ready(cx)).await?;
        self.error_handler
            .call(Request::from_parts(error_head, Body::empty()))
            .await
    }
}
/// Builder for standalone test requests that already carry a project
/// context (config, router, auth backend, and optional database, cache and
/// email services), so handlers can be called directly.
#[derive(Debug, Clone)]
pub struct TestRequestBuilder {
    /// HTTP method of the request.
    method: http::Method,
    /// Request URI.
    url: String,
    /// Router placed in the context; an empty router is used when `None`.
    router: Option<Router>,
    /// Session inserted into the request extensions, if any.
    session: Option<Session>,
    /// Project config; `ProjectConfig::default()` is used when `None`.
    config: Option<Arc<ProjectConfig>>,
    /// Auth backend for the context; `NoAuthBackend` is used when `None`.
    auth_backend: Option<AuthBackendWrapper>,
    /// `Auth` value inserted into the request extensions, if any.
    auth: Option<Auth>,
    /// Database placed in the context.
    #[cfg(feature = "db")]
    database: Option<Database>,
    /// Key/value pairs sent as an `application/x-www-form-urlencoded` body.
    form_data: Option<Vec<(String, String)>>,
    /// Pre-serialized JSON sent as an `application/json` body.
    #[cfg(feature = "json")]
    json_data: Option<String>,
    /// Static files exposed to the request through a `StaticFiles` extension.
    static_files: Vec<StaticFile>,
    /// Cache placed in the context; an in-memory cache is used when `None`.
    #[cfg(feature = "cache")]
    cache: Option<Cache>,
    /// Email service placed in the context; a console transport is used when `None`.
    #[cfg(feature = "email")]
    email: Option<Email>,
}
/// Cloneable, type-erased handle to an [`AuthBackend`], used so the
/// builder can store an arbitrary backend and still derive `Clone`.
#[derive(Debug, Clone)]
struct AuthBackendWrapper {
    // The backend itself is not required to be `Debug`, so it is rendered as "..".
    #[debug("..")]
    inner: Arc<dyn AuthBackend>,
}
impl AuthBackendWrapper {
    /// Moves `inner` behind an `Arc<dyn AuthBackend>` so the wrapper can be
    /// cloned cheaply.
    pub(crate) fn new<AB>(inner: AB) -> Self
    where
        AB: AuthBackend + 'static,
    {
        let erased: Arc<dyn AuthBackend> = Arc::new(inner);
        Self { inner: erased }
    }
}
#[async_trait]
impl AuthBackend for AuthBackendWrapper {
async fn authenticate(
&self,
credentials: &(dyn Any + Send + Sync),
) -> cot::auth::Result<Option<Box<dyn User + Send + Sync>>> {
self.inner.authenticate(credentials).await
}
async fn get_by_id(
&self,
id: UserId,
) -> cot::auth::Result<Option<Box<dyn User + Send + Sync>>> {
self.inner.get_by_id(id).await
}
}
impl Default for TestRequestBuilder {
    /// A `GET /` request with nothing else configured: no router, session,
    /// config, auth, database, body, static files, cache or email service.
    fn default() -> Self {
        Self {
            method: http::Method::GET,
            url: String::from("/"),
            static_files: Vec::new(),
            form_data: None,
            router: None,
            session: None,
            config: None,
            auth: None,
            auth_backend: None,
            #[cfg(feature = "db")]
            database: None,
            #[cfg(feature = "json")]
            json_data: None,
            #[cfg(feature = "cache")]
            cache: None,
            #[cfg(feature = "email")]
            email: None,
        }
    }
}
impl TestRequestBuilder {
    /// Creates a builder for a `GET` request to `url`.
    #[must_use]
    pub fn get(url: &str) -> Self {
        Self::with_method(url, crate::Method::GET)
    }

    /// Creates a builder for a `POST` request to `url`.
    #[must_use]
    pub fn post(url: &str) -> Self {
        Self::with_method(url, crate::Method::POST)
    }

    /// Creates a builder for a request to `url` with the given `method`;
    /// every other option starts at its default.
    #[must_use]
    pub fn with_method(url: &str, method: crate::Method) -> Self {
        Self {
            method,
            url: url.to_string(),
            ..Self::default()
        }
    }

    /// Sets the project config used for the request context.
    pub fn config(&mut self, config: ProjectConfig) -> &mut Self {
        self.config = Some(Arc::new(config));
        self
    }

    /// Explicitly uses `ProjectConfig::default()` for the request context.
    pub fn with_default_config(&mut self) -> &mut Self {
        self.config = Some(Arc::new(ProjectConfig::default()));
        self
    }

    /// Sets the auth backend placed in the request context.
    pub fn auth_backend<T: AuthBackend + 'static>(&mut self, auth_backend: T) -> &mut Self {
        self.auth_backend = Some(AuthBackendWrapper::new(auth_backend));
        self
    }

    /// Sets the router placed in the request context.
    pub fn router(&mut self, router: Router) -> &mut Self {
        self.router = Some(router);
        self
    }

    /// Attaches a fresh session backed by an in-memory store.
    pub fn with_session(&mut self) -> &mut Self {
        let session_store = MemoryStore::default();
        let session_inner = tower_sessions::Session::new(None, Arc::new(session_store), None);
        self.session = Some(Session::new(session_inner));
        self
    }

    /// Reuses the session attached to an existing `request`.
    pub fn with_session_from(&mut self, request: &Request) -> &mut Self {
        self.session = Some(Session::from_request(request).clone());
        self
    }

    /// Attaches the given session.
    pub fn session(&mut self, session: Session) -> &mut Self {
        self.session = Some(session);
        self
    }

    /// Sets the database placed in the request context.
    #[cfg(feature = "db")]
    pub fn database<DB: Into<Database>>(&mut self, database: DB) -> &mut Self {
        self.database = Some(database.into());
        self
    }

    /// Configures database-backed authentication.
    ///
    /// Installs a `DatabaseUserBackend` for `db`, a fresh in-memory session,
    /// the database itself, and an `Auth` value built from them (with a fixed
    /// test secret key).
    ///
    /// # Panics
    ///
    /// Panics if creating the `Auth` value fails.
    #[cfg(feature = "db")]
    pub async fn with_db_auth(&mut self, db: Database) -> &mut Self {
        self.auth_backend(DatabaseUserBackend::new(db.clone()));
        self.with_session();
        self.database(db);
        self.auth = Some(
            Auth::new(
                self.session.clone().expect("Session was just set"),
                self.auth_backend
                    .clone()
                    .expect("Auth backend was just set")
                    .inner,
                crate::config::SecretKey::from("000000"),
                &[],
            )
            .await
            .expect("Failed to create Auth"),
        );
        self
    }

    /// Sets key/value pairs to send as an `application/x-www-form-urlencoded`
    /// body (only supported for `POST` requests).
    pub fn form_data<T: ToString>(&mut self, form_data: &[(T, T)]) -> &mut Self {
        self.form_data = Some(
            form_data
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        );
        self
    }

    /// Serializes `data` and sends it as an `application/json` body.
    ///
    /// # Panics
    ///
    /// Panics if `data` cannot be serialized to JSON.
    #[cfg(feature = "json")]
    pub fn json<T: serde::Serialize>(&mut self, data: &T) -> &mut Self {
        self.json_data = Some(serde_json::to_string(data).expect("Failed to serialize JSON"));
        self
    }

    /// Adds a static file with the given `path` and `content`, served through
    /// a `StaticFiles` request extension.
    pub fn static_file<Path, Content>(&mut self, path: Path, content: Content) -> &mut Self
    where
        Path: Into<String>,
        Content: Into<bytes::Bytes>,
    {
        self.static_files.push(StaticFile::new(path, content));
        self
    }

    /// Builds the request.
    ///
    /// A new [`ProjectContext`] is assembled from the builder's settings.
    /// Unset parts fall back to a default config, an empty router and
    /// [`NoAuthBackend`], plus, when enabled, an in-memory cache and a console
    /// email transport. The context is attached via `prepare_request`, and
    /// then the session, auth, body and static files are added as configured.
    ///
    /// The builder's state is only read, never consumed, so calling `build`
    /// again yields an equivalent request.
    ///
    /// # Panics
    ///
    /// Panics if the method and URL do not form a valid request, or if form
    /// data is set on a non-`POST` request (not yet supported).
    #[must_use]
    pub fn build(&mut self) -> http::Request<Body> {
        let mut request = http::Request::builder()
            .method(self.method.clone())
            .uri(self.url.clone())
            .body(Body::empty())
            .expect("Test request should be valid");

        let config = self.config.clone().unwrap_or_default();
        // Hand the already type-erased backend to the context directly;
        // cloning the `Arc` keeps the builder reusable.
        let auth_backend: Arc<dyn AuthBackend> = match &self.auth_backend {
            Some(wrapper) => Arc::clone(&wrapper.inner),
            None => Arc::new(NoAuthBackend),
        };
        let context = ProjectContext::initialized(
            Arc::clone(&config),
            Vec::new(),
            Arc::new(self.router.clone().unwrap_or_else(Router::empty)),
            auth_backend,
            #[cfg(feature = "db")]
            self.database.clone(),
            #[cfg(feature = "cache")]
            self.cache
                .clone()
                .unwrap_or_else(|| Cache::new(Memory::new(), None, Timeout::default())),
            #[cfg(feature = "email")]
            self.email
                .clone()
                .unwrap_or_else(|| Email::new(Console::new())),
        );
        prepare_request(&mut request, Arc::new(context));

        if let Some(session) = &self.session {
            request.extensions_mut().insert(session.clone());
        }
        if let Some(auth) = &self.auth {
            request.extensions_mut().insert(auth.clone());
        }

        if let Some(form_data) = &self.form_data {
            if self.method != http::Method::POST {
                todo!("Form data can currently only be used with POST requests");
            }
            let mut data = form_urlencoded::Serializer::new(String::new());
            for (key, value) in form_data {
                data.append_pair(key, value);
            }
            *request.body_mut() = Body::fixed(data.finish());
            request.headers_mut().insert(
                http::header::CONTENT_TYPE,
                http::HeaderValue::from_static("application/x-www-form-urlencoded"),
            );
        }

        // Applied after form data, so a JSON body takes precedence if both are set.
        #[cfg(feature = "json")]
        if let Some(json_data) = &self.json_data {
            *request.body_mut() = Body::fixed(json_data.clone());
            request.headers_mut().insert(
                http::header::CONTENT_TYPE,
                http::HeaderValue::from_static("application/json"),
            );
        }

        if !self.static_files.is_empty() {
            let mut static_files = StaticFiles::new(&config.static_files);
            for file in &self.static_files {
                static_files.add_file(file.clone());
            }
            request.extensions_mut().insert(Arc::new(static_files));
        }

        request
    }
}
/// A database created for a single test, together with the information
/// needed to drop it again in `cleanup`.
#[cfg(feature = "db")]
#[derive(Debug)]
pub struct TestDatabase {
    /// Connection to the test database.
    database: Database,
    /// Which backend the database lives on (and, for servers, how to reach it).
    kind: TestDatabaseKind,
    /// Migrations queued by `add_migrations` and applied by `run_migrations`.
    migrations: Vec<MigrationWrapper>,
}
#[cfg(feature = "db")]
impl TestDatabase {
    /// Identifier quote character for PostgreSQL.
    const POSTGRES_QUOTE: char = '"';
    /// Identifier quote character for MySQL.
    const MYSQL_QUOTE: char = '`';

    /// Wraps an already-connected `database` with no queued migrations.
    fn new(database: Database, kind: TestDatabaseKind) -> TestDatabase {
        Self {
            database,
            kind,
            migrations: Vec::new(),
        }
    }

    /// Opens a fresh in-memory SQLite database.
    ///
    /// # Errors
    ///
    /// Returns an error if the database cannot be opened.
    pub async fn new_sqlite() -> Result<Self> {
        let database = Database::new("sqlite::memory:").await?;
        Ok(Self::new(database, TestDatabaseKind::Sqlite))
    }

    /// Drops and re-creates the PostgreSQL database `test_cot__{test_name}`,
    /// then connects to it.
    ///
    /// The server URL comes from `POSTGRES_URL` (default
    /// `postgresql://cot:cot@localhost`). The `postgres` database is used to
    /// issue the `DROP`/`CREATE` statements.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting to the server or re-creating the
    /// database fails.
    pub async fn new_postgres(test_name: &str) -> Result<Self> {
        let db_url = std::env::var("POSTGRES_URL")
            .unwrap_or_else(|_| "postgresql://cot:cot@localhost".to_string());
        let db_name = format!("test_cot__{test_name}");
        let database =
            Self::recreate_database(&db_url, "postgres", &db_name, Self::POSTGRES_QUOTE).await?;
        Ok(Self::new(
            database,
            TestDatabaseKind::Postgres { db_url, db_name },
        ))
    }

    /// Drops and re-creates the MySQL database `test_cot__{test_name}`, then
    /// connects to it.
    ///
    /// The server URL comes from `MYSQL_URL` (default
    /// `mysql://root:@localhost`). The `mysql` database is used to issue the
    /// `DROP`/`CREATE` statements.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting to the server or re-creating the
    /// database fails.
    pub async fn new_mysql(test_name: &str) -> Result<Self> {
        let db_url =
            std::env::var("MYSQL_URL").unwrap_or_else(|_| "mysql://root:@localhost".to_string());
        let db_name = format!("test_cot__{test_name}");
        let database =
            Self::recreate_database(&db_url, "mysql", &db_name, Self::MYSQL_QUOTE).await?;
        Ok(Self::new(
            database,
            TestDatabaseKind::MySql { db_url, db_name },
        ))
    }

    /// Connects to `{db_url}/{admin_db}`, drops and re-creates `db_name`,
    /// closes the admin connection and returns a connection to
    /// `{db_url}/{db_name}`.
    async fn recreate_database(
        db_url: &str,
        admin_db: &str,
        db_name: &str,
        quote: char,
    ) -> Result<Database> {
        let quoted_name = Self::quote_identifier(db_name, quote);
        let admin = Database::new(format!("{db_url}/{admin_db}")).await?;
        admin
            .raw(&format!("DROP DATABASE IF EXISTS {quoted_name}"))
            .await?;
        admin
            .raw(&format!("CREATE DATABASE {quoted_name}"))
            .await?;
        admin.close().await?;
        Ok(Database::new(format!("{db_url}/{db_name}")).await?)
    }

    /// Quotes `name` as a single SQL identifier using `quote`, doubling any
    /// embedded quote characters.
    ///
    /// Quoting preserves the exact name (e.g. PostgreSQL would otherwise
    /// lowercase it), so it matches the name used in the connection URL.
    fn quote_identifier(name: &str, quote: char) -> String {
        let doubled = format!("{quote}{quote}");
        format!("{quote}{}{quote}", name.replace(quote, &doubled))
    }

    /// Queues the built-in auth migrations.
    pub fn with_auth(&mut self) -> &mut Self {
        self.add_migrations(cot::auth::db::migrations::MIGRATIONS.to_vec());
        self
    }

    /// Queues `migrations` to be applied by [`TestDatabase::run_migrations`].
    pub fn add_migrations<T: DynMigration + Send + Sync + 'static, V: IntoIterator<Item = T>>(
        &mut self,
        migrations: V,
    ) -> &mut Self {
        self.migrations
            .extend(migrations.into_iter().map(MigrationWrapper::new));
        self
    }

    /// Applies and clears all queued migrations; does nothing if none are queued.
    ///
    /// # Panics
    ///
    /// Panics if the migration engine cannot be created or a migration fails.
    pub async fn run_migrations(&mut self) -> &mut Self {
        if !self.migrations.is_empty() {
            let engine = MigrationEngine::new(std::mem::take(&mut self.migrations))
                .expect("Failed to initialize the migration engine");
            engine
                .run(&self.database)
                .await
                .expect("Failed to run migrations");
        }
        self
    }

    /// Returns a handle to the test database.
    #[must_use]
    pub fn database(&self) -> Database {
        self.database.clone()
    }

    /// Closes the connection and, for server-backed databases, drops the
    /// test database (forcibly on PostgreSQL).
    ///
    /// # Errors
    ///
    /// Returns an error if closing the connection or dropping the database fails.
    pub async fn cleanup(&self) -> Result<()> {
        self.database.close().await?;
        match &self.kind {
            TestDatabaseKind::Sqlite => {}
            TestDatabaseKind::Postgres { db_url, db_name } => {
                let quoted_name = Self::quote_identifier(db_name, Self::POSTGRES_QUOTE);
                let admin = Database::new(format!("{db_url}/postgres")).await?;
                admin
                    .raw(&format!("DROP DATABASE {quoted_name} WITH (FORCE)"))
                    .await?;
                admin.close().await?;
            }
            TestDatabaseKind::MySql { db_url, db_name } => {
                let quoted_name = Self::quote_identifier(db_name, Self::MYSQL_QUOTE);
                let admin = Database::new(format!("{db_url}/mysql")).await?;
                admin.raw(&format!("DROP DATABASE {quoted_name}")).await?;
                admin.close().await?;
            }
        }
        Ok(())
    }
}
/// Lets a `TestDatabase` be used wherever a `&Database` is expected.
#[cfg(feature = "db")]
impl std::ops::Deref for TestDatabase {
    type Target = Database;

    fn deref(&self) -> &Database {
        let Self { database, .. } = self;
        database
    }
}
/// Backend of a [`TestDatabase`] and what `cleanup` needs to drop it.
#[cfg(feature = "db")]
#[derive(Debug, Clone)]
enum TestDatabaseKind {
    /// In-memory SQLite; nothing to drop.
    Sqlite,
    /// PostgreSQL server base URL and the name of the per-test database.
    Postgres { db_url: String, db_name: String },
    /// MySQL server base URL and the name of the per-test database.
    MySql { db_url: String, db_name: String },
}
/// A migration defined in-line by a test, exposed through [`DynMigration`].
#[cfg(feature = "db")]
#[derive(Debug, Clone)]
pub struct TestMigration {
    /// Name of the app the migration belongs to.
    app_name: &'static str,
    /// Name of the migration.
    name: &'static str,
    /// Migrations that must run before this one.
    dependencies: Vec<MigrationDependency>,
    /// Operations performed by this migration.
    operations: Vec<Operation>,
}
#[cfg(feature = "db")]
impl TestMigration {
    /// Builds a migration for `app_name` called `name` from anything
    /// convertible into lists of dependencies and operations.
    #[must_use]
    pub fn new<D: Into<Vec<MigrationDependency>>, O: Into<Vec<Operation>>>(
        app_name: &'static str,
        name: &'static str,
        dependencies: D,
        operations: O,
    ) -> Self {
        let dependency_list: Vec<MigrationDependency> = dependencies.into();
        let operation_list: Vec<Operation> = operations.into();
        Self {
            app_name,
            name,
            dependencies: dependency_list,
            operations: operation_list,
        }
    }
}
/// Exposes the stored fields through the [`DynMigration`] interface.
#[cfg(feature = "db")]
impl DynMigration for TestMigration {
    fn app_name(&self) -> &str {
        let Self { app_name, .. } = self;
        app_name
    }

    fn name(&self) -> &str {
        let Self { name, .. } = self;
        name
    }

    fn dependencies(&self) -> &[MigrationDependency] {
        self.dependencies.as_slice()
    }

    fn operations(&self) -> &[Operation] {
        self.operations.as_slice()
    }
}
/// Builder that holds a project until a [`TestServer`] is started for it.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TestServerBuilder<T> {
    /// Project to serve.
    project: T,
}
impl<T: Project + 'static> TestServerBuilder<T> {
#[must_use]
pub fn new(project: T) -> Self {
Self { project }
}
pub async fn start(self) -> TestServer<T> {
TestServer::start(self.project).await
}
}
/// A running project server listening on a local TCP port; stop it with
/// `close`.
#[must_use = "TestServer must be used to close the server"]
#[derive(Debug)]
pub struct TestServer<T> {
    /// Address clients should connect to (IP rewritten to localhost).
    address: SocketAddr,
    /// Sending half of the shutdown signal observed by the server task.
    channel_send: oneshot::Sender<()>,
    /// Handle to the spawned server task.
    server_handle: tokio::task::JoinHandle<()>,
    /// Ties the server to project type `T` without storing a `T`.
    project: PhantomData<fn() -> T>,
}
impl<T: Project + 'static> TestServer<T> {
    /// Binds a listener on an OS-assigned port and spawns the project's
    /// server, booted with its `"test"` configuration.
    ///
    /// The task is spawned with `tokio::task::spawn_local`, so this must run
    /// inside a `LocalSet`. The listener binds on all interfaces (`0.0.0.0`),
    /// but the stored address is rewritten to `127.0.0.1`.
    async fn start(project: T) -> Self {
        let tcp_listener = TcpListener::bind("0.0.0.0:0")
            .await
            .expect("Failed to bind to a port");
        let mut address = tcp_listener
            .local_addr()
            .expect("Failed to get the listening address");
        address.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST));
        let (send, recv) = oneshot::channel::<()>();
        let server_handle = tokio::task::spawn_local(async move {
            let bootstrapper = Bootstrapper::new(project)
                .with_config_name("test")
                .expect("Failed to get the \"test\" config")
                .boot()
                .await
                .expect("Failed to boot the project");
            run_at_with_shutdown(bootstrapper, tcp_listener, async move {
                recv.await.expect("Failed to receive a shutdown signal");
            })
            .await
            .expect("Failed to run the server");
        });
        Self {
            address,
            channel_send: send,
            server_handle,
            project: PhantomData,
        }
    }

    /// Returns the localhost address the server is listening on.
    #[must_use]
    pub fn address(&self) -> SocketAddr {
        self.address
    }

    /// Returns the server's base URL.
    ///
    /// If `COT_TEST_SERVER_HOST` is set, that host is used with the server's
    /// port; otherwise the localhost address is used.
    #[must_use]
    pub fn url(&self) -> String {
        if let Ok(host) = std::env::var("COT_TEST_SERVER_HOST") {
            format!("http://{}:{}", host, self.address.port())
        } else {
            format!("http://{}", self.address)
        }
    }

    /// Signals the server to shut down and waits for its task to finish.
    ///
    /// # Panics
    ///
    /// Panics if the server task failed (e.g. it panicked while booting or
    /// running). Also panics if the task ended without ever receiving the
    /// shutdown signal.
    pub async fn close(self) {
        // A failed send means the receiver is gone, i.e. the server task has
        // already exited. Join the task first so that a panic inside it is
        // reported instead of a generic send failure.
        let send_result = self.channel_send.send(());
        self.server_handle
            .await
            .expect("Failed to join the server task");
        send_result.expect("Failed to send a shutdown signal");
    }
}
/// Acquires a process-wide lock; while the returned guard is alive, other
/// callers of `serial_guard` block.
///
/// If a previous holder panicked and poisoned the lock, the poison is
/// cleared and the guard is returned anyway, so one panic does not make
/// every later caller fail.
#[doc(hidden)]
pub fn serial_guard() -> std::sync::MutexGuard<'static, ()> {
    static SERIAL_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    SERIAL_LOCK.lock().unwrap_or_else(|poisoned| {
        SERIAL_LOCK.clear_poison();
        poisoned.into_inner()
    })
}
/// Redis list key holding the indices of databases still available for
/// allocation to tests (filled by `RedisDbAllocator::init`, popped by
/// `RedisDbAllocator::allocate`).
#[cfg(feature = "redis")]
const POOL_KEY: &str = "cot:test:db_pool";
/// Returns the number of databases the Redis server is configured with,
/// falling back to 16 if the reply cannot be parsed.
///
/// # Panics
///
/// Panics if the `CONFIG GET databases` command fails.
#[cfg(feature = "redis")]
async fn get_db_num(conn: &mut Connection) -> usize {
    const FALLBACK_DB_COUNT: usize = 16;

    let reply: Vec<String> = redis::cmd("CONFIG")
        .arg("GET")
        .arg("databases")
        .query_async(conn)
        .await
        .expect("Failed to get Redis config");

    // The value is expected at index 1 of a `[name, value]` reply.
    match reply.get(1).map(|raw| raw.parse::<usize>()) {
        Some(Ok(count)) => count,
        _ => FALLBACK_DB_COUNT,
    }
}
/// Switches `conn` to database `db_num` with `SELECT`.
///
/// # Panics
///
/// Panics if the `SELECT` command fails.
#[cfg(feature = "redis")]
async fn set_current_db(conn: &mut Connection, db_num: usize) {
    let mut select = redis::cmd("SELECT");
    select.arg(db_num);
    select
        .query_async::<()>(conn)
        .await
        .expect("Failed to select Redis database");
}
/// Errors raised while managing the pool of Redis test databases.
#[cfg(feature = "redis")]
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
enum RedisDbAllocatorError {
    /// An I/O error.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// A Redis or connection-pool error, stored as its message.
    #[error("Redis error: {0}")]
    Redis(String),
}
// Allows `?` to convert `RedisDbAllocatorError` into the crate's error type.
#[cfg(feature = "redis")]
impl_into_cot_error!(RedisDbAllocatorError);
/// Hands out distinct Redis database indices to tests from a shared pool
/// stored under [`POOL_KEY`].
#[cfg(feature = "redis")]
#[derive(Debug, Clone)]
struct RedisDbAllocator {
    /// Index of the database reserved for the allocator; it is never handed
    /// out (the pool covers `1..alloc_db`).
    alloc_db: usize,
    /// Connection pool to the Redis server.
    redis: Redis,
}
/// Result type used by [`RedisDbAllocator`] operations.
#[cfg(feature = "redis")]
type RedisAllocatorResult<T> = std::result::Result<T, RedisDbAllocatorError>;
#[cfg(feature = "redis")]
impl RedisDbAllocator {
    /// Creates an allocator that reserves database `alloc_db` for itself and
    /// hands out databases `1..alloc_db`.
    fn new(alloc_db: usize, redis: Redis) -> Self {
        Self { alloc_db, redis }
    }

    /// Wraps any Redis or pool error message in a [`RedisDbAllocatorError`].
    fn redis_error(err: impl ToString) -> RedisDbAllocatorError {
        RedisDbAllocatorError::Redis(err.to_string())
    }

    /// Checks out a pooled connection and switches it to `alloc_db`.
    ///
    /// The `SELECT` is issued on every checkout because the pool may hand out
    /// any of its connections, and a database selected on one connection
    /// does not apply to the others.
    async fn get_conn(&self) -> RedisAllocatorResult<Connection> {
        let mut conn = self
            .redis
            .get_connection()
            .await
            .map_err(Self::redis_error)?;
        redis::cmd("SELECT")
            .arg(self.alloc_db)
            .query_async::<()>(&mut conn)
            .await
            .map_err(Self::redis_error)?;
        Ok(conn)
    }

    /// Fills the shared pool with databases `1..alloc_db` unless another
    /// caller already did so.
    ///
    /// Returns the existing value of the init marker if the pool was already
    /// initialized, or `None` if this call performed (or attempted, when the
    /// `WATCH`ed key changed concurrently) the initialization. Both the marker
    /// and the pool expire after `KEY_TIMEOUT_SECS`.
    ///
    /// # Errors
    ///
    /// Returns an error if no database is available for allocation or if any
    /// Redis command fails.
    async fn init(&self) -> RedisAllocatorResult<Option<String>> {
        const KEY_TIMEOUT_SECS: u64 = 300;
        const INIT_KEY: &str = "cot:test:db_pool:initialized";

        // Database 0 and the allocator's own database are never handed out.
        // A half-open range avoids underflowing when `alloc_db` is 0.
        let eligible_dbs: Vec<String> = (1..self.alloc_db).map(|i| i.to_string()).collect();
        if eligible_dbs.is_empty() {
            // `RPUSH` with no values is rejected by Redis; fail here instead of
            // leaving a half-queued `MULTI` on a pooled connection.
            return Err(RedisDbAllocatorError::Redis(format!(
                "no Redis databases available for tests (allocator database: {})",
                self.alloc_db
            )));
        }

        let mut con = self.get_conn().await?;

        // Optimistic lock: if INIT_KEY changes before EXEC, the queued
        // transaction is not applied.
        redis::cmd("WATCH")
            .arg(INIT_KEY)
            .query_async::<redis::Value>(&mut con)
            .await
            .map_err(Self::redis_error)?;
        let prev = redis::cmd("GET")
            .arg(INIT_KEY)
            .query_async::<Option<String>>(&mut con)
            .await
            .map_err(Self::redis_error)?;
        if prev.is_some() {
            redis::cmd("UNWATCH")
                .query_async::<redis::Value>(&mut con)
                .await
                .map_err(Self::redis_error)?;
            return Ok(prev);
        }

        redis::cmd("MULTI")
            .query_async::<redis::Value>(&mut con)
            .await
            .map_err(Self::redis_error)?;
        redis::cmd("SET")
            .arg(INIT_KEY)
            .arg("1")
            .arg("EX")
            .arg(KEY_TIMEOUT_SECS)
            .query_async::<redis::Value>(&mut con)
            .await
            .map_err(Self::redis_error)?;
        redis::cmd("DEL")
            .arg(POOL_KEY)
            .query_async::<redis::Value>(&mut con)
            .await
            .map_err(Self::redis_error)?;
        redis::cmd("RPUSH")
            .arg(POOL_KEY)
            .arg(eligible_dbs)
            .query_async::<redis::Value>(&mut con)
            .await
            .map_err(Self::redis_error)?;
        redis::cmd("EXPIRE")
            .arg(POOL_KEY)
            .arg(KEY_TIMEOUT_SECS)
            .query_async::<redis::Value>(&mut con)
            .await
            .map_err(Self::redis_error)?;
        redis::cmd("EXEC")
            .query_async::<Option<Vec<redis::Value>>>(&mut con)
            .await
            .map_err(Self::redis_error)?;
        Ok(None)
    }

    /// Pops the next free database index from the pool.
    ///
    /// Returns `None` if the pool is empty or the popped value is not a
    /// valid index.
    ///
    /// # Errors
    ///
    /// Returns an error if the Redis connection or `LPOP` fails.
    async fn allocate(&self) -> RedisAllocatorResult<Option<usize>> {
        let mut connection = self.get_conn().await?;
        let db_index: Option<String> = connection
            .lpop(POOL_KEY, None)
            .await
            .map_err(Self::redis_error)?;
        Ok(db_index.and_then(|i| i.parse::<usize>().ok()))
    }
}
/// Backing store of a [`TestCache`].
#[cfg(feature = "cache")]
#[derive(Debug, Clone)]
enum CacheKind {
    /// In-process memory store.
    Memory,
    /// A Redis database obtained from the shared allocator.
    #[cfg(feature = "redis")]
    Redis {
        // Kept alongside the cache; not read anywhere yet.
        #[expect(unused)]
        allocator: RedisDbAllocator,
    },
}
/// A cache created for tests, either in-memory or on a dedicated Redis database.
#[cfg(feature = "cache")]
#[derive(Debug, Clone)]
pub struct TestCache {
    /// The cache handed to code under test.
    cache: Cache,
    /// Which store backs `cache`; decides what `cleanup` does.
    kind: CacheKind,
}
#[cfg(feature = "cache")]
impl TestCache {
    /// Pairs `cache` with the kind of store backing it.
    fn new(cache: Cache, kind: CacheKind) -> Self {
        Self { cache, kind }
    }

    /// Creates a cache backed by an in-memory store with no key prefix and
    /// the default timeout.
    #[must_use]
    pub fn new_memory() -> Self {
        let cache = Cache::new(Memory::new(), None, Timeout::default());
        Self::new(cache, CacheKind::Memory)
    }

    /// Creates a cache backed by a Redis database reserved for this test.
    ///
    /// The server URL is read from `REDIS_URL` (default `redis://localhost`).
    /// The server's last database is reserved for the allocator, the shared
    /// allocation pool is initialized if needed, and one database index is
    /// taken from it. The resulting cache uses the `test_harness` key prefix.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting to Redis or initializing/using the
    /// allocator fails.
    ///
    /// # Panics
    ///
    /// Panics in three cases:
    /// - the server has fewer than 3 databases;
    /// - querying or selecting databases fails;
    /// - no database is left in the pool.
    #[cfg(feature = "redis")]
    pub async fn new_redis() -> Result<Self> {
        let url = std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://localhost".to_string());
        let mut url = CacheUrl::from(url);
        let redis = Redis::new(&url, crate::config::DEFAULT_REDIS_POOL_SIZE)?;
        let mut conn = redis.get_connection().await?;
        let db_num = get_db_num(&mut conn).await;
        // Database 0 and the last database (reserved for the allocator) are
        // both excluded from the pool, so at least one more is required.
        assert!(
            db_num > 2,
            "Redis must be configured with at least 3 databases for testing"
        );
        let alloc_db = db_num - 1;
        set_current_db(&mut conn, alloc_db).await;
        let allocator = RedisDbAllocator::new(alloc_db, redis);
        allocator.init().await?;
        let current_db = allocator
            .allocate()
            .await?
            .expect("Failed to allocate a Redis database for testing");
        url.inner_mut().set_path(current_db.to_string().as_str());
        let redis = Redis::new(&url, crate::config::DEFAULT_REDIS_POOL_SIZE)?;
        let cache = Cache::new(redis, Some("test_harness".to_string()), Timeout::default());
        Ok(Self::new(cache, CacheKind::Redis { allocator }))
    }

    /// Returns a handle to the underlying cache.
    #[must_use]
    pub fn cache(&self) -> Cache {
        self.cache.clone()
    }

    /// Clears the cache if it is Redis-backed; in-memory caches are left as is.
    ///
    /// # Errors
    ///
    /// Returns an error if clearing the Redis cache fails.
    pub async fn cleanup(&self) -> Result<()> {
        #[cfg(feature = "redis")]
        if matches!(self.kind, CacheKind::Redis { .. }) {
            self.cache.clear().await?;
        }
        Ok(())
    }
}