datafusion_postgres/testing.rs

1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::prelude::SessionContext;
4use datafusion_pg_catalog::pg_catalog::setup_pg_catalog;
5use futures::Sink;
6use pgwire::{
7    api::{ClientInfo, ClientPortalStore, PgWireConnectionState, METADATA_USER},
8    messages::{
9        response::TransactionStatus, startup::SecretKey, PgWireBackendMessage, ProtocolVersion,
10    },
11};
12
13use crate::{auth::AuthManager, DfSessionService};
14
15pub fn setup_handlers() -> DfSessionService {
16    let session_context = SessionContext::new();
17    setup_pg_catalog(
18        &session_context,
19        "datafusion",
20        Arc::new(AuthManager::default()),
21    )
22    .expect("Failed to setup sesession context");
23
24    DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()))
25}
26
/// In-memory stand-in for a pgwire client connection, for use in tests.
///
/// `MockClient::default()` yields empty metadata and an empty portal store.
#[derive(Debug, Default)]
pub struct MockClient {
    /// Key/value connection metadata.
    metadata: HashMap<String, String>,
    /// Storage handed out as this client's portal store.
    portal_store: HashMap<String, String>,
}
32
33impl MockClient {
34    pub fn new() -> MockClient {
35        let mut metadata = HashMap::new();
36        metadata.insert(METADATA_USER.to_string(), "postgres".to_string());
37
38        MockClient {
39            metadata,
40            portal_store: HashMap::default(),
41        }
42    }
43}
44
45impl ClientInfo for MockClient {
46    fn socket_addr(&self) -> std::net::SocketAddr {
47        "127.0.0.1".parse().unwrap()
48    }
49
50    fn is_secure(&self) -> bool {
51        false
52    }
53
54    fn protocol_version(&self) -> ProtocolVersion {
55        ProtocolVersion::PROTOCOL3_0
56    }
57
58    fn set_protocol_version(&mut self, _version: ProtocolVersion) {}
59
60    fn pid_and_secret_key(&self) -> (i32, SecretKey) {
61        (0, SecretKey::I32(0))
62    }
63
64    fn set_pid_and_secret_key(&mut self, _pid: i32, _secret_key: SecretKey) {}
65
66    fn state(&self) -> PgWireConnectionState {
67        PgWireConnectionState::ReadyForQuery
68    }
69
70    fn set_state(&mut self, _new_state: PgWireConnectionState) {}
71
72    fn transaction_status(&self) -> TransactionStatus {
73        TransactionStatus::Idle
74    }
75
76    fn set_transaction_status(&mut self, _new_status: TransactionStatus) {}
77
78    fn metadata(&self) -> &HashMap<String, String> {
79        &self.metadata
80    }
81
82    fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
83        &mut self.metadata
84    }
85
86    fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
87        None
88    }
89
90    fn sni_server_name(&self) -> Option<&str> {
91        None
92    }
93}
94
95impl ClientPortalStore for MockClient {
96    type PortalStore = HashMap<String, String>;
97    fn portal_store(&self) -> &Self::PortalStore {
98        &self.portal_store
99    }
100}
101
102impl Sink<PgWireBackendMessage> for MockClient {
103    type Error = std::io::Error;
104
105    fn poll_ready(
106        self: std::pin::Pin<&mut Self>,
107        _cx: &mut std::task::Context<'_>,
108    ) -> std::task::Poll<Result<(), Self::Error>> {
109        std::task::Poll::Ready(Ok(()))
110    }
111
112    fn start_send(
113        self: std::pin::Pin<&mut Self>,
114        _item: PgWireBackendMessage,
115    ) -> Result<(), Self::Error> {
116        Ok(())
117    }
118
119    fn poll_flush(
120        self: std::pin::Pin<&mut Self>,
121        _cx: &mut std::task::Context<'_>,
122    ) -> std::task::Poll<Result<(), Self::Error>> {
123        std::task::Poll::Ready(Ok(()))
124    }
125
126    fn poll_close(
127        self: std::pin::Pin<&mut Self>,
128        _cx: &mut std::task::Context<'_>,
129    ) -> std::task::Poll<Result<(), Self::Error>> {
130        std::task::Poll::Ready(Ok(()))
131    }
132}