battler_wamp/auth/undisputed/
authenticator.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::lock::Mutex;
4
5use crate::{
6    auth::{
7        Identity,
8        auth_method::AuthMethod,
9        authenticator::{
10            ClientAuthenticator as ClientAuthenticatorInterface,
11            ServerAuthenticator as ServerAuthenticatorInterface,
12        },
13        undisputed::{
14            UserData,
15            message::{
16                ClientFinalMessage,
17                ClientFinalMessageExtra,
18                ClientFirstMessage,
19                ClientFirstMessageExtra,
20                ServerFinalMessage,
21                ServerFinalMessageExtra,
22                ServerFirstMessage,
23                ServerFirstMessageExtra,
24            },
25        },
26    },
27    core::{
28        error::InteractionError,
29        hash::HashSet,
30    },
31};
32
33/// Server authenticator for WAMP-SCRAM.
34pub struct ServerAuthenticator {
35    user: Mutex<Option<UserData>>,
36}
37
38impl ServerAuthenticator {
39    /// Creates a new server authenticator.
40    pub fn new() -> Self {
41        Self {
42            user: Mutex::new(None),
43        }
44    }
45}
46
47#[async_trait]
48impl
49    ServerAuthenticatorInterface<
50        ClientFirstMessageExtra,
51        ServerFirstMessageExtra,
52        ClientFinalMessageExtra,
53        ServerFinalMessageExtra,
54    > for ServerAuthenticator
55{
56    fn auth_method(&self) -> AuthMethod {
57        AuthMethod::Undisputed
58    }
59
60    async fn challenge(&self, message: ClientFirstMessage) -> Result<ServerFirstMessage> {
61        let user = UserData {
62            identity: Identity {
63                id: message.id.clone(),
64                role: message.extra.role.clone(),
65            },
66        };
67        *self.user.lock().await = Some(user);
68        Ok(ServerFirstMessage {
69            method: self.auth_method(),
70            extra: ServerFirstMessageExtra {},
71        })
72    }
73
74    async fn authenticate(&self, _: ClientFinalMessage) -> Result<ServerFinalMessage> {
75        let user = self.user.lock().await;
76        let user = user.as_ref().ok_or_else(|| {
77            InteractionError::AuthenticationFailed("expected pending user".to_owned())
78        })?;
79        Ok(ServerFinalMessage {
80            identity: user.identity.clone(),
81            method: self.auth_method(),
82            provider: "static".to_owned(),
83            extra: ServerFinalMessageExtra {},
84        })
85    }
86}
87
/// Client authenticator for the undisputed authentication method.
///
/// Holds the identity (`id` and `role`) that the client claims when it greets the server.
#[derive(Debug, Clone)]
pub struct ClientAuthenticator {
    /// Identifier the client claims.
    id: String,
    /// Role the client claims.
    role: String,
}

impl ClientAuthenticator {
    /// Creates a new client authenticator that claims the given `id` and `role`.
    pub fn new(id: String, role: String) -> Self {
        Self { id, role }
    }
}
100
101#[async_trait]
102impl
103    ClientAuthenticatorInterface<
104        ClientFirstMessageExtra,
105        ServerFirstMessageExtra,
106        ClientFinalMessageExtra,
107        ServerFinalMessageExtra,
108    > for ClientAuthenticator
109{
110    fn auth_method(&self) -> AuthMethod {
111        AuthMethod::Undisputed
112    }
113
114    async fn hello(&self) -> Result<ClientFirstMessage> {
115        Ok(ClientFirstMessage {
116            id: self.id.clone(),
117            methods: HashSet::from_iter([self.auth_method()]),
118            extra: ClientFirstMessageExtra {
119                role: self.role.clone(),
120            },
121        })
122    }
123
124    async fn handle_challenge(&self, _: ServerFirstMessage) -> Result<ClientFinalMessage> {
125        Ok(ClientFinalMessage {
126            signature: "not_applicable".to_owned(),
127            extra: ClientFinalMessageExtra {},
128        })
129    }
130
131    async fn verify_signature(&self, _: ServerFinalMessage) -> Result<()> {
132        Ok(())
133    }
134}
135
/// Tests for the undisputed authenticators defined in this file.
#[cfg(test)]
mod undisputed_test {
    use super::{
        ClientAuthenticator,
        ClientAuthenticatorInterface,
        InteractionError,
        ServerAuthenticator,
        ServerAuthenticatorInterface,
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn server_accepts_claimed_identity() {
        let server_authenticator = ServerAuthenticator::new();
        let client_authenticator = ClientAuthenticator::new("user".to_owned(), "admin".to_owned());

        let client_first = client_authenticator.hello().await.unwrap();
        let server_first = server_authenticator.challenge(client_first).await.unwrap();
        let client_final = client_authenticator
            .handle_challenge(server_first)
            .await
            .unwrap();
        let server_final = server_authenticator
            .authenticate(client_final)
            .await
            .unwrap();

        // The server reports exactly the identity the client claimed.
        assert_eq!(server_final.identity.id, "user");
        assert_eq!(server_final.identity.role, "admin");
        assert_matches::assert_matches!(
            client_authenticator.verify_signature(server_final).await,
            Ok(())
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn authentication_fails_without_challenge() {
        let client_authenticator = ClientAuthenticator::new("user".to_owned(), "admin".to_owned());

        // Obtain a final client message through a different server instance, so the server
        // under test never sees a challenge.
        let other_server = ServerAuthenticator::new();
        let client_first = client_authenticator.hello().await.unwrap();
        let server_first = other_server.challenge(client_first).await.unwrap();
        let client_final = client_authenticator
            .handle_challenge(server_first)
            .await
            .unwrap();

        let server_authenticator = ServerAuthenticator::new();
        let err = server_authenticator
            .authenticate(client_final)
            .await
            .err()
            .expect("authenticate should fail when no user is pending");
        assert_matches::assert_matches!(
            err.downcast::<InteractionError>(),
            Ok(InteractionError::AuthenticationFailed(_))
        );
    }
}

// NOTE(review): the tests below exercise the SCRAM authenticators rather than the undisputed
// authenticators defined in this file; presumably they were copied from the SCRAM module.
// Confirm whether they are duplicated there and, if so, consider removing them from this file.
#[cfg(test)]
mod scram_test {
    use anyhow::Result;
    use async_trait::async_trait;

    use crate::{
        auth::{
            authenticator::{
                ClientAuthenticator,
                ServerAuthenticator,
            },
            scram::{
                authenticator::{
                    ClientAuthenticator as ScramClientAuthenticator,
                    ServerAuthenticator as ScramServerAuthenticator,
                },
                user::{
                    UserData,
                    UserDatabase,
                    new_user,
                },
            },
        },
        core::{
            error::InteractionError,
            hash::HashMap,
        },
    };

    /// In-memory user database keyed by user ID.
    #[derive(Default)]
    struct FakeUserDatabase {
        users: HashMap<String, UserData>,
    }

    impl<S, T> FromIterator<(S, T)> for FakeUserDatabase
    where
        S: Into<String> + AsRef<str>,
        T: Into<String> + AsRef<str>,
    {
        /// Builds the database from `(id, password)` pairs, creating each user with `new_user`.
        fn from_iter<I>(iter: I) -> Self
        where
            I: IntoIterator<Item = (S, T)>,
        {
            Self {
                users: iter
                    .into_iter()
                    .map(|(s, t)| {
                        let user = new_user(s.as_ref(), t.as_ref()).unwrap();
                        (s.into(), user)
                    })
                    .collect(),
            }
        }
    }

    #[async_trait]
    impl UserDatabase for FakeUserDatabase {
        /// Looks up a user by ID, failing with `NoSuchPrincipal` for unknown IDs.
        async fn user_data(&self, id: &str) -> Result<UserData> {
            self.users
                .get(id)
                .ok_or_else(|| InteractionError::NoSuchPrincipal.into())
                .cloned()
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn client_and_server_authenticate_correctly() {
        let server_authenticator = ScramServerAuthenticator::new(Box::new(
            FakeUserDatabase::from_iter([("user", "password123!")]),
        ));
        let client_authenticator =
            ScramClientAuthenticator::new("user".to_owned(), "password123!".to_owned());

        let client_first = client_authenticator.hello().await.unwrap();
        let server_first = server_authenticator.challenge(client_first).await.unwrap();
        let client_final = client_authenticator
            .handle_challenge(server_first)
            .await
            .unwrap();
        let server_final = server_authenticator
            .authenticate(client_final)
            .await
            .unwrap();
        assert_matches::assert_matches!(
            client_authenticator.verify_signature(server_final).await,
            Ok(())
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn authentication_fails_for_invalid_password() {
        let server_authenticator = ScramServerAuthenticator::new(Box::new(
            FakeUserDatabase::from_iter([("user", "password123!")]),
        ));
        let client_authenticator =
            ScramClientAuthenticator::new("user".to_owned(), "wrong".to_owned());

        let client_first = client_authenticator.hello().await.unwrap();
        let server_first = server_authenticator.challenge(client_first).await.unwrap();
        let client_final = client_authenticator
            .handle_challenge(server_first)
            .await
            .unwrap();
        assert_matches::assert_matches!(server_authenticator.authenticate(client_final).await, Err(err) => {
            assert_matches::assert_matches!(err.downcast::<InteractionError>(), Ok(InteractionError::AuthenticationDenied(_)));
        });
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn authentication_fails_for_invalid_user() {
        let server_authenticator = ScramServerAuthenticator::new(Box::new(
            FakeUserDatabase::from_iter([("user", "password123!")]),
        ));
        let client_authenticator =
            ScramClientAuthenticator::new("another".to_owned(), "password123!".to_owned());

        let client_first = client_authenticator.hello().await.unwrap();
        assert_matches::assert_matches!(server_authenticator.challenge(client_first).await, Err(err) => {
            assert_matches::assert_matches!(err.downcast::<InteractionError>(), Ok(InteractionError::NoSuchPrincipal));
        });
    }
}
255        });
256    }
257}