Skip to main content

modo_auth/
provider.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5/// Trait for loading a user by their session-stored ID.
6///
7/// Implement this on your own type (e.g., a repository struct that holds a DB pool)
8/// and register it via `app.service(UserProviderService::new(your_impl))`.
9pub trait UserProvider: Send + Sync + 'static {
10    /// The user type returned by this provider.
11    type User: Clone + Send + Sync + 'static;
12
13    /// Look up a user by their ID (as stored in the session).
14    ///
15    /// Return `Ok(None)` if the user doesn't exist.
16    /// Return `Err` only for infrastructure failures (DB errors, etc.).
17    fn find_by_id(
18        &self,
19        id: &str,
20    ) -> impl Future<Output = Result<Option<Self::User>, modo::Error>> + Send;
21}
22
23/// Object-safe bridge trait for type-erasing `UserProvider`.
24trait UserProviderDyn<U>: Send + Sync {
25    fn find_by_id<'a>(
26        &'a self,
27        id: &'a str,
28    ) -> Pin<Box<dyn Future<Output = Result<Option<U>, modo::Error>> + Send + 'a>>;
29}
30
31impl<P: UserProvider> UserProviderDyn<P::User> for P {
32    fn find_by_id<'a>(
33        &'a self,
34        id: &'a str,
35    ) -> Pin<Box<dyn Future<Output = Result<Option<P::User>, modo::Error>> + Send + 'a>> {
36        Box::pin(UserProvider::find_by_id(self, id))
37    }
38}
39
40/// Type-erased wrapper around a [`UserProvider`] implementation.
41///
42/// Register with `app.service(UserProviderService::new(your_impl))`. The service
43/// is stored in the registry keyed by user type `U` so that `Auth<U>` and
44/// `OptionalAuth<U>` can retrieve it by `TypeId` at request time.
45pub struct UserProviderService<U: Clone + Send + Sync + 'static> {
46    inner: Arc<dyn UserProviderDyn<U>>,
47}
48
49impl<U: Clone + Send + Sync + 'static> Clone for UserProviderService<U> {
50    fn clone(&self) -> Self {
51        Self {
52            inner: Arc::clone(&self.inner),
53        }
54    }
55}
56
57impl<U: Clone + Send + Sync + 'static> std::fmt::Debug for UserProviderService<U> {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        write!(f, "UserProviderService<{}>", std::any::type_name::<U>())
60    }
61}
62
63impl<U: Clone + Send + Sync + 'static> UserProviderService<U> {
64    /// Wrap a [`UserProvider`] implementation for registration in the service registry.
65    pub fn new<P: UserProvider<User = U>>(provider: P) -> Self {
66        Self {
67            inner: Arc::new(provider),
68        }
69    }
70
71    /// Look up a user by their session-stored ID, delegating to the wrapped provider.
72    pub async fn find_by_id(&self, id: &str) -> Result<Option<U>, modo::Error> {
73        self.inner.find_by_id(id).await
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal user record returned by the test provider.
    #[derive(Clone, Debug, PartialEq)]
    struct TestUser {
        id: String,
        name: String,
    }

    /// Stateless provider with three canned outcomes, selected by the ID.
    struct TestProvider;

    impl UserProvider for TestProvider {
        type User = TestUser;

        async fn find_by_id(&self, id: &str) -> Result<Option<Self::User>, modo::Error> {
            match id {
                // The one user that exists.
                "user-1" => Ok(Some(TestUser {
                    id: String::from("user-1"),
                    name: String::from("Alice"),
                })),
                // Stands in for an infrastructure failure.
                "error-user" => Err(modo::Error::internal("db error")),
                // Every other ID is unknown.
                _ => Ok(None),
            }
        }
    }

    #[tokio::test]
    async fn user_provider_service_finds_existing_user() {
        let service = UserProviderService::new(TestProvider);
        let expected = TestUser {
            id: String::from("user-1"),
            name: String::from("Alice"),
        };

        let found = service.find_by_id("user-1").await.unwrap();

        assert_eq!(found, Some(expected));
    }

    #[tokio::test]
    async fn user_provider_service_returns_none_for_missing_user() {
        let service = UserProviderService::new(TestProvider);

        let found = service.find_by_id("nonexistent").await.unwrap();

        assert_eq!(found, None);
    }

    #[tokio::test]
    async fn user_provider_service_propagates_errors() {
        let service = UserProviderService::new(TestProvider);

        let outcome = service.find_by_id("error-user").await;

        assert!(outcome.is_err());
    }
}