use std::sync::Arc;
use std::time::Duration;
use nodedb::bridge::dispatch::Dispatcher;
use nodedb::config::auth::AuthMode;
use nodedb::control::server::pgwire::listener::PgListener;
use nodedb::control::state::SharedState;
use nodedb::data::executor::core_loop::CoreLoop;
use nodedb::wal::WalManager;
/// End-to-end smoke test for the pgwire listener.
///
/// Wires up a `Dispatcher`, a single `CoreLoop` on a blocking thread, a
/// response poller task and a `PgListener` bound to an ephemeral localhost
/// port, then connects with `tokio_postgres` and checks that:
///
/// * a `SET` command succeeds,
/// * `SELECT 1` may succeed or fail (both outcomes are only logged), and
/// * the connection is still usable afterwards (a second `SET` succeeds).
///
/// Finally every background task is signalled to stop and joined.
#[tokio::test]
async fn pgwire_connect_and_query() {
    // Scratch directory holding the WAL file and the core's data directory.
    let dir = tempfile::tempdir().unwrap();
    let wal_path = dir.path().join("test.wal");
    let wal = Arc::new(WalManager::open_for_testing(&wal_path).unwrap());

    // NOTE(review): arguments presumably mean "1 core, queue capacity 64" —
    // confirm against `Dispatcher::new`.
    let (dispatcher, data_sides) = Dispatcher::new(1, 64);
    let shared = SharedState::new(dispatcher, wal);
    let data_side = data_sides.into_iter().next().unwrap();

    // Drive the core loop on a blocking thread: it uses `std::thread::sleep`
    // and must not run on an async worker.
    let core_dir = dir.path().to_path_buf();
    let (core_stop_tx, core_stop_rx) = std::sync::mpsc::channel::<()>();
    let core_handle = tokio::task::spawn_blocking(move || {
        let mut core =
            CoreLoop::open(0, data_side.request_rx, data_side.response_tx, &core_dir).unwrap();
        loop {
            match core_stop_rx.try_recv() {
                // No stop request yet: keep ticking.
                Err(std::sync::mpsc::TryRecvError::Empty) => {
                    core.tick();
                    std::thread::sleep(Duration::from_millis(1));
                }
                // Stop on an explicit signal, and also when the sender has
                // been dropped (e.g. the test body panicked before sending).
                // Otherwise this thread would spin forever and runtime
                // shutdown, which waits for blocking tasks, would hang
                // instead of reporting the test failure.
                Ok(()) | Err(std::sync::mpsc::TryRecvError::Disconnected) => break,
            }
        }
    });

    // Poll for responses roughly every millisecond until told to shut down.
    // `changed()` also resolves (with an error) if the sender is dropped, so
    // this task exits on its own when the test unwinds.
    let shared_poller = Arc::clone(&shared);
    let (poller_shutdown_tx, mut poller_shutdown_rx) = tokio::sync::watch::channel(false);
    let poller_handle = tokio::spawn(async move {
        loop {
            shared_poller.poll_and_route_responses();
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_millis(1)) => {}
                _ = poller_shutdown_rx.changed() => break,
            }
        }
    });

    // Bind to port 0 so the OS picks a free port; read the real address back.
    let pg_listener = PgListener::bind("127.0.0.1:0".parse().unwrap())
        .await
        .unwrap();
    let pg_addr = pg_listener.local_addr();

    let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
    let shared_pg = Arc::clone(&shared);
    let pg_handle = tokio::spawn(async move {
        pg_listener
            .run(
                shared_pg,
                AuthMode::Trust,
                // NOTE(review): meaning of this `None` is not visible here
                // (possibly TLS configuration) — confirm against `run`.
                None,
                // NOTE(review): presumably a cap on concurrent connections.
                Arc::new(tokio::sync::Semaphore::new(128)),
                shutdown_rx,
            )
            .await
            .unwrap();
    });

    // Give the listener task a moment to start before connecting.
    tokio::time::sleep(Duration::from_millis(50)).await;

    let conn_str = format!(
        "host=127.0.0.1 port={} user=nodedb dbname=nodedb",
        pg_addr.port()
    );
    let (client, connection) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls)
        .await
        .expect("pgwire connect failed");

    // The connection future performs the actual socket I/O for `client`, so
    // it has to be polled on its own task.
    let conn_handle = tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {e}");
        }
    });

    // A plain SET must be accepted.
    let result = client.simple_query("SET client_encoding = 'UTF8'").await;
    assert!(result.is_ok(), "SET command failed: {:?}", result.err());

    // `SELECT 1` is allowed to fail; either outcome is only logged. What
    // matters is that the connection survives it (checked below).
    let result = client.simple_query("SELECT 1").await;
    match &result {
        Ok(msgs) => {
            println!("SELECT 1 returned {} messages", msgs.len());
            for msg in msgs {
                match msg {
                    tokio_postgres::SimpleQueryMessage::Row(row) => {
                        println!(" Row: {:?}", row.get(0));
                    }
                    tokio_postgres::SimpleQueryMessage::CommandComplete(n) => {
                        println!(" CommandComplete: {n}");
                    }
                    _ => {}
                }
            }
        }
        Err(e) => {
            println!("SELECT 1 returned error (expected): {e}");
        }
    }

    // The session must still be usable after the query above.
    let result2 = client.simple_query("SET search_path = 'public'").await;
    assert!(
        result2.is_ok(),
        "Connection died after query: {:?}",
        result2.err()
    );

    // Teardown: drop the client and wait for its connection task, then stop
    // the listener, the poller and finally the core loop. Join results are
    // deliberately ignored (best-effort shutdown).
    drop(client);
    let _ = conn_handle.await;
    let _ = shutdown_tx.send(true);
    let _ = pg_handle.await;
    let _ = poller_shutdown_tx.send(true);
    let _ = poller_handle.await;
    let _ = core_stop_tx.send(());
    let _ = core_handle.await;
}