#![allow(
clippy::cast_lossless,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::doc_markdown,
clippy::similar_names,
clippy::uninlined_format_args,
clippy::unreadable_literal,
unused_mut,
unused_variables
)]
use std::io::{Read, Write};
use std::net::TcpStream;
use std::path::PathBuf;
use std::time::Duration;
mod common;
/// Spawns the server under test against the database file `db`, with the
/// pgwire listener enabled.
///
/// When `admin_pw` is `Some`, it is passed to the child process through the
/// `SPG_ADMIN_PASSWORD` environment variable; otherwise that variable is not
/// set by this helper.
fn local_spawn(
    db: &std::path::Path,
    admin_pw: Option<&str>,
) -> (std::process::Child, common::ServerAddrs) {
    let builder = common::ServerBuilder::new().arg_path(db).with_pgwire();
    let builder = match admin_pw {
        Some(password) => builder.env("SPG_ADMIN_PASSWORD", password),
        None => builder,
    };
    builder.spawn()
}
/// Read timeout applied to pgwire client sockets, so a silent server fails the
/// test instead of hanging it.
const READ_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Creates a uniquely named scratch directory under the system temp dir and
/// returns its path.
///
/// The name combines the process id, a wall-clock nanosecond timestamp and a
/// per-process sequence number. Calls within one process, and concurrently
/// running test binaries, therefore never share a directory, even when the
/// clock's resolution is coarser than a nanosecond.
///
/// # Panics
/// Panics if the system clock reads earlier than the Unix epoch or if the
/// directory cannot be created.
fn unique_tmpdir() -> PathBuf {
    // Monotonic per-process counter: disambiguates calls that observe the same timestamp.
    static SEQ: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_nanos();
    let seq = SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let pid = std::process::id();
    let p = std::env::temp_dir().join(format!("spg-e2e-pgwire-{pid}-{nanos}-{seq}"));
    std::fs::create_dir_all(&p)
        .unwrap_or_else(|e| panic!("failed to create scratch dir {}: {e}", p.display()));
    p
}
/// One backend message as framed on the wire: its 1-byte type tag and the
/// payload that follows the `Int32` length field (the length itself is not kept).
///
/// Derives `Debug`/`PartialEq` so whole messages can be compared with
/// `assert_eq!` and printed in failure output.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PgMessage {
    /// Message type tag, e.g. `b'R'`, `b'T'`, `b'D'` or `b'Z'`.
    ty: u8,
    /// Payload bytes, excluding the type tag and the length field.
    body: Vec<u8>,
}
/// Reads one tagged backend message from `s`.
///
/// The wire layout is a 1-byte type, a big-endian `Int32` length that counts
/// its own 4 bytes but not the type byte, then `length - 4` payload bytes.
/// The helper accepts any [`Read`] source, so it works on a `TcpStream` as
/// well as on in-memory buffers.
///
/// # Panics
/// - If the underlying read fails or the stream ends before the whole message
///   arrives. A socket read timeout surfaces as such a read error.
/// - If the declared length is below 4, which no well-formed message can have.
fn read_message(s: &mut impl Read) -> PgMessage {
    let mut header = [0u8; 5];
    s.read_exact(&mut header).expect("pg header");
    let [ty, l0, l1, l2, l3] = header;
    let len = u32::from_be_bytes([l0, l1, l2, l3]) as usize;
    // The length field includes itself; anything smaller means the stream is
    // desynchronised, so fail loudly rather than guessing an empty body.
    assert!(
        len >= 4,
        "pg message {:?} declares invalid length {len}",
        ty as char
    );
    let mut body = vec![0u8; len - 4];
    // `read_exact` into an empty buffer returns immediately, so zero-length
    // bodies need no special case.
    s.read_exact(&mut body).expect("pg body");
    PgMessage { ty, body }
}
/// Writes an untagged StartupMessage carrying a single `user` parameter.
///
/// Layout: `Int32` total length (counting itself), `Int32` protocol version
/// `196608` (3.0, i.e. `3 << 16`), the `user\0<name>\0` pair, and a final
/// `\0` that terminates the parameter list.
///
/// # Panics
/// Panics if the write fails.
fn send_startup(s: &mut TcpStream, user: &str) {
    const PROTOCOL_V3: u32 = 196608;

    let mut params: Vec<u8> = b"user\0".to_vec();
    params.extend(user.bytes());
    // NUL after the value, then the empty-name NUL that ends the list.
    params.extend([0u8, 0u8]);

    // Length prefix + protocol version + parameters.
    let total_len = (4 + 4 + params.len()) as u32;
    let mut packet = Vec::with_capacity(total_len as usize);
    packet.extend(total_len.to_be_bytes());
    packet.extend(PROTOCOL_V3.to_be_bytes());
    packet.extend(params);
    s.write_all(&packet).unwrap();
}
/// Writes a simple-protocol `Query` (`'Q'`) message whose payload is `sql`
/// followed by a NUL terminator.
///
/// # Panics
/// Panics if the write fails.
fn send_query(s: &mut TcpStream, sql: &str) {
    // SQL text plus its NUL terminator.
    let payload_len = sql.len() + 1;
    // The length field counts itself (4 bytes) but not the 'Q' tag.
    let declared = (payload_len + 4) as u32;

    let mut frame = vec![b'Q'];
    frame.reserve(4 + payload_len);
    frame.extend_from_slice(&declared.to_be_bytes());
    frame.extend_from_slice(sql.as_bytes());
    frame.push(b'\0');
    s.write_all(&frame).unwrap();
}
/// Consumes backend messages until a `ReadyForQuery` (`'Z'`) has been read.
///
/// Every message before it is discarded unexamined. NOTE: that includes any
/// `ErrorResponse` (`'E'`), so callers that care about errors must check for
/// them before calling this.
fn read_until_ready(s: &mut TcpStream) {
    while read_message(s).ty != b'Z' {}
}
/// End-to-end check that the pgwire listener accepts a startup without a
/// password and answers `SELECT version()` with one text cell mentioning "spg".
#[test]
fn select_version_canned_response_works() {
    let dir = unique_tmpdir();
    let db = dir.join("spg.db");
    let (raw, addrs) = local_spawn(&db, None);
    let child = common::ChildGuard(raw);
    let mut s = common::connect_to(addrs.pgwire.as_ref().unwrap());
    // Bound every read so a stalled server fails the test instead of hanging it.
    s.set_read_timeout(Some(READ_TIMEOUT)).unwrap();

    send_startup(&mut s, "anyone");
    let auth = read_message(&mut s);
    assert_eq!(auth.ty, b'R', "expected an Authentication message");
    // No password is ever sent, so reaching ReadyForQuery below requires
    // AuthenticationOk, whose Int32 code is 0.
    assert_eq!(auth.body, [0, 0, 0, 0], "expected AuthenticationOk");
    read_until_ready(&mut s);

    send_query(&mut s, "SELECT version()");
    let row_desc = read_message(&mut s);
    assert_eq!(row_desc.ty, b'T', "expected RowDescription");
    let data_row = read_message(&mut s);
    assert_eq!(data_row.ty, b'D', "expected DataRow");
    let value = single_text_cell(&data_row.body);
    assert!(value.contains("spg"), "got {value:?}");
    let complete = read_message(&mut s);
    assert_eq!(complete.ty, b'C', "expected CommandComplete");
    read_until_ready(&mut s);

    // Release the server guard (presumably terminating the child; verify
    // ChildGuard's Drop) before removing its data directory. Best-effort cleanup.
    drop(child);
    let _ = std::fs::remove_dir_all(&dir);
}

/// Decodes a DataRow payload that must hold exactly one non-NULL, non-empty
/// column and returns that column as UTF-8 text.
///
/// DataRow layout: `Int16` column count, then for each column an `Int32`
/// length (-1 for NULL) followed by that many bytes.
///
/// # Panics
/// Panics with a descriptive message if the payload is truncated, has a column
/// count other than 1, has a non-positive cell length, or is not UTF-8.
fn single_text_cell(body: &[u8]) -> &str {
    assert!(body.len() >= 6, "DataRow too short: {body:?}");
    let cell_count = u16::from_be_bytes([body[0], body[1]]);
    assert_eq!(cell_count, 1, "expected exactly one column");
    let len = i32::from_be_bytes([body[2], body[3], body[4], body[5]]);
    assert!(len > 0, "expected a non-empty, non-NULL cell, got length {len}");
    let raw = body.get(6..6 + len as usize).unwrap_or_else(|| {
        panic!(
            "DataRow declares {len} cell bytes but only {} follow",
            body.len() - 6
        )
    });
    std::str::from_utf8(raw).expect("DataRow cell is not valid UTF-8")
}