1use super::{invalid_state, ConnectionError, ConnectionSslRequirement, Credentials};
2use crate::md5::md5_password;
3use crate::postgres::SslError;
4use crate::scram::{
5 generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError, Sha256Out,
6};
7use crate::AuthType;
8use base64::Engine;
9use gel_pg_protocol::{errors::PgServerError, prelude::*, protocol::*};
10use tracing::{error, trace, warn};
11
/// [`ClientEnvironment`] implementation used by the SCRAM client exchange.
///
/// Supplies random nonces and derives the salted password from the stored
/// connection credentials.
#[derive(Debug)]
struct ClientEnvironmentImpl {
    /// Credentials whose password is used to derive the SCRAM salted password.
    credentials: Credentials,
}
16
17impl ClientEnvironment for ClientEnvironmentImpl {
18 fn generate_nonce(&self) -> String {
19 let nonce: [u8; 32] = rand::random();
20 base64::engine::general_purpose::STANDARD.encode(nonce)
21 }
22 fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out {
23 generate_salted_password(self.credentials.password.as_bytes(), salt, iterations)
24 }
25}
26
/// Internal states of the client-side handshake state machine.
///
/// Transitions between states happen in [`ConnectionState::drive`].
#[derive(Debug)]
enum ConnectionStateImpl {
    /// SSL is not disabled; waiting for the initial drive to send `SSLRequest`.
    SslInitializing(Credentials, ConnectionSslRequirement),
    /// `SSLRequest` sent; waiting for the server's SSL response.
    SslWaiting(Credentials, ConnectionSslRequirement),
    /// Server accepted SSL; waiting for the transport upgrade to finish.
    SslConnecting(Credentials),
    /// SSL disabled; waiting for the initial drive to send the startup message.
    Initializing(Credentials),
    /// Startup message sent; waiting for authentication messages. The `bool`
    /// records whether an authentication exchange was started, so that an
    /// `AuthenticationOk` without one is reported as trust authentication.
    Connecting(Credentials, bool),
    /// SCRAM-SHA-256 exchange in progress.
    Scram(ClientTransaction, ClientEnvironmentImpl),
    /// Authenticated; collecting parameters and key data until `ReadyForQuery`.
    Connected,
    /// Handshake finished successfully; further drives are rejected.
    Ready,
    /// Server reported an error; further drives are rejected.
    Error,
}
53
/// Coarse handshake phase reported through
/// [`ConnectionStateUpdate::state_changed`].
#[derive(Clone, Copy, Debug)]
pub enum ConnectionStateType {
    /// The `SSLRequest` or startup message has been sent.
    Connecting,
    /// The server accepted SSL and the transport upgrade was requested.
    SslConnecting,
    /// A SCRAM-SHA-256 exchange began (MD5 and cleartext auth do not report
    /// this phase).
    Authenticating,
    /// Authentication succeeded; waiting for server parameters and
    /// `ReadyForQuery`.
    Synchronizing,
    /// The handshake finished.
    Ready,
}
62
/// An input event that advances the handshake via [`ConnectionState::drive`].
#[derive(Debug)]
pub enum ConnectionDrive<'a> {
    /// Starts the handshake (sends `SSLRequest` or the startup message).
    Initial,
    /// A backend message, or the error produced while parsing it.
    Message(Result<Message<'a>, ParseError>),
    /// The server's answer to `SSLRequest`.
    SslResponse(SSLResponse<'a>),
    /// The transport upgrade requested through `ConnectionStateSend::upgrade`
    /// has completed.
    SslReady,
}
70
/// Output side of the handshake state machine: how it writes to the transport.
pub trait ConnectionStateSend {
    /// Sends a message from the initial set, such as `SSLRequest` or
    /// `StartupMessage`.
    fn send_initial<'a, M>(
        &mut self,
        message: impl IntoInitialBuilder<'a, M>,
    ) -> Result<(), std::io::Error>;
    /// Sends a regular frontend message (password, SASL responses, ...).
    fn send<'a, M>(
        &mut self,
        message: impl IntoFrontendBuilder<'a, M>,
    ) -> Result<(), std::io::Error>;
    /// Requests an SSL upgrade of the transport; the state machine then
    /// expects a `ConnectionDrive::SslReady` event.
    fn upgrade(&mut self) -> Result<(), std::io::Error>;
}
82
// The default no-op methods below leave their parameters unused.
#[allow(unused)]
/// Callbacks the handshake state machine uses to report progress.
///
/// Every method has a default: informational callbacks do nothing, while
/// server errors and notices are logged through `tracing`.
pub trait ConnectionStateUpdate: ConnectionStateSend {
    /// A `ParameterStatus` name/value pair reported by the server.
    fn parameter(&mut self, name: &str, value: &str) {}
    /// Process id and secret key from the server's `BackendKeyData`.
    fn cancellation_key(&mut self, pid: i32, key: i32) {}
    /// The handshake entered a new phase.
    fn state_changed(&mut self, state: ConnectionStateType) {}
    /// An `ErrorResponse` received during the handshake.
    fn server_error(&mut self, error: &PgServerError) {
        error!("Server error during handshake: {:?}", error);
    }
    /// A `NoticeResponse` received during the handshake.
    fn server_notice(&mut self, notice: &PgServerError) {
        warn!("Server notice during handshake: {:?}", notice);
    }
    /// The authentication method used for this connection.
    fn auth(&mut self, auth: AuthType) {}
}
97
/// Client-side PostgreSQL connection handshake state machine.
///
/// Feed it [`ConnectionDrive`] events through [`ConnectionState::drive`]; it
/// sends messages and reports progress through a [`ConnectionStateUpdate`]
/// implementation until it reaches the ready or error state.
#[derive(Debug)]
pub struct ConnectionState(ConnectionStateImpl);
123
124impl ConnectionState {
125 pub fn new(credentials: Credentials, ssl_mode: ConnectionSslRequirement) -> Self {
126 if ssl_mode == ConnectionSslRequirement::Disable {
127 Self(ConnectionStateImpl::Initializing(credentials))
128 } else {
129 Self(ConnectionStateImpl::SslInitializing(credentials, ssl_mode))
130 }
131 }
132
133 pub fn is_ready(&self) -> bool {
134 matches!(self.0, ConnectionStateImpl::Ready)
135 }
136
137 pub fn is_error(&self) -> bool {
138 matches!(self.0, ConnectionStateImpl::Error)
139 }
140
141 pub fn is_done(&self) -> bool {
142 self.is_ready() || self.is_error()
143 }
144
145 pub fn read_ssl_response(&self) -> bool {
146 matches!(self.0, ConnectionStateImpl::SslWaiting(..))
147 }
148
149 pub fn drive(
150 &mut self,
151 drive: ConnectionDrive,
152 update: &mut impl ConnectionStateUpdate,
153 ) -> Result<(), ConnectionError> {
154 use ConnectionStateImpl::*;
155 trace!("Received drive {drive:?} in state {:?}", self.0);
156 match (&mut self.0, drive) {
157 (SslInitializing(credentials, mode), ConnectionDrive::Initial) => {
158 update.send_initial(&SSLRequestBuilder::default())?;
159 self.0 = SslWaiting(std::mem::take(credentials), *mode);
160 update.state_changed(ConnectionStateType::Connecting);
161 }
162 (SslWaiting(credentials, mode), ConnectionDrive::SslResponse(response)) => {
163 if *mode == ConnectionSslRequirement::Disable {
164 return Err(invalid_state!("SSL mode is Disable in SslWaiting state"));
166 }
167
168 if response.code() == b'S' {
169 update.upgrade()?;
171 self.0 = SslConnecting(std::mem::take(credentials));
172 update.state_changed(ConnectionStateType::SslConnecting);
173 } else if response.code() == b'N' {
174 if *mode == ConnectionSslRequirement::Required {
176 return Err(ConnectionError::SslError(SslError::SslRequiredByClient));
177 }
178 Self::send_startup_message(credentials, update)?;
179 self.0 = Connecting(std::mem::take(credentials), false);
180 } else {
181 return Err(ConnectionError::UnexpectedResponse(format!(
182 "Unexpected SSL response from server: {:?}",
183 response.code() as char
184 )));
185 }
186 }
187 (SslConnecting(credentials), ConnectionDrive::SslReady) => {
188 Self::send_startup_message(credentials, update)?;
189 self.0 = Connecting(std::mem::take(credentials), false);
190 }
191 (Initializing(credentials), ConnectionDrive::Initial) => {
192 Self::send_startup_message(credentials, update)?;
193 self.0 = Connecting(std::mem::take(credentials), false);
194 update.state_changed(ConnectionStateType::Connecting);
195 }
196 (Connecting(credentials, sent_auth), ConnectionDrive::Message(message)) => {
197 match_message!(message, Backend {
198 (AuthenticationOk) => {
199 if !*sent_auth {
200 update.auth(AuthType::Trust);
201 }
202 trace!("auth ok");
203 self.0 = Connected;
204 update.state_changed(ConnectionStateType::Synchronizing);
205 },
206 (AuthenticationSASL as sasl) => {
207 *sent_auth = true;
208 let mut found_scram_sha256 = false;
209 for mech in sasl.mechanisms() {
210 trace!("auth sasl: {:?}", mech);
211 if mech == "SCRAM-SHA-256" {
212 found_scram_sha256 = true;
213 break;
214 }
215 }
216 if !found_scram_sha256 {
217 return Err(ConnectionError::UnexpectedResponse("Server requested SASL authentication but does not support SCRAM-SHA-256".into()));
218 }
219 let credentials = credentials.clone();
220 let mut tx = ClientTransaction::new("".into());
221 let env = ClientEnvironmentImpl { credentials };
222 let Some(initial_message) = tx.process_message(&[], &env)? else {
223 return Err(SCRAMError::ProtocolError.into());
224 };
225 update.auth(AuthType::ScramSha256);
226 update.send(&SASLInitialResponseBuilder {
227 mechanism: "SCRAM-SHA-256",
228 response: initial_message.as_slice(),
229 })?;
230 self.0 = Scram(tx, env);
231 update.state_changed(ConnectionStateType::Authenticating);
232 },
233 (AuthenticationMD5Password as md5) => {
234 *sent_auth = true;
235 trace!("auth md5");
236 let md5_hash = md5_password(&credentials.password, &credentials.username, md5.salt());
237 update.auth(AuthType::Md5);
238 update.send(&PasswordMessageBuilder {
239 password: &md5_hash,
240 })?;
241 },
242 (AuthenticationCleartextPassword) => {
243 *sent_auth = true;
244 trace!("auth cleartext");
245 update.auth(AuthType::Plain);
246 update.send(&PasswordMessageBuilder {
247 password: &credentials.password,
248 })?;
249 },
250 (NoticeResponse as notice) => {
251 let err = PgServerError::from(notice);
252 update.server_notice(&err);
253 },
254 (ErrorResponse as error) => {
255 self.0 = Error;
256 let err = PgServerError::from(error);
257 update.server_error(&err);
258 return Err(err.into());
259 },
260 message => {
261 log_unknown_message(message, "Connecting")?
262 },
263 });
264 }
265 (Scram(tx, env), ConnectionDrive::Message(message)) => {
266 match_message!(message, Backend {
267 (AuthenticationSASLContinue as sasl) => {
268 let Some(message) = tx.process_message(&sasl.data(), env)? else {
269 return Err(SCRAMError::ProtocolError.into());
270 };
271 update.send(&SASLResponseBuilder {
272 response: &message,
273 })?;
274 },
275 (AuthenticationSASLFinal as sasl) => {
276 let None = tx.process_message(&sasl.data(), env)? else {
277 return Err(SCRAMError::ProtocolError.into());
278 };
279 },
280 (AuthenticationOk) => {
281 trace!("auth ok");
282 self.0 = Connected;
283 update.state_changed(ConnectionStateType::Synchronizing);
284 },
285 (AuthenticationMessage as auth) => {
286 trace!("SCRAM Unknown auth message: {}", auth.status())
287 },
288 (NoticeResponse as notice) => {
289 let err = PgServerError::from(notice);
290 update.server_notice(&err);
291 },
292 (ErrorResponse as error) => {
293 self.0 = Error;
294 let err = PgServerError::from(error);
295 update.server_error(&err);
296 return Err(err.into());
297 },
298 message => {
299 log_unknown_message(message, "SCRAM")?
300 },
301 });
302 }
303 (Connected, ConnectionDrive::Message(message)) => {
304 match_message!(message, Backend {
305 (ParameterStatus as param) => {
306 trace!("param: {:?}={:?}", param.name(), param.value());
307 update.parameter(param.name().try_into()?, param.value().try_into()?);
308 },
309 (BackendKeyData as key_data) => {
310 trace!("key={:?} pid={:?}", key_data.key(), key_data.pid());
311 update.cancellation_key(key_data.pid(), key_data.key());
312 },
313 (ReadyForQuery as ready) => {
314 trace!("ready: {:?}", ready.status() as char);
315 trace!("-> Ready");
316 self.0 = Ready;
317 update.state_changed(ConnectionStateType::Ready);
318 },
319 (NoticeResponse as notice) => {
320 let err = PgServerError::from(notice);
321 update.server_notice(&err);
322 },
323 (ErrorResponse as error) => {
324 self.0 = Error;
325 let err = PgServerError::from(error);
326 update.server_error(&err);
327 return Err(err.into());
328 },
329 message => {
330 log_unknown_message(message, "Connected")?
331 },
332 });
333 }
334 (Ready, _) | (Error, _) => {
335 return Err(invalid_state!("Unexpected drive for Ready or Error state"))
336 }
337 _ => return Err(invalid_state!("Unexpected (state, drive) combination")),
338 }
339 Ok(())
340 }
341
342 fn send_startup_message(
343 credentials: &Credentials,
344 update: &mut impl ConnectionStateUpdate,
345 ) -> Result<(), std::io::Error> {
346 let mut params = vec![
347 StartupNameValueBuilder {
348 name: "user",
349 value: &credentials.username,
350 },
351 StartupNameValueBuilder {
352 name: "database",
353 value: &credentials.database,
354 },
355 ];
356 for (name, value) in &credentials.server_settings {
357 params.push(StartupNameValueBuilder { name, value })
358 }
359
360 update.send_initial(&StartupMessageBuilder { params: || ¶ms })
361 }
362}
363
364fn log_unknown_message(
365 message: Result<Message, ParseError>,
366 state: &str,
367) -> Result<(), ParseError> {
368 match message {
369 Ok(message) => {
370 warn!(
371 "Unexpected message {:?} (length {}) received in {} state",
372 message.mtype(),
373 message.mlen(),
374 state
375 );
376 Ok(())
377 }
378 Err(e) => {
379 error!("Corrupted message received in {} state", state);
380 Err(e)
381 }
382 }
383}