gel_auth/gel/
server_state_machine.rs

1use gel_db_protocol::errors::EdbError;
2use gel_db_protocol::prelude::*;
3use gel_db_protocol::protocol::{
4    Annotation, AuthenticationOkBuilder, AuthenticationRequiredSASLMessageBuilder,
5    AuthenticationSASLContinueBuilder, AuthenticationSASLFinalBuilder,
6    AuthenticationSASLInitialResponse, AuthenticationSASLResponse, ClientHandshake,
7    EdgeDBBackendBuilder, ErrorResponseBuilder, IntoEdgeDBBackendBuilder, KeyValue, Message,
8    ParameterStatusBuilder, ProtocolExtension, ReadyForCommandBuilder, ServerHandshakeBuilder,
9    ServerKeyDataBuilder, TransactionState,
10};
11
12use crate::handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse};
13use crate::{AuthType, CredentialData};
14use std::str::Utf8Error;
15use tracing::{error, trace, warn};
16
17use super::ConnectionError;
18
/// Coarse phase of a connection, as reported through
/// `ConnectionStateUpdate::state_changed`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so observers can compare or key on the phase directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ConnectionStateType {
    Connecting,
    Authenticating,
    Synchronizing,
    Ready,
}
26
27#[derive(Debug)]
28pub enum ConnectionDrive<'a> {
29    RawMessage(&'a [u8]),
30    Message(Result<Message<'a>, ParseError>),
31    AuthInfo(AuthType, CredentialData),
32    Parameter(String, String),
33    Ready([u8; 32]),
34    Fail(EdbError, &'a str),
35}
36
37pub trait ConnectionStateSend {
38    fn send<'a, M>(
39        &mut self,
40        message: impl IntoEdgeDBBackendBuilder<'a, M>,
41    ) -> Result<(), std::io::Error>;
42    fn auth(
43        &mut self,
44        user: String,
45        database: String,
46        branch: String,
47    ) -> Result<(), std::io::Error>;
48    fn params(&mut self) -> Result<(), std::io::Error>;
49}
50
51#[allow(unused)]
52pub trait ConnectionStateUpdate: ConnectionStateSend {
53    fn parameter(&mut self, name: &str, value: &str) {}
54    fn state_changed(&mut self, state: ConnectionStateType) {}
55    fn server_error(&mut self, error: &EdbError) {}
56}
57
58#[derive(derive_more::Debug)]
59pub enum ConnectionEvent<'a> {
60    #[debug("Send(...)")]
61    Send(EdgeDBBackendBuilder<'a>),
62    Auth(String, String, String),
63    Params,
64    Parameter(&'a str, &'a str),
65    StateChanged(ConnectionStateType),
66    ServerError(EdbError),
67}
68
69impl<F> ConnectionStateSend for F
70where
71    F: for<'a> FnMut(ConnectionEvent<'a>) -> Result<(), std::io::Error>,
72{
73    fn send<'a, M>(
74        &mut self,
75        message: impl IntoEdgeDBBackendBuilder<'a, M>,
76    ) -> Result<(), std::io::Error> {
77        self(ConnectionEvent::Send(message.into_builder()))
78    }
79
80    fn auth(
81        &mut self,
82        user: String,
83        database: String,
84        branch: String,
85    ) -> Result<(), std::io::Error> {
86        self(ConnectionEvent::Auth(user, database, branch))
87    }
88
89    fn params(&mut self) -> Result<(), std::io::Error> {
90        self(ConnectionEvent::Params)
91    }
92}
93
94impl<F> ConnectionStateUpdate for F
95where
96    F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>,
97{
98    fn parameter(&mut self, name: &str, value: &str) {
99        let _ = self(ConnectionEvent::Parameter(name, value));
100    }
101
102    fn state_changed(&mut self, state: ConnectionStateType) {
103        let _ = self(ConnectionEvent::StateChanged(state));
104    }
105
106    fn server_error(&mut self, error: &EdbError) {
107        let _ = self(ConnectionEvent::ServerError(*error));
108    }
109}
110
111#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
112enum ServerError {
113    IO(#[from] std::io::Error),
114    Protocol(#[from] EdbError),
115    Utf8Error(#[from] Utf8Error),
116}
117
118impl From<ServerAuthError> for ServerError {
119    fn from(value: ServerAuthError) -> Self {
120        match value {
121            ServerAuthError::InvalidAuthorizationSpecification => {
122                ServerError::Protocol(EdbError::AuthenticationError)
123            }
124            ServerAuthError::InvalidPassword => {
125                ServerError::Protocol(EdbError::AuthenticationError)
126            }
127            ServerAuthError::InvalidSaslMessage(_) => {
128                ServerError::Protocol(EdbError::ProtocolError)
129            }
130            ServerAuthError::UnsupportedAuthType => {
131                ServerError::Protocol(EdbError::UnsupportedFeatureError)
132            }
133            ServerAuthError::InvalidMessageType => ServerError::Protocol(EdbError::ProtocolError),
134        }
135    }
136}
137
138const PROTOCOL_ERROR: ServerError = ServerError::Protocol(EdbError::ProtocolError);
139const AUTH_ERROR: ServerError = ServerError::Protocol(EdbError::AuthenticationError);
140const PROTOCOL_VERSION_ERROR: ServerError =
141    ServerError::Protocol(EdbError::UnsupportedProtocolVersionError);
142
143#[derive(Debug, Default)]
144#[allow(clippy::large_enum_variant)] // Auth is much larger
145enum ServerStateImpl {
146    #[default]
147    Initial,
148    AuthInfo(String),
149    Authenticating(ServerAuth),
150    Synchronizing,
151    Ready,
152    Error,
153}
154
155#[derive(Debug, Default)]
156pub struct ServerState {
157    state: ServerStateImpl,
158    buffer: StructBuffer<Message<'static>>,
159}
160
161impl ServerState {
162    pub fn is_ready(&self) -> bool {
163        matches!(self.state, ServerStateImpl::Ready)
164    }
165
166    pub fn is_error(&self) -> bool {
167        matches!(self.state, ServerStateImpl::Error)
168    }
169
170    pub fn is_done(&self) -> bool {
171        self.is_ready() || self.is_error()
172    }
173
174    pub fn drive(
175        &mut self,
176        drive: ConnectionDrive,
177        update: &mut impl ConnectionStateUpdate,
178    ) -> Result<(), ConnectionError> {
179        trace!("SERVER DRIVE: {:?} {:?}", self.state, drive);
180        let res = match drive {
181            ConnectionDrive::RawMessage(raw) => self.buffer.push_fallible(raw, |message| {
182                trace!("Parsed message: {message:?}");
183                self.state
184                    .drive_inner(ConnectionDrive::Message(message), update)
185            }),
186            drive => self.state.drive_inner(drive, update),
187        };
188
189        match res {
190            Ok(_) => Ok(()),
191            Err(ServerError::IO(e)) => Err(e.into()),
192            Err(ServerError::Utf8Error(e)) => Err(e.into()),
193            Err(ServerError::Protocol(code)) => {
194                self.state = ServerStateImpl::Error;
195                send_error(update, code, "Connection error")?;
196                Err(code.into())
197            }
198        }
199    }
200}
201
202impl ServerStateImpl {
203    fn drive_inner(
204        &mut self,
205        drive: ConnectionDrive,
206        update: &mut impl ConnectionStateUpdate,
207    ) -> Result<(), ServerError> {
208        use ServerStateImpl::*;
209
210        match (&mut *self, drive) {
211            (Initial, ConnectionDrive::Message(message)) => {
212                match_message!(message, Message {
213                    (ClientHandshake as handshake) => {
214                        trace!("ClientHandshake: {handshake:?}");
215
216                        // The handshake should generate an event rather than hardcoding the min/max protocol versions.
217
218                        // We support 1.x and 2.0
219                        let major_ver = handshake.major_ver();
220                        let minor_ver = handshake.minor_ver();
221                        match (major_ver, minor_ver) {
222                            (..=0, _) => {
223                                update.send(&ServerHandshakeBuilder { major_ver: 1, minor_ver: 0, extensions: Array::<_, ProtocolExtension>::default() })?;
224                            }
225                            (1, 1..) => {
226                                // 1.(1+) never existed
227                                return Err(PROTOCOL_VERSION_ERROR);
228                            }
229                            (2, 1..) | (3.., _) => {
230                                update.send(&ServerHandshakeBuilder { major_ver: 2, minor_ver: 0, extensions: Array::<_, ProtocolExtension>::default() })?;
231                            }
232                            _ => {}
233                        }
234
235                        let mut user = String::new();
236                        let mut database = String::new();
237                        let mut branch = String::new();
238                        for param in handshake.params() {
239                            match param.name().to_str()? {
240                                "user" => user = param.value().to_owned()?,
241                                "database" => database = param.value().to_owned()?,
242                                "branch" => branch = param.value().to_owned()?,
243                                _ => {}
244                            }
245                            update.parameter(param.name().to_str()?, param.value().to_str()?);
246                        }
247                        if user.is_empty() {
248                            return Err(AUTH_ERROR);
249                        }
250                        if database.is_empty() {
251                            database = user.clone();
252                        }
253                        *self = AuthInfo(user.clone());
254                        update.auth(user, database, branch)?;
255                    },
256                    unknown => {
257                        log_unknown_message(unknown, "Initial")?;
258                    }
259                });
260            }
261            (AuthInfo(username), ConnectionDrive::AuthInfo(auth_type, credential_data)) => {
262                let mut auth = ServerAuth::new(username.clone(), auth_type, credential_data);
263                match auth.drive(ServerAuthDrive::Initial) {
264                    ServerAuthResponse::Initial(AuthType::ScramSha256, _) => {
265                        update.send(&AuthenticationRequiredSASLMessageBuilder {
266                            methods: &["SCRAM-SHA-256"],
267                        })?;
268                    }
269                    ServerAuthResponse::Complete(..) => {
270                        update.send(&AuthenticationOkBuilder {})?;
271                        *self = Synchronizing;
272                        update.params()?;
273                        return Ok(());
274                    }
275                    ServerAuthResponse::Error(e) => return Err(e.into()),
276                    _ => return Err(PROTOCOL_ERROR),
277                }
278                *self = Authenticating(auth);
279            }
280            (Authenticating(auth), ConnectionDrive::Message(message)) => {
281                match_message!(message, Message {
282                    (AuthenticationSASLInitialResponse as sasl) if auth.is_initial_message() => {
283                        match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.sasl_data().as_ref())) {
284                            ServerAuthResponse::Continue(final_message) => {
285                                update.send(&AuthenticationSASLContinueBuilder {
286                                    sasl_data: final_message.as_slice(),
287                                })?;
288                            }
289                            ServerAuthResponse::Error(e) => return Err(e.into()),
290                            _ => return Err(PROTOCOL_ERROR),
291                        }
292                    },
293                    (AuthenticationSASLResponse as sasl) if !auth.is_initial_message() => {
294                        match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.sasl_data().as_ref())) {
295                            ServerAuthResponse::Complete(data) => {
296                                update.send(&AuthenticationSASLFinalBuilder {
297                                    sasl_data: data.as_slice(),
298                                })?;
299                                update.send(&AuthenticationOkBuilder::default())?;
300                                *self = Synchronizing;
301                                update.params()?;
302                            }
303                            ServerAuthResponse::Error(e) => return Err(e.into()),
304                            _ => return Err(PROTOCOL_ERROR),
305                        }
306                    },
307                    unknown => {
308                        log_unknown_message(unknown, "Authenticating")?;
309                    }
310                });
311            }
312            (Synchronizing, ConnectionDrive::Parameter(name, value)) => {
313                update.send(&ParameterStatusBuilder {
314                    name: name.as_bytes(),
315                    value: value.as_bytes(),
316                })?;
317            }
318            (Synchronizing, ConnectionDrive::Ready(key_data)) => {
319                update.send(&ServerKeyDataBuilder { data: key_data })?;
320                update.send(&ReadyForCommandBuilder {
321                    annotations: Array::<_, Annotation>::default(),
322                    transaction_state: TransactionState::NotInTransaction,
323                })?;
324                *self = Ready;
325            }
326            (_, ConnectionDrive::Fail(error, _)) => {
327                return Err(ServerError::Protocol(error));
328            }
329            _ => {
330                error!("Unexpected drive in state {:?}", self);
331                return Err(PROTOCOL_ERROR);
332            }
333        }
334
335        Ok(())
336    }
337}
338
339fn log_unknown_message(
340    message: Result<Message, ParseError>,
341    state: &str,
342) -> Result<(), ServerError> {
343    match message {
344        Ok(message) => {
345            warn!(
346                "Unexpected message {:?} (length {}) received in {} state",
347                message.mtype(),
348                message.mlen(),
349                state
350            );
351            Ok(())
352        }
353        Err(e) => {
354            error!("Corrupted message received in {} state {:?}", state, e);
355            Err(PROTOCOL_ERROR)
356        }
357    }
358}
359
360fn send_error(
361    update: &mut impl ConnectionStateUpdate,
362    code: EdbError,
363    message: &str,
364) -> std::io::Result<()> {
365    update.server_error(&code);
366    update.send(&ErrorResponseBuilder {
367        severity: ErrorSeverity::Error as u8,
368        error_code: code as u32,
369        message,
370        attributes: Array::<_, KeyValue>::default(),
371    })
372}
373
/// Severity byte carried by an `ErrorResponse` message.
///
/// `#[repr(u8)]` matches the single-byte wire encoding, so the compiler guarantees every
/// discriminant fits the `as u8` cast used when building the response.
// Only `Error` is sent in this file; `Fatal` and `Panic` are kept so the complete set of
// wire values lives in one place, hence `allow(unused)`.
#[allow(unused)]
#[repr(u8)]
enum ErrorSeverity {
    Error = 0x78,
    Fatal = 0xc8,
    Panic = 0xff,
}