Skip to main content

datafusion_postgres/
testing.rs

1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::logical_expr::LogicalPlan;
4use datafusion::prelude::{SessionConfig, SessionContext};
5use datafusion::sql::sqlparser;
6use datafusion_pg_catalog::pg_catalog::setup_pg_catalog;
7use futures::Sink;
8use pgwire::{
9    api::{
10        ClientInfo, ClientPortalStore, METADATA_USER, PgWireConnectionState, SessionExtensions,
11        store::MemPortalStore,
12    },
13    messages::{
14        PgWireBackendMessage, ProtocolVersion, response::TransactionStatus, startup::SecretKey,
15    },
16};
17
18use crate::{DfSessionService, auth::AuthManager};
19
20pub fn setup_handlers() -> DfSessionService {
21    let session_config = SessionConfig::new().with_information_schema(true);
22    let session_context = SessionContext::new_with_config(session_config);
23
24    setup_pg_catalog(
25        &session_context,
26        "datafusion",
27        Arc::new(AuthManager::default()),
28    )
29    .expect("Failed to setup sesession context");
30
31    DfSessionService::new(Arc::new(session_context))
32}
33
34type DfStatement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
35
36#[derive(Debug, Default)]
37pub struct MockClient {
38    metadata: HashMap<String, String>,
39    portal_store: MemPortalStore<DfStatement>,
40    pub sent_messages: Vec<PgWireBackendMessage>,
41    session_extensions: SessionExtensions,
42}
43
44impl MockClient {
45    pub fn new() -> MockClient {
46        let mut metadata = HashMap::new();
47        metadata.insert(METADATA_USER.to_string(), "postgres".to_string());
48
49        MockClient {
50            metadata,
51            portal_store: MemPortalStore::new(),
52            sent_messages: Vec::new(),
53            session_extensions: SessionExtensions::new(),
54        }
55    }
56
57    pub fn sent_messages(&self) -> &[PgWireBackendMessage] {
58        &self.sent_messages
59    }
60}
61
62impl ClientInfo for MockClient {
63    fn socket_addr(&self) -> std::net::SocketAddr {
64        "127.0.0.1".parse().unwrap()
65    }
66
67    fn is_secure(&self) -> bool {
68        false
69    }
70
71    fn protocol_version(&self) -> ProtocolVersion {
72        ProtocolVersion::PROTOCOL3_0
73    }
74
75    fn set_protocol_version(&mut self, _version: ProtocolVersion) {}
76
77    fn pid_and_secret_key(&self) -> (i32, SecretKey) {
78        (0, SecretKey::I32(0))
79    }
80
81    fn set_pid_and_secret_key(&mut self, _pid: i32, _secret_key: SecretKey) {}
82
83    fn state(&self) -> PgWireConnectionState {
84        PgWireConnectionState::ReadyForQuery
85    }
86
87    fn set_state(&mut self, _new_state: PgWireConnectionState) {}
88
89    fn transaction_status(&self) -> TransactionStatus {
90        TransactionStatus::Idle
91    }
92
93    fn set_transaction_status(&mut self, _new_status: TransactionStatus) {}
94
95    fn metadata(&self) -> &HashMap<String, String> {
96        &self.metadata
97    }
98
99    fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
100        &mut self.metadata
101    }
102
103    fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
104        None
105    }
106
107    fn sni_server_name(&self) -> Option<&str> {
108        None
109    }
110
111    fn session_extensions(&self) -> &pgwire::api::SessionExtensions {
112        &self.session_extensions
113    }
114}
115
116impl ClientPortalStore for MockClient {
117    type PortalStore = MemPortalStore<DfStatement>;
118    fn portal_store(&self) -> &Self::PortalStore {
119        &self.portal_store
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::MockClient;

    /// A freshly constructed client must not have recorded any outgoing messages.
    #[test]
    fn test_mock_client_captures_messages() {
        let fresh = MockClient::new();
        assert_eq!(fresh.sent_messages().len(), 0);
    }
}
133
134impl Sink<PgWireBackendMessage> for MockClient {
135    type Error = std::io::Error;
136
137    fn poll_ready(
138        self: std::pin::Pin<&mut Self>,
139        _cx: &mut std::task::Context<'_>,
140    ) -> std::task::Poll<Result<(), Self::Error>> {
141        std::task::Poll::Ready(Ok(()))
142    }
143
144    fn start_send(
145        mut self: std::pin::Pin<&mut Self>,
146        item: PgWireBackendMessage,
147    ) -> Result<(), Self::Error> {
148        self.sent_messages.push(item);
149        Ok(())
150    }
151
152    fn poll_flush(
153        self: std::pin::Pin<&mut Self>,
154        _cx: &mut std::task::Context<'_>,
155    ) -> std::task::Poll<Result<(), Self::Error>> {
156        std::task::Poll::Ready(Ok(()))
157    }
158
159    fn poll_close(
160        self: std::pin::Pin<&mut Self>,
161        _cx: &mut std::task::Context<'_>,
162    ) -> std::task::Poll<Result<(), Self::Error>> {
163        std::task::Poll::Ready(Ok(()))
164    }
165}