battler_wamp/auth/undisputed/
authenticator.rs1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::lock::Mutex;
4
5use crate::{
6 auth::{
7 Identity,
8 auth_method::AuthMethod,
9 authenticator::{
10 ClientAuthenticator as ClientAuthenticatorInterface,
11 ServerAuthenticator as ServerAuthenticatorInterface,
12 },
13 undisputed::{
14 UserData,
15 message::{
16 ClientFinalMessage,
17 ClientFinalMessageExtra,
18 ClientFirstMessage,
19 ClientFirstMessageExtra,
20 ServerFinalMessage,
21 ServerFinalMessageExtra,
22 ServerFirstMessage,
23 ServerFirstMessageExtra,
24 },
25 },
26 },
27 core::{
28 error::InteractionError,
29 hash::HashSet,
30 },
31};
32
33pub struct ServerAuthenticator {
35 user: Mutex<Option<UserData>>,
36}
37
38impl ServerAuthenticator {
39 pub fn new() -> Self {
41 Self {
42 user: Mutex::new(None),
43 }
44 }
45}
46
47#[async_trait]
48impl
49 ServerAuthenticatorInterface<
50 ClientFirstMessageExtra,
51 ServerFirstMessageExtra,
52 ClientFinalMessageExtra,
53 ServerFinalMessageExtra,
54 > for ServerAuthenticator
55{
56 fn auth_method(&self) -> AuthMethod {
57 AuthMethod::Undisputed
58 }
59
60 async fn challenge(&self, message: ClientFirstMessage) -> Result<ServerFirstMessage> {
61 let user = UserData {
62 identity: Identity {
63 id: message.id.clone(),
64 role: message.extra.role.clone(),
65 },
66 };
67 *self.user.lock().await = Some(user);
68 Ok(ServerFirstMessage {
69 method: self.auth_method(),
70 extra: ServerFirstMessageExtra {},
71 })
72 }
73
74 async fn authenticate(&self, _: ClientFinalMessage) -> Result<ServerFinalMessage> {
75 let user = self.user.lock().await;
76 let user = user.as_ref().ok_or_else(|| {
77 InteractionError::AuthenticationFailed("expected pending user".to_owned())
78 })?;
79 Ok(ServerFinalMessage {
80 identity: user.identity.clone(),
81 method: self.auth_method(),
82 provider: "static".to_owned(),
83 extra: ServerFinalMessageExtra {},
84 })
85 }
86}
87
/// Client-side authenticator for the undisputed authentication method.
///
/// Carries the ID and role that the client reports in its first message.
#[derive(Debug, Clone)]
pub struct ClientAuthenticator {
    /// ID sent in the client's first message.
    id: String,
    /// Role sent in the extra data of the client's first message.
    role: String,
}
93
94impl ClientAuthenticator {
95 pub fn new(id: String, role: String) -> Self {
97 Self { id, role }
98 }
99}
100
101#[async_trait]
102impl
103 ClientAuthenticatorInterface<
104 ClientFirstMessageExtra,
105 ServerFirstMessageExtra,
106 ClientFinalMessageExtra,
107 ServerFinalMessageExtra,
108 > for ClientAuthenticator
109{
110 fn auth_method(&self) -> AuthMethod {
111 AuthMethod::Undisputed
112 }
113
114 async fn hello(&self) -> Result<ClientFirstMessage> {
115 Ok(ClientFirstMessage {
116 id: self.id.clone(),
117 methods: HashSet::from_iter([self.auth_method()]),
118 extra: ClientFirstMessageExtra {
119 role: self.role.clone(),
120 },
121 })
122 }
123
124 async fn handle_challenge(&self, _: ServerFirstMessage) -> Result<ClientFinalMessage> {
125 Ok(ClientFinalMessage {
126 signature: "not_applicable".to_owned(),
127 extra: ClientFinalMessageExtra {},
128 })
129 }
130
131 async fn verify_signature(&self, _: ServerFinalMessage) -> Result<()> {
132 Ok(())
133 }
134}
135
// NOTE(review): this module exercises the SCRAM authenticators, not the undisputed
// authenticators defined in this file. It is kept as-is; the undisputed authenticators are
// covered by `undisputed_test` below.
#[cfg(test)]
mod scram_test {
    use anyhow::Result;
    use async_trait::async_trait;

    use crate::{
        auth::{
            authenticator::{
                ClientAuthenticator,
                ServerAuthenticator,
            },
            scram::{
                authenticator::{
                    ClientAuthenticator as ScramClientAuthenticator,
                    ServerAuthenticator as ScramServerAuthenticator,
                },
                user::{
                    UserData,
                    UserDatabase,
                    new_user,
                },
            },
        },
        core::{
            error::InteractionError,
            hash::HashMap,
        },
    };

    /// In-memory user database keyed by user ID.
    #[derive(Default)]
    struct FakeUserDatabase {
        users: HashMap<String, UserData>,
    }

    /// Builds the database from `(id, password)` pairs.
    impl<S, T> FromIterator<(S, T)> for FakeUserDatabase
    where
        S: Into<String> + AsRef<str>,
        T: Into<String> + AsRef<str>,
    {
        fn from_iter<I>(iter: I) -> Self
        where
            I: IntoIterator<Item = (S, T)>,
        {
            Self {
                users: iter
                    .into_iter()
                    .map(|(s, t)| {
                        let user = new_user(s.as_ref(), t.as_ref()).unwrap();
                        (s.into(), user)
                    })
                    .collect(),
            }
        }
    }

    #[async_trait]
    impl UserDatabase for FakeUserDatabase {
        /// Looks up a user by ID, failing with `NoSuchPrincipal` when the ID is unknown.
        async fn user_data(&self, id: &str) -> Result<UserData> {
            self.users
                .get(id)
                .ok_or_else(|| InteractionError::NoSuchPrincipal.into())
                .cloned()
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn client_and_server_authenticate_correctly() {
        let server_authenticator = ScramServerAuthenticator::new(Box::new(
            FakeUserDatabase::from_iter([("user", "password123!")]),
        ));
        let client_authenticator =
            ScramClientAuthenticator::new("user".to_owned(), "password123!".to_owned());

        let client_first = client_authenticator.hello().await.unwrap();
        let server_first = server_authenticator.challenge(client_first).await.unwrap();
        let client_final = client_authenticator
            .handle_challenge(server_first)
            .await
            .unwrap();
        let server_final = server_authenticator
            .authenticate(client_final)
            .await
            .unwrap();
        assert_matches::assert_matches!(
            client_authenticator.verify_signature(server_final).await,
            Ok(())
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn authentication_fails_for_invalid_password() {
        let server_authenticator = ScramServerAuthenticator::new(Box::new(
            FakeUserDatabase::from_iter([("user", "password123!")]),
        ));
        let client_authenticator =
            ScramClientAuthenticator::new("user".to_owned(), "wrong".to_owned());

        let client_first = client_authenticator.hello().await.unwrap();
        let server_first = server_authenticator.challenge(client_first).await.unwrap();
        let client_final = client_authenticator
            .handle_challenge(server_first)
            .await
            .unwrap();
        assert_matches::assert_matches!(server_authenticator.authenticate(client_final).await, Err(err) => {
            assert_matches::assert_matches!(err.downcast::<InteractionError>(), Ok(InteractionError::AuthenticationDenied(_)));
        });
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn authentication_fails_for_invalid_user() {
        let server_authenticator = ScramServerAuthenticator::new(Box::new(
            FakeUserDatabase::from_iter([("user", "password123!")]),
        ));
        let client_authenticator =
            ScramClientAuthenticator::new("another".to_owned(), "password123!".to_owned());

        let client_first = client_authenticator.hello().await.unwrap();
        assert_matches::assert_matches!(server_authenticator.challenge(client_first).await, Err(err) => {
            assert_matches::assert_matches!(err.downcast::<InteractionError>(), Ok(InteractionError::NoSuchPrincipal));
        });
    }
}

/// Tests for the undisputed authenticators defined in this file.
#[cfg(test)]
mod undisputed_test {
    use super::{
        ClientAuthenticator,
        ServerAuthenticator,
    };
    use crate::{
        auth::{
            auth_method::AuthMethod,
            authenticator::{
                ClientAuthenticator as _,
                ServerAuthenticator as _,
            },
            undisputed::message::{
                ClientFinalMessage,
                ClientFinalMessageExtra,
            },
        },
        core::error::InteractionError,
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn client_and_server_authenticate_correctly() {
        let server_authenticator = ServerAuthenticator::new();
        let client_authenticator =
            ClientAuthenticator::new("user".to_owned(), "role".to_owned());

        let client_first = client_authenticator.hello().await.unwrap();
        assert_eq!(client_first.id, "user");
        assert_eq!(client_first.extra.role, "role");

        let server_first = server_authenticator.challenge(client_first).await.unwrap();
        let client_final = client_authenticator
            .handle_challenge(server_first)
            .await
            .unwrap();
        let server_final = server_authenticator
            .authenticate(client_final)
            .await
            .unwrap();

        // The server echoes back exactly the identity the client claimed.
        assert_eq!(server_final.identity.id, "user");
        assert_eq!(server_final.identity.role, "role");
        assert_eq!(server_final.provider, "static");
        assert!(matches!(server_final.method, AuthMethod::Undisputed));

        assert_matches::assert_matches!(
            client_authenticator.verify_signature(server_final).await,
            Ok(())
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn authentication_fails_without_pending_user() {
        let server_authenticator = ServerAuthenticator::new();
        // Skip `challenge` entirely so the server has no pending user recorded.
        let client_final = ClientFinalMessage {
            signature: "not_applicable".to_owned(),
            extra: ClientFinalMessageExtra {},
        };

        let err = server_authenticator
            .authenticate(client_final)
            .await
            .err()
            .expect("authentication should fail when no challenge was issued");
        assert_matches::assert_matches!(
            err.downcast::<InteractionError>(),
            Ok(InteractionError::AuthenticationFailed(_))
        );
    }
}