1use std::collections::HashMap;
2use std::fmt::Display;
3use std::io;
4
5use async_trait::async_trait;
6
7use crate::authenticator::Authenticator;
8use crate::msg::*;
9
10mod methods;
11pub use methods::*;
12
13#[async_trait]
15pub trait AuthHandler: AuthMethodHandler + Send {
16 async fn on_initialization(
19 &mut self,
20 initialization: Initialization,
21 ) -> io::Result<InitializationResponse> {
22 Ok(InitializationResponse {
23 methods: initialization.methods,
24 })
25 }
26
27 #[allow(unused_variables)]
29 async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
30 Ok(())
31 }
32
33 async fn on_finished(&mut self) -> io::Result<()> {
35 Ok(())
36 }
37}
38
39pub struct DummyAuthHandler;
42
43#[async_trait]
44impl AuthHandler for DummyAuthHandler {}
45
46#[async_trait]
47impl AuthMethodHandler for DummyAuthHandler {
48 async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
49 Err(io::Error::from(io::ErrorKind::Unsupported))
50 }
51
52 async fn on_verification(&mut self, _: Verification) -> io::Result<VerificationResponse> {
53 Err(io::Error::from(io::ErrorKind::Unsupported))
54 }
55
56 async fn on_info(&mut self, _: Info) -> io::Result<()> {
57 Err(io::Error::from(io::ErrorKind::Unsupported))
58 }
59
60 async fn on_error(&mut self, _: Error) -> io::Result<()> {
61 Err(io::Error::from(io::ErrorKind::Unsupported))
62 }
63}
64
65pub struct SingleAuthHandler(Box<dyn AuthMethodHandler>);
67
68impl SingleAuthHandler {
69 pub fn new<T: AuthMethodHandler + 'static>(method_handler: T) -> Self {
70 Self(Box::new(method_handler))
71 }
72}
73
74#[async_trait]
75impl AuthHandler for SingleAuthHandler {}
76
77#[async_trait]
78impl AuthMethodHandler for SingleAuthHandler {
79 async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
80 self.0.on_challenge(challenge).await
81 }
82
83 async fn on_verification(
84 &mut self,
85 verification: Verification,
86 ) -> io::Result<VerificationResponse> {
87 self.0.on_verification(verification).await
88 }
89
90 async fn on_info(&mut self, info: Info) -> io::Result<()> {
91 self.0.on_info(info).await
92 }
93
94 async fn on_error(&mut self, error: Error) -> io::Result<()> {
95 self.0.on_error(error).await
96 }
97}
98
99pub struct AuthHandlerMap {
108 active: String,
109 map: HashMap<&'static str, Box<dyn AuthMethodHandler>>,
110}
111
112impl AuthHandlerMap {
113 pub fn new() -> Self {
115 Self {
116 active: String::new(),
117 map: HashMap::new(),
118 }
119 }
120
121 pub fn active_id(&self) -> &str {
123 &self.active
124 }
125
126 pub fn set_active_id(&mut self, id: impl Into<String>) {
128 self.active = id.into();
129 }
130
131 pub fn insert_method_handler<T: AuthMethodHandler + 'static>(
134 &mut self,
135 id: &'static str,
136 handler: T,
137 ) -> Option<Box<dyn AuthMethodHandler>> {
138 self.map.insert(id, Box::new(handler))
139 }
140
141 pub fn remove_method_handler(
143 &mut self,
144 id: &'static str,
145 ) -> Option<Box<dyn AuthMethodHandler>> {
146 self.map.remove(id)
147 }
148
149 pub fn get_mut_active_method_handler_or_error(
152 &mut self,
153 ) -> io::Result<&mut (dyn AuthMethodHandler + 'static)> {
154 let id = self.active.clone();
155 self.get_mut_active_method_handler().ok_or_else(|| {
156 io::Error::new(io::ErrorKind::Other, format!("No active handler for {id}"))
157 })
158 }
159
160 pub fn get_mut_active_method_handler(
162 &mut self,
163 ) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
164 self.get_mut_method_handler(&self.active.clone())
166 }
167
168 pub fn get_mut_method_handler(
170 &mut self,
171 id: &str,
172 ) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
173 self.map.get_mut(id).map(|h| h.as_mut())
174 }
175}
176
177impl AuthHandlerMap {
178 pub fn with_static_key<K>(mut self, key: K) -> Self
180 where
181 K: Display + Send + 'static,
182 {
183 self.insert_method_handler("static_key", StaticKeyAuthMethodHandler::simple(key));
184 self
185 }
186}
187
188impl Default for AuthHandlerMap {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194#[async_trait]
195impl AuthHandler for AuthHandlerMap {
196 async fn on_initialization(
197 &mut self,
198 initialization: Initialization,
199 ) -> io::Result<InitializationResponse> {
200 let methods = initialization
201 .methods
202 .into_iter()
203 .filter(|method| self.map.contains_key(method.as_str()))
204 .collect();
205
206 Ok(InitializationResponse { methods })
207 }
208
209 async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
210 self.set_active_id(start_method.method);
211 Ok(())
212 }
213
214 async fn on_finished(&mut self) -> io::Result<()> {
215 Ok(())
216 }
217}
218
219#[async_trait]
220impl AuthMethodHandler for AuthHandlerMap {
221 async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
222 let handler = self.get_mut_active_method_handler_or_error()?;
223 handler.on_challenge(challenge).await
224 }
225
226 async fn on_verification(
227 &mut self,
228 verification: Verification,
229 ) -> io::Result<VerificationResponse> {
230 let handler = self.get_mut_active_method_handler_or_error()?;
231 handler.on_verification(verification).await
232 }
233
234 async fn on_info(&mut self, info: Info) -> io::Result<()> {
235 let handler = self.get_mut_active_method_handler_or_error()?;
236 handler.on_info(info).await
237 }
238
239 async fn on_error(&mut self, error: Error) -> io::Result<()> {
240 let handler = self.get_mut_active_method_handler_or_error()?;
241 handler.on_error(error).await
242 }
243}
244
245pub struct ProxyAuthHandler<'a>(&'a mut dyn Authenticator);
247
248impl<'a> ProxyAuthHandler<'a> {
249 pub fn new(authenticator: &'a mut dyn Authenticator) -> Self {
250 Self(authenticator)
251 }
252}
253
254#[async_trait]
255impl<'a> AuthHandler for ProxyAuthHandler<'a> {
256 async fn on_initialization(
257 &mut self,
258 initialization: Initialization,
259 ) -> io::Result<InitializationResponse> {
260 Authenticator::initialize(self.0, initialization).await
261 }
262
263 async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
264 Authenticator::start_method(self.0, start_method).await
265 }
266
267 async fn on_finished(&mut self) -> io::Result<()> {
268 Authenticator::finished(self.0).await
269 }
270}
271
272#[async_trait]
273impl<'a> AuthMethodHandler for ProxyAuthHandler<'a> {
274 async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
275 Authenticator::challenge(self.0, challenge).await
276 }
277
278 async fn on_verification(
279 &mut self,
280 verification: Verification,
281 ) -> io::Result<VerificationResponse> {
282 Authenticator::verify(self.0, verification).await
283 }
284
285 async fn on_info(&mut self, info: Info) -> io::Result<()> {
286 Authenticator::info(self.0, info).await
287 }
288
289 async fn on_error(&mut self, error: Error) -> io::Result<()> {
290 Authenticator::error(self.0, error).await
291 }
292}
293
294pub struct DynAuthHandler<'a>(&'a mut dyn AuthHandler);
297
298impl<'a> DynAuthHandler<'a> {
299 pub fn new(handler: &'a mut dyn AuthHandler) -> Self {
300 Self(handler)
301 }
302}
303
304impl<'a, T: AuthHandler> From<&'a mut T> for DynAuthHandler<'a> {
305 fn from(handler: &'a mut T) -> Self {
306 Self::new(handler as &mut dyn AuthHandler)
307 }
308}
309
310#[async_trait]
311impl<'a> AuthHandler for DynAuthHandler<'a> {
312 async fn on_initialization(
313 &mut self,
314 initialization: Initialization,
315 ) -> io::Result<InitializationResponse> {
316 self.0.on_initialization(initialization).await
317 }
318
319 async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
320 self.0.on_start_method(start_method).await
321 }
322
323 async fn on_finished(&mut self) -> io::Result<()> {
324 self.0.on_finished().await
325 }
326}
327
328#[async_trait]
329impl<'a> AuthMethodHandler for DynAuthHandler<'a> {
330 async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
331 self.0.on_challenge(challenge).await
332 }
333
334 async fn on_verification(
335 &mut self,
336 verification: Verification,
337 ) -> io::Result<VerificationResponse> {
338 self.0.on_verification(verification).await
339 }
340
341 async fn on_info(&mut self, info: Info) -> io::Result<()> {
342 self.0.on_info(info).await
343 }
344
345 async fn on_error(&mut self, error: Error) -> io::Result<()> {
346 self.0.on_error(error).await
347 }
348}
349
/// Scriptable handler for tests.
///
/// Each callback delegates to the boxed closure stored in the field of the same name, so a
/// test can override individual responses.
#[cfg(any(test, feature = "tests"))]
pub struct TestAuthHandler {
    /// Called by `on_initialization`.
    pub on_initialization:
        Box<dyn FnMut(Initialization) -> io::Result<InitializationResponse> + Send>,
    /// Called by `on_challenge`.
    pub on_challenge: Box<dyn FnMut(Challenge) -> io::Result<ChallengeResponse> + Send>,
    /// Called by `on_verification`.
    pub on_verification: Box<dyn FnMut(Verification) -> io::Result<VerificationResponse> + Send>,
    /// Called by `on_info`.
    pub on_info: Box<dyn FnMut(Info) -> io::Result<()> + Send>,
    /// Called by `on_error`.
    pub on_error: Box<dyn FnMut(Error) -> io::Result<()> + Send>,
    /// Called by `on_start_method`.
    pub on_start_method: Box<dyn FnMut(StartMethod) -> io::Result<()> + Send>,
    /// Called by `on_finished`.
    pub on_finished: Box<dyn FnMut() -> io::Result<()> + Send>,
}

#[cfg(any(test, feature = "tests"))]
impl Default for TestAuthHandler {
    /// Callbacks that always succeed.
    ///
    /// Initialization accepts every offered method, and each challenge question is
    /// answered with its own text. Every verification is valid, and the remaining
    /// callbacks return `Ok(())`.
    fn default() -> Self {
        Self {
            on_initialization: Box::new(|init| {
                Ok(InitializationResponse {
                    methods: init.methods,
                })
            }),
            on_challenge: Box::new(|challenge| {
                let answers = challenge
                    .questions
                    .into_iter()
                    .map(|question| question.text)
                    .collect();
                Ok(ChallengeResponse { answers })
            }),
            on_verification: Box::new(|_| Ok(VerificationResponse { valid: true })),
            on_info: Box::new(|_| Ok(())),
            on_error: Box::new(|_| Ok(())),
            on_start_method: Box::new(|_| Ok(())),
            on_finished: Box::new(|| Ok(())),
        }
    }
}
381
// Each callback invokes the boxed closure stored in the field of the same name and returns
// its result unchanged.
#[cfg(any(test, feature = "tests"))]
#[async_trait]
impl AuthHandler for TestAuthHandler {
    async fn on_initialization(
        &mut self,
        initialization: Initialization,
    ) -> io::Result<InitializationResponse> {
        (*self.on_initialization)(initialization)
    }

    async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
        (*self.on_start_method)(start_method)
    }

    async fn on_finished(&mut self) -> io::Result<()> {
        (*self.on_finished)()
    }
}

#[cfg(any(test, feature = "tests"))]
#[async_trait]
impl AuthMethodHandler for TestAuthHandler {
    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
        (*self.on_challenge)(challenge)
    }

    async fn on_verification(
        &mut self,
        verification: Verification,
    ) -> io::Result<VerificationResponse> {
        (*self.on_verification)(verification)
    }

    async fn on_info(&mut self, info: Info) -> io::Result<()> {
        (*self.on_info)(info)
    }

    async fn on_error(&mut self, error: Error) -> io::Result<()> {
        (*self.on_error)(error)
    }
}