gel_auth/postgres/
mod.rs

1use std::collections::HashMap;
2
/// How strictly a connection insists on SSL.
///
/// Defaults to [`ConnectionSslRequirement::Disable`].
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum ConnectionSslRequirement {
    /// SSL is turned off; any attempt to use it is an error.
    #[default]
    Disable,
    /// SSL may be used and is preferred, but is not mandatory.
    Optional,
    /// SSL is mandatory; having it rejected is an error.
    Required,
}
13
14mod client_state_machine;
15mod server_state_machine;
16
17pub mod client {
18    pub use super::client_state_machine::*;
19}
20
21pub mod server {
22    pub use super::server_state_machine::*;
23}
24
/// Logs `$error` together with a backtrace to stderr, then evaluates to
/// `ConnectionError::__InvalidState`.
///
/// Use this (via its `invalid_state` alias) instead of naming the deprecated
/// variant directly. The backtrace comes from `Backtrace::capture`, so its
/// content depends on whether backtraces are enabled in the environment.
macro_rules! __invalid_state {
    ($error:literal) => {{
        let backtrace = ::std::backtrace::Backtrace::capture();
        ::std::eprintln!("Invalid connection state: {}\n{}", $error, backtrace);
        // The variant is deprecated to steer callers to this macro; the macro
        // itself is the one sanctioned place that constructs it.
        #[allow(deprecated)]
        $crate::postgres::ConnectionError::__InvalidState
    }};
}
// Crate-visible, path-addressable alias used at call sites.
pub(crate) use __invalid_state as invalid_state;
37
38#[derive(Debug, derive_more::Error, derive_more::Display, derive_more::From)]
39pub enum ConnectionError {
40    /// Invalid state error, suggesting a logic error in code rather than a server or client failure.
41    /// Use the `invalid_state!` macro instead which will print a backtrace.
42    #[display("Invalid state")]
43    #[deprecated = "Use invalid_state!"]
44    __InvalidState,
45
46    /// Error returned by the server.
47    #[display("Server error: {_0}")]
48    ServerError(#[from] gel_pg_protocol::errors::PgServerError),
49
50    /// The server sent something we didn't expect
51    #[display("Unexpected server response: {_0}")]
52    UnexpectedResponse(#[error(not(source))] String),
53
54    /// Error related to SCRAM authentication.
55    #[display("SCRAM: {_0}")]
56    Scram(#[from] crate::scram::SCRAMError),
57
58    /// I/O error encountered during connection operations.
59    #[display("I/O error: {_0}")]
60    Io(#[from] std::io::Error),
61
62    /// UTF-8 decoding error.
63    #[display("UTF8 error: {_0}")]
64    Utf8Error(#[from] std::str::Utf8Error),
65
66    /// SSL-related error.
67    #[display("SSL error: {_0}")]
68    SslError(#[from] SslError),
69
70    #[display("Protocol error: {_0}")]
71    ParseError(#[from] gel_pg_protocol::prelude::ParseError),
72}
73
74#[derive(Debug, derive_more::Error, derive_more::Display)]
75pub enum SslError {
76    #[display("SSL is not supported by this client transport")]
77    SslUnsupportedByClient,
78    #[display("SSL was required by the client, but not offered by server (rejected SSL)")]
79    SslRequiredByClient,
80}
81
82/// A sufficient set of required parameters to connect to a given transport.
83#[derive(Clone, Default, derive_more::Debug)]
84pub struct Credentials {
85    pub username: String,
86    #[debug(skip)]
87    pub password: String,
88    pub database: String,
89    pub server_settings: HashMap<String, String>,
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;
    use gel_pg_protocol::errors::*;
    use gel_pg_protocol::prelude::*;
    use gel_pg_protocol::protocol::*;
    use rstest::rstest;
    use std::collections::VecDeque;

    /// Maximum number of pump iterations before the handshake is considered stuck.
    const MAX_ITERATIONS: usize = 100;

    /// In-memory "wire" between a client and a server state machine.
    ///
    /// Both state machines write their outgoing messages and callbacks into
    /// this struct. `test_both` drains the queues and feeds every message to
    /// the opposite side.
    #[derive(Debug, Default)]
    struct ConnectionPipe {
        /// Bytes queued for the client, tagged `true` when they are an SSL response.
        cmsg: VecDeque<(bool, Vec<u8>)>,
        /// Bytes queued for the server, tagged `true` when they are an initial message.
        smsg: VecDeque<(bool, Vec<u8>)>,
        /// Set when the server requests its parameters; consumed by the pump loop.
        sparams: bool,
        /// User name from a pending server `auth` request; consumed by the pump loop.
        sauth_user: Option<String>,
        /// Auth type last reported to the client.
        cauth: Option<AuthType>,
        /// Error code of the last server error observed by the client.
        cerror: Option<PgError>,
        /// Error code of the last server error observed by the server.
        serror: Option<PgError>,
    }

    impl client::ConnectionStateUpdate for ConnectionPipe {
        fn auth(&mut self, auth: AuthType) {
            eprintln!("Client: Auth = {auth:?}");
            self.cauth = Some(auth);
        }
        fn cancellation_key(&mut self, _pid: i32, _key: i32) {}
        fn parameter(&mut self, _name: &str, _value: &str) {}
        fn server_error(&mut self, error: &PgServerError) {
            self.cerror = Some(error.code);
        }
        fn state_changed(&mut self, state: client::ConnectionStateType) {
            eprintln!("Client: State = {state:?}");
        }
    }

    impl client::ConnectionStateSend for ConnectionPipe {
        fn send<'a, M>(
            &mut self,
            message: impl IntoFrontendBuilder<'a, M>,
        ) -> Result<(), std::io::Error> {
            let message = message.into_builder();
            eprintln!("Client -> Server {message:?}");
            self.smsg.push_back((false, message.to_vec()));
            Ok(())
        }
        fn send_initial<'a, M>(
            &mut self,
            message: impl IntoInitialBuilder<'a, M>,
        ) -> Result<(), std::io::Error> {
            let message = message.into_builder();
            eprintln!("Client -> Server {message:?}");
            self.smsg.push_back((true, message.to_vec()));
            Ok(())
        }
        fn upgrade(&mut self) -> Result<(), std::io::Error> {
            // SSL is disabled in these tests, so an upgrade is never requested.
            unimplemented!()
        }
    }

    impl server::ConnectionStateUpdate for ConnectionPipe {
        fn state_changed(&mut self, _state: server::ConnectionStateType) {}
        fn parameter(&mut self, _name: &str, _value: &str) {}
        fn server_error(&mut self, error: &PgServerError) {
            self.serror = Some(error.code);
        }
    }

    impl server::ConnectionStateSend for ConnectionPipe {
        fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error> {
            eprintln!("Server: auth request {user}/{database}");
            self.sauth_user = Some(user);
            Ok(())
        }
        fn params(&mut self) -> Result<(), std::io::Error> {
            eprintln!("Server: param request");
            self.sparams = true;
            Ok(())
        }
        fn send<'a, M>(
            &mut self,
            message: impl IntoBackendBuilder<'a, M>,
        ) -> Result<(), std::io::Error> {
            let message = message.into_builder();
            eprintln!("Server -> Client {message:?}");
            self.cmsg.push_back((false, message.to_vec()));
            Ok(())
        }
        fn send_ssl(&mut self, message: SSLResponseBuilder) -> Result<(), std::io::Error> {
            self.cmsg.push_back((true, message.to_vec()));
            Ok(())
        }
        fn upgrade(&mut self) -> Result<(), std::io::Error> {
            // SSL is disabled in these tests, so an upgrade is never requested.
            unimplemented!()
        }
    }

    /// We test the full matrix of server and client combinations.
    ///
    /// For every server auth type, stored credential type, and password
    /// correctness, a full handshake is pumped through a `ConnectionPipe`.
    /// Both sides must then agree on success (both ready, no errors) or on
    /// failure (both errored, both saw a server error).
    #[rstest]
    fn test_both(
        #[values(
            AuthType::Deny,
            AuthType::Trust,
            AuthType::Plain,
            AuthType::Md5,
            AuthType::ScramSha256
        )]
        auth_type: AuthType,
        #[values(
            AuthType::Deny,
            AuthType::Trust,
            AuthType::Plain,
            AuthType::Md5,
            AuthType::ScramSha256
        )]
        credential_type: AuthType,
        #[values(true, false)] correct_password: bool,
    ) {
        let mut client = client::ConnectionState::new(
            Credentials {
                username: "user".to_string(),
                password: "password".to_string(),
                database: "database".to_string(),
                ..Default::default()
            },
            ConnectionSslRequirement::Disable,
        );
        let mut server = server::ServerState::new(ConnectionSslRequirement::Disable);

        // We test all variations here, but not all combinations will result in
        // valid auth, even with a correct password.
        let expect_success = match (auth_type, credential_type, correct_password) {
            // If the server is set to trust, we always succeed (as no password is exchanged)
            (AuthType::Trust, ..) => true,
            // If the server is asking for a denial auth type, it'll always fail
            (AuthType::Deny, ..) => false,
            // If the credential is denial, it'll always fail
            (_, AuthType::Deny, _) => false,
            // SCRAM succeeds if the credential is SCRAM or Password (it cannot
            // succeed with a Trust credential because the server also sends a
            // verifier to the client).
            (AuthType::ScramSha256, AuthType::ScramSha256 | AuthType::Plain, correct) => correct,
            (AuthType::ScramSha256, _, _) => false,
            // Other auth types will always succeed if credential type is trust
            (_, AuthType::Trust, _) => true,
            // MD5 succeeds if the credential is not SCRAM
            (AuthType::Md5, AuthType::Md5 | AuthType::Plain, correct) => correct,
            (AuthType::Md5, _, _) => false,
            // Plain text works in all cases
            (AuthType::Plain, _, correct) => correct,
        };

        let mut client_error = false;
        let mut server_error = false;

        let mut pipe = ConnectionPipe::default();
        client
            .drive(client::ConnectionDrive::Initial, &mut pipe)
            .expect("the initial client drive can never fail");

        let mut finished = false;
        for _ in 0..MAX_ITERATIONS {
            // The server asked for credentials: answer with the configured
            // auth type and either the correct or a deliberately wrong password.
            if let Some(user) = pipe.sauth_user.take() {
                eprintln!("Sending auth");
                let password = if correct_password { "password" } else { "incorrect" };
                let data = CredentialData::new(credential_type, user, password.to_owned());
                server_error |= server
                    .drive(
                        server::ConnectionDrive::AuthInfo(auth_type, data),
                        &mut pipe,
                    )
                    .is_err();
            }
            // The server asked for its parameters. Consume the flag so they are
            // supplied exactly once, even if the exchange needs further rounds.
            if std::mem::take(&mut pipe.sparams) {
                server_error |= server
                    .drive(
                        server::ConnectionDrive::Parameter("param1".to_owned(), "value".to_owned()),
                        &mut pipe,
                    )
                    .is_err();
                server_error |= server
                    .drive(
                        server::ConnectionDrive::Parameter("param2".to_owned(), "value".to_owned()),
                        &mut pipe,
                    )
                    .is_err();
                server_error |= server
                    .drive(server::ConnectionDrive::Ready(1234, 4567), &mut pipe)
                    .is_err();
            }
            // Deliver everything the client has sent to the server.
            while let Some((initial, msg)) = pipe.smsg.pop_front() {
                let drive = if initial {
                    server::ConnectionDrive::Initial(InitialMessage::new(&msg))
                } else {
                    server::ConnectionDrive::Message(Message::new(&msg))
                };
                server_error |= server.drive(drive, &mut pipe).is_err();
            }
            // Deliver everything the server has sent to the client.
            while let Some((ssl, msg)) = pipe.cmsg.pop_front() {
                if ssl {
                    // SSL is disabled in these tests, so no SSL response is expected.
                    unimplemented!()
                }
                client_error |= client
                    .drive(
                        client::ConnectionDrive::Message(Message::new(&msg)),
                        &mut pipe,
                    )
                    .is_err();
            }
            if client.is_done() && server.is_done() {
                finished = true;
                break;
            }
        }
        assert!(
            finished,
            "Failed to complete: client={client:?} server={server:?}"
        );

        if expect_success {
            assert!(
                !client_error && !server_error,
                "unexpected drive error: client_error={client_error} server_error={server_error}"
            );
            assert!(
                client.is_ready() && server.is_ready(),
                "client={client:?} server={server:?}"
            );
        } else {
            assert!(
                client_error && server_error,
                "client_error={client_error} server_error={server_error}"
            );
            assert!(
                pipe.cerror.is_some() && pipe.serror.is_some(),
                "cerror={:?} serror={:?}",
                pipe.cerror,
                pipe.serror
            );
            assert!(
                client.is_error() && server.is_error(),
                "client={client:?} server={server:?}"
            );
        }
    }
}