loco_rs/
app.rs

1//! This module contains the core components and traits for building a web
2//! server application.
3#[cfg(feature = "with-db")]
4use {sea_orm::DatabaseConnection, std::path::Path};
5
6use std::{
7    any::{Any, TypeId},
8    net::SocketAddr,
9    sync::Arc,
10};
11
12use async_trait::async_trait;
13use axum::Router as AxumRouter;
14use dashmap::DashMap;
15
16use crate::{
17    bgworker::{self, Queue},
18    boot::{shutdown_signal, BootResult, ServeParams, StartMode},
19    cache::{self},
20    config::Config,
21    controller::{
22        middleware::{self, MiddlewareLayer},
23        AppRoutes,
24    },
25    environment::Environment,
26    mailer::EmailSender,
27    storage::Storage,
28    task::Tasks,
29    Result,
30};
31
32/// Type-safe heterogeneous storage for arbitrary application data
33#[derive(Default, Debug)]
34pub struct SharedStore {
35    // Use DashMap for concurrent access with fine-grained locking
36    storage: DashMap<TypeId, Box<dyn Any + Send + Sync>>,
37}
38
39impl SharedStore {
40    /// Insert a value of type T into the shared store
41    ///
42    /// # Example
43    /// ```
44    /// # use loco_rs::app::SharedStore;
45    /// let shared_store = SharedStore::default();
46    ///
47    /// #[derive(Debug)]
48    /// struct TestService {
49    ///     name: String,
50    ///     value: i32,
51    /// }
52    ///
53    /// let service = TestService {
54    ///     name: "test".to_string(),
55    ///     value: 100,
56    /// };
57    ///
58    /// shared_store.insert(service);
59    /// assert!(shared_store.contains::<TestService>());
60    /// ```
61    pub fn insert<T: 'static + Send + Sync>(&self, val: T) {
62        self.storage.insert(TypeId::of::<T>(), Box::new(val));
63    }
64
65    /// Remove a value of type T from the shared store
66    ///
67    /// Returns `Some(T)` if the value was present and removed, `None` otherwise.
68    ///
69    /// # Example
70    /// ```
71    /// # use loco_rs::app::SharedStore;
72    /// let shared_store = SharedStore::default();
73    ///
74    /// struct TestService {
75    ///     name: String,
76    ///     value: i32,
77    /// }
78    ///
79    /// let service = TestService {
80    ///     name: "test".to_string(),
81    ///     value: 100,
82    /// };
83    ///
84    /// shared_store.insert(service);
85    /// assert!(shared_store.contains::<TestService>());
86    ///
87    /// // Remove and get the value
88    /// let removed_service_opt = shared_store.remove::<TestService>();
89    /// assert!(removed_service_opt.is_some(), "Service should be present");
90    /// // Assert fields individually instead of comparing the whole struct
91    /// if let Some(removed_service) = removed_service_opt {
92    ///      assert_eq!(removed_service.name, "test");
93    ///      assert_eq!(removed_service.value, 100);
94    /// }
95    /// // Ensure it's gone
96    /// assert!(!shared_store.contains::<TestService>());
97    ///
98    /// // Trying to remove again returns None
99    /// assert!(shared_store.remove::<TestService>().is_none());
100    /// ```
101    #[must_use]
102    pub fn remove<T: 'static + Send + Sync>(&self) -> Option<T> {
103        self.storage
104            .remove(&TypeId::of::<T>())
105            .map(|(_, v)| v) // Extract the Box<dyn Any>
106            .and_then(|any| any.downcast::<T>().ok()) // Downcast to Box<T>
107            .map(|boxed| *boxed) // Dereference the Box<T> to get T
108    }
109
110    /// Get a reference to a value of type T from the shared store.
111    ///
112    /// Returns `None` if the value doesn't exist.
113    /// The reference is valid for as long as the returned `RefGuard` is held.
114    /// If you need to clone the value, you can do so directly from the
115    /// returned reference, or use the `get` method instead.
116    ///
117    /// # Example
118    /// ```
119    /// # use loco_rs::app::SharedStore;
120    /// let shared_store = SharedStore::default();
121    ///
122    /// #[derive(Clone)]
123    /// struct TestService {
124    ///     name: String,
125    ///     value: i32,
126    /// }
127    ///
128    /// let service = TestService {
129    ///     name: "test".to_string(),
130    ///     value: 100,
131    /// };
132    ///
133    /// shared_store.insert(service);
134    ///
135    /// // Get a reference to the service
136    /// let service_ref = shared_store.get_ref::<TestService>().expect("Service not found");
137    /// // Access fields directly
138    /// assert_eq!(service_ref.name, "test");
139    /// assert_eq!(service_ref.value, 100);
140    ///
141    /// // Clone if needed (the field itself)
142    /// let name_clone = service_ref.name.clone();
143    /// assert_eq!(name_clone, "test");
144    ///
145    /// // Compute values from the reference
146    /// let name_len = service_ref.name.len();
147    /// assert_eq!(name_len, 4);
148    /// ```
149    #[must_use]
150    pub fn get_ref<T: 'static + Send + Sync>(&self) -> Option<RefGuard<'_, T>> {
151        let type_id = TypeId::of::<T>();
152        self.storage.get(&type_id).map(|r| RefGuard::<T> {
153            inner: r,
154            _phantom: std::marker::PhantomData,
155        })
156    }
157
158    /// Get a clone of a value of type T from the shared store.
159    /// Requires T to implement Clone.
160    ///
161    /// Returns `None` if the value doesn't exist.
162    /// This method clones the stored value.
163    /// If cloning is not desired or T does not implement Clone,
164    /// use `get_ref` instead.
165    ///
166    /// # Example
167    /// ```
168    /// # use loco_rs::app::SharedStore;
169    /// let shared_store = SharedStore::default();
170    ///
171    /// #[derive(Clone)]
172    /// struct TestService {
173    ///     name: String,
174    ///     value: i32,
175    /// }
176    ///
177    /// let service = TestService {
178    ///     name: "test".to_string(),
179    ///     value: 100,
180    /// };
181    ///
182    /// shared_store.insert(service);
183    ///
184    /// // Get a clone of the service
185    /// let service_clone_opt = shared_store.get::<TestService>();
186    /// assert!(service_clone_opt.is_some(), "Service not found");
187    /// // Assert fields individually
188    /// if let Some(ref service_clone) = service_clone_opt {
189    ///     assert_eq!(service_clone.name, "test");
190    ///     assert_eq!(service_clone.value, 100);
191    /// }
192    /// ```
193    #[must_use]
194    pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
195        self.get_ref::<T>().map(|guard| (*guard).clone())
196    }
197
198    /// Check if the shared store contains a value of type T
199    ///
200    /// # Example
201    /// ```
202    /// # use loco_rs::app::SharedStore;
203    /// let shared_store = SharedStore::default();
204    ///
205    /// struct TestService {
206    ///     name: String,
207    ///     value: i32,
208    /// }
209    ///
210    /// let service = TestService {
211    ///     name: "test".to_string(),
212    ///     value: 100,
213    /// };
214    ///
215    /// shared_store.insert(service);
216    /// assert!(shared_store.contains::<TestService>());
217    /// assert!(!shared_store.contains::<String>());
218    /// ```
219    #[must_use]
220    pub fn contains<T: 'static + Send + Sync>(&self) -> bool {
221        self.storage.contains_key(&TypeId::of::<T>())
222    }
223}
224
225// A wrapper around DashMap's Ref type that erases the exact type
226// but provides deref to the target type
227pub struct RefGuard<'a, T: 'static + Send + Sync> {
228    inner: dashmap::mapref::one::Ref<'a, TypeId, Box<dyn Any + Send + Sync>>,
229    _phantom: std::marker::PhantomData<&'a T>,
230}
231
232impl<T: 'static + Send + Sync> std::ops::Deref for RefGuard<'_, T> {
233    type Target = T;
234
235    fn deref(&self) -> &Self::Target {
236        // This is safe because we only create a RefGuard for a specific type
237        // after looking it up by its TypeId
238        #[allow(clippy::coerce_container_to_any)]
239        self.inner
240            .value()
241            .downcast_ref::<T>()
242            .expect("Type mismatch in RefGuard")
243    }
244}
245
246/// Represents the application context for a web server.
247///
248/// This struct encapsulates various components and configurations required by
249/// the web server to operate. It is typically used to store and manage shared
250/// resources and settings that are accessible throughout the application's
251/// lifetime.
252#[derive(Clone)]
253#[allow(clippy::module_name_repetitions)]
254pub struct AppContext {
255    /// The environment in which the application is running.
256    pub environment: Environment,
257    #[cfg(feature = "with-db")]
258    /// A database connection used by the application.
259    pub db: DatabaseConnection,
260    /// Queue provider
261    pub queue_provider: Option<Arc<bgworker::Queue>>,
262    /// Configuration settings for the application
263    pub config: Config,
264    /// An optional email sender component that can be used to send email.
265    pub mailer: Option<EmailSender>,
266    // An optional storage instance for the application
267    pub storage: Arc<Storage>,
268    // Cache instance for the application
269    pub cache: Arc<cache::Cache>,
270    /// Shared store for arbitrary application data
271    pub shared_store: Arc<SharedStore>,
272}
273
274/// A trait that defines hooks for customizing and extending the behavior of a
275/// web server application.
276///
277/// Users of the web server application should implement this trait to customize
278/// the application's routing, worker connections, task registration, and
279/// database actions according to their specific requirements and use cases.
280#[async_trait]
281pub trait Hooks: Send {
282    /// Defines the composite app version
283    #[must_use]
284    fn app_version() -> String {
285        "dev".to_string()
286    }
287    /// Defines the crate name
288    ///
289    /// Example
290    /// ```rust
291    /// fn app_name() -> &'static str {
292    ///     env!("CARGO_CRATE_NAME")
293    /// }
294    /// ```
295    fn app_name() -> &'static str;
296
297    /// Initializes and boots the application based on the specified mode and
298    /// environment.
299    ///
300    /// The boot initialization process may vary depending on whether a DB
301    /// migrator is used or not.
302    ///
303    /// # Examples
304    ///
305    /// With DB:
306    /// ```rust,ignore
307    /// async fn boot(mode: StartMode, environment: &str, config: Config) -> Result<BootResult> {
308    ///     create_app::<Self, Migrator>(mode, environment, config).await
309    /// }
310    /// ````
311    ///
312    /// Without DB:
313    /// ```rust,ignore
314    /// async fn boot(mode: StartMode, environment: &str, config: Config) -> Result<BootResult> {
315    ///     create_app::<Self>(mode, environment, config).await
316    /// }
317    /// ````
318    ///
319    ///
320    /// # Errors
321    /// Could not boot the application
322    async fn boot(mode: StartMode, environment: &Environment, config: Config)
323        -> Result<BootResult>;
324
325    /// Start serving the Axum web application on the specified address and
326    /// port.
327    ///
328    /// # Returns
329    /// A Result indicating success () or an error if the server fails to start.
330    async fn serve(app: AxumRouter, ctx: &AppContext, serve_params: &ServeParams) -> Result<()> {
331        let listener = tokio::net::TcpListener::bind(&format!(
332            "{}:{}",
333            serve_params.binding, serve_params.port
334        ))
335        .await?;
336
337        let cloned_ctx = ctx.clone();
338        axum::serve(
339            listener,
340            app.into_make_service_with_connect_info::<SocketAddr>(),
341        )
342        .with_graceful_shutdown(async move {
343            shutdown_signal().await;
344            tracing::info!("shutting down...");
345            Self::on_shutdown(&cloned_ctx).await;
346        })
347        .await?;
348
349        Ok(())
350    }
351
352    /// Override and return `Ok(true)` to provide an alternative logging and
353    /// tracing stack of your own.
354    /// When returning `Ok(true)`, Loco will *not* initialize its own logger,
355    /// so you should set up a complete tracing and logging stack.
356    ///
357    /// # Errors
358    /// If fails returns an error
359    fn init_logger(_ctx: &AppContext) -> Result<bool> {
360        Ok(false)
361    }
362
363    /// Loads the configuration settings for the application based on the given environment.
364    ///
365    /// This function is responsible for retrieving the configuration for the application
366    /// based on the current environment.
367    async fn load_config(env: &Environment) -> Result<Config> {
368        env.load()
369    }
370
371    /// Returns the initial Axum router for the application, allowing the user
372    /// to control the construction of the Axum router. This is where a fallback
373    /// handler can be installed before middleware or other routes are added.
374    ///
375    /// # Errors
376    /// Return an [`Result`] when the router could not be created
377    async fn before_routes(_ctx: &AppContext) -> Result<AxumRouter<AppContext>> {
378        Ok(AxumRouter::new())
379    }
380
381    /// Invoke this function after the Loco routers have been constructed. This
382    /// function enables you to configure custom Axum logics, such as layers,
383    /// that are compatible with Axum.
384    ///
385    /// # Errors
386    /// Axum router error
387    async fn after_routes(router: AxumRouter, _ctx: &AppContext) -> Result<AxumRouter> {
388        Ok(router)
389    }
390
391    /// Provide a list of initializers
392    /// An initializer can be used to seamlessly add functionality to your app
393    /// or to initialize some aspects of it.
394    async fn initializers(_ctx: &AppContext) -> Result<Vec<Box<dyn Initializer>>> {
395        Ok(vec![])
396    }
397
398    /// Provide a list of middlewares
399    #[must_use]
400    fn middlewares(ctx: &AppContext) -> Vec<Box<dyn MiddlewareLayer>> {
401        middleware::default_middleware_stack(ctx)
402    }
403
404    /// Calling the function before run the app
405    /// You can now code some custom loading of resources or other things before
406    /// the app runs
407    async fn before_run(_app_context: &AppContext) -> Result<()> {
408        Ok(())
409    }
410
411    /// Defines the application's routing configuration.
412    fn routes(_ctx: &AppContext) -> AppRoutes;
413
414    // Provides the options to change Loco [`AppContext`] after initialization.
415    async fn after_context(ctx: AppContext) -> Result<AppContext> {
416        Ok(ctx)
417    }
418
419    /// Connects custom workers to the application using the provided
420    /// [`Processor`] and [`AppContext`].
421    async fn connect_workers(ctx: &AppContext, queue: &Queue) -> Result<()>;
422
423    /// Registers custom tasks with the provided [`Tasks`] object.
424    fn register_tasks(tasks: &mut Tasks);
425
426    /// Truncates the database as required. Users should implement this
427    /// function. The truncate controlled from the [`crate::config::Database`]
428    /// by changing `dangerously_truncate` to true (default false).
429    /// Truncate can be useful when you want to truncate the database before any
430    /// test.
431    #[cfg(feature = "with-db")]
432    async fn truncate(_ctx: &AppContext) -> Result<()>;
433
434    /// Seeds the database with initial data.
435    #[cfg(feature = "with-db")]
436    async fn seed(_ctx: &AppContext, path: &Path) -> Result<()>;
437
438    /// Called when the application is shutting down.
439    /// This function allows users to perform any necessary cleanup or final
440    /// actions before the application stops completely.
441    async fn on_shutdown(_ctx: &AppContext) {}
442}
443
444/// An initializer.
445/// Initializers should be kept in `src/initializers/`
446///
447/// Initializers can provide health checks by implementing the `check` method.
448/// These checks will be run during the `cargo loco doctor` command to validate
449/// the initializer's configuration and test its connections.
450#[async_trait]
451// <snip id="initializers-trait">
452pub trait Initializer: Sync + Send {
453    /// The initializer name or identifier
454    fn name(&self) -> String;
455
456    /// Occurs after the app's `before_run`.
457    /// Use this to for one-time initializations, load caches, perform web
458    /// hooks, etc.
459    async fn before_run(&self, _app_context: &AppContext) -> Result<()> {
460        Ok(())
461    }
462
463    /// Occurs after the app's `after_routes`.
464    /// Use this to compose additional functionality and wire it into an Axum
465    /// Router
466    async fn after_routes(&self, router: AxumRouter, _ctx: &AppContext) -> Result<AxumRouter> {
467        Ok(router)
468    }
469
470    /// Perform health checks for this initializer.
471    /// This method is called during the doctor command to validate the initializer's configuration.
472    /// Return `None` if no check is needed, or `Some(Check)` if a check should be performed.
473    async fn check(&self, _app_context: &AppContext) -> Result<Option<crate::doctor::Check>> {
474        Ok(None)
475    }
476}
477// </snip>
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests_cfg::app::get_app_context;

    /// A stored service that deliberately does not implement `Clone`, so it
    /// can only be read through `SharedStore::get_ref`.
    struct TestService {
        name: String,
        value: i32,
    }

    /// A stored service that implements `Clone`, so it can also be read
    /// through `SharedStore::get`.
    #[derive(Clone)]
    struct CloneableTestService {
        name: String,
        value: i32,
    }

    #[test]
    fn test_extensions_insert_and_get() {
        let shared_store = SharedStore::default();

        shared_store.insert(42i32);
        assert_eq!(shared_store.get::<i32>().expect("Value should exist"), 42);

        shared_store.insert(TestService {
            name: "test".to_string(),
            value: 100,
        });

        let service_ref = shared_store
            .get_ref::<TestService>()
            .expect("Service ref should exist");
        assert_eq!(service_ref.name, "test");
        assert_eq!(service_ref.value, 100);
        let name_clone = service_ref.name.clone();
        assert_eq!(name_clone, "test");
    }

    #[test]
    fn test_extensions_insert_replaces_existing_value() {
        let shared_store = SharedStore::default();

        shared_store.insert(1u32);
        shared_store.insert(2u32);

        // The store is keyed by type, so the second insert overwrites the first.
        assert_eq!(shared_store.get::<u32>(), Some(2));
    }

    #[test]
    fn test_extensions_get_without_clone() {
        let shared_store = SharedStore::default();

        shared_store.insert(TestService {
            name: "test_direct".to_string(),
            value: 100,
        });

        {
            let service_ref = shared_store
                .get_ref::<TestService>()
                .expect("Service ref should exist");
            assert_eq!(service_ref.name, "test_direct");
            assert_eq!(service_ref.value, 100);
        }

        let name_len = shared_store
            .get_ref::<TestService>()
            .map(|r| r.name.len())
            .expect("Service ref should exist for len check");
        assert_eq!(name_len, 11);

        let value = shared_store
            .get_ref::<TestService>()
            .map(|r| r.value)
            .expect("Service ref should exist for value check");
        assert_eq!(value, 100);
    }

    #[test]
    fn test_extensions_remove() {
        let shared_store = SharedStore::default();

        shared_store.insert(42i32);
        assert!(shared_store.contains::<i32>());
        assert_eq!(shared_store.remove::<i32>(), Some(42));
        assert!(!shared_store.contains::<i32>());
        assert_eq!(shared_store.remove::<i32>(), None);

        shared_store.insert(TestService {
            name: "rem".to_string(),
            value: 50,
        });
        assert!(shared_store.contains::<TestService>());

        let removed = shared_store
            .remove::<TestService>()
            .expect("Removed option should be Some");
        assert_eq!(removed.name, "rem");
        assert_eq!(removed.value, 50);

        assert!(!shared_store.contains::<TestService>());
        assert!(shared_store.remove::<TestService>().is_none());
    }

    #[test]
    fn test_extensions_contains() {
        let shared_store = SharedStore::default();

        shared_store.insert(42i32);
        shared_store.insert(TestService {
            name: "contains".to_string(),
            value: 1,
        });

        assert!(shared_store.contains::<i32>());
        assert!(shared_store.contains::<TestService>());
        assert!(!shared_store.contains::<String>());
        assert!(!shared_store.contains::<CloneableTestService>());
    }

    #[test]
    fn test_extensions_get_cloned() {
        let shared_store = SharedStore::default();

        shared_store.insert(42i32);
        assert_eq!(shared_store.get::<i32>(), Some(42));
        assert!(shared_store.contains::<i32>());

        shared_store.insert(CloneableTestService {
            name: "cloned_test".to_string(),
            value: 200,
        });

        let service_clone = shared_store
            .get::<CloneableTestService>()
            .expect("Cloned service should exist");
        assert_eq!(service_clone.name, "cloned_test");
        assert_eq!(service_clone.value, 200);

        // Reading a clone must leave the stored original in place.
        assert!(shared_store.contains::<CloneableTestService>());
        {
            let original_ref = shared_store
                .get_ref::<CloneableTestService>()
                .expect("Original ref should exist");
            assert_eq!(original_ref.name, "cloned_test");
            assert_eq!(original_ref.value, 200);
        }

        assert_eq!(shared_store.get::<String>(), None);
        assert!(shared_store.get::<CloneableTestService>().is_some());
        // The following line correctly fails to compile because TestService doesn't impl Clone,
        // which is required by the `get` method.
        // let non_existent_clone = shared_store.get::<TestService>();
    }

    #[tokio::test]
    async fn test_app_context_extensions() {
        let ctx = get_app_context().await;

        ctx.shared_store.insert(CloneableTestService {
            name: "app_context_test_cloneable".to_string(),
            value: 42,
        });

        {
            // Keep the guard scoped: it must be dropped before `remove` below,
            // because modifying the store while a `RefGuard` is alive may
            // deadlock.
            let service_ref = ctx
                .shared_store
                .get_ref::<CloneableTestService>()
                .expect("Cloneable service ref should exist");
            assert_eq!(service_ref.name, "app_context_test_cloneable");
            assert_eq!(service_ref.value, 42);
        }

        let service_clone = ctx
            .shared_store
            .get::<CloneableTestService>()
            .expect("Should get cloned service");
        assert_eq!(service_clone.name, "app_context_test_cloneable");
        assert_eq!(service_clone.value, 42);

        assert!(ctx.shared_store.contains::<CloneableTestService>());
        assert!(!ctx.shared_store.contains::<String>());

        let removed_cloneable = ctx
            .shared_store
            .remove::<CloneableTestService>()
            .expect("Removed cloneable option should be Some");
        assert_eq!(removed_cloneable.name, "app_context_test_cloneable");
        assert_eq!(removed_cloneable.value, 42);
        assert!(!ctx.shared_store.contains::<CloneableTestService>());

        ctx.shared_store.insert(TestService {
            name: "app_context_test_non_cloneable".to_string(),
            value: 99,
        });

        {
            // Same as above: release the guard before removing the entry.
            let service_ref = ctx
                .shared_store
                .get_ref::<TestService>()
                .expect("Non-cloneable service ref should exist");
            assert_eq!(service_ref.name, "app_context_test_non_cloneable");
            assert_eq!(service_ref.value, 99);
        }

        assert!(ctx.shared_store.contains::<TestService>());

        let removed_non_cloneable = ctx
            .shared_store
            .remove::<TestService>()
            .expect("Removed non-cloneable option should be Some");
        assert_eq!(removed_non_cloneable.name, "app_context_test_non_cloneable");
        assert_eq!(removed_non_cloneable.value, 99);
        assert!(!ctx.shared_store.contains::<TestService>());
    }
}