distant_auth/
handler.rs

1use std::collections::HashMap;
2use std::fmt::Display;
3use std::io;
4
5use async_trait::async_trait;
6
7use crate::authenticator::Authenticator;
8use crate::msg::*;
9
10mod methods;
11pub use methods::*;
12
13/// Interface for a handler of authentication requests for all methods.
14#[async_trait]
15pub trait AuthHandler: AuthMethodHandler + Send {
16    /// Callback when authentication is beginning, providing available authentication methods and
17    /// returning selected authentication methods to pursue.
18    async fn on_initialization(
19        &mut self,
20        initialization: Initialization,
21    ) -> io::Result<InitializationResponse> {
22        Ok(InitializationResponse {
23            methods: initialization.methods,
24        })
25    }
26
27    /// Callback when authentication starts for a specific method.
28    #[allow(unused_variables)]
29    async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
30        Ok(())
31    }
32
33    /// Callback when authentication is finished and no more requests will be received.
34    async fn on_finished(&mut self) -> io::Result<()> {
35        Ok(())
36    }
37}
38
39/// Dummy implementation of [`AuthHandler`] where any challenge or verification request will
40/// instantly fail.
41pub struct DummyAuthHandler;
42
43#[async_trait]
44impl AuthHandler for DummyAuthHandler {}
45
46#[async_trait]
47impl AuthMethodHandler for DummyAuthHandler {
48    async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
49        Err(io::Error::from(io::ErrorKind::Unsupported))
50    }
51
52    async fn on_verification(&mut self, _: Verification) -> io::Result<VerificationResponse> {
53        Err(io::Error::from(io::ErrorKind::Unsupported))
54    }
55
56    async fn on_info(&mut self, _: Info) -> io::Result<()> {
57        Err(io::Error::from(io::ErrorKind::Unsupported))
58    }
59
60    async fn on_error(&mut self, _: Error) -> io::Result<()> {
61        Err(io::Error::from(io::ErrorKind::Unsupported))
62    }
63}
64
65/// Implementation of [`AuthHandler`] that uses the same [`AuthMethodHandler`] for all methods.
66pub struct SingleAuthHandler(Box<dyn AuthMethodHandler>);
67
68impl SingleAuthHandler {
69    pub fn new<T: AuthMethodHandler + 'static>(method_handler: T) -> Self {
70        Self(Box::new(method_handler))
71    }
72}
73
74#[async_trait]
75impl AuthHandler for SingleAuthHandler {}
76
77#[async_trait]
78impl AuthMethodHandler for SingleAuthHandler {
79    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
80        self.0.on_challenge(challenge).await
81    }
82
83    async fn on_verification(
84        &mut self,
85        verification: Verification,
86    ) -> io::Result<VerificationResponse> {
87        self.0.on_verification(verification).await
88    }
89
90    async fn on_info(&mut self, info: Info) -> io::Result<()> {
91        self.0.on_info(info).await
92    }
93
94    async fn on_error(&mut self, error: Error) -> io::Result<()> {
95        self.0.on_error(error).await
96    }
97}
98
99/// Implementation of [`AuthHandler`] that maintains a map of [`AuthMethodHandler`] implementations
100/// for specific methods, invoking [`on_challenge`], [`on_verification`], [`on_info`], and
101/// [`on_error`] for a specific handler based on an associated id.
102///
103/// [`on_challenge`]: AuthMethodHandler::on_challenge
104/// [`on_verification`]: AuthMethodHandler::on_verification
105/// [`on_info`]: AuthMethodHandler::on_info
106/// [`on_error`]: AuthMethodHandler::on_error
107pub struct AuthHandlerMap {
108    active: String,
109    map: HashMap<&'static str, Box<dyn AuthMethodHandler>>,
110}
111
112impl AuthHandlerMap {
113    /// Creates a new, empty map of auth method handlers.
114    pub fn new() -> Self {
115        Self {
116            active: String::new(),
117            map: HashMap::new(),
118        }
119    }
120
121    /// Returns the `id` of the active [`AuthMethodHandler`].
122    pub fn active_id(&self) -> &str {
123        &self.active
124    }
125
126    /// Sets the active [`AuthMethodHandler`] by its `id`.
127    pub fn set_active_id(&mut self, id: impl Into<String>) {
128        self.active = id.into();
129    }
130
131    /// Inserts the specified `handler` into the map, associating it with `id` for determining the
132    /// method that would trigger this handler.
133    pub fn insert_method_handler<T: AuthMethodHandler + 'static>(
134        &mut self,
135        id: &'static str,
136        handler: T,
137    ) -> Option<Box<dyn AuthMethodHandler>> {
138        self.map.insert(id, Box::new(handler))
139    }
140
141    /// Removes a handler with the associated `id`.
142    pub fn remove_method_handler(
143        &mut self,
144        id: &'static str,
145    ) -> Option<Box<dyn AuthMethodHandler>> {
146        self.map.remove(id)
147    }
148
149    /// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`,
150    /// returning an error if no handler for the active id is found.
151    pub fn get_mut_active_method_handler_or_error(
152        &mut self,
153    ) -> io::Result<&mut (dyn AuthMethodHandler + 'static)> {
154        let id = self.active.clone();
155        self.get_mut_active_method_handler().ok_or_else(|| {
156            io::Error::new(io::ErrorKind::Other, format!("No active handler for {id}"))
157        })
158    }
159
160    /// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`.
161    pub fn get_mut_active_method_handler(
162        &mut self,
163    ) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
164        // TODO: Optimize this
165        self.get_mut_method_handler(&self.active.clone())
166    }
167
168    /// Retrieves a mutable reference to the [`AuthMethodHandler`] with the specified `id`.
169    pub fn get_mut_method_handler(
170        &mut self,
171        id: &str,
172    ) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
173        self.map.get_mut(id).map(|h| h.as_mut())
174    }
175}
176
177impl AuthHandlerMap {
178    /// Consumes the map, returning a new map that supports the `static_key` method.
179    pub fn with_static_key<K>(mut self, key: K) -> Self
180    where
181        K: Display + Send + 'static,
182    {
183        self.insert_method_handler("static_key", StaticKeyAuthMethodHandler::simple(key));
184        self
185    }
186}
187
188impl Default for AuthHandlerMap {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
194#[async_trait]
195impl AuthHandler for AuthHandlerMap {
196    async fn on_initialization(
197        &mut self,
198        initialization: Initialization,
199    ) -> io::Result<InitializationResponse> {
200        let methods = initialization
201            .methods
202            .into_iter()
203            .filter(|method| self.map.contains_key(method.as_str()))
204            .collect();
205
206        Ok(InitializationResponse { methods })
207    }
208
209    async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
210        self.set_active_id(start_method.method);
211        Ok(())
212    }
213
214    async fn on_finished(&mut self) -> io::Result<()> {
215        Ok(())
216    }
217}
218
219#[async_trait]
220impl AuthMethodHandler for AuthHandlerMap {
221    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
222        let handler = self.get_mut_active_method_handler_or_error()?;
223        handler.on_challenge(challenge).await
224    }
225
226    async fn on_verification(
227        &mut self,
228        verification: Verification,
229    ) -> io::Result<VerificationResponse> {
230        let handler = self.get_mut_active_method_handler_or_error()?;
231        handler.on_verification(verification).await
232    }
233
234    async fn on_info(&mut self, info: Info) -> io::Result<()> {
235        let handler = self.get_mut_active_method_handler_or_error()?;
236        handler.on_info(info).await
237    }
238
239    async fn on_error(&mut self, error: Error) -> io::Result<()> {
240        let handler = self.get_mut_active_method_handler_or_error()?;
241        handler.on_error(error).await
242    }
243}
244
245/// Implementation of [`AuthHandler`] that redirects all requests to an [`Authenticator`].
246pub struct ProxyAuthHandler<'a>(&'a mut dyn Authenticator);
247
248impl<'a> ProxyAuthHandler<'a> {
249    pub fn new(authenticator: &'a mut dyn Authenticator) -> Self {
250        Self(authenticator)
251    }
252}
253
254#[async_trait]
255impl<'a> AuthHandler for ProxyAuthHandler<'a> {
256    async fn on_initialization(
257        &mut self,
258        initialization: Initialization,
259    ) -> io::Result<InitializationResponse> {
260        Authenticator::initialize(self.0, initialization).await
261    }
262
263    async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
264        Authenticator::start_method(self.0, start_method).await
265    }
266
267    async fn on_finished(&mut self) -> io::Result<()> {
268        Authenticator::finished(self.0).await
269    }
270}
271
272#[async_trait]
273impl<'a> AuthMethodHandler for ProxyAuthHandler<'a> {
274    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
275        Authenticator::challenge(self.0, challenge).await
276    }
277
278    async fn on_verification(
279        &mut self,
280        verification: Verification,
281    ) -> io::Result<VerificationResponse> {
282        Authenticator::verify(self.0, verification).await
283    }
284
285    async fn on_info(&mut self, info: Info) -> io::Result<()> {
286        Authenticator::info(self.0, info).await
287    }
288
289    async fn on_error(&mut self, error: Error) -> io::Result<()> {
290        Authenticator::error(self.0, error).await
291    }
292}
293
294/// Implementation of [`AuthHandler`] that holds a mutable reference to another [`AuthHandler`]
295/// trait object to use underneath.
296pub struct DynAuthHandler<'a>(&'a mut dyn AuthHandler);
297
298impl<'a> DynAuthHandler<'a> {
299    pub fn new(handler: &'a mut dyn AuthHandler) -> Self {
300        Self(handler)
301    }
302}
303
304impl<'a, T: AuthHandler> From<&'a mut T> for DynAuthHandler<'a> {
305    fn from(handler: &'a mut T) -> Self {
306        Self::new(handler as &mut dyn AuthHandler)
307    }
308}
309
310#[async_trait]
311impl<'a> AuthHandler for DynAuthHandler<'a> {
312    async fn on_initialization(
313        &mut self,
314        initialization: Initialization,
315    ) -> io::Result<InitializationResponse> {
316        self.0.on_initialization(initialization).await
317    }
318
319    async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
320        self.0.on_start_method(start_method).await
321    }
322
323    async fn on_finished(&mut self) -> io::Result<()> {
324        self.0.on_finished().await
325    }
326}
327
328#[async_trait]
329impl<'a> AuthMethodHandler for DynAuthHandler<'a> {
330    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
331        self.0.on_challenge(challenge).await
332    }
333
334    async fn on_verification(
335        &mut self,
336        verification: Verification,
337    ) -> io::Result<VerificationResponse> {
338        self.0.on_verification(verification).await
339    }
340
341    async fn on_info(&mut self, info: Info) -> io::Result<()> {
342        self.0.on_info(info).await
343    }
344
345    async fn on_error(&mut self, error: Error) -> io::Result<()> {
346        self.0.on_error(error).await
347    }
348}
349
/// Represents an implementor of [`AuthHandler`] used purely for testing purposes.
///
/// Each field is a boxed callback that is invoked by the trait method of the same name, which
/// lets a test script the handler's responses.
#[cfg(any(test, feature = "tests"))]
pub struct TestAuthHandler {
    /// Invoked by `AuthHandler::on_initialization`.
    pub on_initialization:
        Box<dyn FnMut(Initialization) -> io::Result<InitializationResponse> + Send>,
    /// Invoked by `AuthMethodHandler::on_challenge`.
    pub on_challenge: Box<dyn FnMut(Challenge) -> io::Result<ChallengeResponse> + Send>,
    /// Invoked by `AuthMethodHandler::on_verification`.
    pub on_verification: Box<dyn FnMut(Verification) -> io::Result<VerificationResponse> + Send>,
    /// Invoked by `AuthMethodHandler::on_info`.
    pub on_info: Box<dyn FnMut(Info) -> io::Result<()> + Send>,
    /// Invoked by `AuthMethodHandler::on_error`.
    pub on_error: Box<dyn FnMut(Error) -> io::Result<()> + Send>,
    /// Invoked by `AuthHandler::on_start_method`.
    pub on_start_method: Box<dyn FnMut(StartMethod) -> io::Result<()> + Send>,
    /// Invoked by `AuthHandler::on_finished`.
    pub on_finished: Box<dyn FnMut() -> io::Result<()> + Send>,
}
362
#[cfg(any(test, feature = "tests"))]
impl Default for TestAuthHandler {
    /// Builds a permissive handler with the following behavior:
    /// - every offered method is accepted;
    /// - each challenge question is answered with its own text;
    /// - every verification is reported valid;
    /// - all remaining callbacks succeed without doing anything.
    fn default() -> Self {
        Self {
            on_initialization: Box::new(|initialization| {
                Ok(InitializationResponse {
                    methods: initialization.methods,
                })
            }),
            on_start_method: Box::new(|_| Ok(())),
            on_challenge: Box::new(|challenge| {
                let answers = challenge
                    .questions
                    .into_iter()
                    .map(|question| question.text)
                    .collect();
                Ok(ChallengeResponse { answers })
            }),
            on_verification: Box::new(|_| Ok(VerificationResponse { valid: true })),
            on_info: Box::new(|_| Ok(())),
            on_error: Box::new(|_| Ok(())),
            on_finished: Box::new(|| Ok(())),
        }
    }
}
381
#[cfg(any(test, feature = "tests"))]
#[async_trait]
impl AuthHandler for TestAuthHandler {
    /// Invokes the `on_initialization` callback and returns its result.
    async fn on_initialization(
        &mut self,
        initialization: Initialization,
    ) -> io::Result<InitializationResponse> {
        (*self.on_initialization)(initialization)
    }

    /// Invokes the `on_start_method` callback and returns its result.
    async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
        (*self.on_start_method)(start_method)
    }

    /// Invokes the `on_finished` callback and returns its result.
    async fn on_finished(&mut self) -> io::Result<()> {
        (*self.on_finished)()
    }
}
400
#[cfg(any(test, feature = "tests"))]
#[async_trait]
impl AuthMethodHandler for TestAuthHandler {
    /// Invokes the `on_challenge` callback and returns its result.
    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
        (*self.on_challenge)(challenge)
    }

    /// Invokes the `on_verification` callback and returns its result.
    async fn on_verification(
        &mut self,
        verification: Verification,
    ) -> io::Result<VerificationResponse> {
        (*self.on_verification)(verification)
    }

    /// Invokes the `on_info` callback and returns its result.
    async fn on_info(&mut self, info: Info) -> io::Result<()> {
        (*self.on_info)(info)
    }

    /// Invokes the `on_error` callback and returns its result.
    async fn on_error(&mut self, error: Error) -> io::Result<()> {
        (*self.on_error)(error)
    }
}