use std::any::Any;
use std::future::poll_fn;
use std::marker::PhantomData;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use async_trait::async_trait;
use cot::project::run_at_with_shutdown;
use derive_more::Debug;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tower::Service;
use tower_sessions::MemoryStore;
#[cfg(feature = "db")]
use crate::auth::db::DatabaseUserBackend;
use crate::auth::{Auth, AuthBackend, NoAuthBackend, User, UserId};
use crate::config::ProjectConfig;
#[cfg(feature = "db")]
use crate::db::Database;
#[cfg(feature = "db")]
use crate::db::migrations::{
DynMigration, MigrationDependency, MigrationEngine, MigrationWrapper, Operation,
};
use crate::handler::BoxedHandler;
use crate::project::prepare_request;
use crate::request::Request;
use crate::response::Response;
use crate::router::Router;
use crate::session::Session;
use crate::static_files::{StaticFile, StaticFiles};
use crate::{Body, Bootstrapper, Project, ProjectContext, Result};
/// Test client that boots a project and sends requests straight to its root
/// handler, without going through a network socket.
#[derive(Debug)]
pub struct Client {
    // Project context attached to every outgoing request via `prepare_request`.
    context: Arc<ProjectContext>,
    // Root request handler obtained from the booted project.
    handler: BoxedHandler,
}
impl Client {
    /// Boots `project` using its `"test"` configuration and returns a client
    /// that dispatches requests straight to the project's root handler,
    /// without opening any network socket.
    ///
    /// # Panics
    ///
    /// Panics if the `"test"` configuration cannot be obtained or the project
    /// fails to boot.
    #[must_use]
    #[expect(clippy::future_not_send)]
    pub async fn new<P>(project: P) -> Self
    where
        P: Project + 'static,
    {
        let test_config = project.config("test").expect("Could not get test config");
        let (context, handler) = Bootstrapper::new(project)
            .with_config(test_config)
            .boot()
            .await
            .expect("Could not boot project")
            .into_context_and_handler();

        Self {
            context: Arc::new(context),
            handler,
        }
    }

    /// Sends an empty-bodied `GET` request for `path` through
    /// [`Client::request`].
    ///
    /// # Errors
    ///
    /// Returns whatever error the handler produces while becoming ready or
    /// while handling the request.
    ///
    /// # Panics
    ///
    /// Panics if `path` cannot be used as a request URI.
    pub async fn get(&mut self, path: &str) -> Result<Response> {
        let Ok(request) = http::Request::get(path).body(Body::empty()) else {
            unreachable!("Test request should be valid");
        };
        self.request(request).await
    }

    /// Attaches the project context to `request`, waits for the handler to
    /// report readiness, then lets it handle the request.
    ///
    /// # Errors
    ///
    /// Returns the handler's error if it fails to become ready or to produce a
    /// response.
    pub async fn request(&mut self, mut request: Request) -> Result<Response> {
        let context = Arc::clone(&self.context);
        prepare_request(&mut request, context);
        poll_fn(|cx| self.handler.poll_ready(cx)).await?;
        self.handler.call(request).await
    }
}
/// Builder for test [`http::Request`]s whose project context is assembled from
/// the builder's own settings (see `build`) rather than from a booted project.
#[derive(Debug, Clone)]
pub struct TestRequestBuilder {
    // HTTP method of the built request.
    method: http::Method,
    // Request URI.
    url: String,
    // Router placed in the project context; `None` means an empty router.
    router: Option<Router>,
    // Session inserted as a request extension, if any.
    session: Option<Session>,
    // Project configuration; `None` means `ProjectConfig::default()`.
    config: Option<Arc<ProjectConfig>>,
    // Auth backend placed in the project context; `None` means `NoAuthBackend`.
    auth_backend: Option<AuthBackendWrapper>,
    // `Auth` inserted as a request extension, if any.
    auth: Option<Auth>,
    // Database placed in the project context, if any.
    #[cfg(feature = "db")]
    database: Option<Arc<Database>>,
    // Key/value pairs sent as an `application/x-www-form-urlencoded` body.
    form_data: Option<Vec<(String, String)>>,
    // Already-serialized JSON body.
    #[cfg(feature = "json")]
    json_data: Option<String>,
    // In-memory files exposed to the request through a `StaticFiles` extension.
    static_files: Vec<StaticFile>,
}
/// Shared [`AuthBackend`] trait object wrapped in a type that derives `Clone`
/// and `Debug`.
#[derive(Debug, Clone)]
struct AuthBackendWrapper {
    // Rendered as the placeholder `..` in `Debug` output.
    #[debug("..")]
    inner: Arc<dyn AuthBackend>,
}
impl AuthBackendWrapper {
    /// Moves `inner` behind a shared, type-erased [`Arc`].
    pub(crate) fn new<AB>(inner: AB) -> Self
    where
        AB: AuthBackend + 'static,
    {
        let shared: Arc<dyn AuthBackend> = Arc::new(inner);
        Self { inner: shared }
    }
}
#[async_trait]
impl AuthBackend for AuthBackendWrapper {
async fn authenticate(
&self,
credentials: &(dyn Any + Send + Sync),
) -> cot::auth::Result<Option<Box<dyn User + Send + Sync>>> {
self.inner.authenticate(credentials).await
}
async fn get_by_id(
&self,
id: UserId,
) -> cot::auth::Result<Option<Box<dyn User + Send + Sync>>> {
self.inner.get_by_id(id).await
}
}
impl Default for TestRequestBuilder {
    /// Describes a `GET /` request with no router, session, config, auth,
    /// database, body or static files configured.
    fn default() -> Self {
        Self {
            static_files: Vec::default(),
            method: http::Method::GET,
            url: String::from("/"),
            config: None,
            router: None,
            session: None,
            auth: None,
            auth_backend: None,
            form_data: None,
            #[cfg(feature = "db")]
            database: None,
            #[cfg(feature = "json")]
            json_data: None,
        }
    }
}
impl TestRequestBuilder {
    /// Creates a builder for a `GET` request to `url`.
    #[must_use]
    pub fn get(url: &str) -> Self {
        Self::with_method(url, crate::Method::GET)
    }

    /// Creates a builder for a `POST` request to `url`.
    #[must_use]
    pub fn post(url: &str) -> Self {
        Self::with_method(url, crate::Method::POST)
    }

    /// Creates a builder for a request to `url` using an arbitrary HTTP
    /// `method`. Every other setting starts at its [`Default`] value.
    #[must_use]
    pub fn with_method(url: &str, method: crate::Method) -> Self {
        Self {
            method,
            url: url.to_string(),
            ..Self::default()
        }
    }

    /// Uses `config` as the project configuration of the built request.
    pub fn config(&mut self, config: ProjectConfig) -> &mut Self {
        self.config = Some(Arc::new(config));
        self
    }

    /// Uses [`ProjectConfig::default`] as the project configuration.
    pub fn with_default_config(&mut self) -> &mut Self {
        self.config = Some(Arc::new(ProjectConfig::default()));
        self
    }

    /// Places `auth_backend` in the project context of the built request.
    /// Without one, [`NoAuthBackend`] is used.
    pub fn auth_backend<T: AuthBackend + 'static>(&mut self, auth_backend: T) -> &mut Self {
        self.auth_backend = Some(AuthBackendWrapper::new(auth_backend));
        self
    }

    /// Places `router` in the project context. Without one, an empty router
    /// is used.
    pub fn router(&mut self, router: Router) -> &mut Self {
        self.router = Some(router);
        self
    }

    /// Attaches a fresh session backed by an in-memory store.
    pub fn with_session(&mut self) -> &mut Self {
        let session_store = MemoryStore::default();
        let session_inner = tower_sessions::Session::new(None, Arc::new(session_store), None);
        self.session = Some(Session::new(session_inner));
        self
    }

    /// Attaches a clone of the session carried by `request`.
    pub fn with_session_from(&mut self, request: &Request) -> &mut Self {
        self.session = Some(Session::from_request(request).clone());
        self
    }

    /// Attaches `session` to the built request.
    pub fn session(&mut self, session: Session) -> &mut Self {
        self.session = Some(session);
        self
    }

    /// Places `database` in the project context of the built request.
    #[cfg(feature = "db")]
    pub fn database<DB: Into<Arc<Database>>>(&mut self, database: DB) -> &mut Self {
        self.database = Some(database.into());
        self
    }

    /// Configures database-backed authentication against `db`.
    ///
    /// This does four things:
    /// - installs a [`DatabaseUserBackend`];
    /// - attaches a fresh in-memory session, replacing any session set earlier;
    /// - registers `db` as the request's database;
    /// - attaches an [`Auth`] built from that session and backend with a fixed,
    ///   test-only secret key.
    ///
    /// # Panics
    ///
    /// Panics if creating the [`Auth`] instance fails.
    #[cfg(feature = "db")]
    pub async fn with_db_auth(&mut self, db: Arc<Database>) -> &mut Self {
        self.auth_backend(DatabaseUserBackend::new(Arc::clone(&db)));
        self.with_session();
        self.database(db);
        self.auth = Some(
            Auth::new(
                self.session.clone().expect("Session was just set"),
                self.auth_backend
                    .clone()
                    .expect("Auth backend was just set")
                    .inner,
                crate::config::SecretKey::from("000000"),
                &[],
            )
            .await
            .expect("Failed to create Auth"),
        );
        self
    }

    /// Sets key/value pairs to be sent as an
    /// `application/x-www-form-urlencoded` body. Only `POST` requests are
    /// currently supported (see `build`).
    pub fn form_data<T: ToString>(&mut self, form_data: &[(T, T)]) -> &mut Self {
        self.form_data = Some(
            form_data
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        );
        self
    }

    /// Serializes `data` to JSON and uses it as the request body.
    ///
    /// # Panics
    ///
    /// Panics if `data` cannot be serialized.
    #[cfg(feature = "json")]
    pub fn json<T: serde::Serialize>(&mut self, data: &T) -> &mut Self {
        self.json_data = Some(serde_json::to_string(data).expect("Failed to serialize JSON"));
        self
    }

    /// Registers an in-memory static file at `path` with `content`. It is
    /// exposed to the built request through a `StaticFiles` extension.
    pub fn static_file<Path, Content>(&mut self, path: Path, content: Content) -> &mut Self
    where
        Path: Into<String>,
        Content: Into<bytes::Bytes>,
    {
        self.static_files.push(StaticFile::new(path, content));
        self
    }

    /// Builds the request.
    ///
    /// The request gets a freshly initialized [`ProjectContext`] made from the
    /// configured settings. Missing settings fall back to the default config,
    /// an empty router and [`NoAuthBackend`]. Any configured session, auth,
    /// body and static files are attached as well. If both form data and JSON
    /// are set, the JSON body and content type win.
    ///
    /// The builder is not modified, so calling `build` again yields an
    /// equivalent request.
    ///
    /// # Panics
    ///
    /// Panics if the URL is not a valid URI, or if form data is set on a
    /// non-`POST` request.
    #[must_use]
    pub fn build(&mut self) -> http::Request<Body> {
        let Ok(mut request) = http::Request::builder()
            .method(self.method.clone())
            .uri(self.url.clone())
            .body(Body::empty())
        else {
            unreachable!("Test request should be valid");
        };

        // Share the configured backend instead of moving it out of the
        // builder. Moving it would make a second `build` silently fall back to
        // `NoAuthBackend`.
        let auth_backend: Arc<dyn AuthBackend> = match &self.auth_backend {
            Some(wrapper) => Arc::clone(&wrapper.inner),
            None => Arc::new(NoAuthBackend),
        };
        let context = ProjectContext::initialized(
            self.config.clone().unwrap_or_default(),
            Vec::new(),
            Arc::new(self.router.clone().unwrap_or_else(Router::empty)),
            auth_backend,
            #[cfg(feature = "db")]
            self.database.clone(),
        );
        prepare_request(&mut request, Arc::new(context));

        if let Some(session) = &self.session {
            request.extensions_mut().insert(session.clone());
        }
        if let Some(auth) = &self.auth {
            request.extensions_mut().insert(auth.clone());
        }

        if let Some(form_data) = &self.form_data {
            if self.method != http::Method::POST {
                todo!("Form data can currently only be used with POST requests");
            }
            let mut data = form_urlencoded::Serializer::new(String::new());
            for (key, value) in form_data {
                data.append_pair(key, value);
            }
            *request.body_mut() = Body::fixed(data.finish());
            request.headers_mut().insert(
                http::header::CONTENT_TYPE,
                http::HeaderValue::from_static("application/x-www-form-urlencoded"),
            );
        }

        #[cfg(feature = "json")]
        if let Some(json_data) = &self.json_data {
            *request.body_mut() = Body::fixed(json_data.clone());
            request.headers_mut().insert(
                http::header::CONTENT_TYPE,
                http::HeaderValue::from_static("application/json"),
            );
        }

        if !self.static_files.is_empty() {
            let config = self.config.clone().unwrap_or_default();
            let mut static_files = StaticFiles::new(&config.static_files);
            // Clone rather than drain for the same reason as the auth backend
            // above, so repeated builds keep their static files.
            for file in self.static_files.iter().cloned() {
                static_files.add_file(file);
            }
            request.extensions_mut().insert(Arc::new(static_files));
        }

        request
    }
}
#[cfg(feature = "db")]
/// A database created for a single test. It also records what `cleanup` needs
/// to remove it and any migrations queued to run on it.
#[derive(Debug)]
pub struct TestDatabase {
    // Connection to the test database; handed out by `database()` and `Deref`.
    database: Arc<Database>,
    // Backend kind; `cleanup` uses it to decide what to drop.
    kind: TestDatabaseKind,
    // Migrations queued by `add_migrations` and consumed by `run_migrations`.
    migrations: Vec<MigrationWrapper>,
}
#[cfg(feature = "db")]
impl TestDatabase {
    /// Wraps an open `database` connection of the given `kind`, with no
    /// migrations queued yet.
    fn new(database: Database, kind: TestDatabaseKind) -> TestDatabase {
        TestDatabase {
            migrations: Vec::new(),
            database: Arc::new(database),
            kind,
        }
    }

    /// Opens a fresh in-memory SQLite database.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be established.
    pub async fn new_sqlite() -> Result<Self> {
        Ok(Self::new(
            Database::new("sqlite::memory:").await?,
            TestDatabaseKind::Sqlite,
        ))
    }

    /// Creates a fresh PostgreSQL database named `test_cot__{test_name}`.
    ///
    /// The server URL comes from the `POSTGRES_URL` environment variable and
    /// defaults to `postgresql://cot:cot@localhost`. A database with the same
    /// name is dropped first. `test_name` is spliced into the SQL unquoted.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting to the server or (re)creating the
    /// database fails.
    pub async fn new_postgres(test_name: &str) -> Result<Self> {
        let db_url = std::env::var("POSTGRES_URL")
            .unwrap_or_else(|_| "postgresql://cot:cot@localhost".to_string());
        let (database, db_name) =
            Self::recreate_database(&db_url, "postgres", test_name).await?;
        Ok(Self::new(
            database,
            TestDatabaseKind::Postgres { db_url, db_name },
        ))
    }

    /// Creates a fresh MySQL database named `test_cot__{test_name}`.
    ///
    /// The server URL comes from the `MYSQL_URL` environment variable and
    /// defaults to `mysql://root:@localhost`. A database with the same name is
    /// dropped first. `test_name` is spliced into the SQL unquoted.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting to the server or (re)creating the
    /// database fails.
    pub async fn new_mysql(test_name: &str) -> Result<Self> {
        let db_url =
            std::env::var("MYSQL_URL").unwrap_or_else(|_| "mysql://root:@localhost".to_string());
        let (database, db_name) = Self::recreate_database(&db_url, "mysql", test_name).await?;
        Ok(Self::new(database, TestDatabaseKind::MySql { db_url, db_name }))
    }

    /// Recreates the test database and connects to it.
    ///
    /// Connects to the `admin_db` database at `db_url` and drops, then
    /// re-creates, `test_cot__{test_name}`. It then closes that admin
    /// connection and returns a new connection to the freshly created database
    /// together with its name.
    async fn recreate_database(
        db_url: &str,
        admin_db: &str,
        test_name: &str,
    ) -> Result<(Database, String)> {
        let admin = Database::new(format!("{db_url}/{admin_db}")).await?;
        let db_name = format!("test_cot__{test_name}");
        admin
            .raw(&format!("DROP DATABASE IF EXISTS {db_name}"))
            .await?;
        admin.raw(&format!("CREATE DATABASE {db_name}")).await?;
        admin.close().await?;
        let database = Database::new(format!("{db_url}/{db_name}")).await?;
        Ok((database, db_name))
    }

    /// Queues the built-in auth app migrations.
    #[cfg(feature = "db")]
    pub fn with_auth(&mut self) -> &mut Self {
        self.add_migrations(cot::auth::db::migrations::MIGRATIONS.to_vec())
    }

    /// Queues `migrations` to be applied by [`TestDatabase::run_migrations`].
    pub fn add_migrations<T: DynMigration + Send + Sync + 'static, V: IntoIterator<Item = T>>(
        &mut self,
        migrations: V,
    ) -> &mut Self {
        for migration in migrations {
            self.migrations.push(MigrationWrapper::new(migration));
        }
        self
    }

    /// Applies every queued migration and empties the queue. Does nothing when
    /// no migrations are queued.
    ///
    /// # Panics
    ///
    /// Panics if the migration engine cannot be created or a migration fails.
    pub async fn run_migrations(&mut self) -> &mut Self {
        if self.migrations.is_empty() {
            return self;
        }
        let pending = std::mem::take(&mut self.migrations);
        let engine =
            MigrationEngine::new(pending).expect("Failed to initialize the migration engine");
        engine
            .run(&self.database())
            .await
            .expect("Failed to run migrations");
        self
    }

    /// Returns a shared handle to the test database connection.
    #[must_use]
    pub fn database(&self) -> Arc<Database> {
        Arc::clone(&self.database)
    }

    /// Closes the connection and, for server-backed databases, drops the test
    /// database through that server's maintenance database.
    ///
    /// # Errors
    ///
    /// Returns an error if closing, connecting or dropping fails.
    pub async fn cleanup(&self) -> Result<()> {
        self.database.close().await?;

        let (db_url, admin_db, drop_sql) = match &self.kind {
            TestDatabaseKind::Sqlite => return Ok(()),
            TestDatabaseKind::Postgres { db_url, db_name } => (
                db_url,
                "postgres",
                format!("DROP DATABASE {db_name} WITH (FORCE)"),
            ),
            TestDatabaseKind::MySql { db_url, db_name } => {
                (db_url, "mysql", format!("DROP DATABASE {db_name}"))
            }
        };

        let admin = Database::new(format!("{db_url}/{admin_db}")).await?;
        admin.raw(&drop_sql).await?;
        admin.close().await?;
        Ok(())
    }
}
#[cfg(feature = "db")]
impl std::ops::Deref for TestDatabase {
    type Target = Database;

    /// Borrows the underlying [`Database`], so a `TestDatabase` can be used
    /// directly wherever `&Database` is expected.
    fn deref(&self) -> &Self::Target {
        self.database.as_ref()
    }
}
#[cfg(feature = "db")]
/// Backend of a [`TestDatabase`], carrying what `cleanup` needs to drop it.
#[derive(Debug, Clone)]
enum TestDatabaseKind {
    /// In-memory SQLite; `cleanup` has nothing to drop.
    Sqlite,
    /// PostgreSQL server at `db_url` hosting the test database `db_name`.
    Postgres { db_url: String, db_name: String },
    /// MySQL server at `db_url` hosting the test database `db_name`.
    MySql { db_url: String, db_name: String },
}
#[cfg(feature = "db")]
/// A migration defined inline in a test. It owns its dependency and operation
/// lists and exposes them through [`DynMigration`].
#[derive(Debug, Clone)]
pub struct TestMigration {
    // App this migration belongs to.
    app_name: &'static str,
    // Migration name.
    name: &'static str,
    // Dependencies returned by `DynMigration::dependencies`.
    dependencies: Vec<MigrationDependency>,
    // Operations returned by `DynMigration::operations`.
    operations: Vec<Operation>,
}
#[cfg(feature = "db")]
impl TestMigration {
    /// Creates the migration `name` of app `app_name`. `dependencies` and
    /// `operations` are converted into owned vectors.
    #[must_use]
    pub fn new<D: Into<Vec<MigrationDependency>>, O: Into<Vec<Operation>>>(
        app_name: &'static str,
        name: &'static str,
        dependencies: D,
        operations: O,
    ) -> Self {
        let dependencies: Vec<MigrationDependency> = dependencies.into();
        let operations: Vec<Operation> = operations.into();
        Self {
            app_name,
            name,
            dependencies,
            operations,
        }
    }
}
#[cfg(feature = "db")]
impl DynMigration for TestMigration {
    /// App name given to [`TestMigration::new`].
    fn app_name(&self) -> &str {
        self.app_name
    }

    /// Migration name given to [`TestMigration::new`].
    fn name(&self) -> &str {
        self.name
    }

    /// Dependencies given to [`TestMigration::new`].
    fn dependencies(&self) -> &[MigrationDependency] {
        self.dependencies.as_slice()
    }

    /// Operations given to [`TestMigration::new`].
    fn operations(&self) -> &[Operation] {
        self.operations.as_slice()
    }
}
/// Builder that starts a [`TestServer`] serving `project`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TestServerBuilder<T> {
    // Project to boot and serve once `start` is called.
    project: T,
}
impl<T: Project + 'static> TestServerBuilder<T> {
#[must_use]
pub fn new(project: T) -> Self {
Self { project }
}
pub async fn start(self) -> TestServer<T> {
TestServer::start(self.project).await
}
}
/// A project served over real TCP for the duration of a test. Call `close` to
/// shut it down and wait for the server task.
#[must_use = "TestServer must be used to close the server"]
#[derive(Debug)]
pub struct TestServer<T> {
    // Loopback IP plus the port the listener was bound to.
    address: SocketAddr,
    // Sends the shutdown signal the server task waits for.
    channel_send: oneshot::Sender<()>,
    // Handle of the task running the server; awaited by `close`.
    server_handle: tokio::task::JoinHandle<()>,
    // Ties the server to its project type. `fn() -> T` keeps `Send`/`Sync`
    // independent of `T`.
    project: PhantomData<fn() -> T>,
}
impl<T: Project + 'static> TestServer<T> {
    /// Binds an ephemeral port, boots `project` with its `"test"` config and
    /// starts serving it in the background.
    ///
    /// The listener is bound on all interfaces (`0.0.0.0`), but the recorded
    /// [`address`](Self::address) uses the loopback IP. The server runs on a
    /// task created with [`tokio::task::spawn_local`], so this must be called
    /// from within a tokio `LocalSet`.
    ///
    /// # Panics
    ///
    /// Panics if binding the port or reading its address fails. Failures while
    /// booting or running the project panic inside the server task and are
    /// re-raised by [`close`](Self::close).
    async fn start(project: T) -> Self {
        let tcp_listener = TcpListener::bind("0.0.0.0:0")
            .await
            .expect("Failed to bind to a port");
        let mut address = tcp_listener
            .local_addr()
            .expect("Failed to get the listening address");
        address.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST));

        let (send, recv) = oneshot::channel::<()>();

        let server_handle = tokio::task::spawn_local(async move {
            let bootstrapper = Bootstrapper::new(project)
                .with_config_name("test")
                .expect("Failed to get the \"test\" config")
                .boot()
                .await
                .expect("Failed to boot the project");
            run_at_with_shutdown(bootstrapper, tcp_listener, async move {
                recv.await.expect("Failed to receive a shutdown signal");
            })
            .await
            .expect("Failed to run the server");
        });

        Self {
            address,
            channel_send: send,
            server_handle,
            project: PhantomData,
        }
    }

    /// Returns the loopback socket address the server is reachable at.
    #[must_use]
    pub fn address(&self) -> SocketAddr {
        self.address
    }

    /// Returns the server's base URL.
    ///
    /// If the `COT_TEST_SERVER_HOST` environment variable is set, it is used
    /// as the host together with the server's port. Otherwise the loopback
    /// [`address`](Self::address) is used.
    #[must_use]
    pub fn url(&self) -> String {
        if let Ok(host) = std::env::var("COT_TEST_SERVER_HOST") {
            format!("http://{}:{}", host, self.address.port())
        } else {
            format!("http://{}", self.address)
        }
    }

    /// Signals the server to shut down and waits for its task to finish.
    ///
    /// # Panics
    ///
    /// Re-raises any panic from the server task (for example when the project
    /// failed to boot). Also panics if the task was cancelled.
    pub async fn close(self) {
        // If the server task has already exited (e.g. it panicked while
        // booting), its receiver is gone and `send` fails. That error carries
        // no information, so ignore it and let the join below report what
        // actually happened instead of masking it.
        let _ = self.channel_send.send(());
        match self.server_handle.await {
            Ok(()) => {}
            Err(error) if error.is_panic() => std::panic::resume_unwind(error.into_panic()),
            Err(error) => panic!("Failed to join the server task: {error}"),
        }
    }
}
/// Acquires a guard on a single process-wide mutex, so callers holding the
/// guard are serialized with respect to each other.
///
/// A previous holder that panicked does not block later callers: the poison
/// flag is cleared and the guard is returned anyway.
#[doc(hidden)]
pub fn serial_guard() -> std::sync::MutexGuard<'static, ()> {
    static SERIAL_LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
    let mutex = SERIAL_LOCK.get_or_init(std::sync::Mutex::default);
    mutex.lock().unwrap_or_else(|poisoned| {
        mutex.clear_poison();
        poisoned.into_inner()
    })
}