1use std::collections::HashMap;
4use std::net::SocketAddr;
5use std::sync::Arc;
6
7pub use postgres_types::Type;
8#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
9use rustls_pki_types::CertificateDer;
10
11use crate::error::PgWireError;
12use crate::messages::response::TransactionStatus;
13use crate::messages::startup::SecretKey;
14use crate::messages::ProtocolVersion;
15
16pub mod auth;
17pub mod cancel;
18#[cfg(feature = "client-api")]
19pub mod client;
20pub mod copy;
21pub mod portal;
22pub mod query;
23pub mod results;
24pub mod stmt;
25pub mod store;
26pub mod transaction;
27
/// Sentinel name string exported by this module.
///
/// NOTE(review): presumably substituted for the unnamed (empty-name) prepared
/// statement / portal in the statement and portal stores — confirm against the
/// `store`, `stmt` and `portal` modules before relying on this.
pub const DEFAULT_NAME: &str = "POSTGRESQL_DEFAULT_NAME";
29
/// Lifecycle state of a client connection as tracked by the server.
///
/// The type is `Copy` and, with `PartialEq`/`Eq`/`Hash` derived, states can be
/// compared directly (`state == PgWireConnectionState::ReadyForQuery`) or used
/// as map/set keys instead of always going through `matches!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum PgWireConnectionState {
    /// Initial state; this is the `#[default]` variant, so
    /// `PgWireConnectionState::default()` yields it.
    #[default]
    AwaitingSslRequest,
    AwaitingStartup,
    AuthenticationInProgress,
    ReadyForQuery,
    QueryInProgress,
    /// A COPY operation is in progress.
    ///
    /// NOTE(review): the meaning of the `bool` is not visible in this module
    /// (possibly whether the COPY was started from the extended query
    /// protocol) — confirm with the code that constructs this variant.
    CopyInProgress(bool),
    AwaitingSync,
}
41
/// Connection-level information and mutable session state for one client.
///
/// State is exposed as getter/setter pairs. The TLS-specific accessors only
/// exist when one of the `_ring` / `_aws-lc-rs` features is enabled.
pub trait ClientInfo {
    /// Socket address of the connected peer.
    fn socket_addr(&self) -> SocketAddr;

    /// Whether the connection is considered secure.
    ///
    /// NOTE(review): presumably "TLS was negotiated" — confirm with the code
    /// that constructs the client.
    fn is_secure(&self) -> bool;

    /// Protocol version currently associated with the connection.
    fn protocol_version(&self) -> ProtocolVersion;

    /// Replace the stored protocol version.
    fn set_protocol_version(&mut self, version: ProtocolVersion);

    /// Process id and secret key pair for this connection, returned by value.
    ///
    /// NOTE(review): presumably the pair advertised to the client and later
    /// matched against cancel requests — confirm in the `cancel` module.
    fn pid_and_secret_key(&self) -> (i32, SecretKey);

    /// Replace the stored process id and secret key.
    fn set_pid_and_secret_key(&mut self, pid: i32, secret_key: SecretKey);

    /// Current connection lifecycle state.
    fn state(&self) -> PgWireConnectionState;

    /// Move the connection to `new_state`.
    fn set_state(&mut self, new_state: PgWireConnectionState);

    /// Current transaction status.
    fn transaction_status(&self) -> TransactionStatus;

    /// Replace the current transaction status.
    fn set_transaction_status(&mut self, new_status: TransactionStatus);

    /// String key/value metadata attached to the client (this module defines
    /// `METADATA_*` constants for some of the keys).
    fn metadata(&self) -> &HashMap<String, String>;

    /// Mutable access to the client metadata map.
    fn metadata_mut(&mut self) -> &mut HashMap<String, String>;

    /// Server name requested by the client via TLS SNI, if any.
    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    fn sni_server_name(&self) -> Option<&str>;

    /// Certificates presented by the client, if available.
    ///
    /// NOTE(review): `'a` is chosen by the caller and is not tied to `&self`;
    /// only the slice borrow is. Confirm this is intended before relying on it.
    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]>;
}
75
/// Gives access to the portal store owned by a client.
pub trait ClientPortalStore {
    /// Concrete store type; `DefaultClient<S>` uses `store::MemPortalStore<S>`.
    type PortalStore;

    /// Borrow the client's portal store.
    fn portal_store(&self) -> &Self::PortalStore;
}
82
// Well-known keys for `ClientInfo::metadata`. NOTE(review): these match the
// PostgreSQL startup-message parameter names and are presumably populated from
// the startup packet — confirm in the `auth` module.

/// Metadata key `"user"`.
pub const METADATA_USER: &str = "user";
/// Metadata key `"database"`.
pub const METADATA_DATABASE: &str = "database";
/// Metadata key `"client_encoding"`.
pub const METADATA_CLIENT_ENCODING: &str = "client_encoding";
/// Metadata key `"application_name"`.
pub const METADATA_APPLICATION_NAME: &str = "application_name";
87
/// Ready-made client record implementing `ClientInfo` and `ClientPortalStore`
/// with all state held in plain public fields.
///
/// `S` is forwarded to the embedded `store::MemPortalStore<S>`.
/// NOTE(review): presumably the prepared-statement type — confirm in `store`.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultClient<S> {
    /// Peer address, returned by `ClientInfo::socket_addr`.
    pub socket_addr: SocketAddr,
    /// Returned by `ClientInfo::is_secure`.
    pub is_secure: bool,
    /// Returned/updated by the `protocol_version` accessors.
    pub protocol_version: ProtocolVersion,
    /// Process id and secret key; cloned out by `pid_and_secret_key`.
    pub pid_secret_key: (i32, SecretKey),
    /// Current connection lifecycle state.
    pub state: PgWireConnectionState,
    /// Current transaction status.
    pub transaction_status: TransactionStatus,
    /// Client metadata map exposed through `metadata` / `metadata_mut`.
    pub metadata: HashMap<String, String>,
    /// TLS SNI server name, if one was recorded.
    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
    pub sni_server_name: Option<String>,
    /// In-memory portal store returned by `ClientPortalStore::portal_store`.
    pub portal_store: store::MemPortalStore<S>,
}
102
103impl<S> ClientInfo for DefaultClient<S> {
104 fn socket_addr(&self) -> SocketAddr {
105 self.socket_addr
106 }
107
108 fn is_secure(&self) -> bool {
109 self.is_secure
110 }
111
112 fn pid_and_secret_key(&self) -> (i32, SecretKey) {
113 self.pid_secret_key.clone()
114 }
115
116 fn set_pid_and_secret_key(&mut self, pid: i32, secret_key: SecretKey) {
117 self.pid_secret_key = (pid, secret_key);
118 }
119
120 fn protocol_version(&self) -> ProtocolVersion {
121 self.protocol_version
122 }
123
124 fn set_protocol_version(&mut self, version: ProtocolVersion) {
125 self.protocol_version = version;
126 }
127
128 fn state(&self) -> PgWireConnectionState {
129 self.state
130 }
131
132 fn set_state(&mut self, new_state: PgWireConnectionState) {
133 self.state = new_state;
134 }
135
136 fn metadata(&self) -> &HashMap<String, String> {
137 &self.metadata
138 }
139
140 fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
141 &mut self.metadata
142 }
143
144 fn transaction_status(&self) -> TransactionStatus {
145 self.transaction_status
146 }
147
148 fn set_transaction_status(&mut self, new_status: TransactionStatus) {
149 self.transaction_status = new_status
150 }
151
152 #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
153 fn sni_server_name(&self) -> Option<&str> {
154 self.sni_server_name.as_deref()
155 }
156
157 #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
158 fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> {
159 None
160 }
161}
162
163impl<S> DefaultClient<S> {
164 pub fn new(socket_addr: SocketAddr, is_secure: bool) -> DefaultClient<S> {
165 DefaultClient {
166 socket_addr,
167 is_secure,
168 protocol_version: ProtocolVersion::default(),
169 pid_secret_key: (0, SecretKey::default()),
170 state: PgWireConnectionState::default(),
171 transaction_status: TransactionStatus::Idle,
172 metadata: HashMap::new(),
173 #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
174 sni_server_name: None,
175 portal_store: store::MemPortalStore::new(),
176 }
177 }
178}
179
180impl<S> ClientPortalStore for DefaultClient<S> {
181 type PortalStore = store::MemPortalStore<S>;
182
183 fn portal_store(&self) -> &Self::PortalStore {
184 &self.portal_store
185 }
186}
187
/// Hook for observing or rewriting errors raised while serving a client.
///
/// Implementors must be `Send + Sync`.
pub trait ErrorHandler: Send + Sync {
    /// Called with the affected client and the error.
    ///
    /// The error is passed by `&mut`, so an implementation may modify it in
    /// place. NOTE(review): presumably before it is reported to the client —
    /// confirm at the call site. The default implementation does nothing.
    fn on_error<C>(&self, _client: &C, _error: &mut PgWireError)
    where
        C: ClientInfo,
    {
    }
}
200
/// Placeholder handler used as the default for every slot of
/// `PgWireServerHandlers`.
///
/// Its `ErrorHandler` impl keeps the default, empty `on_error`. Its behavior
/// for the other handler traits is defined in their respective modules.
#[derive(Debug)]
pub struct NoopHandler;

impl ErrorHandler for NoopHandler {}
206
/// Factory for the handlers a server uses to process connections.
///
/// Every method has a default that returns a fresh `Arc<NoopHandler>`.
/// Implementors override only the handlers they actually provide.
pub trait PgWireServerHandlers {
    /// Handler for the simple query protocol.
    fn simple_query_handler(&self) -> Arc<impl query::SimpleQueryHandler> {
        Arc::new(NoopHandler)
    }

    /// Handler for the extended query protocol.
    fn extended_query_handler(&self) -> Arc<impl query::ExtendedQueryHandler> {
        Arc::new(NoopHandler)
    }

    /// Handler for connection startup / authentication.
    fn startup_handler(&self) -> Arc<impl auth::StartupHandler> {
        Arc::new(NoopHandler)
    }

    /// Handler for COPY operations.
    fn copy_handler(&self) -> Arc<impl copy::CopyHandler> {
        Arc::new(NoopHandler)
    }

    /// Error hook (see `ErrorHandler`).
    fn error_handler(&self) -> Arc<impl ErrorHandler> {
        Arc::new(NoopHandler)
    }

    /// Handler for cancel requests.
    fn cancel_handler(&self) -> Arc<impl cancel::CancelHandler> {
        Arc::new(NoopHandler)
    }
}
232
233impl<T> PgWireServerHandlers for Arc<T>
234where
235 T: PgWireServerHandlers,
236{
237 fn simple_query_handler(&self) -> Arc<impl query::SimpleQueryHandler> {
238 (**self).simple_query_handler()
239 }
240
241 fn extended_query_handler(&self) -> Arc<impl query::ExtendedQueryHandler> {
242 (**self).extended_query_handler()
243 }
244
245 fn startup_handler(&self) -> Arc<impl auth::StartupHandler> {
246 (**self).startup_handler()
247 }
248
249 fn copy_handler(&self) -> Arc<impl copy::CopyHandler> {
250 (**self).copy_handler()
251 }
252
253 fn error_handler(&self) -> Arc<impl ErrorHandler> {
254 (**self).error_handler()
255 }
256
257 fn cancel_handler(&self) -> Arc<impl cancel::CancelHandler> {
258 (**self).cancel_handler()
259 }
260}