cot/
test.rs

1//! Test utilities for Cot projects.
2
3use std::any::Any;
4use std::future::poll_fn;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use derive_more::Debug;
9use tower::Service;
10use tower_sessions::MemoryStore;
11
12#[cfg(feature = "db")]
13use crate::auth::db::DatabaseUserBackend;
14use crate::auth::{Auth, AuthBackend, NoAuthBackend, User, UserId};
15use crate::config::ProjectConfig;
16#[cfg(feature = "db")]
17use crate::db::Database;
18#[cfg(feature = "db")]
19use crate::db::migrations::{
20    DynMigration, MigrationDependency, MigrationEngine, MigrationWrapper, Operation,
21};
22use crate::handler::BoxedHandler;
23use crate::project::prepare_request;
24use crate::request::Request;
25use crate::response::Response;
26use crate::router::Router;
27use crate::session::Session;
28use crate::{Body, Bootstrapper, Project, ProjectContext, Result};
29
/// A test client for making requests to a Cot project.
///
/// Useful for end-to-end testing of Cot projects.
#[derive(Debug)]
pub struct Client {
    /// Project context attached to every request sent through this client;
    /// `Client::request` hands a clone of it to `prepare_request`.
    context: Arc<ProjectContext>,
    /// The project's root handler. Requests are dispatched to it through
    /// `tower::Service::poll_ready` followed by `call`.
    handler: BoxedHandler,
}
38
39impl Client {
40    /// Create a new test client for a Cot project.
41    ///
42    /// # Panics
43    ///
44    /// Panics if the test config could not be loaded.
45    /// Panics if the project could not be initialized.
46    ///
47    /// # Examples
48    ///
49    /// ```
50    /// use cot::test::Client;
51    /// use cot::Project;
52    ///    use cot::config::ProjectConfig;
53    ///
54    /// struct MyProject;
55    /// impl Project for MyProject {
56    ///     fn config(&self, config_name: &str) -> cot::Result<ProjectConfig> {
57    ///         Ok(ProjectConfig::default())
58    ///     }
59    /// }
60    ///
61    /// # #[tokio::main]
62    /// # async fn main() -> cot::Result<()> {
63    /// let mut client = Client::new(MyProject).await;
64    /// let response = client.get("/").await?;
65    /// assert!(!response.into_body().into_bytes().await?.is_empty());
66    /// # Ok(())
67    /// }
68    /// ```
69    #[must_use]
70    #[expect(clippy::future_not_send)] // used in the test code
71    pub async fn new<P>(project: P) -> Self
72    where
73        P: Project + 'static,
74    {
75        let config = project.config("test").expect("Could not get test config");
76        let bootstrapper = Bootstrapper::new(project)
77            .with_config(config)
78            .boot()
79            .await
80            .expect("Could not boot project");
81
82        let (context, handler) = bootstrapper.into_context_and_handler();
83        Self {
84            context: Arc::new(context),
85            handler,
86        }
87    }
88
89    /// Send a GET request to the given path.
90    ///
91    /// # Errors
92    ///
93    /// Propagates any errors that the request handler might return.
94    ///
95    /// # Examples
96    ///
97    /// ```
98    /// use cot::test::Client;
99    /// use cot::Project;
100    ///    use cot::config::ProjectConfig;
101    ///
102    /// struct MyProject;
103    /// impl Project for MyProject {
104    ///     fn config(&self, config_name: &str) -> cot::Result<ProjectConfig> {
105    ///         Ok(ProjectConfig::default())
106    ///     }
107    /// }
108    ///
109    /// # #[tokio::main]
110    /// # async fn main() -> cot::Result<()> {
111    /// let mut client = Client::new(MyProject).await;
112    /// let response = client.get("/").await?;
113    /// assert!(!response.into_body().into_bytes().await?.is_empty());
114    /// # Ok(())
115    /// }
116    /// ```
117    pub async fn get(&mut self, path: &str) -> Result<Response> {
118        self.request(match http::Request::get(path).body(Body::empty()) {
119            Ok(request) => request,
120            Err(_) => {
121                unreachable!("Test request should be valid")
122            }
123        })
124        .await
125    }
126
127    /// Send a request to the given path.
128    ///
129    /// # Errors
130    ///
131    /// Propagates any errors that the request handler might return.
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// use cot::test::Client;
137    /// use cot::{Body, Project};
138    /// use cot::config::ProjectConfig;
139    ///
140    /// struct MyProject;
141    /// impl Project for MyProject {
142    ///     fn config(&self, config_name: &str) -> cot::Result<ProjectConfig> {
143    ///         Ok(ProjectConfig::default())
144    ///     }
145    /// }
146    ///
147    /// # #[tokio::main]
148    /// # async fn main() -> cot::Result<()> {
149    /// let mut client = Client::new(MyProject).await;
150    /// let response = client.request(cot::http::Request::get("/").body(Body::empty()).unwrap()).await?;
151    /// assert!(!response.into_body().into_bytes().await?.is_empty());
152    /// # Ok(())
153    /// }
154    /// ```
155    pub async fn request(&mut self, mut request: Request) -> Result<Response> {
156        prepare_request(&mut request, self.context.clone());
157
158        poll_fn(|cx| self.handler.poll_ready(cx)).await?;
159        self.handler.call(request).await
160    }
161}
162
/// A builder for creating test requests, typically used for unit testing
/// without having to create a full Cot project and do actual HTTP requests.
///
/// # Examples
///
/// ```
/// use cot::Body;
/// use cot::request::Request;
/// use cot::response::{Response, ResponseExt};
/// use cot::test::TestRequestBuilder;
/// use http::StatusCode;
///
/// # #[tokio::main]
/// # async fn main() -> cot::Result<()> {
/// async fn index(request: Request) -> cot::Result<Response> {
///     Ok(Response::new_html(
///         StatusCode::OK,
///         Body::fixed("Hello world!"),
///     ))
/// }
///
/// let request = TestRequestBuilder::get("/").build();
///
/// assert_eq!(
///     index(request).await?.into_body().into_bytes().await?,
///     "Hello world!"
/// );
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct TestRequestBuilder {
    /// HTTP method of the built request.
    method: http::Method,
    /// Request URI of the built request.
    url: String,
    /// Router placed in the request's project context; an empty router is
    /// used when unset.
    router: Option<Router>,
    /// Session inserted into the request extensions, if any.
    session: Option<Session>,
    /// Project configuration placed in the request's project context; the
    /// default configuration is used when unset.
    config: Option<Arc<ProjectConfig>>,
    /// Authentication backend placed in the request's project context;
    /// `NoAuthBackend` is used when unset.
    auth_backend: Option<AuthBackendWrapper>,
    /// Pre-built `Auth` object inserted into the request extensions
    /// (set by `with_db_auth`).
    auth: Option<Auth>,
    /// Database placed in the request's project context.
    #[cfg(feature = "db")]
    database: Option<Arc<Database>>,
    /// Key/value pairs sent as an `application/x-www-form-urlencoded` body
    /// (POST requests only).
    form_data: Option<Vec<(String, String)>>,
    /// Pre-serialized JSON sent as an `application/json` body.
    #[cfg(feature = "json")]
    json_data: Option<String>,
}
208
/// A wrapper over an auth backend that is cloneable.
///
/// Cloning the wrapper only clones the inner [`Arc`], so all clones share the
/// same backend instance.
#[derive(Debug, Clone)]
struct AuthBackendWrapper {
    /// The shared, type-erased backend. It is rendered as `..` in `Debug`
    /// output — presumably because `dyn AuthBackend` is not required to
    /// implement `Debug` (TODO confirm).
    #[debug("..")]
    inner: Arc<dyn AuthBackend>,
}
215
216impl AuthBackendWrapper {
217    pub(crate) fn new<AB>(inner: AB) -> Self
218    where
219        AB: AuthBackend + 'static,
220    {
221        Self {
222            inner: Arc::new(inner),
223        }
224    }
225}
226
227#[async_trait]
228impl AuthBackend for AuthBackendWrapper {
229    async fn authenticate(
230        &self,
231        credentials: &(dyn Any + Send + Sync),
232    ) -> cot::auth::Result<Option<Box<dyn User + Send + Sync>>> {
233        self.inner.authenticate(credentials).await
234    }
235
236    async fn get_by_id(
237        &self,
238        id: UserId,
239    ) -> cot::auth::Result<Option<Box<dyn User + Send + Sync>>> {
240        self.inner.get_by_id(id).await
241    }
242}
243
244impl Default for TestRequestBuilder {
245    fn default() -> Self {
246        Self {
247            method: http::Method::GET,
248            url: "/".to_string(),
249            router: None,
250            session: None,
251            config: None,
252            auth_backend: None,
253            auth: None,
254            #[cfg(feature = "db")]
255            database: None,
256            form_data: None,
257            #[cfg(feature = "json")]
258            json_data: None,
259        }
260    }
261}
262
263impl TestRequestBuilder {
264    /// Create a new GET request builder.
265    ///
266    /// # Examples
267    ///
268    /// ```
269    /// use cot::Body;
270    /// use cot::request::Request;
271    /// use cot::response::{Response, ResponseExt};
272    /// use cot::test::TestRequestBuilder;
273    /// use http::StatusCode;
274    ///
275    /// # #[tokio::main]
276    /// # async fn main() -> cot::Result<()> {
277    /// async fn index(request: Request) -> cot::Result<Response> {
278    ///     Ok(Response::new_html(
279    ///         StatusCode::OK,
280    ///         Body::fixed("Hello world!"),
281    ///     ))
282    /// }
283    ///
284    /// let request = TestRequestBuilder::get("/").build();
285    ///
286    /// assert_eq!(
287    ///     index(request).await?.into_body().into_bytes().await?,
288    ///     "Hello world!"
289    /// );
290    /// # Ok(())
291    /// # }
292    /// ```
293    #[must_use]
294    pub fn get(url: &str) -> Self {
295        Self {
296            method: http::Method::GET,
297            url: url.to_string(),
298            ..Self::default()
299        }
300    }
301
302    /// Create a new POST request builder.
303    ///
304    /// # Examples
305    ///
306    /// ```
307    /// use cot::Body;
308    /// use cot::request::Request;
309    /// use cot::response::{Response, ResponseExt};
310    /// use cot::test::TestRequestBuilder;
311    /// use http::StatusCode;
312    ///
313    /// # #[tokio::main]
314    /// # async fn main() -> cot::Result<()> {
315    /// async fn index(request: Request) -> cot::Result<Response> {
316    ///     Ok(Response::new_html(
317    ///         StatusCode::OK,
318    ///         Body::fixed("Hello world!"),
319    ///     ))
320    /// }
321    ///
322    /// let request = TestRequestBuilder::post("/").build();
323    ///
324    /// assert_eq!(
325    ///     index(request).await?.into_body().into_bytes().await?,
326    ///     "Hello world!"
327    /// );
328    /// # Ok(())
329    /// # }
330    /// ```
331    #[must_use]
332    pub fn post(url: &str) -> Self {
333        Self {
334            method: http::Method::POST,
335            url: url.to_string(),
336            ..Self::default()
337        }
338    }
339
340    /// Add a project config instance to the request builder.
341    ///
342    /// # Examples
343    ///
344    /// ```
345    /// use cot::config::ProjectConfig;
346    /// use cot::test::TestRequestBuilder;
347    ///
348    /// let request = TestRequestBuilder::get("/")
349    ///     .config(ProjectConfig::dev_default())
350    ///     .build();
351    /// ```
352    pub fn config(&mut self, config: ProjectConfig) -> &mut Self {
353        self.config = Some(Arc::new(config));
354        self
355    }
356
357    /// Create a new request builder with default configuration.
358    ///
359    /// # Examples
360    ///
361    /// ```
362    /// use cot::Body;
363    /// use cot::request::Request;
364    /// use cot::response::{Response, ResponseExt};
365    /// use cot::test::TestRequestBuilder;
366    /// use http::StatusCode;
367    ///
368    /// # #[tokio::main]
369    /// # async fn main() -> cot::Result<()> {
370    /// async fn index(request: Request) -> cot::Result<Response> {
371    ///     Ok(Response::new_html(
372    ///         StatusCode::OK,
373    ///         Body::fixed("Hello world!"),
374    ///     ))
375    /// }
376    ///
377    /// let request = TestRequestBuilder::get("/").with_default_config().build();
378    ///
379    /// assert_eq!(
380    ///     index(request).await?.into_body().into_bytes().await?,
381    ///     "Hello world!"
382    /// );
383    /// # Ok(())
384    /// # }
385    /// ```
386    pub fn with_default_config(&mut self) -> &mut Self {
387        self.config = Some(Arc::new(ProjectConfig::default()));
388        self
389    }
390
391    /// Add an authentication backend to the request builder.
392    ///
393    /// # Examples
394    ///
395    /// ```
396    /// use cot::auth::NoAuthBackend;
397    /// use cot::test::TestRequestBuilder;
398    ///
399    /// let request = TestRequestBuilder::get("/")
400    ///     .auth_backend(NoAuthBackend)
401    ///     .build();
402    /// ```
403    pub fn auth_backend<T: AuthBackend + 'static>(&mut self, auth_backend: T) -> &mut Self {
404        self.auth_backend = Some(AuthBackendWrapper::new(auth_backend));
405        self
406    }
407
408    /// Add a router to the request builder.
409    ///
410    /// # Examples
411    ///
412    /// ```
413    /// use cot::request::Request;
414    /// use cot::response::Response;
415    /// use cot::router::{Route, Router};
416    /// use cot::test::TestRequestBuilder;
417    ///
418    /// async fn index(request: Request) -> cot::Result<Response> {
419    ///     todo!()
420    /// }
421    ///
422    /// let router = Router::with_urls([Route::with_handler_and_name("/", index, "index")]);
423    /// let request = TestRequestBuilder::get("/").router(router).build();
424    /// ```
425    pub fn router(&mut self, router: Router) -> &mut Self {
426        self.router = Some(router);
427        self
428    }
429
430    /// Add a session support to the request builder.
431    ///
432    /// # Examples
433    ///
434    /// ```
435    /// use cot::test::TestRequestBuilder;
436    ///
437    /// let request = TestRequestBuilder::get("/").with_session().build();
438    /// ```
439    pub fn with_session(&mut self) -> &mut Self {
440        let session_store = MemoryStore::default();
441        let session_inner = tower_sessions::Session::new(None, Arc::new(session_store), None);
442        self.session = Some(Session::new(session_inner));
443        self
444    }
445
446    /// Add a session support to the request builder with the session copied
447    /// over from another [`Request`] object.
448    ///
449    /// # Examples
450    ///
451    /// ```
452    /// use cot::request::RequestExt;
453    /// use cot::session::Session;
454    /// use cot::test::TestRequestBuilder;
455    ///
456    /// # #[tokio::main]
457    /// # async fn main() -> cot::Result<()> {
458    /// let mut request = TestRequestBuilder::get("/").with_session().build();
459    /// Session::from_request(&request)
460    ///     .insert("key", "value")
461    ///     .await?;
462    ///
463    /// let mut request = TestRequestBuilder::get("/")
464    ///     .with_session_from(&request)
465    ///     .build();
466    /// # Ok(())
467    /// # }
468    /// ```
469    pub fn with_session_from(&mut self, request: &Request) -> &mut Self {
470        self.session = Some(Session::from_request(request).clone());
471        self
472    }
473
474    /// Add a session support to the request builder with the session object
475    /// provided as a parameter.
476    ///
477    /// # Examples
478    ///
479    /// ```
480    /// use cot::request::RequestExt;
481    /// use cot::session::Session;
482    /// use cot::test::TestRequestBuilder;
483    ///
484    /// # #[tokio::main]
485    /// # async fn main() -> cot::Result<()> {
486    /// let mut request = TestRequestBuilder::get("/").with_session().build();
487    /// let session = Session::from_request(&request);
488    /// session.insert("key", "value").await?;
489    ///
490    /// let mut request = TestRequestBuilder::get("/")
491    ///     .session(session.clone())
492    ///     .build();
493    /// # Ok(())
494    /// # }
495    /// ```
496    pub fn session(&mut self, session: Session) -> &mut Self {
497        self.session = Some(session);
498        self
499    }
500
501    /// Add a database to the request builder.
502    ///
503    /// # Examples
504    ///
505    /// ```
506    /// use std::sync::Arc;
507    ///
508    /// use cot::db::Database;
509    /// use cot::test::TestRequestBuilder;
510    /// use cot::request::{Request, RequestExt};
511    /// use cot::response::{Response, ResponseExt};
512    /// use cot::{Body, StatusCode};
513    ///
514    /// async fn index(request: Request) -> cot::Result<Response> {
515    ///     let db = request.db();
516    ///
517    ///     // ... do something with db
518    ///
519    ///     Ok(Response::new_html(
520    ///         StatusCode::OK,
521    ///         Body::fixed("Hello world!"),
522    ///     ))
523    /// }
524    ///
525    /// # #[tokio::main]
526    /// # async fn main() -> cot::Result<()> {
527    /// let request = TestRequestBuilder::get("/")
528    ///     .database(Database::new("sqlite::memory:").await?)
529    ///     .build();
530    /// # Ok(())
531    /// }
532    /// ```
533    #[cfg(feature = "db")]
534    pub fn database<DB: Into<Arc<Database>>>(&mut self, database: DB) -> &mut Self {
535        self.database = Some(database.into());
536        self
537    }
538
539    /// Use database authentication in the test request.
540    ///
541    /// Note that this calls [`Self::auth_backend`], [`Self::with_session`],
542    /// [`Self::database`], possibly overriding any values set by you earlier.
543    ///
544    /// # Panics
545    ///
546    /// Panics if the auth object fails to be created.
547    ///
548    /// # Examples
549    ///
550    /// ```
551    /// use cot::config::ProjectConfig;
552    /// use cot::test::{TestDatabase, TestRequestBuilder};
553    ///
554    /// # #[tokio::main]
555    /// # async fn main() -> cot::Result<()> {
556    /// let mut test_database = TestDatabase::new_sqlite().await?;
557    /// test_database.with_auth().run_migrations().await;
558    /// let request = TestRequestBuilder::get("/")
559    ///     .with_db_auth(test_database.database())
560    ///     .await
561    ///     .build();
562    /// # Ok(())
563    /// # }
564    /// ```
565    #[cfg(feature = "db")]
566    pub async fn with_db_auth(&mut self, db: Arc<Database>) -> &mut Self {
567        self.auth_backend(DatabaseUserBackend::new(Arc::clone(&db)));
568        self.with_session();
569        self.database(db);
570        self.auth = Some(
571            Auth::new(
572                self.session.clone().expect("Session was just set"),
573                self.auth_backend
574                    .clone()
575                    .expect("Auth backend was just set")
576                    .inner,
577                crate::config::SecretKey::from("000000"),
578                &[],
579            )
580            .await
581            .expect("Failed to create Auth"),
582        );
583
584        self
585    }
586
587    /// Add form data to the request builder.
588    ///
589    /// # Examples
590    ///
591    /// ```
592    /// use cot::test::TestRequestBuilder;
593    ///
594    /// let request = TestRequestBuilder::post("/")
595    ///     .form_data(&[("name", "Alice"), ("age", "30")])
596    ///     .build();
597    /// ```
598    pub fn form_data<T: ToString>(&mut self, form_data: &[(T, T)]) -> &mut Self {
599        self.form_data = Some(
600            form_data
601                .iter()
602                .map(|(k, v)| (k.to_string(), v.to_string()))
603                .collect(),
604        );
605        self
606    }
607
608    /// Add JSON data to the request builder.
609    ///
610    /// # Examples
611    ///
612    /// ```
613    /// use cot::test::TestRequestBuilder;
614    ///
615    /// #[derive(serde::Serialize)]
616    /// struct Data {
617    ///     key: String,
618    ///     value: i32,
619    /// }
620    ///
621    /// let request = TestRequestBuilder::post("/")
622    ///     .json(&Data {
623    ///         key: "value".to_string(),
624    ///         value: 42,
625    ///     })
626    ///     .build();
627    /// ```
628    ///
629    /// # Panics
630    ///
631    /// Panics if the JSON serialization fails.
632    #[cfg(feature = "json")]
633    pub fn json<T: serde::Serialize>(&mut self, data: &T) -> &mut Self {
634        self.json_data = Some(serde_json::to_string(data).expect("Failed to serialize JSON"));
635        self
636    }
637
638    /// Build the request.
639    ///
640    /// # Examples
641    ///
642    /// ```
643    /// use cot::Body;
644    /// use cot::request::Request;
645    /// use cot::response::{Response, ResponseExt};
646    /// use cot::test::TestRequestBuilder;
647    /// use http::StatusCode;
648    ///
649    /// # #[tokio::main]
650    /// # async fn main() -> cot::Result<()> {
651    /// async fn index(request: Request) -> cot::Result<Response> {
652    ///     Ok(Response::new_html(
653    ///         StatusCode::OK,
654    ///         Body::fixed("Hello world!"),
655    ///     ))
656    /// }
657    ///
658    /// let request = TestRequestBuilder::get("/").build();
659    ///
660    /// assert_eq!(
661    ///     index(request).await?.into_body().into_bytes().await?,
662    ///     "Hello world!"
663    /// );
664    /// # Ok(())
665    /// # }
666    /// ```
667    #[must_use]
668    pub fn build(&mut self) -> http::Request<Body> {
669        let Ok(mut request) = http::Request::builder()
670            .method(self.method.clone())
671            .uri(self.url.clone())
672            .body(Body::empty())
673        else {
674            unreachable!("Test request should be valid");
675        };
676
677        let auth_backend = std::mem::take(&mut self.auth_backend);
678        #[expect(trivial_casts)]
679        let auth_backend = match auth_backend {
680            Some(auth_backend) => Arc::new(auth_backend) as Arc<dyn AuthBackend>,
681            None => Arc::new(NoAuthBackend),
682        };
683
684        let context = ProjectContext::initialized(
685            self.config.clone().unwrap_or_default(),
686            Vec::new(),
687            Arc::new(self.router.clone().unwrap_or_else(Router::empty)),
688            auth_backend,
689            #[cfg(feature = "db")]
690            self.database.clone(),
691        );
692        prepare_request(&mut request, Arc::new(context));
693
694        if let Some(session) = &self.session {
695            request.extensions_mut().insert(session.clone());
696        }
697
698        if let Some(auth) = &self.auth {
699            request.extensions_mut().insert(auth.clone());
700        }
701
702        if let Some(form_data) = &self.form_data {
703            if self.method != http::Method::POST {
704                todo!("Form data can currently only be used with POST requests");
705            }
706
707            let mut data = form_urlencoded::Serializer::new(String::new());
708            for (key, value) in form_data {
709                data.append_pair(key, value);
710            }
711
712            *request.body_mut() = Body::fixed(data.finish());
713            request.headers_mut().insert(
714                http::header::CONTENT_TYPE,
715                http::HeaderValue::from_static("application/x-www-form-urlencoded"),
716            );
717        }
718
719        #[cfg(feature = "json")]
720        if let Some(json_data) = &self.json_data {
721            *request.body_mut() = Body::fixed(json_data.clone());
722            request.headers_mut().insert(
723                http::header::CONTENT_TYPE,
724                http::HeaderValue::from_static("application/json"),
725            );
726        }
727
728        request
729    }
730}
731
/// A test database.
///
/// This is used to create a separate database for testing and run migrations on
/// it.
///
/// # Examples
///
/// ```
/// use cot::test::{TestDatabase, TestRequestBuilder};
///
/// # #[tokio::main]
/// # async fn main() -> cot::Result<()> {
/// let mut test_database = TestDatabase::new_sqlite().await?;
/// let request = TestRequestBuilder::get("/")
///     .database(test_database.database())
///     .build();
///
/// // do something with the request
///
/// test_database.cleanup().await?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "db")]
#[derive(Debug)]
pub struct TestDatabase {
    /// Shared connection to the test database.
    database: Arc<Database>,
    /// Which backend the database was created on; for PostgreSQL and MySQL
    /// it also records the server URL and the generated database name
    /// (presumably needed by `cleanup` to drop it — TODO confirm).
    kind: TestDatabaseKind,
    /// Migrations queued by `add_migrations` and consumed by
    /// `run_migrations`.
    migrations: Vec<MigrationWrapper>,
}
762
763#[cfg(feature = "db")]
764impl TestDatabase {
765    fn new(database: Database, kind: TestDatabaseKind) -> TestDatabase {
766        Self {
767            database: Arc::new(database),
768            kind,
769            migrations: Vec::new(),
770        }
771    }
772
773    /// Create a new in-memory SQLite database for testing.
774    ///
775    /// # Errors
776    ///
777    /// If the database could not have been created.
778    ///
779    ///
780    /// # Examples
781    ///
782    /// ```
783    /// use cot::test::{TestDatabase, TestRequestBuilder};
784    ///
785    /// # #[tokio::main]
786    /// # async fn main() -> cot::Result<()> {
787    /// let mut test_database = TestDatabase::new_sqlite().await?;
788    /// let request = TestRequestBuilder::get("/")
789    ///     .database(test_database.database())
790    ///     .build();
791    ///
792    /// // do something with the request
793    ///
794    /// test_database.cleanup().await?;
795    /// # Ok(())
796    /// # }
797    /// ```
798    pub async fn new_sqlite() -> Result<Self> {
799        let database = Database::new("sqlite::memory:").await?;
800        Ok(Self::new(database, TestDatabaseKind::Sqlite))
801    }
802
803    /// Create a new PostgreSQL database for testing and connects to it.
804    ///
805    /// The database URL is read from the `POSTGRES_URL` environment variable.
806    /// Note that it shouldn't include the database name — the function will
807    /// create a new database for the test by connecting to the `postgres`
808    /// database. If no URL is provided, it defaults to
809    /// `postgresql://cot:cot@localhost`.
810    ///
811    /// The database is created with the name `test_cot__{test_name}`.
812    /// Make sure that `test_name` is unique for each test so that the databases
813    /// don't conflict with each other.
814    ///
815    /// The database is dropped when `self.cleanup()` is called. Note that this
816    /// means that the database will not be dropped if the test panics.
817    ///
818    /// # Errors
819    ///
820    /// Returns an error if a database connection (either to the test database,
821    /// or postgres maintenance database) could not be established.
822    ///
823    /// Returns an error if the old test database could not be dropped.
824    ///
825    /// Returns an error if the new test database could not be created.
826    ///
827    /// # Examples
828    ///
829    /// ```no_run
830    /// use cot::test::{TestDatabase, TestRequestBuilder};
831    ///
832    /// # #[tokio::main]
833    /// # async fn main() -> cot::Result<()> {
834    /// let mut test_database = TestDatabase::new_postgres("my_test").await?;
835    /// let request = TestRequestBuilder::get("/")
836    ///     .database(test_database.database())
837    ///     .build();
838    ///
839    /// // do something with the request
840    ///
841    /// test_database.cleanup().await?;
842    /// # Ok(())
843    /// # }
844    /// ```
845    pub async fn new_postgres(test_name: &str) -> Result<Self> {
846        let db_url = std::env::var("POSTGRES_URL")
847            .unwrap_or_else(|_| "postgresql://cot:cot@localhost".to_string());
848        let database = Database::new(format!("{db_url}/postgres")).await?;
849
850        let test_database_name = format!("test_cot__{test_name}");
851        database
852            .raw(&format!("DROP DATABASE IF EXISTS {test_database_name}"))
853            .await?;
854        database
855            .raw(&format!("CREATE DATABASE {test_database_name}"))
856            .await?;
857        database.close().await?;
858
859        let database = Database::new(format!("{db_url}/{test_database_name}")).await?;
860
861        Ok(Self::new(
862            database,
863            TestDatabaseKind::Postgres {
864                db_url,
865                db_name: test_database_name,
866            },
867        ))
868    }
869
870    /// Create a new MySQL database for testing and connects to it.
871    ///
872    /// The database URL is read from the `MYSQL_URL` environment variable.
873    /// Note that it shouldn't include the database name — the function will
874    /// create a new database for the test by connecting to the `mysql`
875    /// database. If no URL is provided, it defaults to
876    /// `mysql://root:@localhost`.
877    ///
878    /// The database is created with the name `test_cot__{test_name}`.
879    /// Make sure that `test_name` is unique for each test so that the databases
880    /// don't conflict with each other.
881    ///
882    /// The database is dropped when `self.cleanup()` is called. Note that this
883    /// means that the database will not be dropped if the test panics.
884    ///
885    ///
886    /// # Errors
887    ///
888    /// Returns an error if a database connection (either to the test database,
889    /// or MySQL maintenance database) could not be established.
890    ///
891    /// Returns an error if the old test database could not be dropped.
892    ///
893    /// Returns an error if the new test database could not be created.
894    ///
895    /// # Examples
896    ///
897    /// ```no_run
898    /// use cot::test::{TestDatabase, TestRequestBuilder};
899    ///
900    /// # #[tokio::main]
901    /// # async fn main() -> cot::Result<()> {
902    /// let mut test_database = TestDatabase::new_mysql("my_test").await?;
903    /// let request = TestRequestBuilder::get("/")
904    ///     .database(test_database.database())
905    ///     .build();
906    ///
907    /// // do something with the request
908    ///
909    /// test_database.cleanup().await?;
910    /// # Ok(())
911    /// # }
912    /// ```
913    pub async fn new_mysql(test_name: &str) -> Result<Self> {
914        let db_url =
915            std::env::var("MYSQL_URL").unwrap_or_else(|_| "mysql://root:@localhost".to_string());
916        let database = Database::new(format!("{db_url}/mysql")).await?;
917
918        let test_database_name = format!("test_cot__{test_name}");
919        database
920            .raw(&format!("DROP DATABASE IF EXISTS {test_database_name}"))
921            .await?;
922        database
923            .raw(&format!("CREATE DATABASE {test_database_name}"))
924            .await?;
925        database.close().await?;
926
927        let database = Database::new(format!("{db_url}/{test_database_name}")).await?;
928
929        Ok(Self::new(
930            database,
931            TestDatabaseKind::MySql {
932                db_url,
933                db_name: test_database_name,
934            },
935        ))
936    }
937
938    /// Add the default Cot authentication migrations to the test database.
939    ///
940    /// This is useful if you want to test something that requires
941    /// authentication.
942    ///
943    /// # Examples
944    ///
945    /// ```
946    /// use cot::test::{TestDatabase, TestRequestBuilder};
947    ///
948    /// # #[tokio::main]
949    /// # async fn main() -> cot::Result<()> {
950    /// let mut test_database = TestDatabase::new_sqlite().await?;
951    /// test_database.with_auth().run_migrations().await;
952    ///
953    /// let request = TestRequestBuilder::get("/")
954    ///     .with_db_auth(test_database.database())
955    ///     .await
956    ///     .build();
957    ///
958    /// // do something with the request
959    ///
960    /// test_database.cleanup().await?;
961    /// # Ok(())
962    /// # }
963    /// ```
964    #[cfg(feature = "db")]
965    pub fn with_auth(&mut self) -> &mut Self {
966        self.add_migrations(cot::auth::db::migrations::MIGRATIONS.to_vec());
967        self
968    }
969
970    /// Add migrations to the test database.
971    ///
972    /// # Examples
973    ///
974    /// ```
975    /// use cot::test::{TestDatabase, TestMigration};
976    ///
977    /// # #[tokio::main]
978    /// # async fn main() -> cot::Result<()> {
979    /// let mut test_database = TestDatabase::new_sqlite().await?;
980    ///
981    /// test_database.add_migrations(vec![TestMigration::new(
982    ///     "auth",
983    ///     "create_users",
984    ///     vec![],
985    ///     vec![],
986    /// )]);
987    /// # Ok(())
988    /// # }
989    /// ```
990    pub fn add_migrations<T: DynMigration + Send + Sync + 'static, V: IntoIterator<Item = T>>(
991        &mut self,
992        migrations: V,
993    ) -> &mut Self {
994        self.migrations
995            .extend(migrations.into_iter().map(MigrationWrapper::new));
996        self
997    }
998
999    /// Run the migrations on the test database.
1000    ///
1001    /// # Panics
1002    ///
1003    /// Panics if the migration engine could not be initialized or if the
1004    /// migrations could not be run.
1005    ///
1006    /// # Examples
1007    ///
1008    /// ```
1009    /// use cot::test::{TestDatabase, TestMigration};
1010    ///
1011    /// # #[tokio::main]
1012    /// # async fn main() -> cot::Result<()> {
1013    /// let mut test_database = TestDatabase::new_sqlite().await?;
1014    /// test_database.add_migrations(vec![TestMigration::new(
1015    ///     "auth",
1016    ///     "create_users",
1017    ///     vec![],
1018    ///     vec![],
1019    /// )]);
1020    ///
1021    /// test_database.run_migrations().await;
1022    /// # Ok(())
1023    /// # }
1024    /// ```
1025    pub async fn run_migrations(&mut self) -> &mut Self {
1026        if !self.migrations.is_empty() {
1027            let engine = MigrationEngine::new(std::mem::take(&mut self.migrations))
1028                .expect("Failed to initialize the migration engine");
1029            engine
1030                .run(&self.database())
1031                .await
1032                .expect("Failed to run migrations");
1033        }
1034        self
1035    }
1036
1037    /// Get the database.
1038    ///
1039    /// # Examples
1040    ///
1041    /// ```
1042    /// use cot::test::{TestDatabase, TestRequestBuilder};
1043    ///
1044    /// # #[tokio::main]
1045    /// # async fn main() -> cot::Result<()> {
1046    /// let database = TestDatabase::new_sqlite().await?;
1047    ///
1048    /// let request = TestRequestBuilder::get("/")
1049    ///     .database(database.database())
1050    ///     .build();
1051    /// # Ok(())
1052    /// # }
1053    /// ```
1054    #[must_use]
1055    pub fn database(&self) -> Arc<Database> {
1056        self.database.clone()
1057    }
1058
1059    /// Cleanup the test database.
1060    ///
1061    /// This removes the test database and closes the connection. Note that this
1062    /// means that the database will not be dropped if the test panics, nor will
1063    /// it be dropped if you don't call this function.
1064    ///
1065    /// # Errors
1066    ///
1067    /// Returns an error if the database could not be closed or if the database
1068    /// could not be dropped.
1069    ///
1070    /// # Examples
1071    ///
1072    /// ```
1073    /// use cot::test::TestDatabase;
1074    ///
1075    /// # #[tokio::main]
1076    /// # async fn main() -> cot::Result<()> {
1077    /// let mut test_database = TestDatabase::new_sqlite().await?;
1078    /// test_database.cleanup().await?;
1079    /// # Ok(())
1080    /// # }
1081    /// ```
1082    pub async fn cleanup(&self) -> Result<()> {
1083        self.database.close().await?;
1084        match &self.kind {
1085            TestDatabaseKind::Sqlite => {}
1086            TestDatabaseKind::Postgres { db_url, db_name } => {
1087                let database = Database::new(format!("{db_url}/postgres")).await?;
1088
1089                database
1090                    .raw(&format!("DROP DATABASE {db_name} WITH (FORCE)"))
1091                    .await?;
1092                database.close().await?;
1093            }
1094            TestDatabaseKind::MySql { db_url, db_name } => {
1095                let database = Database::new(format!("{db_url}/mysql")).await?;
1096
1097                database.raw(&format!("DROP DATABASE {db_name}")).await?;
1098                database.close().await?;
1099            }
1100        }
1101
1102        Ok(())
1103    }
1104}
1105
#[cfg(feature = "db")]
impl std::ops::Deref for TestDatabase {
    type Target = Database;

    /// Expose the wrapped [`Database`] directly, so a `TestDatabase` can be
    /// used wherever a `&Database` is expected.
    fn deref(&self) -> &Database {
        &*self.database
    }
}
1114
/// Which database engine backs a [`TestDatabase`], together with the
/// information [`TestDatabase::cleanup`] needs to remove it afterwards.
#[cfg(feature = "db")]
#[derive(Debug, Clone)]
enum TestDatabaseKind {
    /// SQLite: cleanup only closes the connection.
    Sqlite,
    /// PostgreSQL: `db_url` is the server URL that database names are
    /// appended to, and `db_name` is the test database to drop.
    Postgres {
        db_url: String,
        db_name: String,
    },
    /// MySQL: `db_url` is the server URL that database names are appended
    /// to, and `db_name` is the test database to drop.
    MySql {
        db_url: String,
        db_name: String,
    },
}
1122
/// A migration assembled at runtime, intended for tests.
///
/// Unlike generated migrations, a `TestMigration` is built from values, so a
/// test can describe exactly the schema changes it needs.
///
/// # Examples
///
/// ```
/// use cot::db::migrations::{Field, Operation};
/// use cot::db::{ColumnType, Identifier};
/// use cot::test::TestMigration;
///
/// const OPERATION: Operation = Operation::create_model()
///     .table_name(Identifier::new("myapp__users"))
///     .fields(&[Field::new(Identifier::new("id"), ColumnType::Integer)
///         .auto()
///         .primary_key()])
///     .build();
///
/// let migration = TestMigration::new("auth", "create_users", vec![], vec![OPERATION]);
/// ```
#[cfg(feature = "db")]
#[derive(Debug, Clone)]
pub struct TestMigration {
    /// Name of the app this migration belongs to.
    app_name: &'static str,
    /// Name of the migration itself.
    name: &'static str,
    /// Migrations that must be applied before this one.
    dependencies: Vec<MigrationDependency>,
    /// Operations this migration performs, in order.
    operations: Vec<Operation>,
}
1151
#[cfg(feature = "db")]
impl TestMigration {
    /// Build a test migration from its app name, its own name, the
    /// migrations it depends on, and the operations it performs.
    ///
    /// `dependencies` and `operations` accept anything convertible into a
    /// `Vec`, such as a `Vec` itself or an array.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::db::migrations::{Field, Operation};
    /// use cot::db::{ColumnType, Identifier};
    /// use cot::test::TestMigration;
    ///
    /// const OPERATION: Operation = Operation::create_model()
    ///     .table_name(Identifier::new("myapp__users"))
    ///     .fields(&[Field::new(Identifier::new("id"), ColumnType::Integer)
    ///         .auto()
    ///         .primary_key()])
    ///     .build();
    ///
    /// let migration = TestMigration::new("auth", "create_users", vec![], vec![OPERATION]);
    /// ```
    #[must_use]
    pub fn new<D, O>(
        app_name: &'static str,
        name: &'static str,
        dependencies: D,
        operations: O,
    ) -> Self
    where
        D: Into<Vec<MigrationDependency>>,
        O: Into<Vec<Operation>>,
    {
        let dependencies: Vec<MigrationDependency> = dependencies.into();
        let operations: Vec<Operation> = operations.into();

        Self {
            app_name,
            name,
            dependencies,
            operations,
        }
    }
}
1187
/// Expose the stored values through the dynamic migration interface so a
/// `TestMigration` can be queued like any other migration.
#[cfg(feature = "db")]
impl DynMigration for TestMigration {
    fn app_name(&self) -> &str {
        self.app_name
    }

    fn name(&self) -> &str {
        self.name
    }

    fn dependencies(&self) -> &[MigrationDependency] {
        self.dependencies.as_slice()
    }

    fn operations(&self) -> &[Operation] {
        self.operations.as_slice()
    }
}
1206
/// A guard for running tests serially.
///
/// Hold the returned guard for the whole test. This is mostly useful for tests
/// that modify global process state (e.g. environment variables or the current
/// working directory), so that such tests do not interleave.
///
/// If a previous holder panicked while holding the guard, the poison flag is
/// cleared and the guard is handed out anyway. The mutex protects no data, so
/// there is nothing that could have been left inconsistent.
#[doc(hidden)] // not part of the public API; used in cot-cli
pub fn serial_guard() -> std::sync::MutexGuard<'static, ()> {
    // `Mutex::new` is a `const fn`, so the lock can live directly in a static;
    // no lazy initialization (`OnceLock`) is needed.
    static LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    LOCK.lock().unwrap_or_else(|poisoned| {
        // Reset the flag so later callers take the plain `Ok` path again.
        LOCK.clear_poison();
        poisoned.into_inner()
    })
}