loco_rs/
app.rs

1//! This module contains the core components and traits for building a web
2//! server application.
3#[cfg(feature = "with-db")]
4use {sea_orm::DatabaseConnection, std::path::Path};
5
6use std::{
7    any::{Any, TypeId},
8    net::SocketAddr,
9    sync::Arc,
10};
11
12use async_trait::async_trait;
13use axum::Router as AxumRouter;
14use dashmap::DashMap;
15
16use crate::{
17    bgworker::{self, Queue},
18    boot::{shutdown_signal, BootResult, ServeParams, StartMode},
19    cache::{self},
20    config::Config,
21    controller::{
22        middleware::{self, MiddlewareLayer},
23        AppRoutes,
24    },
25    environment::Environment,
26    mailer::EmailSender,
27    storage::Storage,
28    task::Tasks,
29    Result,
30};
31
32/// Type-safe heterogeneous storage for arbitrary application data
33#[derive(Default, Debug)]
34pub struct SharedStore {
35    // Use DashMap for concurrent access with fine-grained locking
36    storage: DashMap<TypeId, Box<dyn Any + Send + Sync>>,
37}
38
39impl SharedStore {
40    /// Insert a value of type T into the shared store
41    ///
42    /// # Example
43    /// ```
44    /// # use loco_rs::app::SharedStore;
45    /// let shared_store = SharedStore::default();
46    ///
47    /// #[derive(Debug)]
48    /// struct TestService {
49    ///     name: String,
50    ///     value: i32,
51    /// }
52    ///
53    /// let service = TestService {
54    ///     name: "test".to_string(),
55    ///     value: 100,
56    /// };
57    ///
58    /// shared_store.insert(service);
59    /// assert!(shared_store.contains::<TestService>());
60    /// ```
61    pub fn insert<T: 'static + Send + Sync>(&self, val: T) {
62        self.storage.insert(TypeId::of::<T>(), Box::new(val));
63    }
64
65    /// Remove a value of type T from the shared store
66    ///
67    /// Returns `Some(T)` if the value was present and removed, `None` otherwise.
68    ///
69    /// # Example
70    /// ```
71    /// # use loco_rs::app::SharedStore;
72    /// let shared_store = SharedStore::default();
73    ///
74    /// struct TestService {
75    ///     name: String,
76    ///     value: i32,
77    /// }
78    ///
79    /// let service = TestService {
80    ///     name: "test".to_string(),
81    ///     value: 100,
82    /// };
83    ///
84    /// shared_store.insert(service);
85    /// assert!(shared_store.contains::<TestService>());
86    ///
87    /// // Remove and get the value
88    /// let removed_service_opt = shared_store.remove::<TestService>();
89    /// assert!(removed_service_opt.is_some(), "Service should be present");
90    /// // Assert fields individually instead of comparing the whole struct
91    /// if let Some(removed_service) = removed_service_opt {
92    ///      assert_eq!(removed_service.name, "test");
93    ///      assert_eq!(removed_service.value, 100);
94    /// }
95    /// // Ensure it's gone
96    /// assert!(!shared_store.contains::<TestService>());
97    ///
98    /// // Trying to remove again returns None
99    /// assert!(shared_store.remove::<TestService>().is_none());
100    /// ```
101    #[must_use]
102    pub fn remove<T: 'static + Send + Sync>(&self) -> Option<T> {
103        self.storage
104            .remove(&TypeId::of::<T>())
105            .map(|(_, v)| v) // Extract the Box<dyn Any>
106            .and_then(|any| any.downcast::<T>().ok()) // Downcast to Box<T>
107            .map(|boxed| *boxed) // Dereference the Box<T> to get T
108    }
109
110    /// Get a reference to a value of type T from the shared store.
111    ///
112    /// Returns `None` if the value doesn't exist.
113    /// The reference is valid for as long as the returned `RefGuard` is held.
114    /// If you need to clone the value, you can do so directly from the
115    /// returned reference, or use the `get` method instead.
116    ///
117    /// # Example
118    /// ```
119    /// # use loco_rs::app::SharedStore;
120    /// let shared_store = SharedStore::default();
121    ///
122    /// #[derive(Clone)]
123    /// struct TestService {
124    ///     name: String,
125    ///     value: i32,
126    /// }
127    ///
128    /// let service = TestService {
129    ///     name: "test".to_string(),
130    ///     value: 100,
131    /// };
132    ///
133    /// shared_store.insert(service);
134    ///
135    /// // Get a reference to the service
136    /// let service_ref = shared_store.get_ref::<TestService>().expect("Service not found");
137    /// // Access fields directly
138    /// assert_eq!(service_ref.name, "test");
139    /// assert_eq!(service_ref.value, 100);
140    ///
141    /// // Clone if needed (the field itself)
142    /// let name_clone = service_ref.name.clone();
143    /// assert_eq!(name_clone, "test");
144    ///
145    /// // Compute values from the reference
146    /// let name_len = service_ref.name.len();
147    /// assert_eq!(name_len, 4);
148    /// ```
149    #[must_use]
150    pub fn get_ref<T: 'static + Send + Sync>(&self) -> Option<RefGuard<'_, T>> {
151        let type_id = TypeId::of::<T>();
152        self.storage.get(&type_id).map(|r| RefGuard::<T> {
153            inner: r,
154            _phantom: std::marker::PhantomData,
155        })
156    }
157
158    /// Get a clone of a value of type T from the shared store.
159    /// Requires T to implement Clone.
160    ///
161    /// Returns `None` if the value doesn't exist.
162    /// This method clones the stored value.
163    /// If cloning is not desired or T does not implement Clone,
164    /// use `get_ref` instead.
165    ///
166    /// # Example
167    /// ```
168    /// # use loco_rs::app::SharedStore;
169    /// let shared_store = SharedStore::default();
170    ///
171    /// #[derive(Clone)]
172    /// struct TestService {
173    ///     name: String,
174    ///     value: i32,
175    /// }
176    ///
177    /// let service = TestService {
178    ///     name: "test".to_string(),
179    ///     value: 100,
180    /// };
181    ///
182    /// shared_store.insert(service);
183    ///
184    /// // Get a clone of the service
185    /// let service_clone_opt = shared_store.get::<TestService>();
186    /// assert!(service_clone_opt.is_some(), "Service not found");
187    /// // Assert fields individually
188    /// if let Some(ref service_clone) = service_clone_opt {
189    ///     assert_eq!(service_clone.name, "test");
190    ///     assert_eq!(service_clone.value, 100);
191    /// }
192    /// ```
193    #[must_use]
194    pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
195        self.get_ref::<T>().map(|guard| (*guard).clone())
196    }
197
198    /// Check if the shared store contains a value of type T
199    ///
200    /// # Example
201    /// ```
202    /// # use loco_rs::app::SharedStore;
203    /// let shared_store = SharedStore::default();
204    ///
205    /// struct TestService {
206    ///     name: String,
207    ///     value: i32,
208    /// }
209    ///
210    /// let service = TestService {
211    ///     name: "test".to_string(),
212    ///     value: 100,
213    /// };
214    ///
215    /// shared_store.insert(service);
216    /// assert!(shared_store.contains::<TestService>());
217    /// assert!(!shared_store.contains::<String>());
218    /// ```
219    #[must_use]
220    pub fn contains<T: 'static + Send + Sync>(&self) -> bool {
221        self.storage.contains_key(&TypeId::of::<T>())
222    }
223}
224
225// A wrapper around DashMap's Ref type that erases the exact type
226// but provides deref to the target type
227pub struct RefGuard<'a, T: 'static + Send + Sync> {
228    inner: dashmap::mapref::one::Ref<'a, TypeId, Box<dyn Any + Send + Sync>>,
229    _phantom: std::marker::PhantomData<&'a T>,
230}
231
232impl<T: 'static + Send + Sync> std::ops::Deref for RefGuard<'_, T> {
233    type Target = T;
234
235    fn deref(&self) -> &Self::Target {
236        // This is safe because we only create a RefGuard for a specific type
237        // after looking it up by its TypeId
238        self.inner
239            .value()
240            .downcast_ref::<T>()
241            .expect("Type mismatch in RefGuard")
242    }
243}
244
245/// Represents the application context for a web server.
246///
247/// This struct encapsulates various components and configurations required by
248/// the web server to operate. It is typically used to store and manage shared
249/// resources and settings that are accessible throughout the application's
250/// lifetime.
251#[derive(Clone)]
252#[allow(clippy::module_name_repetitions)]
253pub struct AppContext {
254    /// The environment in which the application is running.
255    pub environment: Environment,
256    #[cfg(feature = "with-db")]
257    /// A database connection used by the application.
258    pub db: DatabaseConnection,
259    /// Queue provider
260    pub queue_provider: Option<Arc<bgworker::Queue>>,
261    /// Configuration settings for the application
262    pub config: Config,
263    /// An optional email sender component that can be used to send email.
264    pub mailer: Option<EmailSender>,
265    // An optional storage instance for the application
266    pub storage: Arc<Storage>,
267    // Cache instance for the application
268    pub cache: Arc<cache::Cache>,
269    /// Shared store for arbitrary application data
270    pub shared_store: Arc<SharedStore>,
271}
272
273/// A trait that defines hooks for customizing and extending the behavior of a
274/// web server application.
275///
276/// Users of the web server application should implement this trait to customize
277/// the application's routing, worker connections, task registration, and
278/// database actions according to their specific requirements and use cases.
279#[async_trait]
280pub trait Hooks: Send {
281    /// Defines the composite app version
282    #[must_use]
283    fn app_version() -> String {
284        "dev".to_string()
285    }
286    /// Defines the crate name
287    ///
288    /// Example
289    /// ```rust
290    /// fn app_name() -> &'static str {
291    ///     env!("CARGO_CRATE_NAME")
292    /// }
293    /// ```
294    fn app_name() -> &'static str;
295
296    /// Initializes and boots the application based on the specified mode and
297    /// environment.
298    ///
299    /// The boot initialization process may vary depending on whether a DB
300    /// migrator is used or not.
301    ///
302    /// # Examples
303    ///
304    /// With DB:
305    /// ```rust,ignore
306    /// async fn boot(mode: StartMode, environment: &str, config: Config) -> Result<BootResult> {
307    ///     create_app::<Self, Migrator>(mode, environment, config).await
308    /// }
309    /// ````
310    ///
311    /// Without DB:
312    /// ```rust,ignore
313    /// async fn boot(mode: StartMode, environment: &str, config: Config) -> Result<BootResult> {
314    ///     create_app::<Self>(mode, environment, config).await
315    /// }
316    /// ````
317    ///
318    ///
319    /// # Errors
320    /// Could not boot the application
321    async fn boot(mode: StartMode, environment: &Environment, config: Config)
322        -> Result<BootResult>;
323
324    /// Start serving the Axum web application on the specified address and
325    /// port.
326    ///
327    /// # Returns
328    /// A Result indicating success () or an error if the server fails to start.
329    async fn serve(app: AxumRouter, ctx: &AppContext, serve_params: &ServeParams) -> Result<()> {
330        let listener = tokio::net::TcpListener::bind(&format!(
331            "{}:{}",
332            serve_params.binding, serve_params.port
333        ))
334        .await?;
335
336        let cloned_ctx = ctx.clone();
337        axum::serve(
338            listener,
339            app.into_make_service_with_connect_info::<SocketAddr>(),
340        )
341        .with_graceful_shutdown(async move {
342            shutdown_signal().await;
343            tracing::info!("shutting down...");
344            Self::on_shutdown(&cloned_ctx).await;
345        })
346        .await?;
347
348        Ok(())
349    }
350
351    /// Override and return `Ok(true)` to provide an alternative logging and
352    /// tracing stack of your own.
353    /// When returning `Ok(true)`, Loco will *not* initialize its own logger,
354    /// so you should set up a complete tracing and logging stack.
355    ///
356    /// # Errors
357    /// If fails returns an error
358    fn init_logger(_ctx: &AppContext) -> Result<bool> {
359        Ok(false)
360    }
361
362    /// Loads the configuration settings for the application based on the given environment.
363    ///
364    /// This function is responsible for retrieving the configuration for the application
365    /// based on the current environment.
366    async fn load_config(env: &Environment) -> Result<Config> {
367        env.load()
368    }
369
370    /// Returns the initial Axum router for the application, allowing the user
371    /// to control the construction of the Axum router. This is where a fallback
372    /// handler can be installed before middleware or other routes are added.
373    ///
374    /// # Errors
375    /// Return an [`Result`] when the router could not be created
376    async fn before_routes(_ctx: &AppContext) -> Result<AxumRouter<AppContext>> {
377        Ok(AxumRouter::new())
378    }
379
380    /// Invoke this function after the Loco routers have been constructed. This
381    /// function enables you to configure custom Axum logics, such as layers,
382    /// that are compatible with Axum.
383    ///
384    /// # Errors
385    /// Axum router error
386    async fn after_routes(router: AxumRouter, _ctx: &AppContext) -> Result<AxumRouter> {
387        Ok(router)
388    }
389
390    /// Provide a list of initializers
391    /// An initializer can be used to seamlessly add functionality to your app
392    /// or to initialize some aspects of it.
393    async fn initializers(_ctx: &AppContext) -> Result<Vec<Box<dyn Initializer>>> {
394        Ok(vec![])
395    }
396
397    /// Provide a list of middlewares
398    #[must_use]
399    fn middlewares(ctx: &AppContext) -> Vec<Box<dyn MiddlewareLayer>> {
400        middleware::default_middleware_stack(ctx)
401    }
402
403    /// Calling the function before run the app
404    /// You can now code some custom loading of resources or other things before
405    /// the app runs
406    async fn before_run(_app_context: &AppContext) -> Result<()> {
407        Ok(())
408    }
409
410    /// Defines the application's routing configuration.
411    fn routes(_ctx: &AppContext) -> AppRoutes;
412
413    // Provides the options to change Loco [`AppContext`] after initialization.
414    async fn after_context(ctx: AppContext) -> Result<AppContext> {
415        Ok(ctx)
416    }
417
418    /// Connects custom workers to the application using the provided
419    /// [`Processor`] and [`AppContext`].
420    async fn connect_workers(ctx: &AppContext, queue: &Queue) -> Result<()>;
421
422    /// Registers custom tasks with the provided [`Tasks`] object.
423    fn register_tasks(tasks: &mut Tasks);
424
425    /// Truncates the database as required. Users should implement this
426    /// function. The truncate controlled from the [`crate::config::Database`]
427    /// by changing dangerously_truncate to true (default false).
428    /// Truncate can be useful when you want to truncate the database before any
429    /// test.
430    #[cfg(feature = "with-db")]
431    async fn truncate(_ctx: &AppContext) -> Result<()>;
432
433    /// Seeds the database with initial data.
434    #[cfg(feature = "with-db")]
435    async fn seed(_ctx: &AppContext, path: &Path) -> Result<()>;
436
437    /// Called when the application is shutting down.
438    /// This function allows users to perform any necessary cleanup or final
439    /// actions before the application stops completely.
440    async fn on_shutdown(_ctx: &AppContext) {}
441}
442
443/// An initializer.
444/// Initializers should be kept in `src/initializers/`
445///
446/// Initializers can provide health checks by implementing the `check` method.
447/// These checks will be run during the `cargo loco doctor` command to validate
448/// the initializer's configuration and test its connections.
449#[async_trait]
450// <snip id="initializers-trait">
451pub trait Initializer: Sync + Send {
452    /// The initializer name or identifier
453    fn name(&self) -> String;
454
455    /// Occurs after the app's `before_run`.
456    /// Use this to for one-time initializations, load caches, perform web
457    /// hooks, etc.
458    async fn before_run(&self, _app_context: &AppContext) -> Result<()> {
459        Ok(())
460    }
461
462    /// Occurs after the app's `after_routes`.
463    /// Use this to compose additional functionality and wire it into an Axum
464    /// Router
465    async fn after_routes(&self, router: AxumRouter, _ctx: &AppContext) -> Result<AxumRouter> {
466        Ok(router)
467    }
468
469    /// Perform health checks for this initializer.
470    /// This method is called during the doctor command to validate the initializer's configuration.
471    /// Return `None` if no check is needed, or `Some(Check)` if a check should be performed.
472    async fn check(&self, _app_context: &AppContext) -> Result<Option<crate::doctor::Check>> {
473        Ok(None)
474    }
475}
476// </snip>
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests_cfg::app::get_app_context;

    /// Test payload that does not implement `Clone`; it can only be read via
    /// `get_ref` or taken out via `remove`.
    struct TestService {
        name: String,
        value: i32,
    }

    /// Test payload that implements `Clone`, so `get` can be used as well.
    #[derive(Clone)]
    struct CloneableTestService {
        name: String,
        value: i32,
    }

    // NOTE: a live `RefGuard` keeps a DashMap read lock. Guards are confined to
    // their own `{ ... }` scopes (or closures) so they are dropped before the
    // store is written to again on the same thread, which could deadlock.

    #[test]
    fn test_extensions_insert_and_get() {
        let shared_store = SharedStore::default();

        shared_store.insert(42i32);
        assert_eq!(shared_store.get::<i32>().expect("Value should exist"), 42);

        shared_store.insert(TestService {
            name: "test".to_string(),
            value: 100,
        });

        {
            let service_ref = shared_store
                .get_ref::<TestService>()
                .expect("Service ref should exist");
            assert_eq!(service_ref.name, "test");
            assert_eq!(service_ref.value, 100);
            let name_clone = service_ref.name.clone();
            assert_eq!(name_clone, "test");
        }
    }

    #[test]
    fn test_extensions_insert_overwrites_same_type() {
        let shared_store = SharedStore::default();

        shared_store.insert(1i32);
        shared_store.insert(2i32);

        // Only the most recently inserted value of a type is kept.
        assert_eq!(shared_store.get::<i32>(), Some(2));
        assert_eq!(shared_store.remove::<i32>(), Some(2));
        assert!(!shared_store.contains::<i32>());
    }

    #[test]
    fn test_extensions_get_without_clone() {
        let shared_store = SharedStore::default();

        shared_store.insert(TestService {
            name: "test_direct".to_string(),
            value: 100,
        });

        {
            let service_ref = shared_store
                .get_ref::<TestService>()
                .expect("Service ref should exist");
            assert_eq!(service_ref.name, "test_direct");
            assert_eq!(service_ref.value, 100);
        }

        // Values can be derived from the guard without cloning the service;
        // each guard is dropped when its closure returns.
        let name_len = shared_store.get_ref::<TestService>().map(|r| r.name.len());
        assert_eq!(name_len, Some(11), "Service ref should exist for len check");

        let value = shared_store.get_ref::<TestService>().map(|r| r.value);
        assert_eq!(value, Some(100), "Service ref should exist for value check");
    }

    #[test]
    fn test_extensions_remove() {
        let shared_store = SharedStore::default();

        shared_store.insert(42i32);
        assert!(shared_store.contains::<i32>());
        assert_eq!(shared_store.remove::<i32>(), Some(42));
        assert!(!shared_store.contains::<i32>());
        assert_eq!(shared_store.remove::<i32>(), None);

        shared_store.insert(TestService {
            name: "rem".to_string(),
            value: 50,
        });
        assert!(shared_store.contains::<TestService>());

        let removed = shared_store
            .remove::<TestService>()
            .expect("Removed option should be Some");
        assert_eq!(removed.name, "rem");
        assert_eq!(removed.value, 50);

        assert!(!shared_store.contains::<TestService>());
        assert!(shared_store.remove::<TestService>().is_none());
    }

    #[test]
    fn test_extensions_contains() {
        let shared_store = SharedStore::default();

        shared_store.insert(42i32);
        shared_store.insert(TestService {
            name: "contains".to_string(),
            value: 1,
        });

        assert!(shared_store.contains::<i32>());
        assert!(shared_store.contains::<TestService>());
        assert!(!shared_store.contains::<String>());
        assert!(!shared_store.contains::<CloneableTestService>());
    }

    #[test]
    fn test_extensions_get_cloned() {
        let shared_store = SharedStore::default();

        shared_store.insert(42i32);
        assert_eq!(shared_store.get::<i32>(), Some(42));
        // `get` clones; the original value stays in the store.
        assert!(shared_store.contains::<i32>());

        shared_store.insert(CloneableTestService {
            name: "cloned_test".to_string(),
            value: 200,
        });

        let service_clone = shared_store
            .get::<CloneableTestService>()
            .expect("Cloned service should exist");
        assert_eq!(service_clone.name, "cloned_test");
        assert_eq!(service_clone.value, 200);

        assert!(shared_store.contains::<CloneableTestService>());
        {
            let original_ref = shared_store
                .get_ref::<CloneableTestService>()
                .expect("Original ref should exist");
            assert_eq!(original_ref.name, "cloned_test");
            assert_eq!(original_ref.value, 200);
        }

        assert_eq!(shared_store.get::<String>(), None);
        assert!(shared_store.get::<CloneableTestService>().is_some());
        // `shared_store.get::<TestService>()` would not compile: `get` requires
        // `T: Clone`, which `TestService` deliberately does not implement.
    }

    #[tokio::test]
    async fn test_app_context_extensions() {
        let ctx = get_app_context().await;

        // A cloneable service: readable both by reference and by clone.
        ctx.shared_store.insert(CloneableTestService {
            name: "app_context_test_cloneable".to_string(),
            value: 42,
        });

        {
            let service_ref = ctx
                .shared_store
                .get_ref::<CloneableTestService>()
                .expect("Cloneable service ref should exist");
            assert_eq!(service_ref.name, "app_context_test_cloneable");
            assert_eq!(service_ref.value, 42);
        }

        let service_clone = ctx
            .shared_store
            .get::<CloneableTestService>()
            .expect("Should get cloned service");
        assert_eq!(service_clone.name, "app_context_test_cloneable");
        assert_eq!(service_clone.value, 42);

        assert!(ctx.shared_store.contains::<CloneableTestService>());
        assert!(!ctx.shared_store.contains::<String>());

        let removed_cloneable = ctx
            .shared_store
            .remove::<CloneableTestService>()
            .expect("Removed cloneable option should be Some");
        assert_eq!(removed_cloneable.name, "app_context_test_cloneable");
        assert_eq!(removed_cloneable.value, 42);
        assert!(!ctx.shared_store.contains::<CloneableTestService>());

        // A non-cloneable service: readable only by reference.
        ctx.shared_store.insert(TestService {
            name: "app_context_test_non_cloneable".to_string(),
            value: 99,
        });

        {
            let service_ref = ctx
                .shared_store
                .get_ref::<TestService>()
                .expect("Non-cloneable service ref should exist");
            assert_eq!(service_ref.name, "app_context_test_non_cloneable");
            assert_eq!(service_ref.value, 99);
        }

        assert!(ctx.shared_store.contains::<TestService>());

        let removed_non_cloneable = ctx
            .shared_store
            .remove::<TestService>()
            .expect("Removed non-cloneable option should be Some");
        assert_eq!(removed_non_cloneable.name, "app_context_test_non_cloneable");
        assert_eq!(removed_non_cloneable.value, 99);
        assert!(!ctx.shared_store.contains::<TestService>());
    }
}
710}