1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
use std::fmt;

use crate::common::Credentials;

#[cfg(feature = "scram")]
use crate::common::scram::DeriveError;
#[cfg(feature = "scram")]
use hmac::digest::InvalidLength;

/// Errors returned by a [`Mechanism`], either while it is being built from
/// credentials or while it handles data from the server.
///
/// The variants that wrap another error type exist only when the `scram`
/// feature is enabled.
#[derive(Debug, PartialEq)]
pub enum MechanismError {
    /// The ANONYMOUS mechanism requires no credentials.
    AnonymousRequiresNoCredentials,

    /// PLAIN requires a username.
    PlainRequiresUsername,
    /// PLAIN requires a plaintext password.
    PlainRequiresPlaintextPassword,

    /// A nonce could not be generated.
    CannotGenerateNonce,
    /// SCRAM requires a username.
    ScramRequiresUsername,
    /// SCRAM requires a password.
    ScramRequiresPassword,

    /// The server challenge could not be decoded.
    CannotDecodeChallenge,
    /// No server nonce was available (presumably absent from the challenge).
    NoServerNonce,
    /// No server salt was available (presumably absent from the challenge).
    NoServerSalt,
    /// No server iteration count was available (presumably absent from the
    /// challenge).
    NoServerIterations,
    /// Wraps a `crate::common::scram::DeriveError` (`scram` feature only).
    #[cfg(feature = "scram")]
    DeriveError(DeriveError),
    /// Wraps an `hmac::digest::InvalidLength` key-length error (`scram`
    /// feature only).
    #[cfg(feature = "scram")]
    InvalidKeyLength(InvalidLength),
    /// The mechanism is not in the right state to receive this response.
    InvalidState,

    /// The server's success response could not be decoded.
    CannotDecodeSuccessResponse,
    /// The signature in the server's success response is invalid.
    InvalidSignatureInSuccessResponse,
    /// The server's success response carries no signature.
    NoSignatureInSuccessResponse,
}

#[cfg(feature = "scram")]
impl From<DeriveError> for MechanismError {
    /// Wraps a `DeriveError` in [`MechanismError::DeriveError`], which lets
    /// `?` propagate it from functions returning `MechanismError`.
    fn from(inner: DeriveError) -> Self {
        Self::DeriveError(inner)
    }
}

#[cfg(feature = "scram")]
impl From<InvalidLength> for MechanismError {
    /// Wraps an HMAC `InvalidLength` error in
    /// [`MechanismError::InvalidKeyLength`], which lets `?` propagate it from
    /// functions returning `MechanismError`.
    fn from(inner: InvalidLength) -> Self {
        Self::InvalidKeyLength(inner)
    }
}

impl fmt::Display for MechanismError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(
            fmt,
            "{}",
            match self {
                MechanismError::AnonymousRequiresNoCredentials =>
                    "ANONYMOUS mechanism requires no credentials",

                MechanismError::PlainRequiresUsername => "PLAIN requires a username",
                MechanismError::PlainRequiresPlaintextPassword =>
                    "PLAIN requires a plaintext password",

                MechanismError::CannotGenerateNonce => "can't generate nonce",
                MechanismError::ScramRequiresUsername => "SCRAM requires a username",
                MechanismError::ScramRequiresPassword => "SCRAM requires a password",

                MechanismError::CannotDecodeChallenge => "can't decode challenge",
                MechanismError::NoServerNonce => "no server nonce",
                MechanismError::NoServerSalt => "no server salt",
                MechanismError::NoServerIterations => "no server iterations",
                #[cfg(feature = "scram")]
                MechanismError::DeriveError(err) => return write!(fmt, "derive error: {}", err),
                #[cfg(feature = "scram")]
                MechanismError::InvalidKeyLength(err) =>
                    return write!(fmt, "invalid key length: {}", err),
                MechanismError::InvalidState => "not in the right state to receive this response",

                MechanismError::CannotDecodeSuccessResponse => "can't decode success response",
                MechanismError::InvalidSignatureInSuccessResponse =>
                    "invalid signature in success response",
                MechanismError::NoSignatureInSuccessResponse => "no signature in success response",
            }
        )
    }
}

/// Makes `MechanismError` a standard error type, so it can be converted into
/// `Box<dyn std::error::Error>` and used with generic error-handling code.
///
/// Every trait method keeps its default implementation, so `source()` returns
/// `None` even for the wrapping `scram` variants.
/// NOTE(review): consider overriding `source()` to expose the wrapped errors,
/// if `DeriveError` and `InvalidLength` implement `std::error::Error`.
impl std::error::Error for MechanismError {}

/// A SASL authentication mechanism.
///
/// An implementation is created from [`Credentials`] with
/// [`Mechanism::from_credentials`]. Its other methods correspond to the stages
/// of an exchange:
/// - [`initial`](Mechanism::initial) produces the first client payload;
/// - [`response`](Mechanism::response) answers a server challenge;
/// - [`success`](Mechanism::success) checks the server's success data.
pub trait Mechanism {
    /// Returns the name of this mechanism.
    fn name(&self) -> &str;

    /// Builds this mechanism from the given [`Credentials`].
    ///
    /// # Errors
    ///
    /// Returns a [`MechanismError`] when this mechanism cannot be built from
    /// `credentials`.
    fn from_credentials(credentials: Credentials) -> Result<Self, MechanismError>
    where
        Self: Sized;

    /// Returns the initial payload sent by this mechanism.
    ///
    /// The default implementation returns an empty payload.
    fn initial(&mut self) -> Vec<u8> {
        vec![]
    }

    /// Produces the client's reply to a SASL challenge.
    ///
    /// The default implementation ignores the challenge and replies with an
    /// empty payload.
    ///
    /// # Errors
    ///
    /// Implementations may return a [`MechanismError`] when the challenge
    /// cannot be handled. The default implementation never fails.
    fn response(&mut self, _challenge: &[u8]) -> Result<Vec<u8>, MechanismError> {
        Ok(vec![])
    }

    /// Verifies the data sent with the server's success response, if any.
    ///
    /// The default implementation accepts any data.
    ///
    /// # Errors
    ///
    /// Implementations may return a [`MechanismError`] when verification
    /// fails. The default implementation never fails.
    fn success(&mut self, _data: &[u8]) -> Result<(), MechanismError> {
        Ok(())
    }
}

pub mod mechanisms;