Skip to main content

datafusion_postgres/
testing.rs

1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::prelude::{SessionConfig, SessionContext};
4use datafusion_pg_catalog::pg_catalog::setup_pg_catalog;
5use futures::Sink;
6use pgwire::{
7    api::{ClientInfo, ClientPortalStore, PgWireConnectionState, METADATA_USER},
8    messages::{
9        response::TransactionStatus, startup::SecretKey, PgWireBackendMessage, ProtocolVersion,
10    },
11};
12
13use crate::{auth::AuthManager, DfSessionService};
14
15pub fn setup_handlers() -> DfSessionService {
16    let session_config = SessionConfig::new().with_information_schema(true);
17    let session_context = SessionContext::new_with_config(session_config);
18
19    setup_pg_catalog(
20        &session_context,
21        "datafusion",
22        Arc::new(AuthManager::default()),
23    )
24    .expect("Failed to setup sesession context");
25
26    DfSessionService::new(Arc::new(session_context))
27}
28
29#[derive(Debug, Default)]
30pub struct MockClient {
31    metadata: HashMap<String, String>,
32    portal_store: HashMap<String, String>,
33    pub sent_messages: Vec<PgWireBackendMessage>,
34}
35
36impl MockClient {
37    pub fn new() -> MockClient {
38        let mut metadata = HashMap::new();
39        metadata.insert(METADATA_USER.to_string(), "postgres".to_string());
40
41        MockClient {
42            metadata,
43            portal_store: HashMap::default(),
44            sent_messages: Vec::new(),
45        }
46    }
47
48    pub fn sent_messages(&self) -> &[PgWireBackendMessage] {
49        &self.sent_messages
50    }
51}
52
53impl ClientInfo for MockClient {
54    fn socket_addr(&self) -> std::net::SocketAddr {
55        "127.0.0.1".parse().unwrap()
56    }
57
58    fn is_secure(&self) -> bool {
59        false
60    }
61
62    fn protocol_version(&self) -> ProtocolVersion {
63        ProtocolVersion::PROTOCOL3_0
64    }
65
66    fn set_protocol_version(&mut self, _version: ProtocolVersion) {}
67
68    fn pid_and_secret_key(&self) -> (i32, SecretKey) {
69        (0, SecretKey::I32(0))
70    }
71
72    fn set_pid_and_secret_key(&mut self, _pid: i32, _secret_key: SecretKey) {}
73
74    fn state(&self) -> PgWireConnectionState {
75        PgWireConnectionState::ReadyForQuery
76    }
77
78    fn set_state(&mut self, _new_state: PgWireConnectionState) {}
79
80    fn transaction_status(&self) -> TransactionStatus {
81        TransactionStatus::Idle
82    }
83
84    fn set_transaction_status(&mut self, _new_status: TransactionStatus) {}
85
86    fn metadata(&self) -> &HashMap<String, String> {
87        &self.metadata
88    }
89
90    fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
91        &mut self.metadata
92    }
93
94    fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
95        None
96    }
97
98    fn sni_server_name(&self) -> Option<&str> {
99        None
100    }
101}
102
103impl ClientPortalStore for MockClient {
104    type PortalStore = HashMap<String, String>;
105    fn portal_store(&self) -> &Self::PortalStore {
106        &self.portal_store
107    }
108}
109
110#[cfg(test)]
111mod tests {
112    use super::*;
113
114    #[test]
115    fn test_mock_client_captures_messages() {
116        let client = MockClient::new();
117        assert!(client.sent_messages().is_empty());
118    }
119}
120
121impl Sink<PgWireBackendMessage> for MockClient {
122    type Error = std::io::Error;
123
124    fn poll_ready(
125        self: std::pin::Pin<&mut Self>,
126        _cx: &mut std::task::Context<'_>,
127    ) -> std::task::Poll<Result<(), Self::Error>> {
128        std::task::Poll::Ready(Ok(()))
129    }
130
131    fn start_send(
132        mut self: std::pin::Pin<&mut Self>,
133        item: PgWireBackendMessage,
134    ) -> Result<(), Self::Error> {
135        self.sent_messages.push(item);
136        Ok(())
137    }
138
139    fn poll_flush(
140        self: std::pin::Pin<&mut Self>,
141        _cx: &mut std::task::Context<'_>,
142    ) -> std::task::Poll<Result<(), Self::Error>> {
143        std::task::Poll::Ready(Ok(()))
144    }
145
146    fn poll_close(
147        self: std::pin::Pin<&mut Self>,
148        _cx: &mut std::task::Context<'_>,
149    ) -> std::task::Poll<Result<(), Self::Error>> {
150        std::task::Poll::Ready(Ok(()))
151    }
152}