1use super::{ConnectionError, ConnectionSslRequirement};
2use crate::{
3 handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse},
4 AuthType, CredentialData,
5};
6use gel_pg_protocol::{
7 errors::{
8 PgError, PgErrorConnectionException, PgErrorFeatureNotSupported,
9 PgErrorInvalidAuthorizationSpecification, PgServerError, PgServerErrorField,
10 },
11 prelude::*,
12 protocol::*,
13};
14use std::str::Utf8Error;
15use tracing::{error, trace, warn};
16
/// Coarse connection phase, as reported to observers through
/// [`ConnectionStateUpdate::state_changed`] and
/// [`ConnectionEvent::StateChanged`].
///
/// NOTE(review): nothing in this part of the file constructs these variants;
/// the per-variant meaning is taken from the names (connecting -> optional SSL
/// negotiation -> authentication -> parameter synchronization -> ready) and
/// should be confirmed against the code that emits them.
#[derive(Clone, Copy, Debug)]
pub enum ConnectionStateType {
    Connecting,
    SslConnecting,
    Authenticating,
    Synchronizing,
    Ready,
}
25
/// An input event fed to the handshake state machine via `ServerState::drive`.
#[derive(Debug)]
pub enum ConnectionDrive<'a> {
    /// Unparsed bytes from the peer. `ServerState::drive` pushes them into one
    /// of its buffers and re-dispatches each parsed message as
    /// [`ConnectionDrive::Initial`] or [`ConnectionDrive::Message`], depending
    /// on the current state.
    RawMessage(&'a [u8]),
    /// A parsed initial message (or the parse failure). Handled while the
    /// server is waiting for the startup message or an SSL request.
    Initial(Result<InitialMessage<'a>, ParseError>),
    /// A parsed regular message (or the parse failure). Handled while the
    /// server is authenticating the client.
    Message(Result<Message<'a>, ParseError>),
    /// The SSL handshake finished; the server goes back to waiting for the
    /// startup message.
    SslReady,
    /// The authentication type and credential data to use for the user that
    /// was announced through `ConnectionStateSend::auth`.
    AuthInfo(AuthType, CredentialData),
    /// A `(name, value)` pair forwarded to the client as a parameter status
    /// message once authentication has succeeded.
    Parameter(String, String),
    /// The `(pid, key)` pair sent as backend key data, followed by
    /// ready-for-query; completes the handshake.
    Ready(i32, i32),
    /// Aborts the handshake with the given error code. The server state
    /// machine in this file does not use the accompanying text.
    Fail(PgError, &'a str),
}
51
/// Actions the state machine asks its host to perform while a handshake is
/// being driven. Every method reports failure as a `std::io::Error`, which
/// aborts the current drive step.
pub trait ConnectionStateSend {
    /// Sends the single-byte reply to an SSL request.
    fn send_ssl(&mut self, message: SSLResponseBuilder) -> Result<(), std::io::Error>;
    /// Sends a backend protocol message to the peer.
    fn send<'a, M>(
        &mut self,
        message: impl IntoBackendBuilder<'a, M>,
    ) -> Result<(), std::io::Error>;
    /// Requests that the transport be upgraded to SSL.
    ///
    /// NOTE(review): the state machine waits for `ConnectionDrive::SslReady`
    /// while in its SSL-connecting state; presumably the host sends it once
    /// this upgrade has completed — confirm against the caller.
    fn upgrade(&mut self) -> Result<(), std::io::Error>;
    /// Announces the `user` and `database` taken from the startup message.
    /// The state machine then waits for `ConnectionDrive::AuthInfo`.
    fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error>;
    /// Called once authentication has succeeded. The state machine then
    /// accepts `ConnectionDrive::Parameter` and `ConnectionDrive::Ready`.
    fn params(&mut self) -> Result<(), std::io::Error>;
}
67
/// Optional notifications layered on top of [`ConnectionStateSend`].
///
/// Every method has a default implementation that does nothing, so
/// implementors override only the notifications they care about.
// `allow(unused)`: the empty default bodies never touch their parameters.
#[allow(unused)]
pub trait ConnectionStateUpdate: ConnectionStateSend {
    /// Reports a `name`/`value` parameter seen during the handshake.
    fn parameter(&mut self, name: &str, value: &str) {}
    /// Reports a change of connection phase.
    fn state_changed(&mut self, state: ConnectionStateType) {}
    /// Reports an error that is about to be sent to the peer.
    fn server_error(&mut self, error: &PgServerError) {}
}
75
/// Closure-friendly form of the [`ConnectionStateSend`] and
/// [`ConnectionStateUpdate`] callbacks: one variant per trait method, carrying
/// that method's arguments. The blanket impls for
/// `FnMut(ConnectionEvent) -> Result<(), std::io::Error>` build these values.
#[derive(Debug)]
pub enum ConnectionEvent<'a> {
    /// From `ConnectionStateSend::send_ssl`.
    SendSSL(SSLResponseBuilder),
    /// From `ConnectionStateSend::send`.
    Send(BackendBuilder<'a>),
    /// From `ConnectionStateSend::upgrade`.
    Upgrade,
    /// From `ConnectionStateSend::auth`: `(user, database)`.
    Auth(String, String),
    /// From `ConnectionStateSend::params`.
    Params,
    /// From `ConnectionStateUpdate::parameter`: `(name, value)`.
    Parameter(&'a str, &'a str),
    /// From `ConnectionStateUpdate::state_changed`.
    StateChanged(ConnectionStateType),
    /// From `ConnectionStateUpdate::server_error`.
    ServerError(&'a PgServerError),
}
87
88impl<F> ConnectionStateSend for F
89where
90 F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>,
91{
92 fn send_ssl(&mut self, message: SSLResponseBuilder) -> Result<(), std::io::Error> {
93 self(ConnectionEvent::SendSSL(message))
94 }
95
96 fn send<'a, M>(
97 &mut self,
98 message: impl IntoBackendBuilder<'a, M>,
99 ) -> Result<(), std::io::Error> {
100 self(ConnectionEvent::Send(message.into_builder()))
101 }
102
103 fn upgrade(&mut self) -> Result<(), std::io::Error> {
104 self(ConnectionEvent::Upgrade)
105 }
106
107 fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error> {
108 self(ConnectionEvent::Auth(user, database))
109 }
110
111 fn params(&mut self) -> Result<(), std::io::Error> {
112 self(ConnectionEvent::Params)
113 }
114}
115
116impl<F> ConnectionStateUpdate for F
117where
118 F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>,
119{
120 fn parameter(&mut self, name: &str, value: &str) {
121 let _ = self(ConnectionEvent::Parameter(name, value));
122 }
123
124 fn state_changed(&mut self, state: ConnectionStateType) {
125 let _ = self(ConnectionEvent::StateChanged(state));
126 }
127
128 fn server_error(&mut self, error: &PgServerError) {
129 let _ = self(ConnectionEvent::ServerError(error));
130 }
131}
132
/// Internal states of the server-side handshake.
#[derive(Debug)]
// NOTE(review): presumably silenced because `ServerAuth` makes the
// `Authenticating` variant much larger than the others — confirm.
#[allow(clippy::large_enum_variant)]
enum ServerStateImpl {
    /// Waiting for the client's startup message or an SSL request. Holds
    /// `Some(requirement)` on a fresh connection and `None` once the SSL
    /// handshake has completed; a further SSL request is then a protocol error.
    Initial(Option<ConnectionSslRequirement>),
    /// The SSL request was answered with `S`; waiting for
    /// `ConnectionDrive::SslReady`.
    SslConnecting,
    /// The startup message was received; holds the user name while waiting
    /// for `ConnectionDrive::AuthInfo`.
    AuthInfo(String),
    /// An authentication exchange with the client is in progress.
    Authenticating(ServerAuth),
    /// Authentication succeeded; accepting `ConnectionDrive::Parameter` and
    /// `ConnectionDrive::Ready`.
    Synchronizing,
    /// Handshake complete.
    Ready,
    /// A protocol error ended the handshake (set by `ServerState::drive`).
    Error,
}
151
/// Server-side handshake state machine: the current state plus the buffers
/// used to reassemble raw bytes into protocol messages.
#[derive(derive_more::Debug)]
pub struct ServerState {
    /// Current handshake state.
    state: ServerStateImpl,
    /// Buffer for raw bytes received while waiting for an initial message
    /// (omitted from `Debug` output).
    #[debug(skip)]
    initial_buffer: StructBuffer<InitialMessage<'static>>,
    /// Buffer for raw bytes received while authenticating (omitted from
    /// `Debug` output).
    #[debug(skip)]
    buffer: StructBuffer<Message<'static>>,
}
160
161fn send_error(
162 update: &mut impl ConnectionStateUpdate,
163 code: PgError,
164 message: &str,
165) -> std::io::Result<()> {
166 let error = PgServerError::new(code, message, Default::default());
167 update.server_error(&error);
168 update.send(&ErrorResponseBuilder {
169 fields: &[
170 &ErrorFieldBuilder {
171 etype: PgServerErrorField::Severity as u8,
172 value: "ERROR",
173 },
174 &ErrorFieldBuilder {
175 etype: PgServerErrorField::SeverityNonLocalized as u8,
176 value: "ERROR",
177 },
178 &ErrorFieldBuilder {
179 etype: PgServerErrorField::Code as u8,
180 value: std::str::from_utf8(&code.to_code()).unwrap(),
181 },
182 &ErrorFieldBuilder {
183 etype: PgServerErrorField::Message as u8,
184 value: message,
185 },
186 ],
187 })
188}
189
/// Internal error type of the server state machine.
///
/// `ServerState::drive` translates it for callers: `IO` and `Utf8Error` are
/// converted straight into a `ConnectionError`, while `Protocol` additionally
/// moves the machine into its error state and sends an error response.
#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
enum ServerError {
    /// A host callback failed.
    IO(#[from] std::io::Error),
    /// The handshake must be aborted with this error code.
    Protocol(#[from] PgError),
    /// A value could not be decoded as UTF-8.
    Utf8Error(#[from] Utf8Error),
}
196
197impl From<ServerAuthError> for ServerError {
198 fn from(value: ServerAuthError) -> Self {
199 match value {
200 ServerAuthError::InvalidAuthorizationSpecification => {
201 ServerError::Protocol(PgError::InvalidAuthorizationSpecification(
202 PgErrorInvalidAuthorizationSpecification::InvalidAuthorizationSpecification,
203 ))
204 }
205 ServerAuthError::InvalidPassword => {
206 ServerError::Protocol(PgError::InvalidAuthorizationSpecification(
207 PgErrorInvalidAuthorizationSpecification::InvalidPassword,
208 ))
209 }
210 ServerAuthError::InvalidSaslMessage(_) => ServerError::Protocol(
211 PgError::ConnectionException(PgErrorConnectionException::ProtocolViolation),
212 ),
213 ServerAuthError::UnsupportedAuthType => ServerError::Protocol(
214 PgError::FeatureNotSupported(PgErrorFeatureNotSupported::FeatureNotSupported),
215 ),
216 ServerAuthError::InvalidMessageType => ServerError::Protocol(
217 PgError::ConnectionException(PgErrorConnectionException::ProtocolViolation),
218 ),
219 }
220 }
221}
222
/// Returned for malformed messages and for events that are not valid in the
/// current state (connection exception: protocol violation).
const PROTOCOL_ERROR: ServerError = ServerError::Protocol(PgError::ConnectionException(
    PgErrorConnectionException::ProtocolViolation,
));
/// Returned when the startup message carries no user name (invalid
/// authorization specification).
const AUTH_ERROR: ServerError = ServerError::Protocol(PgError::InvalidAuthorizationSpecification(
    PgErrorInvalidAuthorizationSpecification::InvalidAuthorizationSpecification,
));
/// Returned for a well-formed initial message that is neither a startup
/// message nor an SSL request (feature not supported).
const PROTOCOL_VERSION_ERROR: ServerError = ServerError::Protocol(PgError::FeatureNotSupported(
    PgErrorFeatureNotSupported::FeatureNotSupported,
));
232
233impl ServerState {
234 pub fn new(ssl_requirement: ConnectionSslRequirement) -> Self {
235 Self {
236 state: ServerStateImpl::Initial(Some(ssl_requirement)),
237 initial_buffer: Default::default(),
238 buffer: Default::default(),
239 }
240 }
241
242 pub fn is_ready(&self) -> bool {
243 matches!(self.state, ServerStateImpl::Ready)
244 }
245
246 pub fn is_error(&self) -> bool {
247 matches!(self.state, ServerStateImpl::Error)
248 }
249
250 pub fn is_done(&self) -> bool {
251 self.is_ready() || self.is_error()
252 }
253
254 pub fn drive(
255 &mut self,
256 drive: ConnectionDrive,
257 update: &mut impl ConnectionStateUpdate,
258 ) -> Result<(), ConnectionError> {
259 trace!("SERVER DRIVE: {:?} {:?}", self.state, drive);
260 let res = match drive {
261 ConnectionDrive::RawMessage(raw) => match self.state {
262 ServerStateImpl::Initial(..) => self.initial_buffer.push_fallible(raw, |message| {
263 self.state
264 .drive_inner(ConnectionDrive::Initial(message), update)
265 }),
266 ServerStateImpl::Authenticating(..) => self.buffer.push_fallible(raw, |message| {
267 self.state
268 .drive_inner(ConnectionDrive::Message(message), update)
269 }),
270 _ => {
271 error!("Unexpected drive in state {:?}", self.state);
272 Err(PROTOCOL_ERROR)
273 }
274 },
275 drive => self.state.drive_inner(drive, update),
276 };
277
278 match res {
279 Ok(_) => Ok(()),
280 Err(ServerError::IO(e)) => Err(e.into()),
281 Err(ServerError::Utf8Error(e)) => Err(e.into()),
282 Err(ServerError::Protocol(code)) => {
283 self.state = ServerStateImpl::Error;
284 send_error(update, code, "Connection error")?;
285 Err(PgServerError::new(code, "Connection error", Default::default()).into())
286 }
287 }
288 }
289}
290
291impl ServerStateImpl {
292 fn drive_inner(
293 &mut self,
294 drive: ConnectionDrive,
295 update: &mut impl ConnectionStateUpdate,
296 ) -> Result<(), ServerError> {
297 use ServerStateImpl::*;
298
299 match (&mut *self, drive) {
300 (Initial(ssl), ConnectionDrive::Initial(initial_message)) => {
301 match_message!(initial_message, InitialMessage {
302 (StartupMessage as startup) => {
303 let mut user = String::new();
304 let mut database = String::new();
305 for param in startup.params() {
306 if param.name() == "user" {
307 user = param.value().to_owned()?;
308 } else if param.name() == "database" {
309 database = param.value().to_owned()?;
310 }
311 trace!("param: {:?}={:?}", param.name(), param.value());
312 update.parameter(param.name().to_str()?, param.value().to_str()?);
313 }
314 if user.is_empty() {
315 return Err(AUTH_ERROR);
316 }
317 if database.is_empty() {
318 database = user.clone();
319 }
320 *self = AuthInfo(user.clone());
321 update.auth(user, database)?;
322 },
323 (SSLRequest) => {
324 let Some(ssl) = *ssl else {
325 return Err(PROTOCOL_ERROR);
326 };
327 if ssl == ConnectionSslRequirement::Disable {
328 update.send_ssl(SSLResponseBuilder { code: b'N' })?;
329 update.upgrade()?;
330 } else {
331 update.send_ssl(SSLResponseBuilder { code: b'S' })?;
332 *self = SslConnecting;
333 }
334 },
335 unknown => {
336 log_unknown_initial_message(unknown, "Initial")?;
337 }
338 });
339 }
340 (SslConnecting, ConnectionDrive::SslReady) => {
341 *self = Initial(None);
342 }
343 (SslConnecting, _) => {
344 return Err(PROTOCOL_ERROR);
345 }
346 (AuthInfo(username), ConnectionDrive::AuthInfo(auth_type, credential_data)) => {
347 let mut auth = ServerAuth::new(username.clone(), auth_type, credential_data);
348 match auth.drive(ServerAuthDrive::Initial) {
349 ServerAuthResponse::Initial(AuthType::Plain, _) => {
350 update.send(&AuthenticationCleartextPasswordBuilder::default())?;
351 }
352 ServerAuthResponse::Initial(AuthType::Md5, salt) => {
353 update.send(&AuthenticationMD5PasswordBuilder {
354 salt: TryInto::<[u8; 4]>::try_into(salt).map_err(|_| PROTOCOL_ERROR)?,
355 })?;
356 }
357 ServerAuthResponse::Initial(AuthType::ScramSha256, _) => {
358 update.send(&AuthenticationSASLBuilder {
359 mechanisms: ["SCRAM-SHA-256"],
360 })?;
361 }
362 ServerAuthResponse::Complete(..) => {
363 update.send(&AuthenticationOkBuilder::default())?;
364 *self = Synchronizing;
365 update.params()?;
366 return Ok(());
367 }
368 ServerAuthResponse::Error(e) => {
369 error!("Authentication error in initial state: {e:?}");
370 return Err(e.into());
371 }
372 response => {
373 error!("Unexpected response: {response:?}");
374 return Err(PROTOCOL_ERROR);
375 }
376 }
377 *self = Authenticating(auth);
378 }
379 (Authenticating(auth), ConnectionDrive::Message(message)) => {
380 trace!("auth = {auth:?}, initial = {}", auth.is_initial_message());
381 match_message!(message, Message {
382 (PasswordMessage as password) if matches!(auth.auth_type(), AuthType::Plain | AuthType::Md5) => {
383 match auth.drive(ServerAuthDrive::Message(auth.auth_type(), password.password().to_bytes())) {
384 ServerAuthResponse::Complete(..) => {
385 update.send(&AuthenticationOkBuilder::default())?;
386 *self = Synchronizing;
387 update.params()?;
388 }
389 ServerAuthResponse::Error(e) => {
390 error!("Authentication error for password message: {e:?}");
391 return Err(e.into())
392 },
393 response => {
394 error!("Unexpected response for password message: {response:?}");
395 return Err(PROTOCOL_ERROR);
396 }
397 }
398 },
399 (SASLInitialResponse as sasl) if auth.is_initial_message() => {
400 if sasl.mechanism() != "SCRAM-SHA-256" {
401 error!("Unexpected mechanism: {:?}", sasl.mechanism());
402 return Err(PROTOCOL_ERROR);
403 }
404 match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.response().as_ref())) {
405 ServerAuthResponse::Continue(final_message) => {
406 update.send(&AuthenticationSASLContinueBuilder {
407 data: &final_message,
408 })?;
409 }
410 ServerAuthResponse::Error(e) => {
411 error!("Authentication error for SASL initial response: {e:?}");
412 return Err(e.into())
413 },
414 response => {
415 error!("Unexpected response for SASL initial response: {response:?}");
416 return Err(PROTOCOL_ERROR);
417 }
418 }
419 },
420 (SASLResponse as sasl) if !auth.is_initial_message() => {
421 match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.response().as_ref())) {
422 ServerAuthResponse::Complete(data) => {
423 update.send(&AuthenticationSASLFinalBuilder {
424 data,
425 })?;
426 update.send(&AuthenticationOkBuilder::default())?;
427 *self = Synchronizing;
428 update.params()?;
429 }
430 ServerAuthResponse::Error(e) => {
431 error!("Authentication error for SASL response: {e:?}");
432 return Err(e.into())
433 },
434 response => {
435 error!("Unexpected response for SASL response: {response:?}");
436 return Err(PROTOCOL_ERROR);
437 }
438 }
439 },
440 unknown => {
441 log_unknown_message(unknown, "Authenticating")?;
442 }
443 });
444 }
445 (Synchronizing, ConnectionDrive::Parameter(name, value)) => {
446 update.send(&ParameterStatusBuilder { name, value })?;
447 }
448 (Synchronizing, ConnectionDrive::Ready(pid, key)) => {
449 update.send(&BackendKeyDataBuilder { pid, key })?;
450 update.send(&ReadyForQueryBuilder { status: b'I' })?;
451 *self = Ready;
452 }
453 (_, ConnectionDrive::Fail(error, _)) => {
454 return Err(ServerError::Protocol(error));
455 }
456 _ => {
457 error!("Unexpected drive in state {:?}", self);
458 return Err(PROTOCOL_ERROR);
459 }
460 }
461
462 Ok(())
463 }
464}
465
466fn log_unknown_initial_message(
467 message: Result<InitialMessage, ParseError>,
468 state: &str,
469) -> Result<(), ServerError> {
470 match message {
471 Ok(message) => {
472 warn!(
473 "Unexpected message {:?} (length {}) received in {} state",
474 message.protocol_version(),
475 message.mlen(),
476 state
477 );
478 Err(PROTOCOL_VERSION_ERROR)
479 }
480 Err(e) => {
481 error!("Corrupted message received in {} state {:?}", state, e);
482 Err(PROTOCOL_ERROR)
483 }
484 }
485}
486
487fn log_unknown_message(
488 message: Result<Message, ParseError>,
489 state: &str,
490) -> Result<(), ServerError> {
491 match message {
492 Ok(message) => {
493 warn!(
494 "Unexpected message {:?} (length {}) received in {} state",
495 message.mtype(),
496 message.mlen(),
497 state
498 );
499 Ok(())
500 }
501 Err(e) => {
502 error!("Corrupted message received in {} state {:?}", state, e);
503 Err(PROTOCOL_ERROR)
504 }
505 }
506}