use std::net::TcpStream;
use std::sync::{Arc, Mutex, MutexGuard};
use tracing::{debug, info, trace, warn};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
use super::cancel::Cancellable;
use super::config::Config;
use super::connection::{parse_error_response, RawConnection};
use super::endpoint::ConnectionEndpoint;
use super::error::{Error, ErrorKind, Result};
use super::prepare;
use super::row::{Row, StreamRow};
use super::sync_stream::SyncStream;
use crate::protocol::message::Message;
use crate::types::Oid;
use super::notice::{Notice, NoticeReceiver};
/// Synchronous database client owning a single server connection.
///
/// The raw connection is wrapped in `Arc<Mutex<..>>` so prepared statements
/// and streaming helpers can share it with the client.
pub struct Client {
// Shared protocol connection; every operation locks this mutex.
connection: Arc<Mutex<RawConnection<SyncStream>>>,
// Backend process id captured at startup; used when sending cancel requests.
process_id: i32,
// Backend secret key paired with `process_id` for cancel requests.
secret_key: i32,
// Endpoint we connected to (TCP / unix socket / named pipe); reused to open
// a second connection for out-of-band cancel requests.
endpoint: ConnectionEndpoint,
// Optional user callback invoked for server NOTICE messages.
notice_receiver: Option<Arc<NoticeReceiver>>,
}
impl std::fmt::Debug for Client {
    /// Manual `Debug`: the raw connection and the notice callback are not
    /// debuggable themselves, so only identifying fields are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the callback as an opaque marker rather than its address.
        let receiver_tag = self.notice_receiver.as_ref().map(|_| "<callback>");
        let mut dbg = f.debug_struct("Client");
        dbg.field("process_id", &self.process_id);
        dbg.field("secret_key", &self.secret_key);
        dbg.field("endpoint", &self.endpoint);
        dbg.field("notice_receiver", &receiver_tag);
        dbg.finish_non_exhaustive()
    }
}
impl Client {
/// Connect to the server over TCP using the host/port from `config`.
///
/// # Errors
///
/// Returns a connection error if the TCP connect or the protocol startup
/// handshake fails.
pub fn connect(config: &Config) -> Result<Self> {
    info!(
        target: "hyperdb_api",
        host = %config.host(),
        port = config.port(),
        user = config.user().unwrap_or("(default)"),
        database = config.database().unwrap_or("(none)"),
        "connection-parameters"
    );
    let endpoint = ConnectionEndpoint::tcp(config.host(), config.port());
    let addr = format!("{}:{}", config.host(), config.port());
    let tcp_stream = TcpStream::connect(&addr).map_err(|e| {
        warn!(target: "hyperdb_api", %addr, error = %e, "connection-failed");
        Error::connection(format!("failed to connect to {addr}: {e}"))
    })?;
    // Best-effort socket tuning; failures here are non-fatal.
    tcp_stream.set_nodelay(true).ok();
    let sock = socket2::SockRef::from(&tcp_stream);
    sock.set_recv_buffer_size(4 * 1024 * 1024).ok();
    sock.set_send_buffer_size(4 * 1024 * 1024).ok();
    let stream = SyncStream::tcp(tcp_stream);
    let mut connection = RawConnection::new(stream);
    let params = config.startup_params();
    let params_ref: Vec<(&str, &str)> = params.iter().map(|(k, v)| (*k, *v)).collect();
    // Fixed mis-encoded `&params_ref` (was garbled as `¶ms_ref`).
    connection.startup(&params_ref, config.password())?;
    let process_id = connection.process_id();
    let secret_key = connection.secret_key();
    debug!(
        target: "hyperdb_api",
        process_id,
        "connection-established"
    );
    Ok(Client {
        connection: Arc::new(Mutex::new(connection)),
        process_id,
        secret_key,
        endpoint,
        notice_receiver: None,
    })
}
/// Connect to the server over a unix domain socket at `socket_path`.
///
/// # Errors
///
/// Returns a connection error if the socket connect or the startup
/// handshake fails.
#[cfg(unix)]
pub fn connect_unix(socket_path: impl AsRef<std::path::Path>, config: &Config) -> Result<Self> {
    use std::path::Path;
    let path = socket_path.as_ref();
    info!(
        target: "hyperdb_api",
        socket_path = %path.display(),
        user = config.user().unwrap_or("(default)"),
        database = config.database().unwrap_or("(none)"),
        "connection-parameters-unix"
    );
    let unix_stream = UnixStream::connect(path).map_err(|e| {
        warn!(target: "hyperdb_api", socket_path = %path.display(), error = %e, "connection-failed");
        Error::connection(format!("failed to connect to unix socket {}: {}", path.display(), e))
    })?;
    // Split into directory + file name so the endpoint can be reconstructed
    // later (e.g. when opening a second connection for a cancel request).
    let directory = path.parent().unwrap_or(Path::new("/"));
    let name = path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("socket");
    let endpoint = ConnectionEndpoint::domain_socket(directory, name);
    let stream = SyncStream::unix(unix_stream);
    let mut connection = RawConnection::new(stream);
    let params = config.startup_params();
    let params_ref: Vec<(&str, &str)> = params.iter().map(|(k, v)| (*k, *v)).collect();
    // Fixed mis-encoded `&params_ref` (was garbled as `¶ms_ref`).
    connection.startup(&params_ref, config.password())?;
    let process_id = connection.process_id();
    let secret_key = connection.secret_key();
    debug!(
        target: "hyperdb_api",
        process_id,
        "connection-established-unix"
    );
    Ok(Client {
        connection: Arc::new(Mutex::new(connection)),
        process_id,
        secret_key,
        endpoint,
        notice_receiver: None,
    })
}
/// Connect to the server through a Windows named pipe.
///
/// Retries briefly while the pipe reports `ERROR_PIPE_BUSY` (all server-side
/// pipe instances in use), then gives up after `MAX_WAIT`.
///
/// # Errors
///
/// Returns a connection error if the pipe cannot be opened in time or the
/// startup handshake fails.
#[cfg(windows)]
pub fn connect_named_pipe(pipe_path: &str, config: &Config) -> Result<Self> {
    use std::fs::OpenOptions;
    use std::time::{Duration, Instant};
    info!(
        target: "hyperdb_api",
        pipe_path = %pipe_path,
        user = config.user().unwrap_or("(default)"),
        database = config.database().unwrap_or("(none)"),
        "connection-parameters-named-pipe"
    );
    const RETRY_INTERVAL: Duration = Duration::from_millis(20);
    const MAX_WAIT: Duration = Duration::from_secs(10);
    // Win32 ERROR_PIPE_BUSY.
    const ERROR_PIPE_BUSY: i32 = 231;
    let deadline = Instant::now() + MAX_WAIT;
    let file = loop {
        match OpenOptions::new().read(true).write(true).open(pipe_path) {
            Ok(f) => break f,
            Err(e)
                if e.raw_os_error() == Some(ERROR_PIPE_BUSY) && Instant::now() < deadline =>
            {
                std::thread::sleep(RETRY_INTERVAL);
            }
            Err(e) => {
                warn!(target: "hyperdb_api", pipe_path = %pipe_path, error = %e, "connection-failed");
                return Err(Error::connection(format!(
                    "failed to connect to named pipe {pipe_path}: {e}"
                )));
            }
        }
    };
    // Normalize the pipe path into an endpoint; fall back to manual
    // \\host\pipe\name splitting when the URL form does not parse.
    let endpoint = ConnectionEndpoint::parse(&format!(
        "tab.pipe://{}",
        pipe_path.trim_start_matches(r"\\").replace('\\', "/")
    ))
    .unwrap_or_else(|_| {
        let parts: Vec<&str> = pipe_path
            .trim_start_matches(r"\\")
            .splitn(3, '\\')
            .collect();
        if parts.len() >= 3 {
            ConnectionEndpoint::named_pipe(parts[0], parts[2])
        } else {
            ConnectionEndpoint::named_pipe(".", pipe_path)
        }
    });
    let stream = SyncStream::named_pipe(file);
    let mut connection = RawConnection::new(stream);
    let params = config.startup_params();
    let params_ref: Vec<(&str, &str)> = params.iter().map(|(k, v)| (*k, *v)).collect();
    // Fixed mis-encoded `&params_ref` (was garbled as `¶ms_ref`).
    connection.startup(&params_ref, config.password())?;
    let process_id = connection.process_id();
    let secret_key = connection.secret_key();
    debug!(
        target: "hyperdb_api",
        process_id,
        "connection-established-named-pipe"
    );
    Ok(Client {
        connection: Arc::new(Mutex::new(connection)),
        process_id,
        secret_key,
        endpoint,
        notice_receiver: None,
    })
}
/// Connect using a previously captured [`ConnectionEndpoint`].
///
/// TCP endpoints override the host/port stored in `config`; the socket and
/// named-pipe variants only exist on their respective platforms.
pub fn connect_endpoint(endpoint: &ConnectionEndpoint, config: &Config) -> Result<Self> {
match endpoint {
ConnectionEndpoint::Tcp { host, port } => {
// Clone so the caller's config keeps its original host/port.
let mut cfg = config.clone();
cfg = cfg.with_host(host.clone()).with_port(*port);
Self::connect(&cfg)
}
#[cfg(unix)]
ConnectionEndpoint::DomainSocket { directory, name } => {
let socket_path = directory.join(name);
Self::connect_unix(&socket_path, config)
}
#[cfg(windows)]
ConnectionEndpoint::NamedPipe { host, name } => {
let pipe_path = format!(r"\\{host}\pipe\{name}");
Self::connect_named_pipe(&pipe_path, config)
}
}
}
/// The endpoint this client connected to.
#[must_use]
pub fn endpoint(&self) -> &ConnectionEndpoint {
&self.endpoint
}
/// Backend process id reported during startup; paired with
/// [`Client::secret_key`] when issuing cancel requests.
#[must_use]
pub fn process_id(&self) -> i32 {
self.process_id
}
/// Backend secret key reported during startup; paired with
/// [`Client::process_id`] when issuing cancel requests.
#[must_use]
pub fn secret_key(&self) -> i32 {
self.secret_key
}
pub fn cancel(&self) -> Result<()> {
use crate::protocol::message::frontend;
use bytes::BytesMut;
use std::io::Write;
info!(
target: "hyperdb_api",
process_id = self.process_id,
"query-cancel-request"
);
let endpoint_str = self.endpoint.to_string();
match &self.endpoint {
ConnectionEndpoint::Tcp { host, port } => {
let addr = format!("{host}:{port}");
let mut stream = TcpStream::connect(&addr).map_err(|e| {
warn!(
target: "hyperdb_api",
addr = %endpoint_str,
error = %e,
"query-cancel-connect-failed"
);
Error::connection(format!(
"failed to connect for cancel request to {endpoint_str}: {e}"
))
})?;
stream.set_nodelay(true).ok();
let mut buf = BytesMut::with_capacity(16);
frontend::cancel_request(self.process_id, self.secret_key, &mut buf);
stream.write_all(&buf).map_err(|e| {
warn!(
target: "hyperdb_api",
error = %e,
"query-cancel-send-failed"
);
Error::io(e)
})?;
stream.flush().map_err(Error::io)?;
}
#[cfg(unix)]
ConnectionEndpoint::DomainSocket { directory, name } => {
let socket_path = directory.join(name);
let mut stream = UnixStream::connect(&socket_path).map_err(|e| {
warn!(
target: "hyperdb_api",
addr = %endpoint_str,
error = %e,
"query-cancel-connect-failed"
);
Error::connection(format!(
"failed to connect for cancel request to {endpoint_str}: {e}"
))
})?;
let mut buf = BytesMut::with_capacity(16);
frontend::cancel_request(self.process_id, self.secret_key, &mut buf);
stream.write_all(&buf).map_err(|e| {
warn!(
target: "hyperdb_api",
error = %e,
"query-cancel-send-failed"
);
Error::io(e)
})?;
stream.flush().map_err(Error::io)?;
}
#[cfg(windows)]
ConnectionEndpoint::NamedPipe { host, name } => {
let pipe_path = format!(r"\\{host}\pipe\{name}");
let mut file = std::fs::OpenOptions::new()
.read(true)
.write(true)
.open(&pipe_path)
.map_err(|e| {
warn!(
target: "hyperdb_api",
addr = %endpoint_str,
error = %e,
"query-cancel-connect-failed"
);
Error::connection(format!(
"failed to connect for cancel request to {endpoint_str}: {e}"
))
})?;
let mut buf = BytesMut::with_capacity(16);
frontend::cancel_request(self.process_id, self.secret_key, &mut buf);
file.write_all(&buf).map_err(|e| {
warn!(
target: "hyperdb_api",
error = %e,
"query-cancel-send-failed"
);
Error::io(e)
})?;
file.flush().map_err(Error::io)?;
}
}
debug!(
target: "hyperdb_api",
process_id = self.process_id,
"query-cancel-sent"
);
Ok(())
}
/// Returns the server-reported value of the runtime parameter `name`, if any.
///
/// NOTE(review): returns `None` both when the parameter is unknown and when
/// the connection mutex is poisoned — callers cannot distinguish the two.
#[must_use]
pub fn parameter_status(&self, name: &str) -> Option<String> {
// `.ok()?` silently maps a poisoned mutex to `None`.
let conn = self.connection.lock().ok()?;
conn.parameter_status(name)
.map(std::string::ToString::to_string)
}
/// Install (or clear, with `None`) the callback invoked for server NOTICE
/// messages; see [`Client::process_notices`].
pub fn set_notice_receiver(&mut self, receiver: Option<NoticeReceiver>) {
self.notice_receiver = receiver.map(Arc::new);
}
/// Dispatch every `NoticeResponse` in `messages` to the installed notice
/// receiver, or log it as a warning when no receiver is set.
pub(crate) fn process_notices(&self, messages: &[Message]) {
    for msg in messages {
        let Message::NoticeResponse(body) = msg else {
            continue;
        };
        let notice = Notice::from_response_body(body);
        match self.notice_receiver {
            Some(ref receiver) => receiver(notice),
            None => warn!(
                target: "hyperdb_api",
                severity = notice.severity().unwrap_or("NOTICE"),
                code = notice.code().unwrap_or(""),
                message = %notice.message(),
                "server-notice"
            ),
        }
    }
}
// Locks the shared connection, converting mutex poisoning into a
// connection error instead of panicking.
fn lock_connection(&self) -> Result<MutexGuard<'_, RawConnection<SyncStream>>> {
self.connection
.lock()
.map_err(|_| Error::connection("connection mutex poisoned"))
}
pub fn query(&self, query: &str) -> Result<Vec<Row>> {
let mut conn = self.lock_connection()?;
let messages = conn.simple_query(query)?;
drop(conn);
self.process_notices(&messages);
let mut rows = Vec::new();
let mut columns = None;
for msg in messages {
match msg {
crate::protocol::message::Message::RowDescription(desc) => {
let mut cols = Vec::new();
for f in desc.fields().filter_map(|r| {
r.map_err(|e| trace!(target: "hyperdb_api_core::client", error = %e, "dropped error parsing row description field")).ok()
}) {
cols.push(super::statement::Column::new(
f.name().to_string(),
f.type_oid(),
f.type_modifier(),
super::statement::ColumnFormat::from_code(f.format()),
));
}
columns = Some(Arc::new(cols));
}
crate::protocol::message::Message::DataRow(data) => {
if let Some(ref cols) = columns {
rows.push(Row::new(Arc::clone(cols), data)?);
}
}
_ => {}
}
}
Ok(rows)
}
/// Run a binary-format query and collect the raw data rows without
/// decoding a schema.
pub fn query_fast(&self, query: &str) -> Result<Vec<StreamRow>> {
    // Scope the lock to the wire exchange only.
    let messages = {
        let mut conn = self.lock_connection()?;
        conn.query_binary(query)?
    };
    self.process_notices(&messages);
    let rows = messages
        .into_iter()
        .filter_map(|msg| match msg {
            crate::protocol::message::Message::DataRow(data) => Some(StreamRow::new(data)),
            _ => None,
        })
        .collect();
    Ok(rows)
}
/// Start a binary-format query and return a stream yielding rows in chunks.
///
/// The connection mutex stays locked for the stream's lifetime, so no other
/// operation can run on this client until the stream finishes or is dropped.
pub fn query_streaming<'a>(
&'a self,
query: &str,
chunk_size: usize,
) -> Result<QueryStream<'a>> {
let mut conn = self.lock_connection()?;
conn.start_query_binary(query)?;
Ok(QueryStream {
conn: Some(conn),
canceller: self,
finished: false,
// Guard against a zero chunk size, which would never yield rows.
chunk_size: chunk_size.max(1),
schema: None,
schema_read: false,
})
}
pub fn exec(&self, query: &str) -> Result<u64> {
let mut conn = self.lock_connection()?;
let messages = conn.simple_query(query)?;
drop(conn);
self.process_notices(&messages);
let mut affected = 0u64;
for msg in messages {
if let crate::protocol::message::Message::CommandComplete(body) = msg {
if let Ok(tag) = body.tag() {
if let Some(count) = parse_affected_rows(tag) {
affected = count;
}
}
}
}
Ok(affected)
}
/// Run one or more semicolon-separated statements, discarding all results
/// except notices (which are dispatched to the notice receiver).
pub fn batch_execute(&self, query: &str) -> Result<()> {
    // Scope the lock to the wire exchange only.
    let messages = {
        let mut conn = self.lock_connection()?;
        conn.simple_query(query)?
    };
    self.process_notices(&messages);
    Ok(())
}
/// Prepare `query` with no explicit parameter types (the server infers them).
pub fn prepare(&self, query: &str) -> Result<prepare::OwnedPreparedStatement> {
prepare::prepare_owned(&self.connection, query, &[])
}
/// Prepare `query` with explicit parameter type OIDs.
pub fn prepare_typed(
&self,
query: &str,
param_types: &[Oid],
) -> Result<prepare::OwnedPreparedStatement> {
prepare::prepare_owned(&self.connection, query, param_types)
}
/// Execute a prepared statement and collect all result rows.
///
/// `params` holds pre-encoded parameter values; `None` means SQL NULL.
///
/// # Errors
///
/// Propagates protocol and connection errors from execution.
pub fn execute<P: AsRef<[Option<Vec<u8>>]>>(
    &self,
    statement: &prepare::OwnedPreparedStatement,
    params: P,
) -> Result<Vec<Row>> {
    // Borrow the owned parameter buffers as byte slices for the wire layer.
    let params_ref: Vec<Option<&[u8]>> = params
        .as_ref()
        .iter()
        .map(|p| p.as_ref().map(std::vec::Vec::as_slice))
        .collect();
    // Fixed mis-encoded `&params_ref` (was garbled as `¶ms_ref`).
    prepare::execute_prepared(&self.connection, statement.statement(), &params_ref)
}
/// Execute a prepared statement, discarding any rows, and return the count
/// reported by the server (presumably affected rows — see
/// `prepare::execute_prepared_no_result`).
///
/// # Errors
///
/// Propagates protocol and connection errors from execution.
pub fn execute_no_result<P: AsRef<[Option<Vec<u8>>]>>(
    &self,
    statement: &prepare::OwnedPreparedStatement,
    params: P,
) -> Result<u64> {
    // Borrow the owned parameter buffers as byte slices for the wire layer.
    let params_ref: Vec<Option<&[u8]>> = params
        .as_ref()
        .iter()
        .map(|p| p.as_ref().map(std::vec::Vec::as_slice))
        .collect();
    // Fixed mis-encoded `&params_ref` (was garbled as `¶ms_ref`).
    prepare::execute_prepared_no_result(&self.connection, statement.statement(), &params_ref)
}
/// Execute a prepared statement and stream its rows in chunks of
/// `chunk_size`. The connection mutex stays locked for the stream's
/// lifetime.
///
/// # Errors
///
/// Propagates protocol and connection errors from starting execution.
pub fn execute_streaming<'a, P: AsRef<[Option<Vec<u8>>]>>(
    &'a self,
    statement: &prepare::OwnedPreparedStatement,
    params: P,
    chunk_size: usize,
) -> Result<super::prepared_stream::PreparedQueryStream<'a>> {
    // Borrow the owned parameter buffers as byte slices for the wire layer.
    let params_ref: Vec<Option<&[u8]>> = params
        .as_ref()
        .iter()
        .map(|p| p.as_ref().map(std::vec::Vec::as_slice))
        .collect();
    let mut conn = self.lock_connection()?;
    // Fixed mis-encoded `&params_ref` (was garbled as `¶ms_ref`).
    conn.start_execute_prepared(statement.name(), &params_ref, statement.columns().len())?;
    let columns = std::sync::Arc::new(statement.columns().to_vec());
    Ok(super::prepared_stream::PreparedQueryStream::new(
        conn, self, chunk_size, columns,
    ))
}
/// Gracefully terminate the connection, consuming the client.
pub fn close(self) -> Result<()> {
let mut conn = self.lock_connection()?;
conn.terminate()
}
/// Start a COPY-IN into `table_name` using the default "HYPERBINARY" format.
pub fn copy_in(&self, table_name: &str, columns: &[&str]) -> Result<CopyInWriter<'_>> {
self.copy_in_with_format(table_name, columns, "HYPERBINARY")
}
/// Start a COPY-IN into `table_name` with an explicit wire `format`.
///
/// The returned writer holds the connection lock until it is finished,
/// cancelled, or dropped.
pub fn copy_in_with_format(
&self,
table_name: &str,
columns: &[&str],
format: &str,
) -> Result<CopyInWriter<'_>> {
let mut conn = self.lock_connection()?;
conn.start_copy_in_with_format(table_name, columns, format)?;
Ok(CopyInWriter { connection: conn })
}
/// Start a COPY-IN from a caller-supplied raw COPY statement.
///
/// Rejects queries that do not start with `COPY` (case-insensitive,
/// ignoring leading whitespace) before touching the connection.
pub fn copy_in_raw(&self, query: &str) -> Result<CopyInWriter<'_>> {
    let looks_like_copy = query.trim_start().to_ascii_uppercase().starts_with("COPY");
    if !looks_like_copy {
        return Err(Error::new(
            ErrorKind::Query,
            "copy_in_raw() requires a COPY statement. The query must start with 'COPY'.",
        ));
    }
    let mut conn = self.lock_connection()?;
    conn.start_copy_in_raw(query)?;
    Ok(CopyInWriter { connection: conn })
}
/// NOTE(review): this only checks that the connection mutex is not
/// poisoned; it does not probe the socket, so a dead server connection can
/// still report `true`.
#[must_use]
pub fn is_alive(&self) -> bool {
self.lock_connection().is_ok()
}
/// Run `query` as a COPY-OUT and buffer the entire payload in memory.
pub fn copy_out(&self, query: &str) -> Result<Vec<u8>> {
let mut conn = self.lock_connection()?;
conn.copy_out(query)
}
/// Like [`Client::copy_out`], but streams the payload into `writer`.
/// Returns the count reported by the underlying connection (presumably
/// bytes written — confirm against `RawConnection::copy_out_to_writer`).
pub fn copy_out_to_writer(&self, query: &str, writer: &mut dyn std::io::Write) -> Result<u64> {
let mut conn = self.lock_connection()?;
conn.copy_out_to_writer(query, writer)
}
}
impl Cancellable for Client {
    /// Best-effort cancel: failures are logged and swallowed because the
    /// trait's signature is infallible.
    fn cancel(&self) {
        match Client::cancel(self) {
            Ok(()) => {}
            Err(e) => warn!(
                target: "hyperdb_api_core::client",
                error = %e,
                process_id = self.process_id,
                "cancel request failed (best-effort, swallowed)",
            ),
        }
    }
}
/// Writer half of an in-progress COPY-IN operation.
///
/// Holds the client's connection mutex guard for its whole lifetime, so the
/// owning [`Client`] is unusable until [`CopyInWriter::finish`] or
/// [`CopyInWriter::cancel`] is called (or the writer is dropped).
#[derive(Debug)]
pub struct CopyInWriter<'a> {
connection: MutexGuard<'a, RawConnection<SyncStream>>,
}
impl CopyInWriter<'_> {
/// Queue `data` as a CopyData message in the connection's write buffer.
pub fn send(&mut self, data: &[u8]) -> Result<()> {
self.connection.send_copy_data(data)
}
/// Flush buffered copy data to the server.
pub fn flush(&mut self) -> Result<()> {
self.connection.flush()
}
/// Send `data` bypassing the connection's write buffer.
pub fn send_direct(&mut self, data: &[u8]) -> Result<()> {
self.connection.send_copy_data_direct(data)
}
/// Flush the underlying stream itself (not just the message buffer).
pub fn flush_stream(&mut self) -> Result<()> {
self.connection.flush_stream()
}
/// Pre-size the connection's write buffer for large uploads.
pub fn reserve_buffer(&mut self, capacity: usize) {
self.connection.reserve_write_buffer(capacity);
}
/// Complete the COPY, consuming the writer and releasing the connection.
/// Returns the count reported by the server on completion.
pub fn finish(mut self) -> Result<u64> {
self.connection.finish_copy()
}
/// Abort the COPY with `reason`, consuming the writer and releasing the
/// connection.
pub fn cancel(mut self, reason: &str) -> Result<()> {
self.connection.cancel_copy(reason)
}
}
/// Extract the affected-row count from a CommandComplete tag.
///
/// Tag layouts: `INSERT <oid> <rows>` (count is the third token) and
/// `UPDATE|DELETE|SELECT|COPY <rows>` (count is the second token).
/// Returns `None` for unrecognized commands or unparseable counts.
///
/// Rewritten to match string slices directly instead of the non-idiomatic
/// `&"INSERT"` reference patterns, and to walk the token iterator without
/// collecting into an intermediate `Vec`.
fn parse_affected_rows(tag: &str) -> Option<u64> {
    let mut parts = tag.split_whitespace();
    match parts.next()? {
        // Skip the OID token, then take the row count.
        "INSERT" => parts.nth(1)?.parse().ok(),
        "UPDATE" | "DELETE" | "SELECT" | "COPY" => parts.next()?.parse().ok(),
        _ => None,
    }
}
/// Chunked row stream for an in-flight binary-format query.
///
/// Holds the connection mutex guard while the query is active; the guard is
/// released when the stream finishes, errors, or is dropped.
pub struct QueryStream<'a> {
// `Some` while the query is active; set to `None` at completion/error.
conn: Option<MutexGuard<'a, RawConnection<SyncStream>>>,
// Used by `Drop` to cancel a query abandoned mid-stream.
canceller: &'a dyn Cancellable,
// True once ReadyForQuery or an error has been seen.
finished: bool,
// Maximum rows returned per `next_chunk` call (at least 1).
chunk_size: usize,
// Result schema captured from the first RowDescription message.
schema: Option<Vec<super::statement::Column>>,
// Guards against overwriting the schema on later RowDescription messages.
schema_read: bool,
}
impl std::fmt::Debug for QueryStream<'_> {
    /// Manual `Debug`: the connection guard and canceller are not
    /// debuggable, so only stream-state fields are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("QueryStream");
        dbg.field("finished", &self.finished);
        dbg.field("chunk_size", &self.chunk_size);
        dbg.field("schema_read", &self.schema_read);
        dbg.finish_non_exhaustive()
    }
}
impl Drop for QueryStream<'_> {
/// Best-effort cleanup for a stream abandoned mid-query: cancel the
/// server-side query, then drain a bounded number of pending messages so
/// the connection is left usable for the next operation.
fn drop(&mut self) {
if self.finished {
return;
}
self.canceller.cancel();
// Bound the drain so a huge pending result set cannot stall drop forever.
const POST_CANCEL_DRAIN_CAP: usize = 1024;
if let Some(ref mut conn) = self.conn {
// Result deliberately ignored: drop must not panic or propagate errors.
let _ok = conn.drain_until_ready_bounded(POST_CANCEL_DRAIN_CAP);
}
}
}
impl QueryStream<'_> {
/// Result schema, available once a RowDescription message has been read
/// (typically after the first `next_chunk` call).
#[must_use]
pub fn schema(&self) -> Option<&[super::statement::Column]> {
self.schema.as_deref()
}
/// Read the next chunk of up to `chunk_size` rows.
///
/// Returns `Ok(None)` once the query has completed (and on every call
/// after that). On a server error the stream is marked finished, the
/// connection guard is released, and the error is returned.
pub fn next_chunk(&mut self) -> Result<Option<Vec<StreamRow>>> {
if self.finished {
return Ok(None);
}
let Some(conn) = self.conn.as_mut() else {
return Ok(None);
};
let mut rows = Vec::with_capacity(self.chunk_size);
while rows.len() < self.chunk_size {
let msg = conn.read_message()?;
match msg {
// Capture the schema from the first RowDescription only.
Message::RowDescription(desc) if !self.schema_read => {
let mut cols = Vec::new();
for f in desc.fields().filter_map(std::result::Result::ok) {
cols.push(super::statement::Column::new(
f.name().to_string(),
f.type_oid(),
f.type_modifier(),
super::statement::ColumnFormat::from_code(f.format()),
));
}
self.schema = Some(cols);
self.schema_read = true;
}
Message::DataRow(data) => {
rows.push(StreamRow::new(data));
if rows.len() >= self.chunk_size {
return Ok(Some(rows));
}
}
// End of query: release the connection guard and stop.
Message::ReadyForQuery(_) => {
self.finished = true;
self.conn = None;
return if rows.is_empty() {
Ok(None)
} else {
Ok(Some(rows))
};
}
Message::ErrorResponse(body) => {
self.finished = true;
// Prefer consuming via the live connection so trailing protocol
// traffic is handled; fall back to a plain parse otherwise.
let err = match self.conn {
Some(ref mut c) => c.consume_error(&body),
None => parse_error_response(&body),
};
self.conn = None;
return Err(err);
}
// Other messages (e.g. CommandComplete) carry no rows; skip them.
_ => {}
}
}
Ok(Some(rows))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Covers every recognized tag shape, including the previously untested
    /// COPY branch, plus truncated/malformed tags.
    #[test]
    fn test_parse_affected_rows() {
        assert_eq!(parse_affected_rows("INSERT 0 5"), Some(5));
        assert_eq!(parse_affected_rows("UPDATE 10"), Some(10));
        assert_eq!(parse_affected_rows("DELETE 3"), Some(3));
        assert_eq!(parse_affected_rows("SELECT 100"), Some(100));
        assert_eq!(parse_affected_rows("COPY 42"), Some(42));
        assert_eq!(parse_affected_rows("CREATE TABLE"), None);
        // Truncated or malformed tags must yield None, not panic.
        assert_eq!(parse_affected_rows("INSERT 0"), None);
        assert_eq!(parse_affected_rows("INSERT 0 x"), None);
        assert_eq!(parse_affected_rows(""), None);
    }
    /// Mirrors the prefix check inside `copy_in_raw` without needing a
    /// live server connection.
    #[test]
    fn test_copy_in_raw_rejects_non_copy_query() {
        let query = "SELECT * FROM users";
        assert!(
            !query.trim_start().to_ascii_uppercase().starts_with("COPY"),
            "Non-COPY query should not pass the COPY prefix check"
        );
        let copy_query = "COPY \"users\" FROM STDIN WITH (FORMAT csv)";
        assert!(
            copy_query
                .trim_start()
                .to_ascii_uppercase()
                .starts_with("COPY"),
            "COPY query should pass the prefix check"
        );
        let padded = "   COPY \"users\" FROM STDIN";
        assert!(padded.trim_start().to_ascii_uppercase().starts_with("COPY"));
        let lowercase = "copy \"users\" FROM STDIN";
        assert!(lowercase
            .trim_start()
            .to_ascii_uppercase()
            .starts_with("COPY"));
    }
}