use bytes::{Buf, BufMut, BytesMut};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info};
use crate::YamlBaseError;
use crate::config::Config;
use crate::database::{Storage, Value};
use crate::protocol::catalog_router::CatalogRouter;
use crate::protocol::postgres_catalog::PostgresCatalog;
use crate::protocol::postgres_copy::PostgresCopyProtocol;
use crate::protocol::postgres_extended::ExtendedProtocol;
use crate::protocol::postgres_functions::PostgresFunctionProtocol;
use crate::protocol::postgres_information_schema::PostgresInformationSchema;
use crate::protocol::shared_catalog::SharedCatalog;
use crate::sql::{QueryExecutor, parse_sql};
use crate::yaml::schema::SqlType;
/// Handler for the PostgreSQL frontend/backend wire protocol.
///
/// `handle_connection` serves one client stream: it performs the startup and
/// cleartext-password handshake, then answers simple-query, extended-query
/// (Parse/Bind/Describe/Execute/Sync/Close), COPY and function-call messages.
pub struct PostgresProtocol {
/// Server configuration; supplies the username, password and
/// `allow_anonymous` flag checked during authentication.
config: Arc<Config>,
/// Executes parsed SQL statements against the backing storage.
executor: QueryExecutor,
/// Initialised to an empty string by both constructors and not read here.
_database_name: String,
/// Handlers and state for the extended-query message flow.
extended_protocol: ExtendedProtocol,
/// Emulated `pg_catalog` relations (pg_type, pg_class, pg_attribute, ...).
catalog: PostgresCatalog,
/// Emulated `information_schema` views (tables, columns, schemata).
information_schema: PostgresInformationSchema,
/// Consulted first for every simple query; a hit short-circuits execution.
catalog_router: CatalogRouter,
}
/// Session state gathered while handling the startup/authentication handshake.
#[derive(Debug, Default)]
struct ConnectionState {
/// Set to `true` once the password check during startup succeeds.
authenticated: bool,
/// Value of the `user` startup parameter, if the client sent one.
username: Option<String>,
/// Value of the `database` startup parameter, if the client sent one.
database: Option<String>,
/// Every startup parameter (including `user` and `database`) keyed by name.
parameters: HashMap<String, String>,
}
impl PostgresProtocol {
pub async fn new(config: Arc<Config>, storage: Arc<Storage>) -> crate::Result<Self> {
let executor = QueryExecutor::new(storage.clone()).await?;
let mut catalog = PostgresCatalog::new(storage.clone());
let mut information_schema = PostgresInformationSchema::new(storage.clone());
let db_arc = storage.database();
let db = db_arc.read().await;
let mut table_oid = 16384; for (table_name, table) in &db.tables {
catalog.add_user_table(table_name, table_oid, &table.columns);
information_schema.add_user_table(table_name, &table.columns);
table_oid += 1;
}
drop(db);
let catalog_router = CatalogRouter::new(catalog.clone(), information_schema.clone());
let catalog_router_arc = Arc::new(catalog_router.clone());
let extended_protocol = ExtendedProtocol::with_catalog_router(catalog_router_arc);
Ok(Self {
config,
executor,
_database_name: String::new(),
extended_protocol,
catalog,
information_schema,
catalog_router,
})
}
pub async fn new_with_shared_catalog(
config: Arc<Config>,
storage: Arc<Storage>,
shared_catalog: SharedCatalog
) -> crate::Result<Self> {
let executor = QueryExecutor::new(storage.clone()).await?;
let catalog_state = shared_catalog.read().await;
let catalog = catalog_state.postgres_catalog.clone();
let information_schema = catalog_state.information_schema.clone();
let catalog_router = catalog_state.catalog_router.clone();
drop(catalog_state);
let catalog_router_arc = Arc::new(catalog_router.clone());
let extended_protocol = ExtendedProtocol::with_catalog_router(catalog_router_arc);
Ok(Self {
config,
executor,
_database_name: String::new(),
extended_protocol,
catalog,
information_schema,
catalog_router,
})
}
fn sql_type_to_oid(sql_type: &SqlType) -> u32 {
match sql_type {
SqlType::Integer => 23, SqlType::BigInt => 20, SqlType::Char(_) => 1042, SqlType::Varchar(_) => 1043, SqlType::Text => 25, SqlType::Timestamp => 1114, SqlType::Date => 1082, SqlType::Time => 1083, SqlType::Boolean => 16, SqlType::Decimal(_, _) => 1700, SqlType::Float => 700, SqlType::Double => 701, SqlType::Uuid => 2950, SqlType::Json => 114, }
}
pub async fn handle_connection(&mut self, mut stream: TcpStream) -> crate::Result<()> {
info!("New PostgreSQL connection");
let mut buffer = BytesMut::with_capacity(4096);
let mut state = ConnectionState::default();
self.read_startup_message(&mut stream, &mut buffer, &mut state)
.await?;
loop {
if buffer.is_empty() && stream.read_buf(&mut buffer).await? == 0 {
info!("Client disconnected");
break;
}
if buffer.len() < 5 {
if stream.read_buf(&mut buffer).await? == 0 {
info!("Client disconnected");
break;
}
continue;
}
let msg_type = buffer[0];
let length = u32::from_be_bytes([buffer[1], buffer[2], buffer[3], buffer[4]]) as usize;
if buffer.len() < length + 1 {
if stream.read_buf(&mut buffer).await? == 0 {
return Ok(());
}
continue;
}
match msg_type {
b'Q' => {
let query = self.parse_query(&buffer[5..length + 1])?;
self.handle_query(&mut stream, &query).await?;
}
b'P' => {
self.extended_protocol
.handle_parse(&mut stream, &buffer[5..length + 1])
.await?;
}
b'B' => {
self.extended_protocol
.handle_bind(&mut stream, &buffer[5..length + 1])
.await?;
}
b'D' => {
self.extended_protocol
.handle_describe(&mut stream, &buffer[5..length + 1], &self.executor)
.await?;
}
b'E' => {
self.extended_protocol
.handle_execute(&mut stream, &buffer[5..length + 1], &self.executor)
.await?;
}
b'S' => {
self.extended_protocol.handle_sync(&mut stream).await?;
}
b'C' => {
let close_type = buffer[5];
let name_end = buffer[6..length + 1]
.iter()
.position(|&b| b == 0)
.unwrap_or(length - 5);
let name = std::str::from_utf8(&buffer[6..6 + name_end]).map_err(|_| {
YamlBaseError::Protocol("Invalid UTF-8 in close name".to_string())
})?;
if close_type == b'S' {
self.extended_protocol.close_statement(name);
} else if close_type == b'P' {
self.extended_protocol.close_portal(name);
}
let mut close_buf = BytesMut::new();
close_buf.put_u8(b'3');
close_buf.put_u32(4);
stream.write_all(&close_buf).await?;
}
b'F' => {
PostgresFunctionProtocol::handle_function_call(
&mut stream,
&buffer[5..length + 1],
)
.await?;
}
b'X' => {
info!("Client requested termination");
break;
}
_ => {
debug!("Unhandled message type: {}", msg_type as char);
self.send_error(&mut stream, "XX000", "Unsupported operation")
.await?;
}
}
buffer.advance(length + 1);
}
Ok(())
}
async fn read_startup_message(
&self,
stream: &mut TcpStream,
buffer: &mut BytesMut,
state: &mut ConnectionState,
) -> crate::Result<()> {
stream.read_buf(buffer).await?;
if buffer.len() < 8 {
return Err(YamlBaseError::Protocol(
"Invalid startup packet".to_string(),
));
}
let mut length = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]) as usize;
let version = u32::from_be_bytes([buffer[4], buffer[5], buffer[6], buffer[7]]);
if version == 80877103 {
stream.write_all(b"N").await?;
buffer.clear();
stream.read_buf(buffer).await?;
length = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]) as usize;
}
let mut pos = 8;
while pos < length - 1 {
let key_start = pos;
while pos < buffer.len() && buffer[pos] != 0 {
pos += 1;
}
let key = std::str::from_utf8(&buffer[key_start..pos])
.map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in startup".to_string()))?
.to_string();
pos += 1;
let val_start = pos;
while pos < buffer.len() && buffer[pos] != 0 {
pos += 1;
}
let val = std::str::from_utf8(&buffer[val_start..pos])
.map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in startup".to_string()))?
.to_string();
pos += 1;
match key.as_str() {
"user" => state.username = Some(val.clone()),
"database" => state.database = Some(val.clone()),
_ => {}
}
state.parameters.insert(key, val);
}
self.send_auth_request(stream).await?;
buffer.clear();
stream.read_buf(buffer).await?;
if buffer.len() >= 5 && buffer[0] == b'p' {
let msg_len = u32::from_be_bytes([buffer[1], buffer[2], buffer[3], buffer[4]]) as usize;
let password = self.parse_password_message(&buffer[5..5 + msg_len - 4])?;
debug!(
"Auth check - Expected: {}:{}, Got: {:?}:{}, Allow anonymous: {}",
self.config.username,
self.config.password,
state.username,
password,
self.config.allow_anonymous
);
if self.config.allow_anonymous
|| (state.username.as_deref() == Some(&self.config.username)
&& password == self.config.password)
{
state.authenticated = true;
self.send_auth_ok(stream, state).await?;
buffer.advance(1 + msg_len);
} else {
self.send_error(stream, "28P01", "Authentication failed")
.await?;
return Err(YamlBaseError::Protocol("Authentication failed".to_string()));
}
} else {
return Err(YamlBaseError::Protocol(
"Expected password message".to_string(),
));
}
Ok(())
}
async fn send_auth_request(&self, stream: &mut TcpStream) -> crate::Result<()> {
let mut buf = BytesMut::new();
buf.put_u8(b'R');
buf.put_u32(8); buf.put_u32(3);
stream.write_all(&buf).await?;
Ok(())
}
async fn send_auth_ok(
&self,
stream: &mut TcpStream,
state: &ConnectionState,
) -> crate::Result<()> {
let mut buf = BytesMut::new();
buf.put_u8(b'R');
buf.put_u32(8);
buf.put_u32(0);
stream.write_all(&buf).await?;
buf.clear();
buf.put_u8(b'K');
buf.put_u32(12);
buf.put_u32(12345); buf.put_u32(67890); stream.write_all(&buf).await?;
self.send_parameter_status(stream, "server_version", "14.0")
.await?;
self.send_parameter_status(stream, "server_encoding", "UTF8")
.await?;
self.send_parameter_status(stream, "client_encoding", "UTF8")
.await?;
self.send_parameter_status(stream, "DateStyle", "ISO, MDY")
.await?;
self.send_parameter_status(stream, "TimeZone", "UTC")
.await?;
self.send_parameter_status(stream, "integer_datetimes", "on")
.await?;
self.send_parameter_status(stream, "IntervalStyle", "postgres")
.await?;
self.send_parameter_status(stream, "standard_conforming_strings", "on")
.await?;
self.send_parameter_status(stream, "application_name", "")
.await?;
self.send_parameter_status(stream, "is_superuser", "off")
.await?;
self.send_parameter_status(
stream,
"session_authorization",
&state.username.clone().unwrap_or_default(),
)
.await?;
self.send_ready_for_query(stream).await?;
Ok(())
}
async fn send_parameter_status(
&self,
stream: &mut TcpStream,
name: &str,
value: &str,
) -> crate::Result<()> {
let mut buf = BytesMut::new();
buf.put_u8(b'S');
let length = 4 + name.len() + 1 + value.len() + 1;
buf.put_u32(length as u32);
buf.put_slice(name.as_bytes());
buf.put_u8(0);
buf.put_slice(value.as_bytes());
buf.put_u8(0);
stream.write_all(&buf).await?;
Ok(())
}
async fn send_ready_for_query(&self, stream: &mut TcpStream) -> crate::Result<()> {
let mut buf = BytesMut::new();
buf.put_u8(b'Z');
buf.put_u32(5);
buf.put_u8(b'I');
stream.write_all(&buf).await?;
Ok(())
}
async fn handle_query(&self, stream: &mut TcpStream, query: &str) -> crate::Result<()> {
debug!("Executing query: {}", query);
if PostgresCopyProtocol::is_copy_command(query) {
match PostgresCopyProtocol::parse_copy_command(query) {
Ok(copy_cmd) => {
match parse_sql(©_cmd.select_query) {
Ok(statements) if !statements.is_empty() => {
match self.executor.execute(&statements[0]).await {
Ok(result) => {
PostgresCopyProtocol::handle_copy_to_stdout(
stream,
&result,
copy_cmd.format,
)
.await?;
self.send_ready_for_query(stream).await?;
return Ok(());
}
Err(e) => {
self.send_error(stream, "XX000", &e.to_string()).await?;
}
}
}
Ok(_) => {
self.send_error(stream, "42601", "Invalid COPY statement")
.await?;
}
Err(e) => {
self.send_error(
stream,
"42601",
&format!("Syntax error in COPY: {}", e),
)
.await?;
}
}
}
Err(e) => {
self.send_error(stream, "42601", &e.to_string()).await?;
}
}
self.send_ready_for_query(stream).await?;
return Ok(());
}
if let Some(result) = self.catalog_router.route_query(query)? {
self.send_query_result(stream, &result).await?;
self.send_ready_for_query(stream).await?;
return Ok(());
}
if let Some(result) = self.handle_catalog_query(query).await? {
self.send_query_result(stream, &result).await?;
self.send_ready_for_query(stream).await?;
return Ok(());
}
let statements = match parse_sql(query) {
Ok(stmts) => stmts,
Err(e) => {
self.send_error(stream, "42601", &format!("Syntax error: {}", e))
.await?;
self.send_ready_for_query(stream).await?;
return Ok(());
}
};
for statement in statements {
match self.executor.execute(&statement).await {
Ok(result) => {
self.send_query_result(stream, &result).await?;
}
Err(e) => {
self.send_error(stream, "XX000", &e.to_string()).await?;
}
}
}
self.send_ready_for_query(stream).await?;
Ok(())
}
async fn send_query_result(
&self,
stream: &mut TcpStream,
result: &crate::sql::executor::QueryResult,
) -> crate::Result<()> {
if !result.columns.is_empty() {
let mut buf = BytesMut::new();
buf.put_u8(b'T');
let mut length = 6; for col in &result.columns {
length += col.len() + 1 + 18; }
buf.put_u32(length as u32);
buf.put_u16(result.columns.len() as u16);
for (i, col) in result.columns.iter().enumerate() {
buf.put_slice(col.as_bytes());
buf.put_u8(0); buf.put_u32(0); buf.put_u16(i as u16);
let type_oid = if i < result.column_types.len() {
Self::sql_type_to_oid(&result.column_types[i])
} else {
25 };
buf.put_u32(type_oid);
buf.put_i16(-1); buf.put_i32(-1); buf.put_i16(0); }
stream.write_all(&buf).await?;
}
for row in &result.rows {
let mut buf = BytesMut::new();
buf.put_u8(b'D');
let mut row_length = 6; for val in row {
if matches!(val, Value::Null) {
row_length += 4; } else {
let val_str = val.to_string();
row_length += 4 + val_str.len(); }
}
buf.put_u32(row_length as u32);
buf.put_u16(row.len() as u16);
for val in row {
if matches!(val, Value::Null) {
buf.put_i32(-1); } else {
let val_str = val.to_string();
buf.put_i32(val_str.len() as i32);
buf.put_slice(val_str.as_bytes());
}
}
stream.write_all(&buf).await?;
}
let mut buf = BytesMut::new();
buf.put_u8(b'C');
let tag = if result.columns.is_empty() {
"BEGIN".to_string() } else {
format!("SELECT {}", result.rows.len())
};
buf.put_u32(4 + tag.len() as u32 + 1);
buf.put_slice(tag.as_bytes());
buf.put_u8(0);
stream.write_all(&buf).await?;
Ok(())
}
async fn send_error(
&self,
stream: &mut TcpStream,
code: &str,
message: &str,
) -> crate::Result<()> {
let mut buf = BytesMut::new();
buf.put_u8(b'E');
let error_fields = vec![(b'S', "ERROR"), (b'C', code), (b'M', message)];
let mut length = 4; for (_, val) in &error_fields {
length += 1 + val.len() + 1; }
length += 1;
buf.put_u32(length as u32);
for (field_type, val) in error_fields {
buf.put_u8(field_type);
buf.put_slice(val.as_bytes());
buf.put_u8(0);
}
buf.put_u8(0);
stream.write_all(&buf).await?;
Ok(())
}
fn parse_query(&self, data: &[u8]) -> crate::Result<String> {
let end = data.iter().position(|&b| b == 0).unwrap_or(data.len());
Ok(std::str::from_utf8(&data[..end])
.map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in query".to_string()))?
.to_string())
}
fn parse_password_message(&self, data: &[u8]) -> crate::Result<String> {
let end = data.iter().position(|&b| b == 0).unwrap_or(data.len());
Ok(std::str::from_utf8(&data[..end])
.map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in password".to_string()))?
.to_string())
}
async fn handle_catalog_query(
&self,
query: &str,
) -> crate::Result<Option<crate::sql::executor::QueryResult>> {
let query_upper = query.trim().to_uppercase();
if query_upper.contains("PG_TYPE") || query_upper.contains("PG_CATALOG.PG_TYPE") {
return Ok(Some(self.catalog.query_pg_type(Some(query))));
}
if query_upper.contains("PG_CLASS") || query_upper.contains("PG_CATALOG.PG_CLASS") {
return Ok(Some(self.catalog.query_pg_class(Some(query))));
}
if query_upper.contains("PG_ATTRIBUTE") || query_upper.contains("PG_CATALOG.PG_ATTRIBUTE") {
return Ok(Some(self.catalog.query_pg_attribute(Some(query))));
}
if query_upper.contains("PG_NAMESPACE") || query_upper.contains("PG_CATALOG.PG_NAMESPACE") {
return Ok(Some(self.catalog.query_pg_namespace()));
}
if query_upper.contains("PG_DATABASE") || query_upper.contains("PG_CATALOG.PG_DATABASE") {
return Ok(Some(self.catalog.query_pg_database()));
}
if query_upper.contains("PG_TABLES") || query_upper.contains("PG_CATALOG.PG_TABLES") {
return Ok(Some(self.catalog.query_pg_tables()));
}
if query_upper.contains("PG_STATIO_USER_TABLES") || query_upper.contains("PG_CATALOG.PG_STATIO_USER_TABLES") {
return Ok(Some(self.catalog.query_pg_statio_user_tables()));
}
if query_upper.contains("INFORMATION_SCHEMA.TABLES") {
return Ok(Some(self.information_schema.query_tables(Some(query))));
}
if query_upper.contains("INFORMATION_SCHEMA.COLUMNS") {
return Ok(Some(self.information_schema.query_columns(Some(query))));
}
if query_upper.contains("INFORMATION_SCHEMA.SCHEMATA") {
return Ok(Some(self.information_schema.query_schemata(Some(query))));
}
if query_upper.contains("SELECT VERSION()") {
let result = crate::sql::executor::QueryResult {
columns: vec!["version".to_string()],
column_types: vec![SqlType::Text],
rows: vec![vec![Value::Text(
"PostgreSQL 14.0 (YamlBase Mock Server)".to_string(),
)]],
};
return Ok(Some(result));
}
if query_upper.contains("SELECT CURRENT_DATABASE()") {
let result = crate::sql::executor::QueryResult {
columns: vec!["current_database".to_string()],
column_types: vec![SqlType::Text],
rows: vec![vec![Value::Text("postgres".to_string())]],
};
return Ok(Some(result));
}
if query_upper.contains("SELECT CURRENT_SCHEMA()") {
let result = crate::sql::executor::QueryResult {
columns: vec!["current_schema".to_string()],
column_types: vec![SqlType::Text],
rows: vec![vec![Value::Text("public".to_string())]],
};
return Ok(Some(result));
}
if query_upper.contains("SHOW ") {
return self.handle_show_command(query).await;
}
Ok(None)
}
async fn handle_show_command(
&self,
query: &str,
) -> crate::Result<Option<crate::sql::executor::QueryResult>> {
let query_upper = query.trim().to_uppercase();
if query_upper.contains("SHOW SERVER_VERSION")
|| query_upper.contains("SHOW server_version")
{
let result = crate::sql::executor::QueryResult {
columns: vec!["server_version".to_string()],
column_types: vec![SqlType::Text],
rows: vec![vec![Value::Text("14.0".to_string())]],
};
return Ok(Some(result));
}
if query_upper.contains("SHOW CLIENT_ENCODING")
|| query_upper.contains("SHOW client_encoding")
{
let result = crate::sql::executor::QueryResult {
columns: vec!["client_encoding".to_string()],
column_types: vec![SqlType::Text],
rows: vec![vec![Value::Text("UTF8".to_string())]],
};
return Ok(Some(result));
}
if query_upper.contains("SHOW TIMEZONE") || query_upper.contains("SHOW timezone") {
let result = crate::sql::executor::QueryResult {
columns: vec!["TimeZone".to_string()],
column_types: vec![SqlType::Text],
rows: vec![vec![Value::Text("UTC".to_string())]],
};
return Ok(Some(result));
}
if query_upper.contains("SHOW TRANSACTION ISOLATION LEVEL")
|| query_upper.contains("SHOW TRANSACTION_ISOLATION")
{
let result = crate::sql::executor::QueryResult {
columns: vec!["transaction_isolation".to_string()],
column_types: vec![SqlType::Text],
rows: vec![vec![Value::Text("read committed".to_string())]],
};
return Ok(Some(result));
}
if query_upper.contains("SHOW STANDARD_CONFORMING_STRINGS") {
let result = crate::sql::executor::QueryResult {
columns: vec!["standard_conforming_strings".to_string()],
column_types: vec![SqlType::Text],
rows: vec![vec![Value::Text("on".to_string())]],
};
return Ok(Some(result));
}
if query_upper.contains("SHOW ALL") {
let result = crate::sql::executor::QueryResult {
columns: vec![
"name".to_string(),
"setting".to_string(),
"description".to_string(),
],
column_types: vec![SqlType::Text, SqlType::Text, SqlType::Text],
rows: vec![
vec![
Value::Text("server_version".to_string()),
Value::Text("14.0".to_string()),
Value::Text("PostgreSQL version".to_string()),
],
vec![
Value::Text("client_encoding".to_string()),
Value::Text("UTF8".to_string()),
Value::Text("Client character encoding".to_string()),
],
vec![
Value::Text("TimeZone".to_string()),
Value::Text("UTC".to_string()),
Value::Text("Time zone".to_string()),
],
vec![
Value::Text("DateStyle".to_string()),
Value::Text("ISO, MDY".to_string()),
Value::Text("Date display style".to_string()),
],
vec![
Value::Text("transaction_isolation".to_string()),
Value::Text("read committed".to_string()),
Value::Text("Transaction isolation level".to_string()),
],
vec![
Value::Text("standard_conforming_strings".to_string()),
Value::Text("on".to_string()),
Value::Text("Treat backslashes literally in string literals".to_string()),
],
],
};
return Ok(Some(result));
}
Ok(None)
}
}