1use std::io;
2
3use async_trait::async_trait;
4use distant_auth::msg::*;
5use distant_auth::{AuthHandler, Authenticate, Authenticator};
6use log::*;
7
8use crate::common::{utils, FramedTransport, Transport};
9
// Serializes `$data` and sends the resulting bytes to `$transport` as one frame.
//
// Meant for the `async fn`s in this file that return `io::Result`: the macro `.await`s the
// write, and both a serialization failure and a write failure return early via `?`.
macro_rules! write_frame {
    ($transport:expr, $data:expr) => {{
        // Serialize up front so a bad payload fails before any bytes hit the transport.
        let data = utils::serialize_to_vec(&$data)?;

        // Only pay for `Debug`-formatting the byte buffer when trace logging is enabled.
        if log_enabled!(Level::Trace) {
            trace!("Writing data as frame: {data:?}");
        }

        $transport.write_frame(data).await?
    }};
}
20
// Reads the next frame from `$transport` and decodes it.
//
// * `next_frame_as!(transport, Type)` evaluates to the frame decoded as `Type`.
// * `next_frame_as!(transport, Type, Variant)` additionally requires the decoded value to be
//   `Type::Variant(inner)` and evaluates to `inner`.
//
// Meant for the `async fn`s in this file that return `io::Result`. It `.await`s and returns
// early from the enclosing function with:
//   * the transport's read error, if reading the frame fails;
//   * `UnexpectedEof`, if the transport yields no frame at all;
//   * the decode error, if the frame bytes cannot be deserialized as `Type`;
//   * `InvalidData`, if the decoded value is a variant other than `Variant`.
macro_rules! next_frame_as {
    ($transport:expr, $type:ident, $variant:ident) => {{
        match next_frame_as!($transport, $type) {
            $type::$variant(x) => x,
            x => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Unexpected frame: {x:?}"),
                ))
            }
        }
    }};
    ($transport:expr, $type:ident) => {{
        // `read_frame` yields `None` when no frame is available; surface that as EOF.
        let frame = $transport.read_frame().await?.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::UnexpectedEof,
                concat!(
                    "Transport closed early waiting for frame of type ",
                    stringify!($type),
                ),
            )
        })?;

        match utils::deserialize_from_slice::<$type>(frame.as_item()) {
            Ok(value) => value,
            Err(x) => {
                // Dump the raw frame bytes to help diagnose protocol mismatches.
                if log_enabled!(Level::Trace) {
                    trace!(
                        "Failed to deserialize frame item as {}: {:?}",
                        stringify!($type),
                        frame.as_item()
                    );
                }

                // `return` diverges, so this arm type-checks against `$type` without needing
                // a dead `unreachable!()` afterwards.
                return Err(x);
            }
        }
    }};
}
61
62#[async_trait]
63impl<T> Authenticate for FramedTransport<T>
64where
65 T: Transport,
66{
67 async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()> {
68 loop {
69 trace!("Authenticate::authenticate waiting on next authentication frame");
70 match next_frame_as!(self, Authentication) {
71 Authentication::Initialization(x) => {
72 trace!("Authenticate::Initialization({x:?})");
73 let response = handler.on_initialization(x).await?;
74 write_frame!(self, AuthenticationResponse::Initialization(response));
75 }
76 Authentication::Challenge(x) => {
77 trace!("Authenticate::Challenge({x:?})");
78 let response = handler.on_challenge(x).await?;
79 write_frame!(self, AuthenticationResponse::Challenge(response));
80 }
81 Authentication::Verification(x) => {
82 trace!("Authenticate::Verify({x:?})");
83 let response = handler.on_verification(x).await?;
84 write_frame!(self, AuthenticationResponse::Verification(response));
85 }
86 Authentication::Info(x) => {
87 trace!("Authenticate::Info({x:?})");
88 handler.on_info(x).await?;
89 }
90 Authentication::Error(x) => {
91 trace!("Authenticate::Error({x:?})");
92 handler.on_error(x.clone()).await?;
93
94 if x.is_fatal() {
95 return Err(x.into_io_permission_denied());
96 }
97 }
98 Authentication::StartMethod(x) => {
99 trace!("Authenticate::StartMethod({x:?})");
100 handler.on_start_method(x).await?;
101 }
102 Authentication::Finished => {
103 trace!("Authenticate::Finished");
104 handler.on_finished().await?;
105 return Ok(());
106 }
107 }
108 }
109 }
110}
111
112#[async_trait]
113impl<T> Authenticator for FramedTransport<T>
114where
115 T: Transport,
116{
117 async fn initialize(
118 &mut self,
119 initialization: Initialization,
120 ) -> io::Result<InitializationResponse> {
121 trace!("Authenticator::initialize({initialization:?})");
122 write_frame!(self, Authentication::Initialization(initialization));
123 let response = next_frame_as!(self, AuthenticationResponse, Initialization);
124 Ok(response)
125 }
126
127 async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
128 trace!("Authenticator::challenge({challenge:?})");
129 write_frame!(self, Authentication::Challenge(challenge));
130 let response = next_frame_as!(self, AuthenticationResponse, Challenge);
131 Ok(response)
132 }
133
134 async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse> {
135 trace!("Authenticator::verify({verification:?})");
136 write_frame!(self, Authentication::Verification(verification));
137 let response = next_frame_as!(self, AuthenticationResponse, Verification);
138 Ok(response)
139 }
140
141 async fn info(&mut self, info: Info) -> io::Result<()> {
142 trace!("Authenticator::info({info:?})");
143 write_frame!(self, Authentication::Info(info));
144 Ok(())
145 }
146
147 async fn error(&mut self, error: Error) -> io::Result<()> {
148 trace!("Authenticator::error({error:?})");
149 write_frame!(self, Authentication::Error(error));
150 Ok(())
151 }
152
153 async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
154 trace!("Authenticator::start_method({start_method:?})");
155 write_frame!(self, Authentication::StartMethod(start_method));
156 Ok(())
157 }
158
159 async fn finished(&mut self) -> io::Result<()> {
160 trace!("Authenticator::finished()");
161 write_frame!(self, Authentication::Finished);
162 Ok(())
163 }
164}
165
// Each test connects an `Authenticator` (`t1`) to an `Authenticate` loop (`t2`) running in a
// spawned task. The two ends come from `FramedTransport::test_pair`.
#[cfg(test)]
mod tests {
    use distant_auth::tests::TestAuthHandler;
    use test_log::test;
    use tokio::sync::mpsc;

    use super::*;

    // Request/response: the handler echoes the offered methods back, and the authenticator must
    // receive exactly that response. No `Finished` frame is sent, so the handler task is
    // expected to still be running.
    #[test(tokio::test)]
    async fn authenticator_initialization_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_initialization: Box::new(|x| Ok(InitializationResponse { methods: x.methods })),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .initialize(Initialization {
                methods: vec!["test method".to_string()].into_iter().collect(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(
            response,
            InitializationResponse {
                methods: vec!["test method".to_string()].into_iter().collect()
            }
        );
    }

    // Request/response: the handler checks that the challenge's questions and options arrived
    // intact, then replies with a fixed answer that the authenticator must receive.
    #[test(tokio::test)]
    async fn authenticator_challenge_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_challenge: Box::new(|challenge| {
                    assert_eq!(
                        challenge.questions,
                        vec![Question {
                            label: "label".to_string(),
                            text: "text".to_string(),
                            options: vec![(
                                "question_key".to_string(),
                                "question_value".to_string()
                            )]
                            .into_iter()
                            .collect(),
                        }]
                    );
                    assert_eq!(
                        challenge.options,
                        vec![("key".to_string(), "value".to_string())]
                            .into_iter()
                            .collect(),
                    );
                    Ok(ChallengeResponse {
                        answers: vec!["some answer".to_string()].into_iter().collect(),
                    })
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .challenge(Challenge {
                questions: vec![Question {
                    label: "label".to_string(),
                    text: "text".to_string(),
                    options: vec![("question_key".to_string(), "question_value".to_string())]
                        .into_iter()
                        .collect(),
                }],
                options: vec![("key".to_string(), "value".to_string())]
                    .into_iter()
                    .collect(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(
            response,
            ChallengeResponse {
                answers: vec!["some answer".to_string()],
            }
        );
    }

    // Request/response: the handler checks the verification kind and text, then reports the
    // verification as valid.
    #[test(tokio::test)]
    async fn authenticator_verification_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_verification: Box::new(|verification| {
                    assert_eq!(verification.kind, VerificationKind::Host);
                    assert_eq!(verification.text, "some text");
                    Ok(VerificationResponse { valid: true })
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .verify(Verification {
                kind: VerificationKind::Host,
                text: "some text".to_string(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(response, VerificationResponse { valid: true });
    }

    // One-way notification: `info` writes a frame and does not wait for a reply, so delivery
    // is observed through a channel fed by the handler callback.
    #[test(tokio::test)]
    async fn authenticator_info_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_info: Box::new(move |info| {
                    tx.try_send(info).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.info(Info {
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Info {
                text: "some text".to_string()
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    // A non-fatal error (`ErrorKind::Error`) is delivered to `on_error`. The authenticate loop
    // must keep running afterwards.
    #[test(tokio::test)]
    async fn authenticator_error_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_error: Box::new(move |error| {
                    tx.try_send(error).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.error(Error {
            kind: ErrorKind::Error,
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Error {
                kind: ErrorKind::Error,
                text: "some text".to_string(),
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    // A fatal error is still delivered to `on_error`, but then `authenticate` returns `Err`.
    // The spawned task `.unwrap()`s that result and panics, so awaiting the task yields a
    // `JoinError`.
    #[test(tokio::test)]
    async fn auth_handler_received_error_should_fail_auth_handler_if_fatal() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_error: Box::new(move |error| {
                    tx.try_send(error).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.error(Error {
            kind: ErrorKind::Fatal,
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Error {
                kind: ErrorKind::Fatal,
                text: "some text".to_string(),
            }
        );

        task.await.unwrap_err();
    }

    // One-way notification: `start_method` is delivered to `on_start_method`, and the loop
    // keeps running.
    #[test(tokio::test)]
    async fn authenticator_start_method_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_start_method: Box::new(move |start_method| {
                    tx.try_send(start_method).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.start_method(StartMethod {
            method: "some method".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            StartMethod {
                method: "some method".to_string()
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    // `finished` invokes `on_finished` and makes `authenticate` return `Ok(())`, so the
    // spawned task completes without panicking.
    #[test(tokio::test)]
    async fn authenticator_finished_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_finished: Box::new(move || {
                    tx.try_send(()).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.finished().await.unwrap();

        rx.recv().await.unwrap();

        task.await.unwrap();
    }
}