1use gel_db_protocol::errors::EdbError;
2use gel_db_protocol::prelude::*;
3use gel_db_protocol::protocol::{
4 Annotation, AuthenticationOkBuilder, AuthenticationRequiredSASLMessageBuilder,
5 AuthenticationSASLContinueBuilder, AuthenticationSASLFinalBuilder,
6 AuthenticationSASLInitialResponse, AuthenticationSASLResponse, ClientHandshake,
7 EdgeDBBackendBuilder, ErrorResponseBuilder, IntoEdgeDBBackendBuilder, KeyValue, Message,
8 ParameterStatusBuilder, ProtocolExtension, ReadyForCommandBuilder, ServerHandshakeBuilder,
9 ServerKeyDataBuilder, TransactionState,
10};
11
12use crate::handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse};
13use crate::{AuthType, CredentialData};
14use std::str::Utf8Error;
15use tracing::{error, trace, warn};
16
17use super::ConnectionError;
18
/// Coarse phase of a connection's lifecycle.
///
/// Values of this type are handed to observers through
/// `ConnectionStateUpdate::state_changed` and `ConnectionEvent::StateChanged`.
#[derive(Clone, Copy, Debug)]
pub enum ConnectionStateType {
    /// The connection is being established.
    Connecting,
    /// The peer is being authenticated.
    Authenticating,
    /// Authentication is done; session parameters are being exchanged.
    Synchronizing,
    /// The connection is ready for commands.
    Ready,
}
26
27#[derive(Debug)]
28pub enum ConnectionDrive<'a> {
29 RawMessage(&'a [u8]),
30 Message(Result<Message<'a>, ParseError>),
31 AuthInfo(AuthType, CredentialData),
32 Parameter(String, String),
33 Ready([u8; 32]),
34 Fail(EdbError, &'a str),
35}
36
37pub trait ConnectionStateSend {
38 fn send<'a, M>(
39 &mut self,
40 message: impl IntoEdgeDBBackendBuilder<'a, M>,
41 ) -> Result<(), std::io::Error>;
42 fn auth(
43 &mut self,
44 user: String,
45 database: String,
46 branch: String,
47 ) -> Result<(), std::io::Error>;
48 fn params(&mut self) -> Result<(), std::io::Error>;
49}
50
51#[allow(unused)]
52pub trait ConnectionStateUpdate: ConnectionStateSend {
53 fn parameter(&mut self, name: &str, value: &str) {}
54 fn state_changed(&mut self, state: ConnectionStateType) {}
55 fn server_error(&mut self, error: &EdbError) {}
56}
57
58#[derive(derive_more::Debug)]
59pub enum ConnectionEvent<'a> {
60 #[debug("Send(...)")]
61 Send(EdgeDBBackendBuilder<'a>),
62 Auth(String, String, String),
63 Params,
64 Parameter(&'a str, &'a str),
65 StateChanged(ConnectionStateType),
66 ServerError(EdbError),
67}
68
69impl<F> ConnectionStateSend for F
70where
71 F: for<'a> FnMut(ConnectionEvent<'a>) -> Result<(), std::io::Error>,
72{
73 fn send<'a, M>(
74 &mut self,
75 message: impl IntoEdgeDBBackendBuilder<'a, M>,
76 ) -> Result<(), std::io::Error> {
77 self(ConnectionEvent::Send(message.into_builder()))
78 }
79
80 fn auth(
81 &mut self,
82 user: String,
83 database: String,
84 branch: String,
85 ) -> Result<(), std::io::Error> {
86 self(ConnectionEvent::Auth(user, database, branch))
87 }
88
89 fn params(&mut self) -> Result<(), std::io::Error> {
90 self(ConnectionEvent::Params)
91 }
92}
93
94impl<F> ConnectionStateUpdate for F
95where
96 F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>,
97{
98 fn parameter(&mut self, name: &str, value: &str) {
99 let _ = self(ConnectionEvent::Parameter(name, value));
100 }
101
102 fn state_changed(&mut self, state: ConnectionStateType) {
103 let _ = self(ConnectionEvent::StateChanged(state));
104 }
105
106 fn server_error(&mut self, error: &EdbError) {
107 let _ = self(ConnectionEvent::ServerError(*error));
108 }
109}
110
111#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
112enum ServerError {
113 IO(#[from] std::io::Error),
114 Protocol(#[from] EdbError),
115 Utf8Error(#[from] Utf8Error),
116}
117
118impl From<ServerAuthError> for ServerError {
119 fn from(value: ServerAuthError) -> Self {
120 match value {
121 ServerAuthError::InvalidAuthorizationSpecification => {
122 ServerError::Protocol(EdbError::AuthenticationError)
123 }
124 ServerAuthError::InvalidPassword => {
125 ServerError::Protocol(EdbError::AuthenticationError)
126 }
127 ServerAuthError::InvalidSaslMessage(_) => {
128 ServerError::Protocol(EdbError::ProtocolError)
129 }
130 ServerAuthError::UnsupportedAuthType => {
131 ServerError::Protocol(EdbError::UnsupportedFeatureError)
132 }
133 ServerAuthError::InvalidMessageType => ServerError::Protocol(EdbError::ProtocolError),
134 }
135 }
136}
137
138const PROTOCOL_ERROR: ServerError = ServerError::Protocol(EdbError::ProtocolError);
139const AUTH_ERROR: ServerError = ServerError::Protocol(EdbError::AuthenticationError);
140const PROTOCOL_VERSION_ERROR: ServerError =
141 ServerError::Protocol(EdbError::UnsupportedProtocolVersionError);
142
143#[derive(Debug, Default)]
144#[allow(clippy::large_enum_variant)] enum ServerStateImpl {
146 #[default]
147 Initial,
148 AuthInfo(String),
149 Authenticating(ServerAuth),
150 Synchronizing,
151 Ready,
152 Error,
153}
154
155#[derive(Debug, Default)]
156pub struct ServerState {
157 state: ServerStateImpl,
158 buffer: StructBuffer<Message<'static>>,
159}
160
161impl ServerState {
162 pub fn is_ready(&self) -> bool {
163 matches!(self.state, ServerStateImpl::Ready)
164 }
165
166 pub fn is_error(&self) -> bool {
167 matches!(self.state, ServerStateImpl::Error)
168 }
169
170 pub fn is_done(&self) -> bool {
171 self.is_ready() || self.is_error()
172 }
173
174 pub fn drive(
175 &mut self,
176 drive: ConnectionDrive,
177 update: &mut impl ConnectionStateUpdate,
178 ) -> Result<(), ConnectionError> {
179 trace!("SERVER DRIVE: {:?} {:?}", self.state, drive);
180 let res = match drive {
181 ConnectionDrive::RawMessage(raw) => self.buffer.push_fallible(raw, |message| {
182 trace!("Parsed message: {message:?}");
183 self.state
184 .drive_inner(ConnectionDrive::Message(message), update)
185 }),
186 drive => self.state.drive_inner(drive, update),
187 };
188
189 match res {
190 Ok(_) => Ok(()),
191 Err(ServerError::IO(e)) => Err(e.into()),
192 Err(ServerError::Utf8Error(e)) => Err(e.into()),
193 Err(ServerError::Protocol(code)) => {
194 self.state = ServerStateImpl::Error;
195 send_error(update, code, "Connection error")?;
196 Err(code.into())
197 }
198 }
199 }
200}
201
202impl ServerStateImpl {
203 fn drive_inner(
204 &mut self,
205 drive: ConnectionDrive,
206 update: &mut impl ConnectionStateUpdate,
207 ) -> Result<(), ServerError> {
208 use ServerStateImpl::*;
209
210 match (&mut *self, drive) {
211 (Initial, ConnectionDrive::Message(message)) => {
212 match_message!(message, Message {
213 (ClientHandshake as handshake) => {
214 trace!("ClientHandshake: {handshake:?}");
215
216 let major_ver = handshake.major_ver();
220 let minor_ver = handshake.minor_ver();
221 match (major_ver, minor_ver) {
222 (..=0, _) => {
223 update.send(&ServerHandshakeBuilder { major_ver: 1, minor_ver: 0, extensions: Array::<_, ProtocolExtension>::default() })?;
224 }
225 (1, 1..) => {
226 return Err(PROTOCOL_VERSION_ERROR);
228 }
229 (2, 1..) | (3.., _) => {
230 update.send(&ServerHandshakeBuilder { major_ver: 2, minor_ver: 0, extensions: Array::<_, ProtocolExtension>::default() })?;
231 }
232 _ => {}
233 }
234
235 let mut user = String::new();
236 let mut database = String::new();
237 let mut branch = String::new();
238 for param in handshake.params() {
239 match param.name().to_str()? {
240 "user" => user = param.value().to_owned()?,
241 "database" => database = param.value().to_owned()?,
242 "branch" => branch = param.value().to_owned()?,
243 _ => {}
244 }
245 update.parameter(param.name().to_str()?, param.value().to_str()?);
246 }
247 if user.is_empty() {
248 return Err(AUTH_ERROR);
249 }
250 if database.is_empty() {
251 database = user.clone();
252 }
253 *self = AuthInfo(user.clone());
254 update.auth(user, database, branch)?;
255 },
256 unknown => {
257 log_unknown_message(unknown, "Initial")?;
258 }
259 });
260 }
261 (AuthInfo(username), ConnectionDrive::AuthInfo(auth_type, credential_data)) => {
262 let mut auth = ServerAuth::new(username.clone(), auth_type, credential_data);
263 match auth.drive(ServerAuthDrive::Initial) {
264 ServerAuthResponse::Initial(AuthType::ScramSha256, _) => {
265 update.send(&AuthenticationRequiredSASLMessageBuilder {
266 methods: &["SCRAM-SHA-256"],
267 })?;
268 }
269 ServerAuthResponse::Complete(..) => {
270 update.send(&AuthenticationOkBuilder {})?;
271 *self = Synchronizing;
272 update.params()?;
273 return Ok(());
274 }
275 ServerAuthResponse::Error(e) => return Err(e.into()),
276 _ => return Err(PROTOCOL_ERROR),
277 }
278 *self = Authenticating(auth);
279 }
280 (Authenticating(auth), ConnectionDrive::Message(message)) => {
281 match_message!(message, Message {
282 (AuthenticationSASLInitialResponse as sasl) if auth.is_initial_message() => {
283 match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.sasl_data().as_ref())) {
284 ServerAuthResponse::Continue(final_message) => {
285 update.send(&AuthenticationSASLContinueBuilder {
286 sasl_data: final_message.as_slice(),
287 })?;
288 }
289 ServerAuthResponse::Error(e) => return Err(e.into()),
290 _ => return Err(PROTOCOL_ERROR),
291 }
292 },
293 (AuthenticationSASLResponse as sasl) if !auth.is_initial_message() => {
294 match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.sasl_data().as_ref())) {
295 ServerAuthResponse::Complete(data) => {
296 update.send(&AuthenticationSASLFinalBuilder {
297 sasl_data: data.as_slice(),
298 })?;
299 update.send(&AuthenticationOkBuilder::default())?;
300 *self = Synchronizing;
301 update.params()?;
302 }
303 ServerAuthResponse::Error(e) => return Err(e.into()),
304 _ => return Err(PROTOCOL_ERROR),
305 }
306 },
307 unknown => {
308 log_unknown_message(unknown, "Authenticating")?;
309 }
310 });
311 }
312 (Synchronizing, ConnectionDrive::Parameter(name, value)) => {
313 update.send(&ParameterStatusBuilder {
314 name: name.as_bytes(),
315 value: value.as_bytes(),
316 })?;
317 }
318 (Synchronizing, ConnectionDrive::Ready(key_data)) => {
319 update.send(&ServerKeyDataBuilder { data: key_data })?;
320 update.send(&ReadyForCommandBuilder {
321 annotations: Array::<_, Annotation>::default(),
322 transaction_state: TransactionState::NotInTransaction,
323 })?;
324 *self = Ready;
325 }
326 (_, ConnectionDrive::Fail(error, _)) => {
327 return Err(ServerError::Protocol(error));
328 }
329 _ => {
330 error!("Unexpected drive in state {:?}", self);
331 return Err(PROTOCOL_ERROR);
332 }
333 }
334
335 Ok(())
336 }
337}
338
339fn log_unknown_message(
340 message: Result<Message, ParseError>,
341 state: &str,
342) -> Result<(), ServerError> {
343 match message {
344 Ok(message) => {
345 warn!(
346 "Unexpected message {:?} (length {}) received in {} state",
347 message.mtype(),
348 message.mlen(),
349 state
350 );
351 Ok(())
352 }
353 Err(e) => {
354 error!("Corrupted message received in {} state {:?}", state, e);
355 Err(PROTOCOL_ERROR)
356 }
357 }
358}
359
360fn send_error(
361 update: &mut impl ConnectionStateUpdate,
362 code: EdbError,
363 message: &str,
364) -> std::io::Result<()> {
365 update.server_error(&code);
366 update.send(&ErrorResponseBuilder {
367 severity: ErrorSeverity::Error as u8,
368 error_code: code as u32,
369 message,
370 attributes: Array::<_, KeyValue>::default(),
371 })
372}
373
/// Severity byte carried in an error response.
// Only `Error` is used by the code in this file; the remaining variants are
// kept for completeness, hence the lint exemption.
#[allow(unused)]
enum ErrorSeverity {
    Error = 0x78,
    Fatal = 0xc8,
    Panic = 0xff,
}