datafusion_postgres/
testing.rs1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::logical_expr::LogicalPlan;
4use datafusion::prelude::{SessionConfig, SessionContext};
5use datafusion::sql::sqlparser;
6use datafusion_pg_catalog::pg_catalog::setup_pg_catalog;
7use futures::Sink;
8use pgwire::{
9 api::{
10 ClientInfo, ClientPortalStore, METADATA_USER, PgWireConnectionState, SessionExtensions,
11 store::MemPortalStore,
12 },
13 messages::{
14 PgWireBackendMessage, ProtocolVersion, response::TransactionStatus, startup::SecretKey,
15 },
16};
17
18use crate::{DfSessionService, auth::AuthManager};
19
20pub fn setup_handlers() -> DfSessionService {
21 let session_config = SessionConfig::new().with_information_schema(true);
22 let session_context = SessionContext::new_with_config(session_config);
23
24 setup_pg_catalog(
25 &session_context,
26 "datafusion",
27 Arc::new(AuthManager::default()),
28 )
29 .expect("Failed to setup sesession context");
30
31 DfSessionService::new(Arc::new(session_context))
32}
33
/// Statement type stored in the mock client's portal store: a `String`
/// (presumably the SQL text -- confirm against the handler implementation)
/// paired with an optional parsed statement and its logical plan.
type DfStatement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
35
/// In-memory stand-in for a pgwire client connection.
///
/// Messages pushed through its `Sink` implementation are collected in
/// `sent_messages` rather than written anywhere.
///
/// NOTE(review): the derived `Default` leaves `metadata` empty, whereas
/// `MockClient::new` pre-populates the user entry.
#[derive(Debug, Default)]
pub struct MockClient {
    /// Key/value metadata exposed through `ClientInfo::metadata`.
    metadata: HashMap<String, String>,
    /// Portal store exposed through `ClientPortalStore::portal_store`.
    portal_store: MemPortalStore<DfStatement>,
    /// Backend messages recorded by `Sink::start_send`, in send order.
    pub sent_messages: Vec<PgWireBackendMessage>,
    /// Extensions exposed through `ClientInfo::session_extensions`.
    session_extensions: SessionExtensions,
}
43
44impl MockClient {
45 pub fn new() -> MockClient {
46 let mut metadata = HashMap::new();
47 metadata.insert(METADATA_USER.to_string(), "postgres".to_string());
48
49 MockClient {
50 metadata,
51 portal_store: MemPortalStore::new(),
52 sent_messages: Vec::new(),
53 session_extensions: SessionExtensions::new(),
54 }
55 }
56
57 pub fn sent_messages(&self) -> &[PgWireBackendMessage] {
58 &self.sent_messages
59 }
60}
61
62impl ClientInfo for MockClient {
63 fn socket_addr(&self) -> std::net::SocketAddr {
64 "127.0.0.1".parse().unwrap()
65 }
66
67 fn is_secure(&self) -> bool {
68 false
69 }
70
71 fn protocol_version(&self) -> ProtocolVersion {
72 ProtocolVersion::PROTOCOL3_0
73 }
74
75 fn set_protocol_version(&mut self, _version: ProtocolVersion) {}
76
77 fn pid_and_secret_key(&self) -> (i32, SecretKey) {
78 (0, SecretKey::I32(0))
79 }
80
81 fn set_pid_and_secret_key(&mut self, _pid: i32, _secret_key: SecretKey) {}
82
83 fn state(&self) -> PgWireConnectionState {
84 PgWireConnectionState::ReadyForQuery
85 }
86
87 fn set_state(&mut self, _new_state: PgWireConnectionState) {}
88
89 fn transaction_status(&self) -> TransactionStatus {
90 TransactionStatus::Idle
91 }
92
93 fn set_transaction_status(&mut self, _new_status: TransactionStatus) {}
94
95 fn metadata(&self) -> &HashMap<String, String> {
96 &self.metadata
97 }
98
99 fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
100 &mut self.metadata
101 }
102
103 fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
104 None
105 }
106
107 fn sni_server_name(&self) -> Option<&str> {
108 None
109 }
110
111 fn session_extensions(&self) -> &pgwire::api::SessionExtensions {
112 &self.session_extensions
113 }
114}
115
116impl ClientPortalStore for MockClient {
117 type PortalStore = MemPortalStore<DfStatement>;
118 fn portal_store(&self) -> &Self::PortalStore {
119 &self.portal_store
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// A newly constructed mock client has not recorded any backend messages.
    #[test]
    fn test_mock_client_captures_messages() {
        let fresh = MockClient::new();
        assert_eq!(fresh.sent_messages().len(), 0);
    }
}
133
134impl Sink<PgWireBackendMessage> for MockClient {
135 type Error = std::io::Error;
136
137 fn poll_ready(
138 self: std::pin::Pin<&mut Self>,
139 _cx: &mut std::task::Context<'_>,
140 ) -> std::task::Poll<Result<(), Self::Error>> {
141 std::task::Poll::Ready(Ok(()))
142 }
143
144 fn start_send(
145 mut self: std::pin::Pin<&mut Self>,
146 item: PgWireBackendMessage,
147 ) -> Result<(), Self::Error> {
148 self.sent_messages.push(item);
149 Ok(())
150 }
151
152 fn poll_flush(
153 self: std::pin::Pin<&mut Self>,
154 _cx: &mut std::task::Context<'_>,
155 ) -> std::task::Poll<Result<(), Self::Error>> {
156 std::task::Poll::Ready(Ok(()))
157 }
158
159 fn poll_close(
160 self: std::pin::Pin<&mut Self>,
161 _cx: &mut std::task::Context<'_>,
162 ) -> std::task::Poll<Result<(), Self::Error>> {
163 std::task::Poll::Ready(Ok(()))
164 }
165}