// gel_auth/handshake/client_auth.rs

1use crate::{
2    md5::md5_password,
3    scram::{
4        generate_nonce, generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError,
5        Sha256Out,
6    },
7    AuthType, CredentialData,
8};
9use tracing::error;
10
11#[derive(Debug)]
12pub enum ClientAuthResponse {
13    Initial(AuthType, Vec<u8>),
14    Continue(Vec<u8>),
15    Complete,
16    Waiting,
17    Error(ClientAuthError),
18}
19
20#[derive(Debug, thiserror::Error)]
21pub enum ClientAuthError {
22    #[error("SCRAM protocol error: {0}")]
23    ScramError(#[from] SCRAMError),
24    #[error("Invalid authentication state")]
25    InvalidState,
26    #[error("Invalid credentials")]
27    InvalidCredentials,
28    #[error("Unexpected message during authentication")]
29    UnexpectedMessage,
30}
31
32#[derive(Debug)]
33enum ClientAuthState {
34    Initial(String, CredentialData),
35    Complete,
36    Waiting,
37    Sasl(ClientTransaction, ClientEnvironmentImpl),
38}
39
/// Server-side events that advance [`ClientAuth`].
#[derive(Debug)]
pub enum ClientAuthDrive<'a> {
    /// The server reported that authentication succeeded.
    Ok,
    /// The server asked for a plain (cleartext) password.
    Plain,
    /// The server asked for MD5 authentication; the array is the server-supplied salt.
    Md5([u8; 4]),
    /// The server asked for SCRAM authentication.
    Scram,
    /// The server sent a SCRAM message; the slice borrows its raw bytes.
    ScramResponse(&'a [u8]),
}
53
54#[derive(Debug)]
55pub struct ClientAuth {
56    state: ClientAuthState,
57    auth_type: Option<AuthType>,
58}
59
60impl ClientAuth {
61    /// Create a new client authentication state.
62    pub fn new(username: String, credentials: CredentialData) -> Self {
63        Self {
64            state: ClientAuthState::Initial(username, credentials),
65            auth_type: None,
66        }
67    }
68
69    pub fn is_complete(&self) -> bool {
70        matches!(self.state, ClientAuthState::Complete)
71    }
72
73    pub fn auth_type(&self) -> Option<AuthType> {
74        self.auth_type
75    }
76
77    pub fn drive(&mut self, drive: ClientAuthDrive) -> Result<ClientAuthResponse, ClientAuthError> {
78        match (&mut self.state, drive) {
79            (ClientAuthState::Initial(username, credentials), drive) => {
80                let username = std::mem::take(username);
81                let credentials = std::mem::replace(credentials, CredentialData::Deny);
82                self.handle_initial(username, credentials, drive)
83            }
84            // SCRAM authentication: Handle SCRAM protocol messages.
85            (ClientAuthState::Sasl(tx, env), ClientAuthDrive::ScramResponse(message)) => {
86                let response = tx.process_message(&message, env)?;
87                match response {
88                    Some(response) => Ok(ClientAuthResponse::Continue(response)),
89                    None => {
90                        self.state = ClientAuthState::Waiting;
91                        Ok(ClientAuthResponse::Waiting)
92                    }
93                }
94            }
95            // Handle "Ok" drive (authentication successful).
96            (ClientAuthState::Waiting, ClientAuthDrive::Ok) => {
97                self.state = ClientAuthState::Complete;
98                Ok(ClientAuthResponse::Complete)
99            }
100            // Invalid state/drive combination.
101            (_, drive) => {
102                error!("Received invalid drive {drive:?} in state {:?}", self.state);
103                Err(ClientAuthError::InvalidState)
104            }
105        }
106    }
107
108    fn handle_initial(
109        &mut self,
110        username: String,
111        credentials: CredentialData,
112        drive: ClientAuthDrive,
113    ) -> Result<ClientAuthResponse, ClientAuthError> {
114        let (auth_type, (state, response)) = match drive {
115            ClientAuthDrive::Ok => (
116                AuthType::Trust,
117                match credentials {
118                    CredentialData::Deny => (
119                        ClientAuthState::Complete,
120                        ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
121                    ),
122                    _ => (ClientAuthState::Complete, ClientAuthResponse::Complete),
123                },
124            ),
125            ClientAuthDrive::Plain => (
126                AuthType::Plain,
127                match credentials {
128                    CredentialData::Plain(credentials) => (
129                        ClientAuthState::Waiting,
130                        ClientAuthResponse::Initial(
131                            AuthType::Plain,
132                            credentials.clone().into_bytes(),
133                        ),
134                    ),
135                    _ => (
136                        ClientAuthState::Complete,
137                        ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
138                    ),
139                },
140            ),
141            ClientAuthDrive::Md5(salt) => (
142                AuthType::Md5,
143                match credentials {
144                    CredentialData::Md5(credentials) => (
145                        ClientAuthState::Waiting,
146                        ClientAuthResponse::Initial(
147                            AuthType::Md5,
148                            credentials.salted(salt).into_bytes(),
149                        ),
150                    ),
151                    CredentialData::Plain(credentials) => (
152                        ClientAuthState::Waiting,
153                        ClientAuthResponse::Initial(
154                            AuthType::Md5,
155                            md5_password(&credentials, &username, salt).into_bytes(),
156                        ),
157                    ),
158                    _ => (
159                        ClientAuthState::Complete,
160                        ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
161                    ),
162                },
163            ),
164            ClientAuthDrive::Scram => (
165                AuthType::ScramSha256,
166                match credentials {
167                    CredentialData::Plain(credentials) => {
168                        let env = ClientEnvironmentImpl {
169                            password: credentials,
170                        };
171                        let mut tx = ClientTransaction::new(username.into());
172                        let response = tx.process_message(&[], &env);
173                        match response {
174                            Ok(Some(response)) => (
175                                ClientAuthState::Sasl(tx, env),
176                                ClientAuthResponse::Initial(AuthType::ScramSha256, response),
177                            ),
178                            Ok(None) => (
179                                ClientAuthState::Complete,
180                                ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
181                            ),
182                            Err(e) => (
183                                ClientAuthState::Complete,
184                                ClientAuthResponse::Error(ClientAuthError::ScramError(e)),
185                            ),
186                        }
187                    }
188                    _ => (
189                        ClientAuthState::Complete,
190                        ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
191                    ),
192                },
193            ),
194            _ => {
195                error!("Received invalid drive {drive:?} in state Initial");
196                return Err(ClientAuthError::InvalidState);
197            }
198        };
199
200        self.auth_type = Some(auth_type);
201        self.state = state;
202        Ok(response)
203    }
204}
205
/// SCRAM client environment backed by a plaintext password.
///
/// `Debug` is implemented by hand so that the password never appears in debug
/// output or logs. A derived impl would print it verbatim, and this type is
/// embedded in the `Debug` output of the authentication state.
struct ClientEnvironmentImpl {
    /// Plaintext password used to derive the salted password.
    password: String,
}

impl std::fmt::Debug for ClientEnvironmentImpl {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientEnvironmentImpl")
            .field("password", &"<redacted>")
            .finish()
    }
}
210
211impl ClientEnvironment for ClientEnvironmentImpl {
212    fn generate_nonce(&self) -> String {
213        generate_nonce()
214    }
215
216    fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out {
217        generate_salted_password(self.password.as_bytes(), salt, iterations)
218    }
219}