gel_auth/handshake/server_auth.rs

1use crate::{
2    md5::StoredHash,
3    scram::{SCRAMError, ServerTransaction, StoredKey},
4    AuthType, CredentialData,
5};
6use tracing::error;
7
8#[derive(Debug)]
9pub enum ServerAuthResponse {
10    Initial(AuthType, Vec<u8>),
11    Continue(Vec<u8>),
12    Complete(Vec<u8>),
13    Error(ServerAuthError),
14}
15
16#[derive(Debug, derive_more::Error, derive_more::Display, derive_more::From)]
17pub enum ServerAuthError {
18    #[display("Invalid authorization specification")]
19    InvalidAuthorizationSpecification,
20    #[display("Invalid password")]
21    InvalidPassword,
22    #[display("Invalid SASL message ({_0})")]
23    InvalidSaslMessage(#[from] SCRAMError),
24    #[display("Unsupported authentication type")]
25    UnsupportedAuthType,
26    #[display("Invalid message type")]
27    InvalidMessageType,
28}
29
30#[derive(Debug)]
31enum ServerAuthState {
32    Initial,
33    Password(CredentialData),
34    MD5([u8; 4], CredentialData),
35    Sasl(ServerTransaction, StoredKey),
36    Complete,
37}
38
39#[derive(Debug)]
40pub enum ServerAuthDrive<'a> {
41    Initial,
42    Message(AuthType, &'a [u8]),
43}
44
45#[derive(Debug)]
46pub struct ServerAuth {
47    state: ServerAuthState,
48    username: String,
49    auth_type: AuthType,
50    credential_data: CredentialData,
51}
52
53impl ServerAuth {
54    pub fn new(username: String, auth_type: AuthType, credential_data: CredentialData) -> Self {
55        Self {
56            state: ServerAuthState::Initial,
57            username,
58            auth_type,
59            credential_data,
60        }
61    }
62
63    pub fn is_complete(&self) -> bool {
64        matches!(self.state, ServerAuthState::Complete)
65    }
66
67    pub fn is_initial_message(&self) -> bool {
68        match &self.state {
69            ServerAuthState::Initial => false,
70            ServerAuthState::Sasl(tx, _) => tx.initial(),
71            _ => true,
72        }
73    }
74
75    pub fn auth_type(&self) -> AuthType {
76        self.auth_type
77    }
78
79    pub fn drive(&mut self, drive: ServerAuthDrive) -> ServerAuthResponse {
80        match (&mut self.state, drive) {
81            (ServerAuthState::Initial, ServerAuthDrive::Initial) => self.handle_initial(),
82            (ServerAuthState::Password(data), ServerAuthDrive::Message(AuthType::Plain, input)) => {
83                let client_password = input;
84                let success = match data {
85                    CredentialData::Deny => false,
86                    CredentialData::Trust => true,
87                    CredentialData::Plain(password) => client_password == password.as_bytes(),
88                    CredentialData::Md5(md5) => {
89                        let md5_1 = StoredHash::generate(client_password, &self.username);
90                        md5_1 == *md5
91                    }
92                    CredentialData::Scram(scram) => {
93                        let key =
94                            StoredKey::generate(client_password, &scram.salt, scram.iterations);
95                        key.stored_key == scram.stored_key
96                    }
97                };
98                self.state = ServerAuthState::Complete;
99                if success {
100                    ServerAuthResponse::Complete(Vec::new())
101                } else {
102                    ServerAuthResponse::Error(ServerAuthError::InvalidPassword)
103                }
104            }
105            (ServerAuthState::MD5(salt, data), ServerAuthDrive::Message(AuthType::Md5, input)) => {
106                let success = match data {
107                    CredentialData::Deny => false,
108                    CredentialData::Trust => true,
109                    CredentialData::Plain(password) => {
110                        let server_md5 = StoredHash::generate(password.as_bytes(), &self.username);
111                        server_md5.matches(input, *salt)
112                    }
113                    CredentialData::Md5(server_md5) => server_md5.matches(input, *salt),
114                    CredentialData::Scram(_) => {
115                        // Unreachable
116                        false
117                    }
118                };
119
120                self.state = ServerAuthState::Complete;
121                if success {
122                    ServerAuthResponse::Complete(Vec::new())
123                } else {
124                    ServerAuthResponse::Error(ServerAuthError::InvalidPassword)
125                }
126            }
127            (
128                ServerAuthState::Sasl(tx, data),
129                ServerAuthDrive::Message(AuthType::ScramSha256, input),
130            ) => {
131                let initial = tx.initial();
132                match tx.process_message(input, data) {
133                    Ok(final_message) => {
134                        if initial {
135                            ServerAuthResponse::Continue(final_message)
136                        } else {
137                            self.state = ServerAuthState::Complete;
138                            ServerAuthResponse::Complete(final_message)
139                        }
140                    }
141                    Err(e) => {
142                        self.state = ServerAuthState::Complete;
143                        ServerAuthResponse::Error(ServerAuthError::InvalidSaslMessage(e))
144                    }
145                }
146            }
147            (_, drive) => {
148                self.state = ServerAuthState::Complete;
149                error!("Received invalid drive {drive:?} in state {:?}", self.state);
150                ServerAuthResponse::Error(ServerAuthError::InvalidMessageType)
151            }
152        }
153    }
154
155    fn handle_initial(&mut self) -> ServerAuthResponse {
156        match self.auth_type {
157            AuthType::Deny => {
158                self.state = ServerAuthState::Complete;
159                ServerAuthResponse::Error(ServerAuthError::InvalidAuthorizationSpecification)
160            }
161            AuthType::Trust => {
162                self.state = ServerAuthState::Complete;
163                ServerAuthResponse::Complete(Vec::new())
164            }
165            AuthType::Plain => {
166                self.state = ServerAuthState::Password(self.credential_data.clone());
167                ServerAuthResponse::Initial(AuthType::Plain, Vec::new())
168            }
169            AuthType::Md5 => {
170                let salt: [u8; 4] = rand::random();
171                match self.credential_data {
172                    CredentialData::Scram(..) => {
173                        ServerAuthResponse::Error(ServerAuthError::UnsupportedAuthType)
174                    }
175                    _ => {
176                        self.state = ServerAuthState::MD5(salt, self.credential_data.clone());
177                        ServerAuthResponse::Initial(AuthType::Md5, salt.into())
178                    }
179                }
180            }
181            AuthType::ScramSha256 => {
182                let salt: [u8; 32] = rand::random();
183                let scram = match &self.credential_data {
184                    CredentialData::Scram(scram) => scram.clone(),
185                    CredentialData::Plain(password) => {
186                        StoredKey::generate(password.as_bytes(), &salt, 4096)
187                    }
188                    CredentialData::Deny => StoredKey::generate(b"", &salt, 4096),
189                    _ => {
190                        return ServerAuthResponse::Error(ServerAuthError::UnsupportedAuthType);
191                    }
192                };
193                let tx = ServerTransaction::default();
194                self.state = ServerAuthState::Sasl(tx, scram);
195                ServerAuthResponse::Initial(AuthType::ScramSha256, Vec::new())
196            }
197        }
198    }
199}