1use std::collections::HashMap;
2
/// How strongly one side of a connection insists on SSL.
///
/// A value of this type is passed to both the client and the server
/// connection state machines when they are constructed.
// NOTE(review): the precise negotiation rules for each variant live in the
// state machines, not in this file — confirm the per-variant wording there.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionSslRequirement {
    /// Do not use SSL. This is the [`Default`] value.
    #[default]
    Disable,
    /// SSL is acceptable but not mandatory.
    Optional,
    /// SSL is mandatory.
    Required,
}
13
14mod client_state_machine;
15mod server_state_machine;
16
17pub mod client {
18 pub use super::client_state_machine::*;
19}
20
21pub mod server {
22 pub use super::server_state_machine::*;
23}
24
/// Reports an invalid connection state and evaluates to
/// `ConnectionError::__InvalidState`.
///
/// The single argument must be a literal describing the problem. It is
/// written to stderr together with a backtrace obtained from
/// `std::backtrace::Backtrace::capture()`.
macro_rules! __invalid_state {
    ($error:literal) => {{
        let backtrace = ::std::backtrace::Backtrace::capture();
        eprintln!("Invalid connection state: {}\n{}", $error, backtrace);
        // The variant is marked deprecated to push callers towards this
        // macro, so its one intended use has to silence the lint.
        #[allow(deprecated)]
        $crate::postgres::ConnectionError::__InvalidState
    }};
}
// Crate-internal alias without the leading underscores.
pub(crate) use __invalid_state as invalid_state;
37
38#[derive(Debug, derive_more::Error, derive_more::Display, derive_more::From)]
39pub enum ConnectionError {
40 #[display("Invalid state")]
43 #[deprecated = "Use invalid_state!"]
44 __InvalidState,
45
46 #[display("Server error: {_0}")]
48 ServerError(#[from] gel_pg_protocol::errors::PgServerError),
49
50 #[display("Unexpected server response: {_0}")]
52 UnexpectedResponse(#[error(not(source))] String),
53
54 #[display("SCRAM: {_0}")]
56 Scram(#[from] crate::scram::SCRAMError),
57
58 #[display("I/O error: {_0}")]
60 Io(#[from] std::io::Error),
61
62 #[display("UTF8 error: {_0}")]
64 Utf8Error(#[from] std::str::Utf8Error),
65
66 #[display("SSL error: {_0}")]
68 SslError(#[from] SslError),
69
70 #[display("Protocol error: {_0}")]
71 ParseError(#[from] gel_pg_protocol::prelude::ParseError),
72}
73
74#[derive(Debug, derive_more::Error, derive_more::Display)]
75pub enum SslError {
76 #[display("SSL is not supported by this client transport")]
77 SslUnsupportedByClient,
78 #[display("SSL was required by the client, but not offered by server (rejected SSL)")]
79 SslRequiredByClient,
80}
81
82#[derive(Clone, Default, derive_more::Debug)]
84pub struct Credentials {
85 pub username: String,
86 #[debug(skip)]
87 pub password: String,
88 pub database: String,
89 pub server_settings: HashMap<String, String>,
90}
91
92#[cfg(test)]
93mod tests {
94 use super::*;
95 use crate::*;
96 use gel_pg_protocol::errors::*;
97 use gel_pg_protocol::prelude::*;
98 use gel_pg_protocol::protocol::*;
99 use rstest::rstest;
100 use std::collections::VecDeque;
101
102 #[derive(Debug, Default)]
103 struct ConnectionPipe {
104 cmsg: VecDeque<(bool, Vec<u8>)>,
105 smsg: VecDeque<(bool, Vec<u8>)>,
106 sparams: bool,
107 sauth_user: Option<String>,
108 cauth: Option<AuthType>,
109 cerror: Option<PgError>,
110 serror: Option<PgError>,
111 }
112
113 impl client::ConnectionStateUpdate for ConnectionPipe {
114 fn auth(&mut self, auth: AuthType) {
115 eprintln!("Client: Auth = {auth:?}");
116 self.cauth = Some(auth);
117 }
118 fn cancellation_key(&mut self, _pid: i32, _key: i32) {}
119 fn parameter(&mut self, _name: &str, _value: &str) {}
120 fn server_error(&mut self, error: &PgServerError) {
121 self.cerror = Some(error.code);
122 }
123 fn state_changed(&mut self, state: client::ConnectionStateType) {
124 eprintln!("Client: Start = {state:?}");
125 }
126 }
127
128 impl client::ConnectionStateSend for ConnectionPipe {
129 fn send<'a, M>(
130 &mut self,
131 message: impl IntoFrontendBuilder<'a, M>,
132 ) -> Result<(), std::io::Error> {
133 let message = message.into_builder();
134 eprintln!("Client -> Server {message:?}");
135 self.smsg.push_back((false, message.to_vec()));
136 Ok(())
137 }
138 fn send_initial<'a, M>(
139 &mut self,
140 message: impl IntoInitialBuilder<'a, M>,
141 ) -> Result<(), std::io::Error> {
142 let message = message.into_builder();
143 eprintln!("Client -> Server {message:?}");
144 self.smsg.push_back((true, message.to_vec()));
145 Ok(())
146 }
147 fn upgrade(&mut self) -> Result<(), std::io::Error> {
148 unimplemented!()
149 }
150 }
151
152 impl server::ConnectionStateUpdate for ConnectionPipe {
153 fn state_changed(&mut self, _state: server::ConnectionStateType) {}
154 fn parameter(&mut self, _name: &str, _value: &str) {}
155 fn server_error(&mut self, error: &PgServerError) {
156 self.serror = Some(error.code);
157 }
158 }
159
160 impl server::ConnectionStateSend for ConnectionPipe {
161 fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error> {
162 eprintln!("Server: auth request {user}/{database}");
163 self.sauth_user = Some(user);
164 Ok(())
165 }
166 fn params(&mut self) -> Result<(), std::io::Error> {
167 eprintln!("Server: param request");
168 self.sparams = true;
169 Ok(())
170 }
171 fn send<'a, M>(
172 &mut self,
173 message: impl IntoBackendBuilder<'a, M>,
174 ) -> Result<(), std::io::Error> {
175 let message = message.into_builder();
176 eprintln!("Server -> Client {message:?}");
177 self.cmsg.push_back((false, message.to_vec()));
178 Ok(())
179 }
180 fn send_ssl(&mut self, message: SSLResponseBuilder) -> Result<(), std::io::Error> {
181 self.cmsg.push_back((true, message.to_vec()));
182 Ok(())
183 }
184 fn upgrade(&mut self) -> Result<(), std::io::Error> {
185 unimplemented!()
186 }
187 }
188
189 #[rstest]
191 fn test_both(
192 #[values(
193 AuthType::Deny,
194 AuthType::Trust,
195 AuthType::Plain,
196 AuthType::Md5,
197 AuthType::ScramSha256
198 )]
199 auth_type: AuthType,
200 #[values(
201 AuthType::Deny,
202 AuthType::Trust,
203 AuthType::Plain,
204 AuthType::Md5,
205 AuthType::ScramSha256
206 )]
207 credential_type: AuthType,
208 #[values(true, false)] correct_password: bool,
209 ) {
210 let mut client = client::ConnectionState::new(
211 Credentials {
212 username: "user".to_string(),
213 password: "password".to_string(),
214 database: "database".to_string(),
215 ..Default::default()
216 },
217 ConnectionSslRequirement::Disable,
218 );
219 let mut server = server::ServerState::new(ConnectionSslRequirement::Disable);
220
221 let expect_success = match (auth_type, credential_type, correct_password) {
224 (AuthType::Trust, ..) => true,
226 (AuthType::Deny, ..) => false,
228 (_, AuthType::Deny, _) => false,
230 (AuthType::ScramSha256, AuthType::ScramSha256 | AuthType::Plain, correct) => correct,
234 (AuthType::ScramSha256, _, _) => false,
235 (_, AuthType::Trust, _) => true,
237 (AuthType::Md5, AuthType::Md5 | AuthType::Plain, correct) => correct,
239 (AuthType::Md5, _, _) => false,
240 (AuthType::Plain, _, correct) => correct,
242 };
243
244 let mut client_error = false;
245 let mut server_error = false;
246
247 let mut pipe = ConnectionPipe::default();
248 client
250 .drive(client::ConnectionDrive::Initial, &mut pipe)
251 .unwrap();
252 let mut max_iterations: i32 = 100;
253 loop {
254 max_iterations -= 1;
255 if max_iterations == 0 {
256 panic!("Failed to complete");
257 }
258 if let Some(user) = pipe.sauth_user.take() {
259 eprintln!("Sending auth");
260 let password = if correct_password {
261 "password".to_owned()
262 } else {
263 "incorrect".to_owned()
264 };
265 let data = CredentialData::new(credential_type, user.clone(), password);
266 server_error |= server
267 .drive(
268 server::ConnectionDrive::AuthInfo(auth_type, data),
269 &mut pipe,
270 )
271 .is_err();
272 }
273 if pipe.sparams {
274 server_error |= server
275 .drive(
276 server::ConnectionDrive::Parameter("param1".to_owned(), "value".to_owned()),
277 &mut pipe,
278 )
279 .is_err();
280 server_error |= server
281 .drive(
282 server::ConnectionDrive::Parameter("param2".to_owned(), "value".to_owned()),
283 &mut pipe,
284 )
285 .is_err();
286 server_error |= server
287 .drive(server::ConnectionDrive::Ready(1234, 4567), &mut pipe)
288 .is_err();
289 }
290 while let Some((initial, msg)) = pipe.smsg.pop_front() {
291 if initial {
292 server_error |= server
293 .drive(
294 server::ConnectionDrive::Initial(InitialMessage::new(&msg)),
295 &mut pipe,
296 )
297 .is_err();
298 } else {
299 server_error |= server
300 .drive(
301 server::ConnectionDrive::Message(Message::new(&msg)),
302 &mut pipe,
303 )
304 .is_err();
305 }
306 }
307 while let Some((ssl, msg)) = pipe.cmsg.pop_front() {
308 if ssl {
309 unimplemented!()
310 } else {
311 client_error |= client
312 .drive(
313 client::ConnectionDrive::Message(Message::new(&msg)),
314 &mut pipe,
315 )
316 .is_err();
317 }
318 }
319 if client.is_done() && server.is_done() {
320 break;
321 }
322 }
323
324 if expect_success {
325 assert!(
326 client.is_ready() && server.is_ready(),
327 "client={client:?} server={server:?}"
328 );
329 } else {
330 assert!(client_error && server_error);
331 assert!(pipe.cerror.is_some() && pipe.serror.is_some());
332 assert!(client.is_error() && server.is_error())
333 }
334 }
335}