1use std::collections::HashMap;
2use std::io;
3use std::str::FromStr;
4
5use async_trait::async_trait;
6use log::*;
7
8use crate::authenticator::Authenticator;
9use crate::msg::*;
10
11mod none;
12mod static_key;
13
14pub use none::*;
15pub use static_key::*;
16
/// Drives an authentication handshake using a set of [`AuthenticationMethod`]s, trying the
/// methods requested by the other side until one of them succeeds.
pub struct Verifier {
    /// Supported methods, keyed by [`AuthenticationMethod::id`].
    methods: HashMap<&'static str, Box<dyn AuthenticationMethod>>,
}
21
22impl Verifier {
23 pub fn new<I>(methods: I) -> Self
24 where
25 I: IntoIterator<Item = Box<dyn AuthenticationMethod>>,
26 {
27 let mut m = HashMap::new();
28
29 for method in methods {
30 m.insert(method.id(), method);
31 }
32
33 Self { methods: m }
34 }
35
36 pub fn empty() -> Self {
38 Self {
39 methods: HashMap::new(),
40 }
41 }
42
43 pub fn none() -> Self {
45 Self::new(vec![
46 Box::new(NoneAuthenticationMethod::new()) as Box<dyn AuthenticationMethod>
47 ])
48 }
49
50 pub fn static_key<K>(key: K) -> Self
52 where
53 K: FromStr + PartialEq + Send + Sync + 'static,
54 {
55 Self::new(vec![
56 Box::new(StaticKeyAuthenticationMethod::new(key)) as Box<dyn AuthenticationMethod>
57 ])
58 }
59
60 pub fn methods(&self) -> impl Iterator<Item = &'static str> + '_ {
62 self.methods.keys().copied()
63 }
64
65 pub async fn verify(&self, authenticator: &mut dyn Authenticator) -> io::Result<&'static str> {
68 let response = authenticator
70 .initialize(Initialization {
71 methods: self.methods.keys().map(ToString::to_string).collect(),
72 })
73 .await?;
74
75 for method in response.methods {
76 match self.methods.get(method.as_str()) {
77 Some(method) => {
78 authenticator
80 .start_method(StartMethod {
81 method: method.id().to_string(),
82 })
83 .await?;
84
85 if method.authenticate(authenticator).await.is_ok() {
87 authenticator.finished().await?;
88 return Ok(method.id());
89 }
90 }
91 None => {
92 trace!("Skipping authentication {method} as it is not available or supported");
93 }
94 }
95 }
96
97 Err(io::Error::new(
98 io::ErrorKind::PermissionDenied,
99 "No authentication method succeeded",
100 ))
101 }
102}
103
104impl From<Vec<Box<dyn AuthenticationMethod>>> for Verifier {
105 fn from(methods: Vec<Box<dyn AuthenticationMethod>>) -> Self {
106 Self::new(methods)
107 }
108}
109
#[async_trait]
/// A single authentication strategy that a [`Verifier`] can advertise and run.
pub trait AuthenticationMethod: Send + Sync {
    /// Returns the id of this method.
    ///
    /// The [`Verifier`] uses this id in three ways: as the key when storing methods, as the
    /// name it advertises during initialization, and as the value `verify` returns on success.
    /// Ids should therefore be distinct. A later method with the same id replaces an earlier
    /// one in `Verifier::new`.
    fn id(&self) -> &'static str;

    /// Performs this method's authentication by using `authenticator` to exchange
    /// method-specific messages.
    ///
    /// Returning `Ok(())` marks the method as successful. Returning an `Err` makes the
    /// [`Verifier`] move on to the next requested method.
    async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()>;
}
120
#[cfg(test)]
mod tests {
    use std::sync::mpsc;

    use test_log::test;

    use super::*;
    use crate::authenticator::TestAuthenticator;

    /// Test method whose authentication always succeeds.
    struct SuccessAuthenticationMethod;

    #[async_trait]
    impl AuthenticationMethod for SuccessAuthenticationMethod {
        fn id(&self) -> &'static str {
            "success"
        }

        async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
            Ok(())
        }
    }

    /// Test method whose authentication always fails.
    struct FailAuthenticationMethod;

    #[async_trait]
    impl AuthenticationMethod for FailAuthenticationMethod {
        fn id(&self) -> &'static str {
            "fail"
        }

        async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
            Err(io::Error::from(io::ErrorKind::Other))
        }
    }

    /// Builds an initialization response that requests the methods in `ids`, in order.
    fn requesting(ids: &[&str]) -> InitializationResponse {
        InitializationResponse {
            methods: ids.iter().copied().map(String::from).collect(),
        }
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_initialization_fails() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Err(io::Error::from(io::ErrorKind::Other))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);
        verifier.verify(&mut authenticator).await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_fails_to_send_start_method() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Ok(requesting(&[SuccessAuthenticationMethod.id()]))),
            start_method: Box::new(|_| Err(io::Error::new(io::ErrorKind::Other, "test error"))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);

        let err = verifier.verify(&mut authenticator).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "test error");
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_fails_to_send_finished_indicator_after_success() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Ok(requesting(&[SuccessAuthenticationMethod.id()]))),
            finished: Box::new(|| Err(io::Error::new(io::ErrorKind::Other, "test error"))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);

        let err = verifier.verify(&mut authenticator).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "test error");
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_has_no_authentication_methods() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Ok(requesting(&[SuccessAuthenticationMethod.id()]))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![];
        let verifier = Verifier::from(methods);
        verifier.verify(&mut authenticator).await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_initialization_yields_no_valid_authentication_methods(
    ) {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Ok(requesting(&["other"]))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);
        verifier.verify(&mut authenticator).await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn verifier_should_fail_to_verify_if_no_authentication_method_succeeds() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Ok(requesting(&[FailAuthenticationMethod.id()]))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![Box::new(FailAuthenticationMethod)];
        let verifier = Verifier::from(methods);
        verifier.verify(&mut authenticator).await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn verifier_should_not_send_finished_if_no_authentication_method_succeeds() {
        let (tx, rx) = mpsc::channel();

        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Ok(requesting(&[FailAuthenticationMethod.id()]))),
            finished: Box::new(move || {
                tx.send(()).unwrap();
                Ok(())
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![Box::new(FailAuthenticationMethod)];
        Verifier::from(methods)
            .verify(&mut authenticator)
            .await
            .unwrap_err();

        // The sender is still owned by `authenticator`, so an idle channel reports `Empty`.
        assert_eq!(rx.try_recv().unwrap_err(), mpsc::TryRecvError::Empty);
    }

    #[test(tokio::test)]
    async fn verifier_should_return_id_of_authentication_method_upon_success() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| Ok(requesting(&[SuccessAuthenticationMethod.id()]))),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(SuccessAuthenticationMethod)];
        let verifier = Verifier::from(methods);
        assert_eq!(
            verifier.verify(&mut authenticator).await.unwrap(),
            SuccessAuthenticationMethod.id()
        );
    }

    #[test(tokio::test)]
    async fn verifier_should_try_authentication_methods_in_order_until_one_succeeds() {
        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(requesting(&[
                    FailAuthenticationMethod.id(),
                    SuccessAuthenticationMethod.id(),
                ]))
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
            Box::new(FailAuthenticationMethod),
            Box::new(SuccessAuthenticationMethod),
        ];
        let verifier = Verifier::from(methods);
        assert_eq!(
            verifier.verify(&mut authenticator).await.unwrap(),
            SuccessAuthenticationMethod.id()
        );
    }

    #[test(tokio::test)]
    async fn verifier_should_send_start_method_before_attempting_each_method() {
        let (tx, rx) = mpsc::channel();

        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(requesting(&[
                    FailAuthenticationMethod.id(),
                    SuccessAuthenticationMethod.id(),
                ]))
            }),
            start_method: Box::new(move |method| {
                tx.send(method.method).unwrap();
                Ok(())
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
            Box::new(FailAuthenticationMethod),
            Box::new(SuccessAuthenticationMethod),
        ];
        Verifier::from(methods)
            .verify(&mut authenticator)
            .await
            .unwrap();

        assert_eq!(rx.try_recv().unwrap(), FailAuthenticationMethod.id());
        assert_eq!(rx.try_recv().unwrap(), SuccessAuthenticationMethod.id());
        assert_eq!(rx.try_recv().unwrap_err(), mpsc::TryRecvError::Empty);
    }

    #[test(tokio::test)]
    async fn verifier_should_send_finished_when_a_method_succeeds() {
        let (tx, rx) = mpsc::channel();

        let mut authenticator = TestAuthenticator {
            initialize: Box::new(|_| {
                Ok(requesting(&[
                    FailAuthenticationMethod.id(),
                    SuccessAuthenticationMethod.id(),
                ]))
            }),
            finished: Box::new(move || {
                tx.send(()).unwrap();
                Ok(())
            }),
            ..Default::default()
        };

        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
            Box::new(FailAuthenticationMethod),
            Box::new(SuccessAuthenticationMethod),
        ];
        Verifier::from(methods)
            .verify(&mut authenticator)
            .await
            .unwrap();

        rx.try_recv().unwrap();
        assert_eq!(rx.try_recv().unwrap_err(), mpsc::TryRecvError::Empty);
    }

    #[test]
    fn verifier_methods_should_report_ids_of_supported_methods() {
        let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
            Box::new(FailAuthenticationMethod),
            Box::new(SuccessAuthenticationMethod),
        ];
        let verifier = Verifier::from(methods);

        // HashMap iteration order is unspecified, so sort before comparing.
        let mut ids: Vec<&str> = verifier.methods().collect();
        ids.sort_unstable();
        assert_eq!(ids, ["fail", "success"]);

        assert_eq!(Verifier::empty().methods().count(), 0);
    }
}