gel_auth/postgres/server_state_machine.rs

1use super::{ConnectionError, ConnectionSslRequirement};
2use crate::{
3    handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse},
4    AuthType, CredentialData,
5};
6use gel_pg_protocol::{
7    errors::{
8        PgError, PgErrorConnectionException, PgErrorFeatureNotSupported,
9        PgErrorInvalidAuthorizationSpecification, PgServerError, PgServerErrorField,
10    },
11    prelude::*,
12    protocol::*,
13};
14use std::str::Utf8Error;
15use tracing::{error, trace, warn};
16
/// Coarse connection phase passed to `ConnectionStateUpdate::state_changed`.
///
/// Derives `PartialEq`/`Eq` so observers can compare phases directly
/// (e.g. `state == ConnectionStateType::Ready`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionStateType {
    Connecting,
    SslConnecting,
    Authenticating,
    Synchronizing,
    Ready,
}
25
26#[derive(Debug)]
27pub enum ConnectionDrive<'a> {
28    /// Raw bytes from a client.
29    RawMessage(&'a [u8]),
30    /// Initial message from client.
31    Initial(Result<InitialMessage<'a>, ParseError>),
32    /// Non-initial message from the client.
33    Message(Result<Message<'a>, ParseError>),
34    /// SSL is ready.
35    SslReady,
36    /// Provide authentication information. The environment may supply credential data
37    /// that doesn't match the auth type. In such cases, the server will try to adapt
38    /// the auth data appropriately.
39    ///
40    /// Additionally, the environment can provide a "Trust" credential for automatic
41    /// success or a "Deny" credential for automatic failure. The server will simulate
42    /// a login process before unconditionally succeeding or failing in these cases.
43    AuthInfo(AuthType, CredentialData),
44    /// Once authorized, the server may sync any number of parameters until ready.
45    Parameter(String, String),
46    /// Ready, handshake complete.
47    Ready(i32, i32),
48    /// Fail the connection with a Postgres error code and message.
49    Fail(PgError, &'a str),
50}
51
52pub trait ConnectionStateSend {
53    /// Send the response to the SSL initiation.
54    fn send_ssl(&mut self, message: SSLResponseBuilder) -> Result<(), std::io::Error>;
55    /// Send an ordinary message.
56    fn send<'a, M>(
57        &mut self,
58        message: impl IntoBackendBuilder<'a, M>,
59    ) -> Result<(), std::io::Error>;
60    /// Perform the SSL upgrade.
61    fn upgrade(&mut self) -> Result<(), std::io::Error>;
62    /// Notify the environment that a user and database were selected.
63    fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error>;
64    /// Notify the environment that parameters are requested.
65    fn params(&mut self) -> Result<(), std::io::Error>;
66}
67
68/// A callback for connection state changes.
69#[allow(unused)]
70pub trait ConnectionStateUpdate: ConnectionStateSend {
71    fn parameter(&mut self, name: &str, value: &str) {}
72    fn state_changed(&mut self, state: ConnectionStateType) {}
73    fn server_error(&mut self, error: &PgServerError) {}
74}
75
76#[derive(Debug)]
77pub enum ConnectionEvent<'a> {
78    SendSSL(SSLResponseBuilder),
79    Send(BackendBuilder<'a>),
80    Upgrade,
81    Auth(String, String),
82    Params,
83    Parameter(&'a str, &'a str),
84    StateChanged(ConnectionStateType),
85    ServerError(&'a PgServerError),
86}
87
88impl<F> ConnectionStateSend for F
89where
90    F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>,
91{
92    fn send_ssl(&mut self, message: SSLResponseBuilder) -> Result<(), std::io::Error> {
93        self(ConnectionEvent::SendSSL(message))
94    }
95
96    fn send<'a, M>(
97        &mut self,
98        message: impl IntoBackendBuilder<'a, M>,
99    ) -> Result<(), std::io::Error> {
100        self(ConnectionEvent::Send(message.into_builder()))
101    }
102
103    fn upgrade(&mut self) -> Result<(), std::io::Error> {
104        self(ConnectionEvent::Upgrade)
105    }
106
107    fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error> {
108        self(ConnectionEvent::Auth(user, database))
109    }
110
111    fn params(&mut self) -> Result<(), std::io::Error> {
112        self(ConnectionEvent::Params)
113    }
114}
115
116impl<F> ConnectionStateUpdate for F
117where
118    F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>,
119{
120    fn parameter(&mut self, name: &str, value: &str) {
121        let _ = self(ConnectionEvent::Parameter(name, value));
122    }
123
124    fn state_changed(&mut self, state: ConnectionStateType) {
125        let _ = self(ConnectionEvent::StateChanged(state));
126    }
127
128    fn server_error(&mut self, error: &PgServerError) {
129        let _ = self(ConnectionEvent::ServerError(error));
130    }
131}
132
133#[derive(Debug)]
134#[allow(clippy::large_enum_variant)] // Auth is much larger
135enum ServerStateImpl {
136    /// Initial state, enum indicates whether SSL is required (or None if enabled)
137    Initial(Option<ConnectionSslRequirement>),
138    /// SSL connection is being established
139    SslConnecting,
140    /// Waiting for AuthInfo
141    AuthInfo(String),
142    /// Authentication process has begun
143    Authenticating(ServerAuth),
144    /// Synchronizing connection parameters
145    Synchronizing,
146    /// Connection is ready for queries
147    Ready,
148    /// An error has occurred
149    Error,
150}
151
152#[derive(derive_more::Debug)]
153pub struct ServerState {
154    state: ServerStateImpl,
155    #[debug(skip)]
156    initial_buffer: StructBuffer<InitialMessage<'static>>,
157    #[debug(skip)]
158    buffer: StructBuffer<Message<'static>>,
159}
160
161fn send_error(
162    update: &mut impl ConnectionStateUpdate,
163    code: PgError,
164    message: &str,
165) -> std::io::Result<()> {
166    let error = PgServerError::new(code, message, Default::default());
167    update.server_error(&error);
168    update.send(&ErrorResponseBuilder {
169        fields: &[
170            &ErrorFieldBuilder {
171                etype: PgServerErrorField::Severity as u8,
172                value: "ERROR",
173            },
174            &ErrorFieldBuilder {
175                etype: PgServerErrorField::SeverityNonLocalized as u8,
176                value: "ERROR",
177            },
178            &ErrorFieldBuilder {
179                etype: PgServerErrorField::Code as u8,
180                value: std::str::from_utf8(&code.to_code()).unwrap(),
181            },
182            &ErrorFieldBuilder {
183                etype: PgServerErrorField::Message as u8,
184                value: message,
185            },
186        ],
187    })
188}
189
190#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
191enum ServerError {
192    IO(#[from] std::io::Error),
193    Protocol(#[from] PgError),
194    Utf8Error(#[from] Utf8Error),
195}
196
197impl From<ServerAuthError> for ServerError {
198    fn from(value: ServerAuthError) -> Self {
199        match value {
200            ServerAuthError::InvalidAuthorizationSpecification => {
201                ServerError::Protocol(PgError::InvalidAuthorizationSpecification(
202                    PgErrorInvalidAuthorizationSpecification::InvalidAuthorizationSpecification,
203                ))
204            }
205            ServerAuthError::InvalidPassword => {
206                ServerError::Protocol(PgError::InvalidAuthorizationSpecification(
207                    PgErrorInvalidAuthorizationSpecification::InvalidPassword,
208                ))
209            }
210            ServerAuthError::InvalidSaslMessage(_) => ServerError::Protocol(
211                PgError::ConnectionException(PgErrorConnectionException::ProtocolViolation),
212            ),
213            ServerAuthError::UnsupportedAuthType => ServerError::Protocol(
214                PgError::FeatureNotSupported(PgErrorFeatureNotSupported::FeatureNotSupported),
215            ),
216            ServerAuthError::InvalidMessageType => ServerError::Protocol(
217                PgError::ConnectionException(PgErrorConnectionException::ProtocolViolation),
218            ),
219        }
220    }
221}
222
223const PROTOCOL_ERROR: ServerError = ServerError::Protocol(PgError::ConnectionException(
224    PgErrorConnectionException::ProtocolViolation,
225));
226const AUTH_ERROR: ServerError = ServerError::Protocol(PgError::InvalidAuthorizationSpecification(
227    PgErrorInvalidAuthorizationSpecification::InvalidAuthorizationSpecification,
228));
229const PROTOCOL_VERSION_ERROR: ServerError = ServerError::Protocol(PgError::FeatureNotSupported(
230    PgErrorFeatureNotSupported::FeatureNotSupported,
231));
232
233impl ServerState {
234    pub fn new(ssl_requirement: ConnectionSslRequirement) -> Self {
235        Self {
236            state: ServerStateImpl::Initial(Some(ssl_requirement)),
237            initial_buffer: Default::default(),
238            buffer: Default::default(),
239        }
240    }
241
242    pub fn is_ready(&self) -> bool {
243        matches!(self.state, ServerStateImpl::Ready)
244    }
245
246    pub fn is_error(&self) -> bool {
247        matches!(self.state, ServerStateImpl::Error)
248    }
249
250    pub fn is_done(&self) -> bool {
251        self.is_ready() || self.is_error()
252    }
253
254    pub fn drive(
255        &mut self,
256        drive: ConnectionDrive,
257        update: &mut impl ConnectionStateUpdate,
258    ) -> Result<(), ConnectionError> {
259        trace!("SERVER DRIVE: {:?} {:?}", self.state, drive);
260        let res = match drive {
261            ConnectionDrive::RawMessage(raw) => match self.state {
262                ServerStateImpl::Initial(..) => self.initial_buffer.push_fallible(raw, |message| {
263                    self.state
264                        .drive_inner(ConnectionDrive::Initial(message), update)
265                }),
266                ServerStateImpl::Authenticating(..) => self.buffer.push_fallible(raw, |message| {
267                    self.state
268                        .drive_inner(ConnectionDrive::Message(message), update)
269                }),
270                _ => {
271                    error!("Unexpected drive in state {:?}", self.state);
272                    Err(PROTOCOL_ERROR)
273                }
274            },
275            drive => self.state.drive_inner(drive, update),
276        };
277
278        match res {
279            Ok(_) => Ok(()),
280            Err(ServerError::IO(e)) => Err(e.into()),
281            Err(ServerError::Utf8Error(e)) => Err(e.into()),
282            Err(ServerError::Protocol(code)) => {
283                self.state = ServerStateImpl::Error;
284                send_error(update, code, "Connection error")?;
285                Err(PgServerError::new(code, "Connection error", Default::default()).into())
286            }
287        }
288    }
289}
290
291impl ServerStateImpl {
292    fn drive_inner(
293        &mut self,
294        drive: ConnectionDrive,
295        update: &mut impl ConnectionStateUpdate,
296    ) -> Result<(), ServerError> {
297        use ServerStateImpl::*;
298
299        match (&mut *self, drive) {
300            (Initial(ssl), ConnectionDrive::Initial(initial_message)) => {
301                match_message!(initial_message, InitialMessage {
302                    (StartupMessage as startup) => {
303                        let mut user = String::new();
304                        let mut database = String::new();
305                        for param in startup.params() {
306                            if param.name() == "user" {
307                                user = param.value().to_owned()?;
308                            } else if param.name() == "database" {
309                                database = param.value().to_owned()?;
310                            }
311                            trace!("param: {:?}={:?}", param.name(), param.value());
312                            update.parameter(param.name().to_str()?, param.value().to_str()?);
313                        }
314                        if user.is_empty() {
315                            return Err(AUTH_ERROR);
316                        }
317                        if database.is_empty() {
318                            database = user.clone();
319                        }
320                        *self = AuthInfo(user.clone());
321                        update.auth(user, database)?;
322                    },
323                    (SSLRequest) => {
324                        let Some(ssl) = *ssl else {
325                            return Err(PROTOCOL_ERROR);
326                        };
327                        if ssl == ConnectionSslRequirement::Disable {
328                            update.send_ssl(SSLResponseBuilder { code: b'N' })?;
329                            update.upgrade()?;
330                        } else {
331                            update.send_ssl(SSLResponseBuilder { code: b'S' })?;
332                            *self = SslConnecting;
333                        }
334                    },
335                    unknown => {
336                        log_unknown_initial_message(unknown, "Initial")?;
337                    }
338                });
339            }
340            (SslConnecting, ConnectionDrive::SslReady) => {
341                *self = Initial(None);
342            }
343            (SslConnecting, _) => {
344                return Err(PROTOCOL_ERROR);
345            }
346            (AuthInfo(username), ConnectionDrive::AuthInfo(auth_type, credential_data)) => {
347                let mut auth = ServerAuth::new(username.clone(), auth_type, credential_data);
348                match auth.drive(ServerAuthDrive::Initial) {
349                    ServerAuthResponse::Initial(AuthType::Plain, _) => {
350                        update.send(&AuthenticationCleartextPasswordBuilder::default())?;
351                    }
352                    ServerAuthResponse::Initial(AuthType::Md5, salt) => {
353                        update.send(&AuthenticationMD5PasswordBuilder {
354                            salt: TryInto::<[u8; 4]>::try_into(salt).map_err(|_| PROTOCOL_ERROR)?,
355                        })?;
356                    }
357                    ServerAuthResponse::Initial(AuthType::ScramSha256, _) => {
358                        update.send(&AuthenticationSASLBuilder {
359                            mechanisms: ["SCRAM-SHA-256"],
360                        })?;
361                    }
362                    ServerAuthResponse::Complete(..) => {
363                        update.send(&AuthenticationOkBuilder::default())?;
364                        *self = Synchronizing;
365                        update.params()?;
366                        return Ok(());
367                    }
368                    ServerAuthResponse::Error(e) => {
369                        error!("Authentication error in initial state: {e:?}");
370                        return Err(e.into());
371                    }
372                    response => {
373                        error!("Unexpected response: {response:?}");
374                        return Err(PROTOCOL_ERROR);
375                    }
376                }
377                *self = Authenticating(auth);
378            }
379            (Authenticating(auth), ConnectionDrive::Message(message)) => {
380                trace!("auth = {auth:?}, initial = {}", auth.is_initial_message());
381                match_message!(message, Message {
382                    (PasswordMessage as password) if matches!(auth.auth_type(), AuthType::Plain | AuthType::Md5) => {
383                        match auth.drive(ServerAuthDrive::Message(auth.auth_type(), password.password().to_bytes())) {
384                            ServerAuthResponse::Complete(..) => {
385                                update.send(&AuthenticationOkBuilder::default())?;
386                                *self = Synchronizing;
387                                update.params()?;
388                            }
389                            ServerAuthResponse::Error(e) => {
390                                error!("Authentication error for password message: {e:?}");
391                                return Err(e.into())
392                            },
393                            response => {
394                                error!("Unexpected response for password message: {response:?}");
395                                return Err(PROTOCOL_ERROR);
396                            }
397                        }
398                    },
399                    (SASLInitialResponse as sasl) if auth.is_initial_message() => {
400                        if sasl.mechanism() != "SCRAM-SHA-256" {
401                            error!("Unexpected mechanism: {:?}", sasl.mechanism());
402                            return Err(PROTOCOL_ERROR);
403                        }
404                        match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.response().as_ref())) {
405                            ServerAuthResponse::Continue(final_message) => {
406                                update.send(&AuthenticationSASLContinueBuilder {
407                                    data: &final_message,
408                                })?;
409                            }
410                            ServerAuthResponse::Error(e) => {
411                                error!("Authentication error for SASL initial response: {e:?}");
412                                return Err(e.into())
413                            },
414                            response => {
415                                error!("Unexpected response for SASL initial response: {response:?}");
416                                return Err(PROTOCOL_ERROR);
417                            }
418                        }
419                    },
420                    (SASLResponse as sasl) if !auth.is_initial_message() => {
421                        match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.response().as_ref())) {
422                            ServerAuthResponse::Complete(data) => {
423                                update.send(&AuthenticationSASLFinalBuilder {
424                                    data,
425                                })?;
426                                update.send(&AuthenticationOkBuilder::default())?;
427                                *self = Synchronizing;
428                                update.params()?;
429                            }
430                            ServerAuthResponse::Error(e) => {
431                                error!("Authentication error for SASL response: {e:?}");
432                                return Err(e.into())
433                            },
434                            response => {
435                                error!("Unexpected response for SASL response: {response:?}");
436                                return Err(PROTOCOL_ERROR);
437                            }
438                        }
439                    },
440                    unknown => {
441                        log_unknown_message(unknown, "Authenticating")?;
442                    }
443                });
444            }
445            (Synchronizing, ConnectionDrive::Parameter(name, value)) => {
446                update.send(&ParameterStatusBuilder { name, value })?;
447            }
448            (Synchronizing, ConnectionDrive::Ready(pid, key)) => {
449                update.send(&BackendKeyDataBuilder { pid, key })?;
450                update.send(&ReadyForQueryBuilder { status: b'I' })?;
451                *self = Ready;
452            }
453            (_, ConnectionDrive::Fail(error, _)) => {
454                return Err(ServerError::Protocol(error));
455            }
456            _ => {
457                error!("Unexpected drive in state {:?}", self);
458                return Err(PROTOCOL_ERROR);
459            }
460        }
461
462        Ok(())
463    }
464}
465
466fn log_unknown_initial_message(
467    message: Result<InitialMessage, ParseError>,
468    state: &str,
469) -> Result<(), ServerError> {
470    match message {
471        Ok(message) => {
472            warn!(
473                "Unexpected message {:?} (length {}) received in {} state",
474                message.protocol_version(),
475                message.mlen(),
476                state
477            );
478            Err(PROTOCOL_VERSION_ERROR)
479        }
480        Err(e) => {
481            error!("Corrupted message received in {} state {:?}", state, e);
482            Err(PROTOCOL_ERROR)
483        }
484    }
485}
486
487fn log_unknown_message(
488    message: Result<Message, ParseError>,
489    state: &str,
490) -> Result<(), ServerError> {
491    match message {
492        Ok(message) => {
493            warn!(
494                "Unexpected message {:?} (length {}) received in {} state",
495                message.mtype(),
496                message.mlen(),
497                state
498            );
499            Ok(())
500        }
501        Err(e) => {
502            error!("Corrupted message received in {} state {:?}", state, e);
503            Err(PROTOCOL_ERROR)
504        }
505    }
506}