use std::io::{self, Read, Write};
use std::net::TcpStream;
use std::sync::RwLock;
use std::time::Duration;
/// Outcome of a query forwarded to the primary.
#[derive(Debug, Clone)]
pub enum ForwardedResult {
    /// The exchange produced result-set data (a row description and/or data rows).
    Rows {
        /// Column metadata from the last row description received.
        columns: Vec<ColumnInfo>,
        /// Every data row received, with values decoded as (lossy) UTF-8 text;
        /// `None` marks a NULL value (wire length -1).
        rows: Vec<Vec<Option<String>>>,
    },
    /// The statement completed without returning result-set data.
    Command {
        /// Command-completion tag as sent by the primary (e.g. `INSERT 0 1`),
        /// or `"OK"` when no tag was received.
        tag: String,
        /// Trailing number of the tag, or 0 when it does not end in a number.
        rows_affected: u64,
    },
    /// The primary reported an error.
    Error {
        /// Severity field (`S`); `"ERROR"` if absent.
        severity: String,
        /// SQLSTATE code field (`C`); `"XX000"` if absent.
        code: String,
        /// Primary message field (`M`); `"Unknown error"` if absent.
        message: String,
        /// Detail field (`D`), if sent.
        detail: Option<String>,
        /// Hint field (`H`), if sent.
        hint: Option<String>,
    },
}
/// Metadata for one result column, taken from a row-description message.
///
/// Derives equality and hashing so callers can compare result schemas or key maps by column.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnInfo {
    /// Column name as reported by the primary.
    pub name: String,
    /// Object ID of the column's data type.
    pub type_oid: i32,
}
/// Forwards SQL text to a primary server over the PostgreSQL wire protocol (version 3.0),
/// keeping a small pool of idle connections for reuse.
pub struct QueryForwarder {
    /// Host name or address of the primary.
    primary_host: String,
    /// TCP port of the primary.
    primary_port: u16,
    /// Timeout for establishing a TCP connection (5 s by default, see `new`).
    connection_timeout: Duration,
    /// Read/write timeout set on each connection socket (30 s by default, see `new`).
    query_timeout: Duration,
    /// Idle connections available for reuse; only ever accessed through `write()`.
    connections: RwLock<Vec<TcpStream>>,
    /// Maximum number of idle connections kept; surplus connections are dropped (10 by default).
    max_connections: usize,
}
impl QueryForwarder {
    /// Protocol version 3.0 as sent in the StartupMessage (`3 << 16`).
    const PROTOCOL_VERSION_3: i32 = 196_608;

    /// ReadyForQuery transaction-status byte meaning "idle, not inside a transaction block".
    const TX_STATUS_IDLE: u8 = b'I';

    /// Creates a forwarder targeting `primary_host:primary_port`.
    ///
    /// Defaults: 5 s connection timeout, 30 s read/write timeout per socket operation, and at
    /// most 10 idle connections kept for reuse. No connection is opened until the first query.
    pub fn new(primary_host: String, primary_port: u16) -> Self {
        Self {
            primary_host,
            primary_port,
            connection_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(30),
            connections: RwLock::new(Vec::new()),
            max_connections: 10,
        }
    }

    /// Sets the timeout for establishing each TCP connection to the primary.
    ///
    /// A zero duration makes every connection attempt fail, because
    /// `TcpStream::connect_timeout` rejects it.
    #[must_use]
    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Sets the read/write timeout applied to every connection socket.
    ///
    /// The timeout bounds each individual socket read or write, not the total duration of a
    /// query. A zero duration makes connection setup fail, because std rejects zero socket
    /// timeouts.
    #[must_use]
    pub fn with_query_timeout(mut self, timeout: Duration) -> Self {
        self.query_timeout = timeout;
        self
    }

    /// Forwards `query` to the primary with the Simple Query protocol and collects the result.
    ///
    /// A pooled connection is reused when one is still healthy; otherwise a new one is opened.
    /// Afterwards the connection goes back to the pool only if the primary reported the session
    /// as idle, so an open or failed transaction block never leaks into another caller's query.
    ///
    /// # Errors
    ///
    /// An error the primary raises for the query itself (e.g. a syntax error) is *not* an `Err`:
    /// it is returned as `Ok(ForwardedResult::Error { .. })`. `Err` means the connection could
    /// not be established (`Connection`, or `Primary` if the primary rejected the startup), or
    /// the exchange failed or was malformed (`Protocol`); the connection is discarded then.
    pub fn forward_query(&self, query: &str) -> Result<ForwardedResult, ForwarderError> {
        let mut conn = self.get_connection()?;
        let (result, tx_status) = self.execute_query(&mut conn, query)?;
        if tx_status == Self::TX_STATUS_IDLE {
            self.return_connection(conn);
        }
        // Otherwise `conn` is dropped here, which closes the socket.
        Ok(result)
    }

    /// Returns a healthy pooled connection, or opens a new one when the pool has none.
    ///
    /// Stale pooled connections are dropped (closed) as they are encountered.
    fn get_connection(&self) -> Result<TcpStream, ForwarderError> {
        while let Some(conn) = self.pop_pooled() {
            if Self::is_connection_alive(&conn) {
                return Ok(conn);
            }
        }
        self.create_connection()
    }

    /// Pops the most recently returned idle connection.
    ///
    /// A poisoned lock is treated as an empty pool. The lock is released before the caller
    /// probes the socket.
    fn pop_pooled(&self) -> Option<TcpStream> {
        self.connections.write().ok().and_then(|mut pool| pool.pop())
    }

    /// Puts `conn` back into the pool unless the pool is full (or its lock is poisoned),
    /// in which case the connection is dropped and closed.
    fn return_connection(&self, conn: TcpStream) {
        if let Ok(mut pool) = self.connections.write() {
            if pool.len() < self.max_connections {
                pool.push(conn);
            }
        }
    }

    /// Checks whether an idle pooled connection can safely carry a new query.
    ///
    /// The socket is probed with a non-blocking `peek`. `WouldBlock` means it is open with
    /// nothing pending: usable. A closed peer (`Ok(0)`), an error, or unsolicited pending bytes
    /// (which would be misread as the reply to the next query) all make it unusable. The socket
    /// must also be restored to blocking mode. Otherwise later `read_exact` calls would fail
    /// with `WouldBlock` instead of honouring the read timeout.
    fn is_connection_alive(conn: &TcpStream) -> bool {
        if conn.set_nonblocking(true).is_err() {
            return false;
        }
        let mut probe = [0u8; 1];
        let idle = matches!(
            conn.peek(&mut probe),
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock
        );
        let restored = conn.set_nonblocking(false).is_ok();
        idle && restored
    }

    /// Resolves the primary's address, connects, and performs the startup handshake.
    ///
    /// Every resolved address is tried in order (a host name may resolve to several, e.g. an
    /// IPv6 and an IPv4 address); the last connection error is reported if none accepts.
    fn create_connection(&self) -> Result<TcpStream, ForwarderError> {
        use std::net::ToSocketAddrs;

        let addr_str = format!("{}:{}", self.primary_host, self.primary_port);
        let addrs = addr_str.to_socket_addrs().map_err(|e| {
            ForwarderError::Connection(format!("Cannot resolve '{}': {}", addr_str, e))
        })?;

        let mut last_err: Option<io::Error> = None;
        for addr in addrs {
            match TcpStream::connect_timeout(&addr, self.connection_timeout) {
                // Startup failures (e.g. rejected credentials) are not address-specific,
                // so they are returned directly instead of trying the next address.
                Ok(conn) => return self.prepare_connection(conn),
                Err(e) => last_err = Some(e),
            }
        }
        let reason = match last_err {
            Some(e) => format!("Failed to connect to primary at {}: {}", addr_str, e),
            None => format!("No addresses found for '{}'", addr_str),
        };
        Err(ForwarderError::Connection(reason))
    }

    /// Applies the socket timeouts to a freshly connected stream and runs the startup handshake.
    fn prepare_connection(&self, mut conn: TcpStream) -> Result<TcpStream, ForwarderError> {
        conn.set_read_timeout(Some(self.query_timeout))
            .map_err(|e| ForwarderError::Connection(format!("Failed to set read timeout: {}", e)))?;
        conn.set_write_timeout(Some(self.query_timeout))
            .map_err(|e| ForwarderError::Connection(format!("Failed to set write timeout: {}", e)))?;
        self.perform_startup(&mut conn)?;
        Ok(conn)
    }

    /// Sends the StartupMessage (user and database `heliosdb`) and reads responses until the
    /// primary reports it is ready for queries.
    ///
    /// # Errors
    ///
    /// `Connection` if the primary asks for any authentication exchange (this forwarder cannot
    /// answer one), `Primary` if it sends an error response, `Protocol` on I/O or framing
    /// failures.
    fn perform_startup<S: Read + Write>(&self, conn: &mut S) -> Result<(), ForwarderError> {
        let mut params: Vec<u8> = Vec::new();
        params.extend_from_slice(b"user\0");
        params.extend_from_slice(b"heliosdb\0");
        params.extend_from_slice(b"database\0");
        params.extend_from_slice(b"heliosdb\0");
        // An extra NUL terminates the parameter list.
        params.push(0);

        // StartupMessage has no type byte: Int32 total length, Int32 protocol version, params.
        let msg_len = 4 + 4 + params.len();
        let mut msg = Vec::with_capacity(msg_len);
        msg.extend_from_slice(&(msg_len as i32).to_be_bytes());
        msg.extend_from_slice(&Self::PROTOCOL_VERSION_3.to_be_bytes());
        msg.extend_from_slice(&params);
        conn.write_all(&msg)
            .map_err(|e| ForwarderError::Protocol(format!("Failed to send startup: {}", e)))?;
        conn.flush()
            .map_err(|e| ForwarderError::Protocol(format!("Failed to flush startup: {}", e)))?;

        loop {
            let (msg_type, body) = Self::read_message(conn)?;
            match msg_type {
                // Authentication request: code 0 means authentication is complete.
                b'R' => {
                    let auth_type = Self::read_i32(&mut &body[..])?;
                    if auth_type != 0 {
                        // Any other code expects a reply we cannot produce; waiting would only
                        // stall until the read timeout.
                        return Err(ForwarderError::Connection(format!(
                            "Primary requested authentication (method code {}), which the query forwarder does not support",
                            auth_type
                        )));
                    }
                }
                // ReadyForQuery: startup is complete.
                b'Z' => return Ok(()),
                b'E' => return Err(ForwarderError::Primary(Self::parse_error_fields(&body))),
                // Parameter status, backend key data, notices, ...: not needed here. The body
                // has already been consumed, so the stream stays in sync.
                _ => {}
            }
        }
    }

    /// Sends `query` as a Simple Query message and reads responses up to ReadyForQuery.
    ///
    /// Returns the collected result together with the ReadyForQuery transaction-status byte
    /// (0 if the primary sent none), which tells the caller whether the session is idle.
    fn execute_query<S: Read + Write>(
        &self,
        conn: &mut S,
        query: &str,
    ) -> Result<(ForwardedResult, u8), ForwarderError> {
        let msg = Self::build_query_message(query)?;
        conn.write_all(&msg)
            .map_err(|e| ForwarderError::Protocol(format!("Failed to send query: {}", e)))?;
        conn.flush()
            .map_err(|e| ForwarderError::Protocol(format!("Failed to flush query: {}", e)))?;

        let mut columns: Vec<ColumnInfo> = Vec::new();
        let mut rows: Vec<Vec<Option<String>>> = Vec::new();
        let mut command_tag: Option<String> = None;
        let mut primary_error: Option<ForwardedResult> = None;
        loop {
            let (msg_type, body) = Self::read_message(conn)?;
            match msg_type {
                // Row description: later descriptions replace earlier metadata.
                b'T' => columns = Self::parse_row_description(&body)?,
                b'D' => rows.push(Self::parse_data_row(&body, columns.len())?),
                b'C' => command_tag = Some(Self::parse_command_tag(&body)),
                // Error response: remember it but keep reading up to ReadyForQuery so the
                // connection stays in sync.
                b'E' => {
                    if primary_error.is_none() {
                        primary_error = Some(Self::parse_error_fields(&body));
                    }
                }
                // CopyInResponse: the primary now waits for COPY data we cannot supply.
                // Refuse with CopyFail instead of stalling until the read timeout.
                b'G' => Self::send_copy_fail(conn)?,
                b'Z' => {
                    let tx_status = body.first().copied().unwrap_or(0);
                    let result = match primary_error {
                        Some(error) => error,
                        None => Self::build_result(columns, rows, command_tag),
                    };
                    return Ok((result, tx_status));
                }
                // Notices, empty-query responses, parameter status, COPY OUT data, ...:
                // nothing the result needs; the body has already been consumed.
                _ => {}
            }
        }
    }

    /// Encodes `query` as a Simple Query (`Q`) message.
    ///
    /// # Errors
    ///
    /// `Protocol` if the query contains a NUL byte (it is sent NUL-terminated, so the primary
    /// would see a truncated statement) or does not fit the Int32 length field.
    fn build_query_message(query: &str) -> Result<Vec<u8>, ForwarderError> {
        let query_bytes = query.as_bytes();
        if query_bytes.contains(&0) {
            return Err(ForwarderError::Protocol(
                "Query must not contain NUL bytes".to_string(),
            ));
        }
        // The length covers itself (4 bytes), the query text, and its terminating NUL.
        let msg_len = 4 + query_bytes.len() + 1;
        if msg_len > i32::MAX as usize {
            return Err(ForwarderError::Protocol(format!(
                "Query of {} bytes is too large to send",
                query_bytes.len()
            )));
        }
        let mut msg = Vec::with_capacity(1 + msg_len);
        msg.push(b'Q');
        msg.extend_from_slice(&(msg_len as i32).to_be_bytes());
        msg.extend_from_slice(query_bytes);
        msg.push(0);
        Ok(msg)
    }

    /// Sends a CopyFail (`f`) message so the primary aborts a pending `COPY ... FROM STDIN`.
    fn send_copy_fail<W: Write>(conn: &mut W) -> Result<(), ForwarderError> {
        const REASON: &str = "COPY FROM STDIN is not supported by the query forwarder";
        let msg_len = 4 + REASON.len() + 1;
        let mut msg = Vec::with_capacity(1 + msg_len);
        msg.push(b'f');
        msg.extend_from_slice(&(msg_len as i32).to_be_bytes());
        msg.extend_from_slice(REASON.as_bytes());
        msg.push(0);
        conn.write_all(&msg)
            .and_then(|()| conn.flush())
            .map_err(|e| ForwarderError::Protocol(format!("Failed to send CopyFail: {}", e)))
    }

    /// Chooses the result shape once ReadyForQuery arrives without an error.
    fn build_result(
        columns: Vec<ColumnInfo>,
        rows: Vec<Vec<Option<String>>>,
        command_tag: Option<String>,
    ) -> ForwardedResult {
        if !columns.is_empty() || !rows.is_empty() {
            ForwardedResult::Rows { columns, rows }
        } else if let Some(tag) = command_tag {
            let rows_affected = Self::parse_rows_affected(&tag);
            ForwardedResult::Command { tag, rows_affected }
        } else {
            // Neither result-set data nor a command-completion tag was received.
            ForwardedResult::Command {
                tag: "OK".to_string(),
                rows_affected: 0,
            }
        }
    }

    /// Reads one backend message: a type byte, an Int32 length that counts itself, and
    /// `length - 4` body bytes.
    ///
    /// # Errors
    ///
    /// `Protocol` on I/O failure, a declared length below 4, or a stream that ends mid-message.
    fn read_message<R: Read>(conn: &mut R) -> Result<(u8, Vec<u8>), ForwarderError> {
        let mut header = [0u8; 5];
        conn.read_exact(&mut header)
            .map_err(|e| ForwarderError::Protocol(format!("Failed to read message header: {}", e)))?;
        let msg_type = header[0];
        let declared_len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
        if declared_len < 4 {
            return Err(ForwarderError::Protocol(format!(
                "Invalid length {} for message type '{}'",
                declared_len, msg_type as char
            )));
        }
        let body_len = (declared_len - 4) as usize;
        // Reading through `take` lets the buffer grow only as bytes actually arrive, so a
        // corrupt length cannot force a huge up-front allocation.
        let mut body = Vec::with_capacity(body_len.min(8192));
        conn.by_ref()
            .take(body_len as u64)
            .read_to_end(&mut body)
            .map_err(|e| {
                ForwarderError::Protocol(format!(
                    "Failed to read message type '{}': {}",
                    msg_type as char, e
                ))
            })?;
        if body.len() != body_len {
            return Err(ForwarderError::Protocol(format!(
                "Connection closed while reading message type '{}' ({} of {} bytes)",
                msg_type as char,
                body.len(),
                body_len
            )));
        }
        Ok((msg_type, body))
    }

    /// Parses a row-description body into column metadata (name and type OID per column).
    fn parse_row_description(body: &[u8]) -> Result<Vec<ColumnInfo>, ForwarderError> {
        let mut cur = body;
        let num_fields = Self::read_count(&mut cur)?;
        let mut columns = Vec::with_capacity(num_fields);
        for _ in 0..num_fields {
            let name = Self::read_string(&mut cur)?;
            let _table_oid = Self::read_i32(&mut cur)?;
            let _column_attr = Self::read_i16(&mut cur)?;
            let type_oid = Self::read_i32(&mut cur)?;
            let _type_size = Self::read_i16(&mut cur)?;
            let _type_mod = Self::read_i32(&mut cur)?;
            let _format = Self::read_i16(&mut cur)?;
            columns.push(ColumnInfo { name, type_oid });
        }
        Ok(columns)
    }

    /// Parses a data-row body; a length of -1 denotes NULL.
    ///
    /// Values are decoded as lossy UTF-8. `num_columns` is only a capacity hint.
    fn parse_data_row(
        body: &[u8],
        num_columns: usize,
    ) -> Result<Vec<Option<String>>, ForwarderError> {
        let mut cur = body;
        let num_values = Self::read_count(&mut cur)?;
        let mut row = Vec::with_capacity(num_columns.max(num_values));
        for _ in 0..num_values {
            let len = Self::read_i32(&mut cur)?;
            if len == -1 {
                row.push(None);
                continue;
            }
            if len < 0 || len as usize > cur.len() {
                return Err(ForwarderError::Protocol(format!(
                    "Failed to read data: invalid column length {}",
                    len
                )));
            }
            let (value, rest) = cur.split_at(len as usize);
            row.push(Some(String::from_utf8_lossy(value).into_owned()));
            cur = rest;
        }
        Ok(row)
    }

    /// Parses an error-response body into `ForwardedResult::Error`.
    ///
    /// Each field is a one-byte type followed by a NUL-terminated value; an empty field (the
    /// final NUL) ends the list. Unknown field types are ignored, and missing fields keep the
    /// defaults `"ERROR"`, `"XX000"` and `"Unknown error"`.
    fn parse_error_fields(body: &[u8]) -> ForwardedResult {
        let mut severity = String::from("ERROR");
        let mut code = String::from("XX000");
        let mut message = String::from("Unknown error");
        let mut detail = None;
        let mut hint = None;
        for field in body.split(|&b| b == 0) {
            let (&field_type, raw_value) = match field.split_first() {
                Some(parts) => parts,
                None => break,
            };
            let value = String::from_utf8_lossy(raw_value).into_owned();
            match field_type {
                b'S' => severity = value,
                b'C' => code = value,
                b'M' => message = value,
                b'D' => detail = Some(value),
                b'H' => hint = Some(value),
                _ => {}
            }
        }
        ForwardedResult::Error {
            severity,
            code,
            message,
            detail,
            hint,
        }
    }

    /// Decodes a command-completion body, dropping its trailing NUL terminator if present.
    fn parse_command_tag(body: &[u8]) -> String {
        let text = match body.split_last() {
            Some((&0, rest)) => rest,
            _ => body,
        };
        String::from_utf8_lossy(text).into_owned()
    }

    /// Extracts the trailing row count from a command tag (`"INSERT 0 1"` → 1); 0 if the last
    /// word is not a number (`"CREATE TABLE"`) or the tag is empty.
    fn parse_rows_affected(tag: &str) -> u64 {
        tag.split_whitespace()
            .next_back()
            .and_then(|count| count.parse().ok())
            .unwrap_or(0)
    }

    /// Reads a single byte.
    fn read_byte<R: Read>(conn: &mut R) -> Result<u8, ForwarderError> {
        let mut buf = [0u8; 1];
        conn.read_exact(&mut buf)
            .map_err(|e| ForwarderError::Protocol(format!("Failed to read byte: {}", e)))?;
        Ok(buf[0])
    }

    /// Reads a big-endian Int16.
    fn read_i16<R: Read>(conn: &mut R) -> Result<i16, ForwarderError> {
        let mut buf = [0u8; 2];
        conn.read_exact(&mut buf)
            .map_err(|e| ForwarderError::Protocol(format!("Failed to read i16: {}", e)))?;
        Ok(i16::from_be_bytes(buf))
    }

    /// Reads a big-endian Int32.
    fn read_i32<R: Read>(conn: &mut R) -> Result<i32, ForwarderError> {
        let mut buf = [0u8; 4];
        conn.read_exact(&mut buf)
            .map_err(|e| ForwarderError::Protocol(format!("Failed to read i32: {}", e)))?;
        Ok(i32::from_be_bytes(buf))
    }

    /// Reads an Int16 element count, rejecting negative values. A plain `as usize` cast
    /// would turn them into huge counts.
    fn read_count<R: Read>(conn: &mut R) -> Result<usize, ForwarderError> {
        let count = Self::read_i16(conn)?;
        if count < 0 {
            return Err(ForwarderError::Protocol(format!("Invalid field count {}", count)));
        }
        Ok(count as usize)
    }

    /// Reads a NUL-terminated string (terminator consumed, not included), decoded as lossy UTF-8.
    fn read_string<R: Read>(conn: &mut R) -> Result<String, ForwarderError> {
        let mut bytes = Vec::new();
        loop {
            let b = Self::read_byte(conn)?;
            if b == 0 {
                break;
            }
            bytes.push(b);
        }
        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }
}
/// Failures reported by the query forwarder.
#[derive(Debug)]
pub enum ForwarderError {
    /// Resolving, connecting to, or configuring the connection to the primary failed.
    Connection(String),
    /// Sending or receiving protocol messages failed, or a message could not be processed.
    Protocol(String),
    /// The primary sent an error response while the connection was being established;
    /// carries the parsed `ForwardedResult::Error`. Errors for a forwarded query itself are
    /// returned as `Ok(ForwardedResult::Error { .. })` instead.
    Primary(ForwardedResult),
    /// No forwarder is configured. Not constructed anywhere in the code shown; presumably
    /// used by callers when `query_forwarder()` returns `None` (TODO confirm).
    NotConfigured,
}
impl std::fmt::Display for ForwarderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ForwarderError::Connection(msg) => write!(f, "Connection error: {}", msg),
ForwarderError::Protocol(msg) => write!(f, "Protocol error: {}", msg),
ForwarderError::Primary(result) => {
if let ForwardedResult::Error { message, .. } = result {
write!(f, "Primary error: {}", message)
} else {
write!(f, "Primary error")
}
}
ForwarderError::NotConfigured => write!(f, "Query forwarder not configured"),
}
}
}
impl std::error::Error for ForwarderError {}
/// Process-wide forwarder instance, populated at most once by `init_query_forwarder`.
static QUERY_FORWARDER: once_cell::sync::OnceCell<QueryForwarder> =
    once_cell::sync::OnceCell::new();

/// Installs the process-wide forwarder targeting `primary_host:primary_port`.
///
/// Only the first call has an effect. Once the cell holds a forwarder, `set` hands the new
/// one back and it is dropped unused, without reporting an error.
pub fn init_query_forwarder(primary_host: String, primary_port: u16) {
    let forwarder = QueryForwarder::new(primary_host, primary_port);
    drop(QUERY_FORWARDER.set(forwarder));
}

/// Returns the process-wide forwarder, or `None` if `init_query_forwarder` has not run yet.
pub fn query_forwarder() -> Option<&'static QueryForwarder> {
    QUERY_FORWARDER.get()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of the trailing-count extraction from command tags.
    #[test]
    fn test_parse_rows_affected() {
        let cases: [(&str, u64); 5] = [
            ("INSERT 0 1", 1),
            ("UPDATE 5", 5),
            ("DELETE 10", 10),
            ("SELECT 100", 100),
            // No trailing number → 0.
            ("CREATE TABLE", 0),
        ];
        for &(tag, expected) in cases.iter() {
            assert_eq!(QueryForwarder::parse_rows_affected(tag), expected, "tag: {}", tag);
        }
    }
}