// distant_net/authentication.rs

use std::io;

use async_trait::async_trait;
use distant_auth::msg::*;
use distant_auth::{AuthHandler, Authenticate, Authenticator};
use log::*;

use crate::common::{utils, FramedTransport, Transport};

/// Serializes `$data` and sends the resulting bytes over `$transport` as a single frame.
///
/// Both the serialization and the write are propagated with `?`, so this may only be expanded
/// inside an `async` body whose error type accepts an `io::Error` (every caller in this file
/// returns `io::Result<_>`). The macro evaluates to the value of `write_frame(..).await?`.
macro_rules! write_frame {
    ($transport:expr, $data:expr) => {{
        let data = utils::serialize_to_vec(&$data)?;

        // NOTE(review): at trace level this dumps the serialized payload verbatim; for challenge
        // responses that includes the user's answers, which may be secrets.
        let trace_enabled = log_enabled!(Level::Trace);
        if trace_enabled {
            trace!("Writing data as frame: {data:?}");
        }

        $transport.write_frame(data).await?
    }};
}
/// Reads the next frame from `$transport` and deserializes it.
///
/// * `next_frame_as!(transport, Type)` evaluates to the decoded `Type`.
/// * `next_frame_as!(transport, Type, Variant)` additionally requires the decoded value to be
///   `Type::Variant(inner)` and evaluates to `inner`.
///
/// May only be expanded inside an `async` body returning `io::Result<_>`, because it returns
/// early:
/// * with `io::ErrorKind::UnexpectedEof` if the transport yields no frame,
/// * with the deserialization error (converted by `?`) if the frame is not a valid `Type`,
/// * with `io::ErrorKind::InvalidData` if the frame decodes to a different variant than the
///   one requested.
macro_rules! next_frame_as {
    ($transport:expr, $type:ident, $variant:ident) => {{
        match next_frame_as!($transport, $type) {
            $type::$variant(x) => x,
            // NOTE(review): the whole unexpected frame is Debug-printed into the error, which may
            // include user-supplied answers if it happens to be a challenge response.
            x => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Unexpected frame: {x:?}"),
                ))
            }
        }
    }};
    ($transport:expr, $type:ident) => {{
        let frame = $transport.read_frame().await?.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::UnexpectedEof,
                concat!(
                    "Transport closed early waiting for frame of type ",
                    stringify!($type),
                ),
            )
        })?;

        let result = utils::deserialize_from_slice::<$type>(frame.as_item());
        if result.is_err() && log_enabled!(Level::Trace) {
            // Dump the raw bytes so a malformed frame can be diagnosed.
            trace!(
                "Failed to deserialize frame item as {}: {:?}",
                stringify!($type),
                frame.as_item()
            );
        }

        // Propagate a decoding failure to the caller; otherwise yield the decoded value.
        result?
    }};
}
62#[async_trait]
63impl<T> Authenticate for FramedTransport<T>
64where
65    T: Transport,
66{
67    async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()> {
68        loop {
69            trace!("Authenticate::authenticate waiting on next authentication frame");
70            match next_frame_as!(self, Authentication) {
71                Authentication::Initialization(x) => {
72                    trace!("Authenticate::Initialization({x:?})");
73                    let response = handler.on_initialization(x).await?;
74                    write_frame!(self, AuthenticationResponse::Initialization(response));
75                }
76                Authentication::Challenge(x) => {
77                    trace!("Authenticate::Challenge({x:?})");
78                    let response = handler.on_challenge(x).await?;
79                    write_frame!(self, AuthenticationResponse::Challenge(response));
80                }
81                Authentication::Verification(x) => {
82                    trace!("Authenticate::Verify({x:?})");
83                    let response = handler.on_verification(x).await?;
84                    write_frame!(self, AuthenticationResponse::Verification(response));
85                }
86                Authentication::Info(x) => {
87                    trace!("Authenticate::Info({x:?})");
88                    handler.on_info(x).await?;
89                }
90                Authentication::Error(x) => {
91                    trace!("Authenticate::Error({x:?})");
92                    handler.on_error(x.clone()).await?;
93
94                    if x.is_fatal() {
95                        return Err(x.into_io_permission_denied());
96                    }
97                }
98                Authentication::StartMethod(x) => {
99                    trace!("Authenticate::StartMethod({x:?})");
100                    handler.on_start_method(x).await?;
101                }
102                Authentication::Finished => {
103                    trace!("Authenticate::Finished");
104                    handler.on_finished().await?;
105                    return Ok(());
106                }
107            }
108        }
109    }
110}
111
112#[async_trait]
113impl<T> Authenticator for FramedTransport<T>
114where
115    T: Transport,
116{
117    async fn initialize(
118        &mut self,
119        initialization: Initialization,
120    ) -> io::Result<InitializationResponse> {
121        trace!("Authenticator::initialize({initialization:?})");
122        write_frame!(self, Authentication::Initialization(initialization));
123        let response = next_frame_as!(self, AuthenticationResponse, Initialization);
124        Ok(response)
125    }
126
127    async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
128        trace!("Authenticator::challenge({challenge:?})");
129        write_frame!(self, Authentication::Challenge(challenge));
130        let response = next_frame_as!(self, AuthenticationResponse, Challenge);
131        Ok(response)
132    }
133
134    async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse> {
135        trace!("Authenticator::verify({verification:?})");
136        write_frame!(self, Authentication::Verification(verification));
137        let response = next_frame_as!(self, AuthenticationResponse, Verification);
138        Ok(response)
139    }
140
141    async fn info(&mut self, info: Info) -> io::Result<()> {
142        trace!("Authenticator::info({info:?})");
143        write_frame!(self, Authentication::Info(info));
144        Ok(())
145    }
146
147    async fn error(&mut self, error: Error) -> io::Result<()> {
148        trace!("Authenticator::error({error:?})");
149        write_frame!(self, Authentication::Error(error));
150        Ok(())
151    }
152
153    async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
154        trace!("Authenticator::start_method({start_method:?})");
155        write_frame!(self, Authentication::StartMethod(start_method));
156        Ok(())
157    }
158
159    async fn finished(&mut self) -> io::Result<()> {
160        trace!("Authenticator::finished()");
161        write_frame!(self, Authentication::Finished);
162        Ok(())
163    }
164}
165
#[cfg(test)]
mod tests {
    // Each test pairs an `Authenticator` (t1) with an `Authenticate` handler (t2) over the two
    // ends of `FramedTransport::test_pair`.
    use distant_auth::tests::TestAuthHandler;
    use test_log::test;
    use tokio::sync::mpsc;

    use super::*;

    #[test(tokio::test)]
    async fn authenticator_initialization_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_initialization: Box::new(|x| Ok(InitializationResponse { methods: x.methods })),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .initialize(Initialization {
                methods: vec!["test method".to_string()].into_iter().collect(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(
            response,
            InitializationResponse {
                methods: vec!["test method".to_string()].into_iter().collect()
            }
        );
    }

    #[test(tokio::test)]
    async fn authenticator_challenge_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_challenge: Box::new(|challenge| {
                    assert_eq!(
                        challenge.questions,
                        vec![Question {
                            label: "label".to_string(),
                            text: "text".to_string(),
                            options: vec![(
                                "question_key".to_string(),
                                "question_value".to_string()
                            )]
                            .into_iter()
                            .collect(),
                        }]
                    );
                    assert_eq!(
                        challenge.options,
                        vec![("key".to_string(), "value".to_string())]
                            .into_iter()
                            .collect(),
                    );
                    Ok(ChallengeResponse {
                        answers: vec!["some answer".to_string()].into_iter().collect(),
                    })
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .challenge(Challenge {
                questions: vec![Question {
                    label: "label".to_string(),
                    text: "text".to_string(),
                    options: vec![("question_key".to_string(), "question_value".to_string())]
                        .into_iter()
                        .collect(),
                }],
                options: vec![("key".to_string(), "value".to_string())]
                    .into_iter()
                    .collect(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(
            response,
            ChallengeResponse {
                answers: vec!["some answer".to_string()],
            }
        );
    }

    #[test(tokio::test)]
    async fn authenticator_verification_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_verification: Box::new(|verification| {
                    assert_eq!(verification.kind, VerificationKind::Host);
                    assert_eq!(verification.text, "some text");
                    Ok(VerificationResponse { valid: true })
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .verify(Verification {
                kind: VerificationKind::Host,
                text: "some text".to_string(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(response, VerificationResponse { valid: true });
    }

    #[test(tokio::test)]
    async fn authenticator_info_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_info: Box::new(move |info| {
                    tx.try_send(info).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.info(Info {
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Info {
                text: "some text".to_string()
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    #[test(tokio::test)]
    async fn authenticator_error_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_error: Box::new(move |error| {
                    tx.try_send(error).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.error(Error {
            kind: ErrorKind::Error,
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Error {
                kind: ErrorKind::Error,
                text: "some text".to_string(),
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    #[test(tokio::test)]
    async fn auth_handler_received_error_should_fail_auth_handler_if_fatal() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        // Hand the authentication result back instead of unwrapping it inside the task, so the
        // final assertion checks the actual error rather than accepting any panic.
        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_error: Box::new(move |error| {
                    tx.try_send(error).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
        });

        t1.error(Error {
            kind: ErrorKind::Fatal,
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Error {
                kind: ErrorKind::Fatal,
                text: "some text".to_string(),
            }
        );

        // Verify that the handler exited with a permission-denied error (not a panic)
        let err = task.await.unwrap().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::PermissionDenied);
    }

    #[test(tokio::test)]
    async fn authenticator_start_method_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_start_method: Box::new(move |start_method| {
                    tx.try_send(start_method).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.start_method(StartMethod {
            method: "some method".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            StartMethod {
                method: "some method".to_string()
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    #[test(tokio::test)]
    async fn authenticator_finished_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_finished: Box::new(move || {
                    tx.try_send(()).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.finished().await.unwrap();

        // Verify that the callback was triggered
        rx.recv().await.unwrap();

        // Finished should signal that the handler completed successfully
        task.await.unwrap();
    }
}