1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
use super::{
    AuthMethodHandler, Challenge, ChallengeResponse, Error, Info, Verification,
    VerificationResponse,
};
use crate::common::HeapSecretKey;
use async_trait::async_trait;
use log::*;
use std::io;

/// Implementation of [`AuthMethodHandler`] that answers challenge requests using a static
/// [`HeapSecretKey`]. All other portions of method authentication are handled by another
/// [`AuthMethodHandler`].
///
/// A challenge can only be answered when every one of its questions is labeled `"key"`;
/// each such question is answered with the string form of the stored key.
pub struct StaticKeyAuthMethodHandler {
    /// Secret whose string form (via `to_string`) is sent as the answer to `"key"` questions.
    key: HeapSecretKey,
    /// Fallback handler that receives verification, info, and error requests.
    handler: Box<dyn AuthMethodHandler>,
}

impl StaticKeyAuthMethodHandler {
    /// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
    /// `key`. All other requests are passed to the `handler`.
    pub fn new<T: AuthMethodHandler + 'static>(key: impl Into<HeapSecretKey>, handler: T) -> Self {
        Self {
            key: key.into(),
            handler: Box::new(handler),
        }
    }

    /// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
    /// `key`. All other requests are handled automatically, meaning that verification is always
    /// approved and info/errors are ignored.
    pub fn simple(key: impl Into<HeapSecretKey>) -> Self {
        Self::new(key, {
            // Fallback that approves every verification and silently discards info and error
            // messages. It is private to this function, so it is only ever reached through the
            // wrapping `StaticKeyAuthMethodHandler`.
            struct __AuthMethodHandler;

            #[async_trait]
            impl AuthMethodHandler for __AuthMethodHandler {
                async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
                    // The wrapping handler answers challenges itself and never delegates them,
                    // so this can only run if that invariant is broken.
                    unreachable!("on_challenge should be handled by StaticKeyAuthMethodHandler");
                }

                async fn on_verification(
                    &mut self,
                    _: Verification,
                ) -> io::Result<VerificationResponse> {
                    Ok(VerificationResponse { valid: true })
                }

                async fn on_info(&mut self, _: Info) -> io::Result<()> {
                    Ok(())
                }

                async fn on_error(&mut self, _: Error) -> io::Result<()> {
                    Ok(())
                }
            }

            __AuthMethodHandler
        })
    }
}

#[async_trait]
impl AuthMethodHandler for StaticKeyAuthMethodHandler {
    /// Answers every question in `challenge` with the stringified static key.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error if any question is not labeled `"key"`;
    /// no answers are produced in that case.
    async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
        trace!("on_challenge({challenge:?})");

        // Only challenges with a "key" label are allowed, all else will fail
        if challenge
            .questions
            .iter()
            .any(|question| question.label != "key")
        {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Only 'key' challenges are supported",
            ));
        }

        // Every question gets the same answer, so format the key once and replicate it rather
        // than re-running `Display` on the secret for each question.
        let answers = vec![self.key.to_string(); challenge.questions.len()];
        Ok(ChallengeResponse { answers })
    }

    /// Delegates verification to the fallback handler.
    async fn on_verification(
        &mut self,
        verification: Verification,
    ) -> io::Result<VerificationResponse> {
        trace!("on_verify({verification:?})");
        self.handler.on_verification(verification).await
    }

    /// Delegates informational messages to the fallback handler.
    async fn on_info(&mut self, info: Info) -> io::Result<()> {
        trace!("on_info({info:?})");
        self.handler.on_info(info).await
    }

    /// Delegates error messages to the fallback handler.
    async fn on_error(&mut self, error: Error) -> io::Result<()> {
        trace!("on_error({error:?})");
        self.handler.on_error(error).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::authentication::msg::{ErrorKind, Question, VerificationKind};
    use test_log::test;

    #[test(tokio::test)]
    async fn on_challenge_should_fail_if_non_key_question_received() {
        let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());

        let err = handler
            .on_challenge(Challenge {
                questions: vec![Question::new("test")],
                options: Default::default(),
            })
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test(tokio::test)]
    async fn on_challenge_should_fail_if_any_question_is_not_a_key_question() {
        let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());

        // A valid "key" question must not mask a later unsupported one.
        let err = handler
            .on_challenge(Challenge {
                questions: vec![Question::new("key"), Question::new("test")],
                options: Default::default(),
            })
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test(tokio::test)]
    async fn on_challenge_should_answer_with_stringified_key_for_key_questions() {
        let key = HeapSecretKey::generate(32).unwrap();
        let expected = key.to_string();
        let mut handler = StaticKeyAuthMethodHandler::simple(key);

        let response = handler
            .on_challenge(Challenge {
                questions: vec![Question::new("key")],
                options: Default::default(),
            })
            .await
            .unwrap();
        assert_eq!(response.answers, vec![expected], "Wrong answer set received");
    }

    #[test(tokio::test)]
    async fn on_challenge_should_answer_every_key_question() {
        let key = HeapSecretKey::generate(32).unwrap();
        let expected = key.to_string();
        let mut handler = StaticKeyAuthMethodHandler::simple(key);

        let response = handler
            .on_challenge(Challenge {
                questions: vec![Question::new("key"), Question::new("key")],
                options: Default::default(),
            })
            .await
            .unwrap();
        assert_eq!(
            response.answers,
            vec![expected.clone(), expected],
            "Each key question should receive the stringified key"
        );
    }

    #[test(tokio::test)]
    async fn on_challenge_should_return_no_answers_if_no_questions() {
        let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());

        let response = handler
            .on_challenge(Challenge {
                questions: Vec::new(),
                options: Default::default(),
            })
            .await
            .unwrap();
        assert!(response.answers.is_empty(), "Unexpected answers for empty challenge");
    }

    #[test(tokio::test)]
    async fn on_verification_should_leverage_fallback_handler() {
        let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());

        let response = handler
            .on_verification(Verification {
                kind: VerificationKind::Host,
                text: "host".to_string(),
            })
            .await
            .unwrap();
        assert!(response.valid, "Unexpected result from fallback handler");
    }

    #[test(tokio::test)]
    async fn on_info_should_leverage_fallback_handler() {
        let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());

        handler
            .on_info(Info {
                text: "info".to_string(),
            })
            .await
            .unwrap();
    }

    #[test(tokio::test)]
    async fn on_error_should_leverage_fallback_handler() {
        let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());

        handler
            .on_error(Error {
                kind: ErrorKind::Error,
                text: "text".to_string(),
            })
            .await
            .unwrap();
    }
}