gel_auth/postgres/client_state_machine.rs

1use super::{invalid_state, ConnectionError, ConnectionSslRequirement, Credentials};
2use crate::md5::md5_password;
3use crate::postgres::SslError;
4use crate::scram::{
5    generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError, Sha256Out,
6};
7use crate::AuthType;
8use base64::Engine;
9use gel_pg_protocol::{errors::PgServerError, prelude::*, protocol::*};
10use tracing::{error, trace, warn};
11
12#[derive(Debug)]
13struct ClientEnvironmentImpl {
14    credentials: Credentials,
15}
16
17impl ClientEnvironment for ClientEnvironmentImpl {
18    fn generate_nonce(&self) -> String {
19        let nonce: [u8; 32] = rand::random();
20        base64::engine::general_purpose::STANDARD.encode(nonce)
21    }
22    fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out {
23        generate_salted_password(self.credentials.password.as_bytes(), salt, iterations)
24    }
25}
26
/// Internal states of the client-side Postgres connection handshake.
///
/// See [`ConnectionState`] for the transition diagram.
// NOTE(review): `Debug` is derived and `ConnectionState::drive` traces the
// current state with `{:?}`; whether that exposes the password depends on the
// `Debug` impl of `Credentials` — confirm it redacts secrets.
#[derive(Debug)]
enum ConnectionStateImpl {
    /// SSL negotiation has not started yet. The next
    /// `ConnectionDrive::Initial` sends an `SSLRequest`.
    SslInitializing(Credentials, ConnectionSslRequirement),
    /// SSL upgrade message was sent, awaiting the server's response
    /// (`S` = accepted, `N` = rejected).
    SslWaiting(Credentials, ConnectionSslRequirement),
    /// SSL upgrade in progress, waiting for handshake to complete.
    SslConnecting(Credentials),
    /// Plaintext (SSL disabled) connection that has not started yet. The next
    /// `ConnectionDrive::Initial` sends the startup message.
    Initializing(Credentials),
    /// The startup message has been sent and we are waiting for an
    /// auth response. The `bool` records whether the client has already
    /// answered an authentication request (SASL, MD5 or cleartext); if it is
    /// still `false` when `AuthenticationOk` arrives, trust authentication is
    /// reported.
    Connecting(Credentials, bool),
    /// The server has requested SCRAM auth. This holds a sub-state-machine that
    /// manages a SCRAM challenge.
    Scram(ClientTransaction, ClientEnvironmentImpl),
    /// The authentication is successful and we are synchronizing server
    /// parameters (`ParameterStatus`, `BackendKeyData`) until `ReadyForQuery`.
    Connected,
    /// The server is ready for queries. Terminal: further drives are rejected.
    Ready,
    /// The connection failed. Terminal: further drives are rejected.
    Error,
}
53
/// Coarse connection phase reported through
/// [`ConnectionStateUpdate::state_changed`].
#[derive(Clone, Copy, Debug)]
pub enum ConnectionStateType {
    /// The handshake has started (SSL request or startup message sent).
    Connecting,
    /// The server accepted SSL and the upgrade has been requested.
    SslConnecting,
    /// A SCRAM-SHA-256 exchange is in progress.
    Authenticating,
    /// Authentication succeeded; server parameters are being received.
    Synchronizing,
    /// The server sent `ReadyForQuery`; the handshake is complete.
    Ready,
}
62
/// An input fed to [`ConnectionState::drive`].
#[derive(Debug)]
pub enum ConnectionDrive<'a> {
    /// Starts the handshake; only accepted in the initial state.
    Initial,
    /// A backend message received from the server, or the error from parsing it.
    Message(Result<Message<'a>, ParseError>),
    /// The server's reply to the SSL request; expected while
    /// [`ConnectionState::read_ssl_response`] returns `true`.
    SslResponse(SSLResponse<'a>),
    /// The SSL upgrade requested via [`ConnectionStateSend::upgrade`] has completed.
    SslReady,
}
70
/// Transport operations the connection state machine asks its caller to perform.
pub trait ConnectionStateSend {
    /// Sends an initial message; the state machine uses this for the SSL
    /// request and the startup message.
    fn send_initial<'a, M>(
        &mut self,
        message: impl IntoInitialBuilder<'a, M>,
    ) -> Result<(), std::io::Error>;
    /// Sends a regular frontend message (password or SASL responses).
    fn send<'a, M>(
        &mut self,
        message: impl IntoFrontendBuilder<'a, M>,
    ) -> Result<(), std::io::Error>;
    /// Starts the SSL upgrade after the server accepted the SSL request. The
    /// state machine then waits for [`ConnectionDrive::SslReady`].
    fn upgrade(&mut self) -> Result<(), std::io::Error>;
}
82
/// A callback for connection state changes.
///
/// Extends [`ConnectionStateSend`] with notifications. Every method has a
/// default body, so implementors only override what they need.
// `unused` is allowed because the default no-op bodies ignore their parameters.
#[allow(unused)]
pub trait ConnectionStateUpdate: ConnectionStateSend {
    /// Called for each `ParameterStatus` received after authentication.
    fn parameter(&mut self, name: &str, value: &str) {}
    /// Called with the backend process ID and key from `BackendKeyData`.
    fn cancellation_key(&mut self, pid: i32, key: i32) {}
    /// Called when the connection enters a new [`ConnectionStateType`] phase.
    fn state_changed(&mut self, state: ConnectionStateType) {}
    /// Called when the server sends an `ErrorResponse`; logs it by default.
    fn server_error(&mut self, error: &PgServerError) {
        error!("Server error during handshake: {:?}", error);
    }
    /// Called for each `NoticeResponse`; logs it as a warning by default.
    fn server_notice(&mut self, notice: &PgServerError) {
        warn!("Server notice during handshake: {:?}", notice);
    }
    /// Called with the authentication method in use (`AuthType::Trust` when
    /// the server accepts the connection without any challenge).
    fn auth(&mut self, auth: AuthType) {}
}
97
/// The state machine for a Postgres connection handshake. The state machine
/// is driven with calls to [`Self::drive`].
///
/// State diagram for the connection state machine (Mermaid syntax):
///
/// ```mermaid
/// stateDiagram-v2
///     [*] --> SslInitializing: SSL not disabled
///     [*] --> Initializing: SSL disabled
///     SslInitializing --> SslWaiting: Send SSL request
///     SslWaiting --> SslConnecting: SSL accepted
///     SslWaiting --> Connecting: SSL rejected (if not required)
///     SslConnecting --> Connecting: SSL handshake complete
///     Initializing --> Connecting: Send startup message
///     Connecting --> Connected: Authentication successful
///     Connecting --> Scram: SCRAM auth requested
///     Scram --> Connected: SCRAM auth successful
///     Connected --> Ready: Parameter sync complete
///     Ready --> [*]: Connection closed
///     state Error {
///         [*] --> [*]: Any state can transition to Error
///     }
/// ```
///
/// A server `ErrorResponse` received while connecting, authenticating or
/// synchronizing moves the machine into the `Error` state. Once `Ready` or
/// `Error` has been reached, further drives are rejected.
#[derive(Debug)]
pub struct ConnectionState(ConnectionStateImpl);
123
124impl ConnectionState {
125    pub fn new(credentials: Credentials, ssl_mode: ConnectionSslRequirement) -> Self {
126        if ssl_mode == ConnectionSslRequirement::Disable {
127            Self(ConnectionStateImpl::Initializing(credentials))
128        } else {
129            Self(ConnectionStateImpl::SslInitializing(credentials, ssl_mode))
130        }
131    }
132
133    pub fn is_ready(&self) -> bool {
134        matches!(self.0, ConnectionStateImpl::Ready)
135    }
136
137    pub fn is_error(&self) -> bool {
138        matches!(self.0, ConnectionStateImpl::Error)
139    }
140
141    pub fn is_done(&self) -> bool {
142        self.is_ready() || self.is_error()
143    }
144
145    pub fn read_ssl_response(&self) -> bool {
146        matches!(self.0, ConnectionStateImpl::SslWaiting(..))
147    }
148
149    pub fn drive(
150        &mut self,
151        drive: ConnectionDrive,
152        update: &mut impl ConnectionStateUpdate,
153    ) -> Result<(), ConnectionError> {
154        use ConnectionStateImpl::*;
155        trace!("Received drive {drive:?} in state {:?}", self.0);
156        match (&mut self.0, drive) {
157            (SslInitializing(credentials, mode), ConnectionDrive::Initial) => {
158                update.send_initial(&SSLRequestBuilder::default())?;
159                self.0 = SslWaiting(std::mem::take(credentials), *mode);
160                update.state_changed(ConnectionStateType::Connecting);
161            }
162            (SslWaiting(credentials, mode), ConnectionDrive::SslResponse(response)) => {
163                if *mode == ConnectionSslRequirement::Disable {
164                    // Should not be possible
165                    return Err(invalid_state!("SSL mode is Disable in SslWaiting state"));
166                }
167
168                if response.code() == b'S' {
169                    // Accepted
170                    update.upgrade()?;
171                    self.0 = SslConnecting(std::mem::take(credentials));
172                    update.state_changed(ConnectionStateType::SslConnecting);
173                } else if response.code() == b'N' {
174                    // Rejected
175                    if *mode == ConnectionSslRequirement::Required {
176                        return Err(ConnectionError::SslError(SslError::SslRequiredByClient));
177                    }
178                    Self::send_startup_message(credentials, update)?;
179                    self.0 = Connecting(std::mem::take(credentials), false);
180                } else {
181                    return Err(ConnectionError::UnexpectedResponse(format!(
182                        "Unexpected SSL response from server: {:?}",
183                        response.code() as char
184                    )));
185                }
186            }
187            (SslConnecting(credentials), ConnectionDrive::SslReady) => {
188                Self::send_startup_message(credentials, update)?;
189                self.0 = Connecting(std::mem::take(credentials), false);
190            }
191            (Initializing(credentials), ConnectionDrive::Initial) => {
192                Self::send_startup_message(credentials, update)?;
193                self.0 = Connecting(std::mem::take(credentials), false);
194                update.state_changed(ConnectionStateType::Connecting);
195            }
196            (Connecting(credentials, sent_auth), ConnectionDrive::Message(message)) => {
197                match_message!(message, Backend {
198                    (AuthenticationOk) => {
199                        if !*sent_auth {
200                            update.auth(AuthType::Trust);
201                        }
202                        trace!("auth ok");
203                        self.0 = Connected;
204                        update.state_changed(ConnectionStateType::Synchronizing);
205                    },
206                    (AuthenticationSASL as sasl) => {
207                        *sent_auth = true;
208                        let mut found_scram_sha256 = false;
209                        for mech in sasl.mechanisms() {
210                            trace!("auth sasl: {:?}", mech);
211                            if mech == "SCRAM-SHA-256" {
212                                found_scram_sha256 = true;
213                                break;
214                            }
215                        }
216                        if !found_scram_sha256 {
217                            return Err(ConnectionError::UnexpectedResponse("Server requested SASL authentication but does not support SCRAM-SHA-256".into()));
218                        }
219                        let credentials = credentials.clone();
220                        let mut tx = ClientTransaction::new("".into());
221                        let env = ClientEnvironmentImpl { credentials };
222                        let Some(initial_message) = tx.process_message(&[], &env)? else {
223                            return Err(SCRAMError::ProtocolError.into());
224                        };
225                        update.auth(AuthType::ScramSha256);
226                        update.send(&SASLInitialResponseBuilder {
227                            mechanism: "SCRAM-SHA-256",
228                            response: initial_message.as_slice(),
229                        })?;
230                        self.0 = Scram(tx, env);
231                        update.state_changed(ConnectionStateType::Authenticating);
232                    },
233                    (AuthenticationMD5Password as md5) => {
234                        *sent_auth = true;
235                        trace!("auth md5");
236                        let md5_hash = md5_password(&credentials.password, &credentials.username, md5.salt());
237                        update.auth(AuthType::Md5);
238                        update.send(&PasswordMessageBuilder {
239                            password: &md5_hash,
240                        })?;
241                    },
242                    (AuthenticationCleartextPassword) => {
243                        *sent_auth = true;
244                        trace!("auth cleartext");
245                        update.auth(AuthType::Plain);
246                        update.send(&PasswordMessageBuilder {
247                            password: &credentials.password,
248                        })?;
249                    },
250                    (NoticeResponse as notice) => {
251                        let err = PgServerError::from(notice);
252                        update.server_notice(&err);
253                    },
254                    (ErrorResponse as error) => {
255                        self.0 = Error;
256                        let err = PgServerError::from(error);
257                        update.server_error(&err);
258                        return Err(err.into());
259                    },
260                    message => {
261                        log_unknown_message(message, "Connecting")?
262                    },
263                });
264            }
265            (Scram(tx, env), ConnectionDrive::Message(message)) => {
266                match_message!(message, Backend {
267                    (AuthenticationSASLContinue as sasl) => {
268                        let Some(message) = tx.process_message(&sasl.data(), env)? else {
269                            return Err(SCRAMError::ProtocolError.into());
270                        };
271                        update.send(&SASLResponseBuilder {
272                            response: &message,
273                        })?;
274                    },
275                    (AuthenticationSASLFinal as sasl) => {
276                        let None = tx.process_message(&sasl.data(), env)? else {
277                            return Err(SCRAMError::ProtocolError.into());
278                        };
279                    },
280                    (AuthenticationOk) => {
281                        trace!("auth ok");
282                        self.0 = Connected;
283                        update.state_changed(ConnectionStateType::Synchronizing);
284                    },
285                    (AuthenticationMessage as auth) => {
286                        trace!("SCRAM Unknown auth message: {}", auth.status())
287                    },
288                    (NoticeResponse as notice) => {
289                        let err = PgServerError::from(notice);
290                        update.server_notice(&err);
291                    },
292                    (ErrorResponse as error) => {
293                        self.0 = Error;
294                        let err = PgServerError::from(error);
295                        update.server_error(&err);
296                        return Err(err.into());
297                    },
298                    message => {
299                        log_unknown_message(message, "SCRAM")?
300                    },
301                });
302            }
303            (Connected, ConnectionDrive::Message(message)) => {
304                match_message!(message, Backend {
305                    (ParameterStatus as param) => {
306                        trace!("param: {:?}={:?}", param.name(), param.value());
307                        update.parameter(param.name().try_into()?, param.value().try_into()?);
308                    },
309                    (BackendKeyData as key_data) => {
310                        trace!("key={:?} pid={:?}", key_data.key(), key_data.pid());
311                        update.cancellation_key(key_data.pid(), key_data.key());
312                    },
313                    (ReadyForQuery as ready) => {
314                        trace!("ready: {:?}", ready.status() as char);
315                        trace!("-> Ready");
316                        self.0 = Ready;
317                        update.state_changed(ConnectionStateType::Ready);
318                    },
319                    (NoticeResponse as notice) => {
320                        let err = PgServerError::from(notice);
321                        update.server_notice(&err);
322                    },
323                    (ErrorResponse as error) => {
324                        self.0 = Error;
325                        let err = PgServerError::from(error);
326                        update.server_error(&err);
327                        return Err(err.into());
328                    },
329                    message => {
330                        log_unknown_message(message, "Connected")?
331                    },
332                });
333            }
334            (Ready, _) | (Error, _) => {
335                return Err(invalid_state!("Unexpected drive for Ready or Error state"))
336            }
337            _ => return Err(invalid_state!("Unexpected (state, drive) combination")),
338        }
339        Ok(())
340    }
341
342    fn send_startup_message(
343        credentials: &Credentials,
344        update: &mut impl ConnectionStateUpdate,
345    ) -> Result<(), std::io::Error> {
346        let mut params = vec![
347            StartupNameValueBuilder {
348                name: "user",
349                value: &credentials.username,
350            },
351            StartupNameValueBuilder {
352                name: "database",
353                value: &credentials.database,
354            },
355        ];
356        for (name, value) in &credentials.server_settings {
357            params.push(StartupNameValueBuilder { name, value })
358        }
359
360        update.send_initial(&StartupMessageBuilder { params: || &params })
361    }
362}
363
364fn log_unknown_message(
365    message: Result<Message, ParseError>,
366    state: &str,
367) -> Result<(), ParseError> {
368    match message {
369        Ok(message) => {
370            warn!(
371                "Unexpected message {:?} (length {}) received in {} state",
372                message.mtype(),
373                message.mlen(),
374                state
375            );
376            Ok(())
377        }
378        Err(e) => {
379            error!("Corrupted message received in {} state", state);
380            Err(e)
381        }
382    }
383}