cot/test.rs
1//! Test utilities for Cot projects.
2
3use std::any::Any;
4use std::future::poll_fn;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use derive_more::Debug;
9use tower::Service;
10use tower_sessions::MemoryStore;
11
12#[cfg(feature = "db")]
13use crate::auth::db::DatabaseUserBackend;
14use crate::auth::{Auth, AuthBackend, NoAuthBackend, User, UserId};
15use crate::config::ProjectConfig;
16#[cfg(feature = "db")]
17use crate::db::Database;
18#[cfg(feature = "db")]
19use crate::db::migrations::{
20 DynMigration, MigrationDependency, MigrationEngine, MigrationWrapper, Operation,
21};
22use crate::handler::BoxedHandler;
23use crate::project::prepare_request;
24use crate::request::Request;
25use crate::response::Response;
26use crate::router::Router;
27use crate::session::Session;
28use crate::{Body, Bootstrapper, Project, ProjectContext, Result};
29
30/// A test client for making requests to a Cot project.
31///
32/// Useful for End-to-End testing Cot projects.
33#[derive(Debug)]
34pub struct Client {
35 context: Arc<ProjectContext>,
36 handler: BoxedHandler,
37}
38
39impl Client {
40 /// Create a new test client for a Cot project.
41 ///
42 /// # Panics
43 ///
44 /// Panics if the test config could not be loaded.
45 /// Panics if the project could not be initialized.
46 ///
47 /// # Examples
48 ///
49 /// ```
50 /// use cot::test::Client;
51 /// use cot::Project;
52 /// use cot::config::ProjectConfig;
53 ///
54 /// struct MyProject;
55 /// impl Project for MyProject {
56 /// fn config(&self, config_name: &str) -> cot::Result<ProjectConfig> {
57 /// Ok(ProjectConfig::default())
58 /// }
59 /// }
60 ///
61 /// # #[tokio::main]
62 /// # async fn main() -> cot::Result<()> {
63 /// let mut client = Client::new(MyProject).await;
64 /// let response = client.get("/").await?;
65 /// assert!(!response.into_body().into_bytes().await?.is_empty());
66 /// # Ok(())
67 /// }
68 /// ```
69 #[must_use]
70 #[expect(clippy::future_not_send)] // used in the test code
71 pub async fn new<P>(project: P) -> Self
72 where
73 P: Project + 'static,
74 {
75 let config = project.config("test").expect("Could not get test config");
76 let bootstrapper = Bootstrapper::new(project)
77 .with_config(config)
78 .boot()
79 .await
80 .expect("Could not boot project");
81
82 let (context, handler) = bootstrapper.into_context_and_handler();
83 Self {
84 context: Arc::new(context),
85 handler,
86 }
87 }
88
89 /// Send a GET request to the given path.
90 ///
91 /// # Errors
92 ///
93 /// Propagates any errors that the request handler might return.
94 ///
95 /// # Examples
96 ///
97 /// ```
98 /// use cot::test::Client;
99 /// use cot::Project;
100 /// use cot::config::ProjectConfig;
101 ///
102 /// struct MyProject;
103 /// impl Project for MyProject {
104 /// fn config(&self, config_name: &str) -> cot::Result<ProjectConfig> {
105 /// Ok(ProjectConfig::default())
106 /// }
107 /// }
108 ///
109 /// # #[tokio::main]
110 /// # async fn main() -> cot::Result<()> {
111 /// let mut client = Client::new(MyProject).await;
112 /// let response = client.get("/").await?;
113 /// assert!(!response.into_body().into_bytes().await?.is_empty());
114 /// # Ok(())
115 /// }
116 /// ```
117 pub async fn get(&mut self, path: &str) -> Result<Response> {
118 self.request(match http::Request::get(path).body(Body::empty()) {
119 Ok(request) => request,
120 Err(_) => {
121 unreachable!("Test request should be valid")
122 }
123 })
124 .await
125 }
126
127 /// Send a request to the given path.
128 ///
129 /// # Errors
130 ///
131 /// Propagates any errors that the request handler might return.
132 ///
133 /// # Examples
134 ///
135 /// ```
136 /// use cot::test::Client;
137 /// use cot::{Body, Project};
138 /// use cot::config::ProjectConfig;
139 ///
140 /// struct MyProject;
141 /// impl Project for MyProject {
142 /// fn config(&self, config_name: &str) -> cot::Result<ProjectConfig> {
143 /// Ok(ProjectConfig::default())
144 /// }
145 /// }
146 ///
147 /// # #[tokio::main]
148 /// # async fn main() -> cot::Result<()> {
149 /// let mut client = Client::new(MyProject).await;
150 /// let response = client.request(cot::http::Request::get("/").body(Body::empty()).unwrap()).await?;
151 /// assert!(!response.into_body().into_bytes().await?.is_empty());
152 /// # Ok(())
153 /// }
154 /// ```
155 pub async fn request(&mut self, mut request: Request) -> Result<Response> {
156 prepare_request(&mut request, self.context.clone());
157
158 poll_fn(|cx| self.handler.poll_ready(cx)).await?;
159 self.handler.call(request).await
160 }
161}
162
163/// A builder for creating test requests, typically used for unit testing
164/// without having to create a full Cot project and do actual HTTP requests.
165///
166/// # Examples
167///
168/// ```
169/// use cot::Body;
170/// use cot::request::Request;
171/// use cot::response::{Response, ResponseExt};
172/// use cot::test::TestRequestBuilder;
173/// use http::StatusCode;
174///
175/// # #[tokio::main]
176/// # async fn main() -> cot::Result<()> {
177/// async fn index(request: Request) -> cot::Result<Response> {
178/// Ok(Response::new_html(
179/// StatusCode::OK,
180/// Body::fixed("Hello world!"),
181/// ))
182/// }
183///
184/// let request = TestRequestBuilder::get("/").build();
185///
186/// assert_eq!(
187/// index(request).await?.into_body().into_bytes().await?,
188/// "Hello world!"
189/// );
190/// # Ok(())
191/// # }
192/// ```
193#[derive(Debug, Clone)]
194pub struct TestRequestBuilder {
195 method: http::Method,
196 url: String,
197 router: Option<Router>,
198 session: Option<Session>,
199 config: Option<Arc<ProjectConfig>>,
200 auth_backend: Option<AuthBackendWrapper>,
201 auth: Option<Auth>,
202 #[cfg(feature = "db")]
203 database: Option<Arc<Database>>,
204 form_data: Option<Vec<(String, String)>>,
205 #[cfg(feature = "json")]
206 json_data: Option<String>,
207}
208
209/// A wrapper over an auth backend that is cloneable.
210#[derive(Debug, Clone)]
211struct AuthBackendWrapper {
212 #[debug("..")]
213 inner: Arc<dyn AuthBackend>,
214}
215
216impl AuthBackendWrapper {
217 pub(crate) fn new<AB>(inner: AB) -> Self
218 where
219 AB: AuthBackend + 'static,
220 {
221 Self {
222 inner: Arc::new(inner),
223 }
224 }
225}
226
227#[async_trait]
228impl AuthBackend for AuthBackendWrapper {
229 async fn authenticate(
230 &self,
231 credentials: &(dyn Any + Send + Sync),
232 ) -> cot::auth::Result<Option<Box<dyn User + Send + Sync>>> {
233 self.inner.authenticate(credentials).await
234 }
235
236 async fn get_by_id(
237 &self,
238 id: UserId,
239 ) -> cot::auth::Result<Option<Box<dyn User + Send + Sync>>> {
240 self.inner.get_by_id(id).await
241 }
242}
243
244impl Default for TestRequestBuilder {
245 fn default() -> Self {
246 Self {
247 method: http::Method::GET,
248 url: "/".to_string(),
249 router: None,
250 session: None,
251 config: None,
252 auth_backend: None,
253 auth: None,
254 #[cfg(feature = "db")]
255 database: None,
256 form_data: None,
257 #[cfg(feature = "json")]
258 json_data: None,
259 }
260 }
261}
262
263impl TestRequestBuilder {
264 /// Create a new GET request builder.
265 ///
266 /// # Examples
267 ///
268 /// ```
269 /// use cot::Body;
270 /// use cot::request::Request;
271 /// use cot::response::{Response, ResponseExt};
272 /// use cot::test::TestRequestBuilder;
273 /// use http::StatusCode;
274 ///
275 /// # #[tokio::main]
276 /// # async fn main() -> cot::Result<()> {
277 /// async fn index(request: Request) -> cot::Result<Response> {
278 /// Ok(Response::new_html(
279 /// StatusCode::OK,
280 /// Body::fixed("Hello world!"),
281 /// ))
282 /// }
283 ///
284 /// let request = TestRequestBuilder::get("/").build();
285 ///
286 /// assert_eq!(
287 /// index(request).await?.into_body().into_bytes().await?,
288 /// "Hello world!"
289 /// );
290 /// # Ok(())
291 /// # }
292 /// ```
293 #[must_use]
294 pub fn get(url: &str) -> Self {
295 Self {
296 method: http::Method::GET,
297 url: url.to_string(),
298 ..Self::default()
299 }
300 }
301
302 /// Create a new POST request builder.
303 ///
304 /// # Examples
305 ///
306 /// ```
307 /// use cot::Body;
308 /// use cot::request::Request;
309 /// use cot::response::{Response, ResponseExt};
310 /// use cot::test::TestRequestBuilder;
311 /// use http::StatusCode;
312 ///
313 /// # #[tokio::main]
314 /// # async fn main() -> cot::Result<()> {
315 /// async fn index(request: Request) -> cot::Result<Response> {
316 /// Ok(Response::new_html(
317 /// StatusCode::OK,
318 /// Body::fixed("Hello world!"),
319 /// ))
320 /// }
321 ///
322 /// let request = TestRequestBuilder::post("/").build();
323 ///
324 /// assert_eq!(
325 /// index(request).await?.into_body().into_bytes().await?,
326 /// "Hello world!"
327 /// );
328 /// # Ok(())
329 /// # }
330 /// ```
331 #[must_use]
332 pub fn post(url: &str) -> Self {
333 Self {
334 method: http::Method::POST,
335 url: url.to_string(),
336 ..Self::default()
337 }
338 }
339
340 /// Add a project config instance to the request builder.
341 ///
342 /// # Examples
343 ///
344 /// ```
345 /// use cot::config::ProjectConfig;
346 /// use cot::test::TestRequestBuilder;
347 ///
348 /// let request = TestRequestBuilder::get("/")
349 /// .config(ProjectConfig::dev_default())
350 /// .build();
351 /// ```
352 pub fn config(&mut self, config: ProjectConfig) -> &mut Self {
353 self.config = Some(Arc::new(config));
354 self
355 }
356
357 /// Create a new request builder with default configuration.
358 ///
359 /// # Examples
360 ///
361 /// ```
362 /// use cot::Body;
363 /// use cot::request::Request;
364 /// use cot::response::{Response, ResponseExt};
365 /// use cot::test::TestRequestBuilder;
366 /// use http::StatusCode;
367 ///
368 /// # #[tokio::main]
369 /// # async fn main() -> cot::Result<()> {
370 /// async fn index(request: Request) -> cot::Result<Response> {
371 /// Ok(Response::new_html(
372 /// StatusCode::OK,
373 /// Body::fixed("Hello world!"),
374 /// ))
375 /// }
376 ///
377 /// let request = TestRequestBuilder::get("/").with_default_config().build();
378 ///
379 /// assert_eq!(
380 /// index(request).await?.into_body().into_bytes().await?,
381 /// "Hello world!"
382 /// );
383 /// # Ok(())
384 /// # }
385 /// ```
386 pub fn with_default_config(&mut self) -> &mut Self {
387 self.config = Some(Arc::new(ProjectConfig::default()));
388 self
389 }
390
391 /// Add an authentication backend to the request builder.
392 ///
393 /// # Examples
394 ///
395 /// ```
396 /// use cot::auth::NoAuthBackend;
397 /// use cot::test::TestRequestBuilder;
398 ///
399 /// let request = TestRequestBuilder::get("/")
400 /// .auth_backend(NoAuthBackend)
401 /// .build();
402 /// ```
403 pub fn auth_backend<T: AuthBackend + 'static>(&mut self, auth_backend: T) -> &mut Self {
404 self.auth_backend = Some(AuthBackendWrapper::new(auth_backend));
405 self
406 }
407
408 /// Add a router to the request builder.
409 ///
410 /// # Examples
411 ///
412 /// ```
413 /// use cot::request::Request;
414 /// use cot::response::Response;
415 /// use cot::router::{Route, Router};
416 /// use cot::test::TestRequestBuilder;
417 ///
418 /// async fn index(request: Request) -> cot::Result<Response> {
419 /// todo!()
420 /// }
421 ///
422 /// let router = Router::with_urls([Route::with_handler_and_name("/", index, "index")]);
423 /// let request = TestRequestBuilder::get("/").router(router).build();
424 /// ```
425 pub fn router(&mut self, router: Router) -> &mut Self {
426 self.router = Some(router);
427 self
428 }
429
430 /// Add a session support to the request builder.
431 ///
432 /// # Examples
433 ///
434 /// ```
435 /// use cot::test::TestRequestBuilder;
436 ///
437 /// let request = TestRequestBuilder::get("/").with_session().build();
438 /// ```
439 pub fn with_session(&mut self) -> &mut Self {
440 let session_store = MemoryStore::default();
441 let session_inner = tower_sessions::Session::new(None, Arc::new(session_store), None);
442 self.session = Some(Session::new(session_inner));
443 self
444 }
445
446 /// Add a session support to the request builder with the session copied
447 /// over from another [`Request`] object.
448 ///
449 /// # Examples
450 ///
451 /// ```
452 /// use cot::request::RequestExt;
453 /// use cot::session::Session;
454 /// use cot::test::TestRequestBuilder;
455 ///
456 /// # #[tokio::main]
457 /// # async fn main() -> cot::Result<()> {
458 /// let mut request = TestRequestBuilder::get("/").with_session().build();
459 /// Session::from_request(&request)
460 /// .insert("key", "value")
461 /// .await?;
462 ///
463 /// let mut request = TestRequestBuilder::get("/")
464 /// .with_session_from(&request)
465 /// .build();
466 /// # Ok(())
467 /// # }
468 /// ```
469 pub fn with_session_from(&mut self, request: &Request) -> &mut Self {
470 self.session = Some(Session::from_request(request).clone());
471 self
472 }
473
474 /// Add a session support to the request builder with the session object
475 /// provided as a parameter.
476 ///
477 /// # Examples
478 ///
479 /// ```
480 /// use cot::request::RequestExt;
481 /// use cot::session::Session;
482 /// use cot::test::TestRequestBuilder;
483 ///
484 /// # #[tokio::main]
485 /// # async fn main() -> cot::Result<()> {
486 /// let mut request = TestRequestBuilder::get("/").with_session().build();
487 /// let session = Session::from_request(&request);
488 /// session.insert("key", "value").await?;
489 ///
490 /// let mut request = TestRequestBuilder::get("/")
491 /// .session(session.clone())
492 /// .build();
493 /// # Ok(())
494 /// # }
495 /// ```
496 pub fn session(&mut self, session: Session) -> &mut Self {
497 self.session = Some(session);
498 self
499 }
500
501 /// Add a database to the request builder.
502 ///
503 /// # Examples
504 ///
505 /// ```
506 /// use std::sync::Arc;
507 ///
508 /// use cot::db::Database;
509 /// use cot::test::TestRequestBuilder;
510 /// use cot::request::{Request, RequestExt};
511 /// use cot::response::{Response, ResponseExt};
512 /// use cot::{Body, StatusCode};
513 ///
514 /// async fn index(request: Request) -> cot::Result<Response> {
515 /// let db = request.db();
516 ///
517 /// // ... do something with db
518 ///
519 /// Ok(Response::new_html(
520 /// StatusCode::OK,
521 /// Body::fixed("Hello world!"),
522 /// ))
523 /// }
524 ///
525 /// # #[tokio::main]
526 /// # async fn main() -> cot::Result<()> {
527 /// let request = TestRequestBuilder::get("/")
528 /// .database(Database::new("sqlite::memory:").await?)
529 /// .build();
530 /// # Ok(())
531 /// }
532 /// ```
533 #[cfg(feature = "db")]
534 pub fn database<DB: Into<Arc<Database>>>(&mut self, database: DB) -> &mut Self {
535 self.database = Some(database.into());
536 self
537 }
538
539 /// Use database authentication in the test request.
540 ///
541 /// Note that this calls [`Self::auth_backend`], [`Self::with_session`],
542 /// [`Self::database`], possibly overriding any values set by you earlier.
543 ///
544 /// # Panics
545 ///
546 /// Panics if the auth object fails to be created.
547 ///
548 /// # Examples
549 ///
550 /// ```
551 /// use cot::config::ProjectConfig;
552 /// use cot::test::{TestDatabase, TestRequestBuilder};
553 ///
554 /// # #[tokio::main]
555 /// # async fn main() -> cot::Result<()> {
556 /// let mut test_database = TestDatabase::new_sqlite().await?;
557 /// test_database.with_auth().run_migrations().await;
558 /// let request = TestRequestBuilder::get("/")
559 /// .with_db_auth(test_database.database())
560 /// .await
561 /// .build();
562 /// # Ok(())
563 /// # }
564 /// ```
565 #[cfg(feature = "db")]
566 pub async fn with_db_auth(&mut self, db: Arc<Database>) -> &mut Self {
567 self.auth_backend(DatabaseUserBackend::new(Arc::clone(&db)));
568 self.with_session();
569 self.database(db);
570 self.auth = Some(
571 Auth::new(
572 self.session.clone().expect("Session was just set"),
573 self.auth_backend
574 .clone()
575 .expect("Auth backend was just set")
576 .inner,
577 crate::config::SecretKey::from("000000"),
578 &[],
579 )
580 .await
581 .expect("Failed to create Auth"),
582 );
583
584 self
585 }
586
587 /// Add form data to the request builder.
588 ///
589 /// # Examples
590 ///
591 /// ```
592 /// use cot::test::TestRequestBuilder;
593 ///
594 /// let request = TestRequestBuilder::post("/")
595 /// .form_data(&[("name", "Alice"), ("age", "30")])
596 /// .build();
597 /// ```
598 pub fn form_data<T: ToString>(&mut self, form_data: &[(T, T)]) -> &mut Self {
599 self.form_data = Some(
600 form_data
601 .iter()
602 .map(|(k, v)| (k.to_string(), v.to_string()))
603 .collect(),
604 );
605 self
606 }
607
608 /// Add JSON data to the request builder.
609 ///
610 /// # Examples
611 ///
612 /// ```
613 /// use cot::test::TestRequestBuilder;
614 ///
615 /// #[derive(serde::Serialize)]
616 /// struct Data {
617 /// key: String,
618 /// value: i32,
619 /// }
620 ///
621 /// let request = TestRequestBuilder::post("/")
622 /// .json(&Data {
623 /// key: "value".to_string(),
624 /// value: 42,
625 /// })
626 /// .build();
627 /// ```
628 ///
629 /// # Panics
630 ///
631 /// Panics if the JSON serialization fails.
632 #[cfg(feature = "json")]
633 pub fn json<T: serde::Serialize>(&mut self, data: &T) -> &mut Self {
634 self.json_data = Some(serde_json::to_string(data).expect("Failed to serialize JSON"));
635 self
636 }
637
638 /// Build the request.
639 ///
640 /// # Examples
641 ///
642 /// ```
643 /// use cot::Body;
644 /// use cot::request::Request;
645 /// use cot::response::{Response, ResponseExt};
646 /// use cot::test::TestRequestBuilder;
647 /// use http::StatusCode;
648 ///
649 /// # #[tokio::main]
650 /// # async fn main() -> cot::Result<()> {
651 /// async fn index(request: Request) -> cot::Result<Response> {
652 /// Ok(Response::new_html(
653 /// StatusCode::OK,
654 /// Body::fixed("Hello world!"),
655 /// ))
656 /// }
657 ///
658 /// let request = TestRequestBuilder::get("/").build();
659 ///
660 /// assert_eq!(
661 /// index(request).await?.into_body().into_bytes().await?,
662 /// "Hello world!"
663 /// );
664 /// # Ok(())
665 /// # }
666 /// ```
667 #[must_use]
668 pub fn build(&mut self) -> http::Request<Body> {
669 let Ok(mut request) = http::Request::builder()
670 .method(self.method.clone())
671 .uri(self.url.clone())
672 .body(Body::empty())
673 else {
674 unreachable!("Test request should be valid");
675 };
676
677 let auth_backend = std::mem::take(&mut self.auth_backend);
678 #[expect(trivial_casts)]
679 let auth_backend = match auth_backend {
680 Some(auth_backend) => Arc::new(auth_backend) as Arc<dyn AuthBackend>,
681 None => Arc::new(NoAuthBackend),
682 };
683
684 let context = ProjectContext::initialized(
685 self.config.clone().unwrap_or_default(),
686 Vec::new(),
687 Arc::new(self.router.clone().unwrap_or_else(Router::empty)),
688 auth_backend,
689 #[cfg(feature = "db")]
690 self.database.clone(),
691 );
692 prepare_request(&mut request, Arc::new(context));
693
694 if let Some(session) = &self.session {
695 request.extensions_mut().insert(session.clone());
696 }
697
698 if let Some(auth) = &self.auth {
699 request.extensions_mut().insert(auth.clone());
700 }
701
702 if let Some(form_data) = &self.form_data {
703 if self.method != http::Method::POST {
704 todo!("Form data can currently only be used with POST requests");
705 }
706
707 let mut data = form_urlencoded::Serializer::new(String::new());
708 for (key, value) in form_data {
709 data.append_pair(key, value);
710 }
711
712 *request.body_mut() = Body::fixed(data.finish());
713 request.headers_mut().insert(
714 http::header::CONTENT_TYPE,
715 http::HeaderValue::from_static("application/x-www-form-urlencoded"),
716 );
717 }
718
719 #[cfg(feature = "json")]
720 if let Some(json_data) = &self.json_data {
721 *request.body_mut() = Body::fixed(json_data.clone());
722 request.headers_mut().insert(
723 http::header::CONTENT_TYPE,
724 http::HeaderValue::from_static("application/json"),
725 );
726 }
727
728 request
729 }
730}
731
/// A database dedicated to a single test.
///
/// It creates a separate database for testing and runs any queued
/// migrations on it.
///
/// # Examples
///
/// ```
/// use cot::test::{TestDatabase, TestRequestBuilder};
///
/// # #[tokio::main]
/// # async fn main() -> cot::Result<()> {
/// let mut test_database = TestDatabase::new_sqlite().await?;
/// let request = TestRequestBuilder::get("/")
///     .database(test_database.database())
///     .build();
///
/// // do something with the request
///
/// test_database.cleanup().await?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "db")]
#[derive(Debug)]
pub struct TestDatabase {
    /// Connection to the test database, shared with the requests under test.
    database: Arc<Database>,
    /// Backend kind, plus whatever is needed to drop the database on cleanup.
    kind: TestDatabaseKind,
    /// Migrations queued via `add_migrations` that have not been run yet.
    migrations: Vec<MigrationWrapper>,
}
762
#[cfg(feature = "db")]
impl TestDatabase {
    fn new(database: Database, kind: TestDatabaseKind) -> TestDatabase {
        Self {
            database: Arc::new(database),
            kind,
            migrations: Vec::new(),
        }
    }

    /// Create a new in-memory SQLite database for testing.
    ///
    /// # Errors
    ///
    /// If the database could not have been created.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::test::{TestDatabase, TestRequestBuilder};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let mut test_database = TestDatabase::new_sqlite().await?;
    /// let request = TestRequestBuilder::get("/")
    ///     .database(test_database.database())
    ///     .build();
    ///
    /// // do something with the request
    ///
    /// test_database.cleanup().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn new_sqlite() -> Result<Self> {
        let database = Database::new("sqlite::memory:").await?;
        Ok(Self::new(database, TestDatabaseKind::Sqlite))
    }

    /// Create a new PostgreSQL database for testing and connect to it.
    ///
    /// The database URL is read from the `POSTGRES_URL` environment variable.
    /// Note that it shouldn't include the database name — the function will
    /// create a new database for the test by connecting to the `postgres`
    /// database. If no URL is provided, it defaults to
    /// `postgresql://cot:cot@localhost`.
    ///
    /// The database is created with the name `test_cot__{test_name}`. The name
    /// is quoted in SQL, so its case is preserved and it matches the name
    /// used in the connection URL. Make sure that `test_name` is unique for
    /// each test so that the databases don't conflict with each other.
    ///
    /// The database is dropped when `self.cleanup()` is called. Note that this
    /// means that the database will not be dropped if the test panics.
    ///
    /// # Errors
    ///
    /// Returns an error if a database connection (either to the test database,
    /// or postgres maintenance database) could not be established.
    ///
    /// Returns an error if the old test database could not be dropped.
    ///
    /// Returns an error if the new test database could not be created.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use cot::test::{TestDatabase, TestRequestBuilder};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let mut test_database = TestDatabase::new_postgres("my_test").await?;
    /// let request = TestRequestBuilder::get("/")
    ///     .database(test_database.database())
    ///     .build();
    ///
    /// // do something with the request
    ///
    /// test_database.cleanup().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn new_postgres(test_name: &str) -> Result<Self> {
        let db_url = std::env::var("POSTGRES_URL")
            .unwrap_or_else(|_| "postgresql://cot:cot@localhost".to_string());

        let test_database_name = format!("test_cot__{test_name}");
        // Unquoted identifiers are folded to lowercase by PostgreSQL, while the
        // connection URL below uses the name verbatim; quoting keeps them equal.
        let quoted_name = Self::quote_identifier(&test_database_name, '"');
        Self::execute_maintenance(
            format!("{db_url}/postgres"),
            &[
                format!("DROP DATABASE IF EXISTS {quoted_name}"),
                format!("CREATE DATABASE {quoted_name}"),
            ],
        )
        .await?;

        let database = Database::new(format!("{db_url}/{test_database_name}")).await?;

        Ok(Self::new(
            database,
            TestDatabaseKind::Postgres {
                db_url,
                db_name: test_database_name,
            },
        ))
    }

    /// Create a new MySQL database for testing and connect to it.
    ///
    /// The database URL is read from the `MYSQL_URL` environment variable.
    /// Note that it shouldn't include the database name — the function will
    /// create a new database for the test by connecting to the `mysql`
    /// database. If no URL is provided, it defaults to
    /// `mysql://root:@localhost`.
    ///
    /// The database is created with the name `test_cot__{test_name}` (quoted
    /// in SQL). Make sure that `test_name` is unique for each test so that the
    /// databases don't conflict with each other.
    ///
    /// The database is dropped when `self.cleanup()` is called. Note that this
    /// means that the database will not be dropped if the test panics.
    ///
    /// # Errors
    ///
    /// Returns an error if a database connection (either to the test database,
    /// or MySQL maintenance database) could not be established.
    ///
    /// Returns an error if the old test database could not be dropped.
    ///
    /// Returns an error if the new test database could not be created.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use cot::test::{TestDatabase, TestRequestBuilder};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let mut test_database = TestDatabase::new_mysql("my_test").await?;
    /// let request = TestRequestBuilder::get("/")
    ///     .database(test_database.database())
    ///     .build();
    ///
    /// // do something with the request
    ///
    /// test_database.cleanup().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn new_mysql(test_name: &str) -> Result<Self> {
        let db_url =
            std::env::var("MYSQL_URL").unwrap_or_else(|_| "mysql://root:@localhost".to_string());

        let test_database_name = format!("test_cot__{test_name}");
        let quoted_name = Self::quote_identifier(&test_database_name, '`');
        Self::execute_maintenance(
            format!("{db_url}/mysql"),
            &[
                format!("DROP DATABASE IF EXISTS {quoted_name}"),
                format!("CREATE DATABASE {quoted_name}"),
            ],
        )
        .await?;

        let database = Database::new(format!("{db_url}/{test_database_name}")).await?;

        Ok(Self::new(
            database,
            TestDatabaseKind::MySql {
                db_url,
                db_name: test_database_name,
            },
        ))
    }

    /// Add the default Cot authentication migrations to the test database.
    ///
    /// This is useful if you want to test something that requires
    /// authentication.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::test::{TestDatabase, TestRequestBuilder};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let mut test_database = TestDatabase::new_sqlite().await?;
    /// test_database.with_auth().run_migrations().await;
    ///
    /// let request = TestRequestBuilder::get("/")
    ///     .with_db_auth(test_database.database())
    ///     .await
    ///     .build();
    ///
    /// // do something with the request
    ///
    /// test_database.cleanup().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_auth(&mut self) -> &mut Self {
        self.add_migrations(cot::auth::db::migrations::MIGRATIONS.to_vec());
        self
    }

    /// Add migrations to the test database.
    ///
    /// The migrations are only queued; call [`Self::run_migrations`] to apply
    /// them.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::test::{TestDatabase, TestMigration};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let mut test_database = TestDatabase::new_sqlite().await?;
    ///
    /// test_database.add_migrations(vec![TestMigration::new(
    ///     "auth",
    ///     "create_users",
    ///     vec![],
    ///     vec![],
    /// )]);
    /// # Ok(())
    /// # }
    /// ```
    pub fn add_migrations<T: DynMigration + Send + Sync + 'static, V: IntoIterator<Item = T>>(
        &mut self,
        migrations: V,
    ) -> &mut Self {
        self.migrations
            .extend(migrations.into_iter().map(MigrationWrapper::new));
        self
    }

    /// Run the queued migrations on the test database.
    ///
    /// The queue is emptied, so calling this again only runs migrations added
    /// in the meantime.
    ///
    /// # Panics
    ///
    /// Panics if the migration engine could not be initialized or if the
    /// migrations could not be run.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::test::{TestDatabase, TestMigration};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let mut test_database = TestDatabase::new_sqlite().await?;
    /// test_database.add_migrations(vec![TestMigration::new(
    ///     "auth",
    ///     "create_users",
    ///     vec![],
    ///     vec![],
    /// )]);
    ///
    /// test_database.run_migrations().await;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn run_migrations(&mut self) -> &mut Self {
        if !self.migrations.is_empty() {
            let engine = MigrationEngine::new(std::mem::take(&mut self.migrations))
                .expect("Failed to initialize the migration engine");
            engine
                .run(&self.database)
                .await
                .expect("Failed to run migrations");
        }
        self
    }

    /// Get the database.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::test::{TestDatabase, TestRequestBuilder};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let database = TestDatabase::new_sqlite().await?;
    ///
    /// let request = TestRequestBuilder::get("/")
    ///     .database(database.database())
    ///     .build();
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn database(&self) -> Arc<Database> {
        Arc::clone(&self.database)
    }

    /// Cleanup the test database.
    ///
    /// This closes the connection and, for PostgreSQL and MySQL, drops the
    /// test database. Note that this means that the database will not be
    /// dropped if the test panics, nor will it be dropped if you don't call
    /// this function.
    ///
    /// # Errors
    ///
    /// Returns an error if the database could not be closed or if the database
    /// could not be dropped.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::test::TestDatabase;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let mut test_database = TestDatabase::new_sqlite().await?;
    /// test_database.cleanup().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn cleanup(&self) -> Result<()> {
        self.database.close().await?;
        match &self.kind {
            TestDatabaseKind::Sqlite => {}
            TestDatabaseKind::Postgres { db_url, db_name } => {
                let quoted_name = Self::quote_identifier(db_name, '"');
                Self::execute_maintenance(
                    format!("{db_url}/postgres"),
                    &[format!("DROP DATABASE {quoted_name} WITH (FORCE)")],
                )
                .await?;
            }
            TestDatabaseKind::MySql { db_url, db_name } => {
                let quoted_name = Self::quote_identifier(db_name, '`');
                Self::execute_maintenance(
                    format!("{db_url}/mysql"),
                    &[format!("DROP DATABASE {quoted_name}")],
                )
                .await?;
            }
        }

        Ok(())
    }

    /// Wraps `name` in `quote` delimiters, doubling any embedded delimiter,
    /// so it can be spliced into SQL as a single identifier.
    fn quote_identifier(name: &str, quote: char) -> String {
        let doubled = format!("{quote}{quote}");
        format!("{quote}{}{quote}", name.replace(quote, &doubled))
    }

    /// Connects to the maintenance database at `maintenance_url`, executes
    /// `statements` in order (stopping at the first failure), and always
    /// closes the connection afterwards.
    ///
    /// A statement error takes precedence over an error from closing.
    async fn execute_maintenance(maintenance_url: String, statements: &[String]) -> Result<()> {
        let database = Database::new(maintenance_url).await?;
        let outcome = Self::execute_statements(&database, statements).await;
        let closed = database.close().await;
        outcome?;
        closed?;
        Ok(())
    }

    /// Executes `statements` in order on `database`, stopping at the first error.
    async fn execute_statements(database: &Database, statements: &[String]) -> Result<()> {
        for statement in statements {
            database.raw(statement).await?;
        }
        Ok(())
    }
}
1105
/// Lets a [`TestDatabase`] be used anywhere a [`Database`] reference is
/// expected.
#[cfg(feature = "db")]
impl std::ops::Deref for TestDatabase {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        Arc::as_ref(&self.database)
    }
}
1114
/// The backend a [`TestDatabase`] runs on, plus what cleanup needs.
#[cfg(feature = "db")]
#[derive(Debug, Clone)]
enum TestDatabaseKind {
    /// In-memory SQLite; nothing to drop on cleanup.
    Sqlite,
    /// PostgreSQL server at `db_url` hosting the database `db_name`.
    Postgres { db_url: String, db_name: String },
    /// MySQL server at `db_url` hosting the database `db_name`.
    MySql { db_url: String, db_name: String },
}
1122
/// A migration assembled at runtime.
///
/// Handy when a test needs a migration that is built dynamically rather than
/// generated ahead of time.
///
/// # Examples
///
/// ```
/// use cot::db::migrations::{Field, Operation};
/// use cot::db::{ColumnType, Identifier};
/// use cot::test::{TestDatabase, TestMigration};
///
/// const OPERATION: Operation = Operation::create_model()
///     .table_name(Identifier::new("myapp__users"))
///     .fields(&[Field::new(Identifier::new("id"), ColumnType::Integer)
///         .auto()
///         .primary_key()])
///     .build();
///
/// let migration = TestMigration::new("auth", "create_users", vec![], vec![OPERATION]);
/// ```
#[cfg(feature = "db")]
#[derive(Debug, Clone)]
pub struct TestMigration {
    /// Name of the app the migration belongs to.
    app_name: &'static str,
    /// Name of the migration itself.
    name: &'static str,
    /// Migrations that must be applied before this one.
    dependencies: Vec<MigrationDependency>,
    /// Operations the migration performs.
    operations: Vec<Operation>,
}
1151
#[cfg(feature = "db")]
impl TestMigration {
    /// Build a test migration from its app name, its own name, its
    /// dependencies, and its operations.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::db::migrations::{Field, Operation};
    /// use cot::db::{ColumnType, Identifier};
    /// use cot::test::{TestDatabase, TestMigration};
    ///
    /// const OPERATION: Operation = Operation::create_model()
    ///     .table_name(Identifier::new("myapp__users"))
    ///     .fields(&[Field::new(Identifier::new("id"), ColumnType::Integer)
    ///         .auto()
    ///         .primary_key()])
    ///     .build();
    ///
    /// let migration = TestMigration::new("auth", "create_users", vec![], vec![OPERATION]);
    /// ```
    #[must_use]
    pub fn new<D, O>(
        app_name: &'static str,
        name: &'static str,
        dependencies: D,
        operations: O,
    ) -> Self
    where
        D: Into<Vec<MigrationDependency>>,
        O: Into<Vec<Operation>>,
    {
        let dependencies: Vec<MigrationDependency> = dependencies.into();
        let operations: Vec<Operation> = operations.into();
        Self {
            app_name,
            name,
            dependencies,
            operations,
        }
    }
}
1187
/// Exposes the stored fields through the migration engine's interface.
#[cfg(feature = "db")]
impl DynMigration for TestMigration {
    fn app_name(&self) -> &str {
        self.app_name
    }

    fn name(&self) -> &str {
        self.name
    }

    fn dependencies(&self) -> &[MigrationDependency] {
        self.dependencies.as_slice()
    }

    fn operations(&self) -> &[Operation] {
        self.operations.as_slice()
    }
}
1206
/// A guard for running tests serially.
///
/// This is mostly useful for tests that need to modify some global state (e.g.
/// environment variables or current working directory). Keep the returned
/// guard alive for the duration of the test; other callers block until it is
/// dropped.
///
/// If a test panics while holding the guard, the lock becomes poisoned. The
/// poison is cleared here, because the mutex guards no data that could be
/// left in an inconsistent state.
#[doc(hidden)] // not part of the public API; used in cot-cli
pub fn serial_guard() -> std::sync::MutexGuard<'static, ()> {
    // `Mutex::new` is `const`, so a plain static suffices (no lazy init needed).
    static LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
    LOCK.lock().unwrap_or_else(|poisoned| {
        LOCK.clear_poison();
        poisoned.into_inner()
    })
}