Skip to main content

modo_auth/
provider.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5/// Trait for loading a user by their session-stored ID.
6///
7/// Implement this on your own type (e.g., a repository struct that holds a DB pool)
8/// and register it via `app.service(UserProviderService::new(your_impl))`.
9pub trait UserProvider: Send + Sync + 'static {
10    type User: Clone + Send + Sync + 'static;
11
12    /// Look up a user by their ID (as stored in the session).
13    ///
14    /// Return `Ok(None)` if the user doesn't exist.
15    /// Return `Err` only for infrastructure failures (DB errors, etc.).
16    fn find_by_id(
17        &self,
18        id: &str,
19    ) -> impl Future<Output = Result<Option<Self::User>, modo::Error>> + Send;
20}
21
22/// Object-safe bridge trait for type-erasing `UserProvider`.
23trait UserProviderDyn<U>: Send + Sync {
24    fn find_by_id<'a>(
25        &'a self,
26        id: &'a str,
27    ) -> Pin<Box<dyn Future<Output = Result<Option<U>, modo::Error>> + Send + 'a>>;
28}
29
30impl<P: UserProvider> UserProviderDyn<P::User> for P {
31    fn find_by_id<'a>(
32        &'a self,
33        id: &'a str,
34    ) -> Pin<Box<dyn Future<Output = Result<Option<P::User>, modo::Error>> + Send + 'a>> {
35        Box::pin(UserProvider::find_by_id(self, id))
36    }
37}
38
39/// Type-erased wrapper around a [`UserProvider`] implementation.
40///
41/// Register with `app.service(UserProviderService::new(your_impl))`. The service
42/// is stored in the registry keyed by user type `U` so that `Auth<U>` and
43/// `OptionalAuth<U>` can retrieve it by `TypeId` at request time.
44pub struct UserProviderService<U: Clone + Send + Sync + 'static> {
45    inner: Arc<dyn UserProviderDyn<U>>,
46}
47
48impl<U: Clone + Send + Sync + 'static> Clone for UserProviderService<U> {
49    fn clone(&self) -> Self {
50        Self {
51            inner: Arc::clone(&self.inner),
52        }
53    }
54}
55
56impl<U: Clone + Send + Sync + 'static> std::fmt::Debug for UserProviderService<U> {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        write!(f, "UserProviderService<{}>", std::any::type_name::<U>())
59    }
60}
61
62impl<U: Clone + Send + Sync + 'static> UserProviderService<U> {
63    /// Wrap a [`UserProvider`] implementation for registration in the service registry.
64    pub fn new<P: UserProvider<User = U>>(provider: P) -> Self {
65        Self {
66            inner: Arc::new(provider),
67        }
68    }
69
70    /// Look up a user by their session-stored ID, delegating to the wrapped provider.
71    pub async fn find_by_id(&self, id: &str) -> Result<Option<U>, modo::Error> {
72        self.inner.find_by_id(id).await
73    }
74}
75
/// Unit tests for `UserProviderService`, driven by an in-memory provider.
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct TestUser {
        id: String,
        name: String,
    }

    /// Knows one user (`user-1`), fails for `error-user`, and reports every other
    /// ID as missing.
    struct TestProvider;

    impl UserProvider for TestProvider {
        type User = TestUser;

        async fn find_by_id(&self, id: &str) -> Result<Option<Self::User>, modo::Error> {
            if id == "user-1" {
                Ok(Some(TestUser {
                    id: "user-1".to_string(),
                    name: "Alice".to_string(),
                }))
            } else if id == "error-user" {
                Err(modo::Error::internal("db error"))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn user_provider_service_finds_existing_user() {
        let svc = UserProviderService::new(TestProvider);
        let user = svc.find_by_id("user-1").await.unwrap();
        assert_eq!(
            user,
            Some(TestUser {
                id: "user-1".to_string(),
                name: "Alice".to_string(),
            })
        );
    }

    #[tokio::test]
    async fn user_provider_service_returns_none_for_missing_user() {
        let svc = UserProviderService::new(TestProvider);
        let user = svc.find_by_id("nonexistent").await.unwrap();
        assert_eq!(user, None);
    }

    #[tokio::test]
    async fn user_provider_service_propagates_errors() {
        let svc = UserProviderService::new(TestProvider);
        let result = svc.find_by_id("error-user").await;
        assert!(result.is_err());
    }

    /// The manual `Clone` impl must yield a handle that keeps working even after
    /// the original handle has been dropped.
    #[tokio::test]
    async fn cloned_service_outlives_original() {
        let svc = UserProviderService::new(TestProvider);
        let cloned = svc.clone();
        drop(svc);
        let user = cloned.find_by_id("user-1").await.unwrap();
        assert_eq!(user.map(|u| u.name), Some("Alice".to_string()));
    }

    /// The manual `Debug` impl names the service and its user type.
    #[test]
    fn debug_output_names_user_type() {
        let svc = UserProviderService::new(TestProvider);
        let rendered = format!("{svc:?}");
        assert!(rendered.starts_with("UserProviderService<"), "got {rendered}");
        assert!(rendered.ends_with('>'), "got {rendered}");
        // `type_name` output is not guaranteed to be stable, so only check that the
        // user type's name appears somewhere in it.
        assert!(rendered.contains("TestUser"), "got {rendered}");
    }

    /// Compile-time guard: the service must be shareable across tasks, and its
    /// lookup future must be `Send` so it can be awaited in spawned tasks.
    #[test]
    fn service_and_lookup_future_are_thread_safe() {
        fn assert_send_sync<T: Send + Sync>() {}
        fn assert_send<T: Send>(_: &T) {}

        assert_send_sync::<UserProviderService<TestUser>>();

        let svc = UserProviderService::new(TestProvider);
        let lookup = svc.find_by_id("user-1");
        assert_send(&lookup);
    }
}