pgwire/api/
mod.rs

//! APIs for building PostgreSQL-compatible servers.
2
3use std::collections::HashMap;
4use std::net::SocketAddr;
5use std::sync::Arc;
6
7pub use postgres_types::Type;
8#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
9use rustls_pki_types::CertificateDer;
10
11use crate::error::PgWireError;
12use crate::messages::response::TransactionStatus;
13use crate::messages::startup::SecretKey;
14use crate::messages::ProtocolVersion;
15
16pub mod auth;
17pub mod cancel;
18#[cfg(feature = "client-api")]
19pub mod client;
20pub mod copy;
21pub mod portal;
22pub mod query;
23pub mod results;
24pub mod stmt;
25pub mod store;
26pub mod transaction;
27
/// Placeholder name used when no explicit name is supplied.
///
/// NOTE(review): presumably the key for the unnamed prepared statement /
/// portal — confirm against the `stmt` and `portal` modules.
pub const DEFAULT_NAME: &str = "POSTGRESQL_DEFAULT_NAME";
29
/// Lifecycle state of a single client connection as tracked by the server.
///
/// The initial (`Default`) state is [`PgWireConnectionState::AwaitingSslRequest`].
/// `PartialEq`/`Eq` are derived so callers can compare states with `==`
/// instead of pattern matching.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum PgWireConnectionState {
    /// Waiting for the client's SSL request; this is the initial state.
    #[default]
    AwaitingSslRequest,
    /// Waiting for the startup message.
    AwaitingStartup,
    /// Authentication has started but has not completed.
    AuthenticationInProgress,
    /// Ready to accept the next query.
    ReadyForQuery,
    /// A query is currently being processed.
    QueryInProgress,
    /// A COPY operation is in progress.
    ///
    /// NOTE(review): the meaning of the flag is defined by the code that
    /// enters this state (possibly "started from the extended query
    /// protocol") — confirm before relying on it.
    CopyInProgress(bool),
    /// Waiting for a Sync message before processing continues.
    AwaitingSync,
}
41
42// TODO: add oauth scope and issuer
43/// Describe a client information holder
44pub trait ClientInfo {
45    fn socket_addr(&self) -> SocketAddr;
46
47    fn is_secure(&self) -> bool;
48
49    fn protocol_version(&self) -> ProtocolVersion;
50
51    fn set_protocol_version(&mut self, version: ProtocolVersion);
52
53    fn pid_and_secret_key(&self) -> (i32, SecretKey);
54
55    fn set_pid_and_secret_key(&mut self, pid: i32, secret_key: SecretKey);
56
57    fn state(&self) -> PgWireConnectionState;
58
59    fn set_state(&mut self, new_state: PgWireConnectionState);
60
61    fn transaction_status(&self) -> TransactionStatus;
62
63    fn set_transaction_status(&mut self, new_status: TransactionStatus);
64
65    fn metadata(&self) -> &HashMap<String, String>;
66
67    fn metadata_mut(&mut self) -> &mut HashMap<String, String>;
68
69    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
70    fn sni_server_name(&self) -> Option<&str>;
71
72    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
73    fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]>;
74}
75
/// Gives access to the portal store owned by a client.
pub trait ClientPortalStore {
    /// Concrete store type holding this client's portals.
    type PortalStore;

    /// Borrows the client's portal store.
    fn portal_store(&self) -> &Self::PortalStore;
}
82
/// Metadata key holding the user name supplied by the client.
pub const METADATA_USER: &str = "user";
/// Metadata key holding the database name supplied by the client.
pub const METADATA_DATABASE: &str = "database";
/// Metadata key holding the client's character encoding.
pub const METADATA_CLIENT_ENCODING: &str = "client_encoding";
/// Metadata key holding the client's application name.
pub const METADATA_APPLICATION_NAME: &str = "application_name";
87
88#[non_exhaustive]
89#[derive(Debug)]
90pub struct DefaultClient<S> {
91    pub socket_addr: SocketAddr,
92    pub is_secure: bool,
93    pub protocol_version: ProtocolVersion,
94    pub pid_secret_key: (i32, SecretKey),
95    pub state: PgWireConnectionState,
96    pub transaction_status: TransactionStatus,
97    pub metadata: HashMap<String, String>,
98    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
99    pub sni_server_name: Option<String>,
100    pub portal_store: store::MemPortalStore<S>,
101}
102
103impl<S> ClientInfo for DefaultClient<S> {
104    fn socket_addr(&self) -> SocketAddr {
105        self.socket_addr
106    }
107
108    fn is_secure(&self) -> bool {
109        self.is_secure
110    }
111
112    fn pid_and_secret_key(&self) -> (i32, SecretKey) {
113        self.pid_secret_key.clone()
114    }
115
116    fn set_pid_and_secret_key(&mut self, pid: i32, secret_key: SecretKey) {
117        self.pid_secret_key = (pid, secret_key);
118    }
119
120    fn protocol_version(&self) -> ProtocolVersion {
121        self.protocol_version
122    }
123
124    fn set_protocol_version(&mut self, version: ProtocolVersion) {
125        self.protocol_version = version;
126    }
127
128    fn state(&self) -> PgWireConnectionState {
129        self.state
130    }
131
132    fn set_state(&mut self, new_state: PgWireConnectionState) {
133        self.state = new_state;
134    }
135
136    fn metadata(&self) -> &HashMap<String, String> {
137        &self.metadata
138    }
139
140    fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
141        &mut self.metadata
142    }
143
144    fn transaction_status(&self) -> TransactionStatus {
145        self.transaction_status
146    }
147
148    fn set_transaction_status(&mut self, new_status: TransactionStatus) {
149        self.transaction_status = new_status
150    }
151
152    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
153    fn sni_server_name(&self) -> Option<&str> {
154        self.sni_server_name.as_deref()
155    }
156
157    #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
158    fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> {
159        None
160    }
161}
162
163impl<S> DefaultClient<S> {
164    pub fn new(socket_addr: SocketAddr, is_secure: bool) -> DefaultClient<S> {
165        DefaultClient {
166            socket_addr,
167            is_secure,
168            protocol_version: ProtocolVersion::default(),
169            pid_secret_key: (0, SecretKey::default()),
170            state: PgWireConnectionState::default(),
171            transaction_status: TransactionStatus::Idle,
172            metadata: HashMap::new(),
173            #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
174            sni_server_name: None,
175            portal_store: store::MemPortalStore::new(),
176        }
177    }
178}
179
180impl<S> ClientPortalStore for DefaultClient<S> {
181    type PortalStore = store::MemPortalStore<S>;
182
183    fn portal_store(&self) -> &Self::PortalStore {
184        &self.portal_store
185    }
186}
187
188/// A centralized handler for all errors
189///
190/// This handler captures all errors produces by authentication, query and
191/// copy. You can do logging, filtering or masking the error before it sent to
192/// client.
193pub trait ErrorHandler: Send + Sync {
194    fn on_error<C>(&self, _client: &C, _error: &mut PgWireError)
195    where
196        C: ClientInfo,
197    {
198    }
199}
200
201/// A noop implementation for `ErrorHandler`.
202#[derive(Debug)]
203pub struct NoopHandler;
204
205impl ErrorHandler for NoopHandler {}
206
207pub trait PgWireServerHandlers {
208    fn simple_query_handler(&self) -> Arc<impl query::SimpleQueryHandler> {
209        Arc::new(NoopHandler)
210    }
211
212    fn extended_query_handler(&self) -> Arc<impl query::ExtendedQueryHandler> {
213        Arc::new(NoopHandler)
214    }
215
216    fn startup_handler(&self) -> Arc<impl auth::StartupHandler> {
217        Arc::new(NoopHandler)
218    }
219
220    fn copy_handler(&self) -> Arc<impl copy::CopyHandler> {
221        Arc::new(NoopHandler)
222    }
223
224    fn error_handler(&self) -> Arc<impl ErrorHandler> {
225        Arc::new(NoopHandler)
226    }
227
228    fn cancel_handler(&self) -> Arc<impl cancel::CancelHandler> {
229        Arc::new(NoopHandler)
230    }
231}
232
233impl<T> PgWireServerHandlers for Arc<T>
234where
235    T: PgWireServerHandlers,
236{
237    fn simple_query_handler(&self) -> Arc<impl query::SimpleQueryHandler> {
238        (**self).simple_query_handler()
239    }
240
241    fn extended_query_handler(&self) -> Arc<impl query::ExtendedQueryHandler> {
242        (**self).extended_query_handler()
243    }
244
245    fn startup_handler(&self) -> Arc<impl auth::StartupHandler> {
246        (**self).startup_handler()
247    }
248
249    fn copy_handler(&self) -> Arc<impl copy::CopyHandler> {
250        (**self).copy_handler()
251    }
252
253    fn error_handler(&self) -> Arc<impl ErrorHandler> {
254        (**self).error_handler()
255    }
256
257    fn cancel_handler(&self) -> Arc<impl cancel::CancelHandler> {
258        (**self).cancel_handler()
259    }
260}