use crate::auth::PasswordStore;
use crate::config::Config;
use crate::observability::ObservabilityProvider;
use crate::protocol::{
BackendMessage, FieldDescription, FrontendMessage, SubscriptionUpdateType, TransactionStatus,
};
use crate::session::{ExecutionResult, Session};
use crate::subscription::{SessionSubscriptionManager, SubscriptionManager};
use anyhow::Result;
use bytes::BytesMut;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, error, info, warn};
use vibesql_executor::cache::table_extractor;
/// Serves one client connection speaking the PostgreSQL frontend/backend protocol.
///
/// Owns the client socket, the buffers used to decode frontend messages and encode backend
/// messages, and the SQL session that is created once startup and authentication succeed.
pub struct ConnectionHandler {
    /// Client socket; all protocol traffic is read from and written to it.
    stream: TcpStream,
    /// Remote address of the client, used in log messages.
    peer_addr: SocketAddr,
    /// Server configuration; `auth.method` selects the authentication exchange.
    config: Arc<Config>,
    /// Source of the optional metrics sink for query and connection statistics.
    observability: Arc<ObservabilityProvider>,
    /// Credentials used to verify `password`/`md5` logins; `None` means password
    /// authentication is not configured.
    password_store: Option<Arc<PasswordStore>>,
    /// Bytes received from the client that have not been decoded yet.
    read_buf: BytesMut,
    /// Encoded backend messages waiting to be written to the socket.
    write_buf: BytesMut,
    /// SQL session; `None` until startup and authentication complete.
    session: Option<Session>,
    /// When the handler was created; used for the connection-duration metric.
    connection_start: Instant,
    /// Shared count of open connections, decremented when this handler is dropped.
    active_connections: Arc<AtomicUsize>,
    /// Subscriptions registered by this connection.
    subscription_manager: SessionSubscriptionManager,
    // Stored but not read by any code in this file, hence the dead-code allowance.
    // NOTE(review): presumably reserved for cross-session change notification — confirm,
    // or remove the field.
    #[allow(dead_code)]
    global_subscription_manager: Arc<SubscriptionManager>,
}
impl ConnectionHandler {
pub fn new(
stream: TcpStream,
peer_addr: SocketAddr,
config: Arc<Config>,
observability: Arc<ObservabilityProvider>,
password_store: Option<Arc<PasswordStore>>,
active_connections: Arc<AtomicUsize>,
global_subscription_manager: Arc<SubscriptionManager>,
) -> Self {
Self {
stream,
peer_addr,
config,
observability,
password_store,
read_buf: BytesMut::with_capacity(8192),
write_buf: BytesMut::with_capacity(8192),
session: None,
connection_start: Instant::now(),
active_connections,
subscription_manager: SessionSubscriptionManager::new(),
global_subscription_manager,
}
}
pub async fn handle(&mut self) -> Result<()> {
self.startup_handshake().await?;
self.process_queries().await?;
Ok(())
}
async fn startup_handshake(&mut self) -> Result<()> {
debug!("Starting handshake with {}", self.peer_addr);
self.read_message().await?;
let startup_msg = FrontendMessage::decode_startup(&mut self.read_buf)?;
match startup_msg {
Some(FrontendMessage::SSLRequest) => {
debug!("Received SSL request");
self.stream.write_u8(b'N').await?;
self.stream.flush().await?;
self.read_buf.clear();
self.read_message().await?;
let startup_msg = FrontendMessage::decode_startup(&mut self.read_buf)?;
self.handle_startup(startup_msg).await?;
}
Some(msg) => {
self.handle_startup(Some(msg)).await?;
}
None => {
return Err(anyhow::anyhow!("No startup message received"));
}
}
Ok(())
}
async fn handle_startup(&mut self, msg: Option<FrontendMessage>) -> Result<()> {
match msg {
Some(FrontendMessage::Startup {
protocol_version,
params,
}) => {
debug!("Startup: version={}, params={:?}", protocol_version, params);
let user = params.get("user").cloned().unwrap_or_else(|| "postgres".to_string());
let database = params.get("database").cloned().unwrap_or_else(|| user.clone());
self.authenticate(&user).await?;
self.session = Some(Session::new(database.clone(), user.clone())?);
info!("User '{}' connected to database '{}'", user, database);
self.send_parameter_status("server_version", "14.0 (VibeSQL)").await?;
self.send_parameter_status("server_encoding", "UTF8").await?;
self.send_parameter_status("client_encoding", "UTF8").await?;
self.send_parameter_status("DateStyle", "ISO, MDY").await?;
self.send_parameter_status("TimeZone", "UTC").await?;
self.send_backend_key_data().await?;
self.send_ready_for_query(TransactionStatus::Idle).await?;
Ok(())
}
_ => Err(anyhow::anyhow!("Invalid startup message")),
}
}
async fn authenticate(&mut self, user: &str) -> Result<()> {
match self.config.auth.method.as_str() {
"trust" => {
debug!("Using trust authentication for user '{}'", user);
self.send_authentication_ok().await?;
Ok(())
}
"password" => {
debug!("Requesting cleartext password for user '{}'", user);
self.send_cleartext_password_request().await?;
self.read_message().await?;
let msg = FrontendMessage::decode(&mut self.read_buf)?;
match msg {
Some(FrontendMessage::Password { password }) => {
debug!("Received password from user '{}'", user);
if let Some(ref store) = self.password_store {
if store.verify_cleartext(user, &password) {
info!("User '{}' authenticated successfully", user);
self.send_authentication_ok().await?;
Ok(())
} else {
error!("Authentication failed for user '{}'", user);
Err(anyhow::anyhow!("Authentication failed"))
}
} else {
error!("No password store configured");
Err(anyhow::anyhow!("Authentication not configured"))
}
}
_ => {
error!("Expected password message, got: {:?}", msg);
Err(anyhow::anyhow!("Expected password message"))
}
}
}
"md5" => {
debug!("Requesting MD5 password for user '{}'", user);
use rand::Rng;
let salt: [u8; 4] = rand::rng().random();
self.send_md5_password_request(&salt).await?;
self.read_message().await?;
let msg = FrontendMessage::decode(&mut self.read_buf)?;
match msg {
Some(FrontendMessage::Password { password }) => {
debug!("Received MD5 password response from user '{}'", user);
if let Some(ref store) = self.password_store {
if store.verify_md5(user, &password, &salt) {
info!("User '{}' authenticated successfully (MD5)", user);
self.send_authentication_ok().await?;
Ok(())
} else {
error!("MD5 authentication failed for user '{}'", user);
Err(anyhow::anyhow!("Authentication failed"))
}
} else {
error!("No password store configured");
Err(anyhow::anyhow!("Authentication not configured"))
}
}
_ => {
error!("Expected password message, got: {:?}", msg);
Err(anyhow::anyhow!("Expected password message"))
}
}
}
"scram-sha-256" => {
error!("SCRAM-SHA-256 authentication not yet implemented");
Err(anyhow::anyhow!("SCRAM-SHA-256 not implemented"))
}
_ => {
error!("Unsupported authentication method: {}", self.config.auth.method);
Err(anyhow::anyhow!("Unsupported authentication method"))
}
}
}
async fn process_queries(&mut self) -> Result<()> {
loop {
self.read_message().await?;
let msg = FrontendMessage::decode(&mut self.read_buf)?;
match msg {
Some(FrontendMessage::Query { query }) => {
debug!("Query: {}", query);
self.execute_query(&query).await?;
}
Some(FrontendMessage::Subscribe { query, params }) => {
debug!("Subscribe: {}", query);
self.handle_subscribe(&query, params).await?;
}
Some(FrontendMessage::Unsubscribe { subscription_id }) => {
debug!("Unsubscribe: {:?}", subscription_id);
self.subscription_manager.unsubscribe(&subscription_id);
}
Some(FrontendMessage::Terminate) => {
debug!("Client requested termination");
break;
}
Some(msg) => {
warn!("Unexpected message: {:?}", msg);
}
None => {
debug!("Connection closed by client");
break;
}
}
}
self.subscription_manager.clear();
Ok(())
}
async fn execute_query(&mut self, query: &str) -> Result<()> {
let session = self.session.as_mut().ok_or_else(|| anyhow::anyhow!("No session"))?;
if query.trim().is_empty() {
self.send_empty_query_response().await?;
self.send_ready_for_query(TransactionStatus::Idle).await?;
return Ok(());
}
let query_start = Instant::now();
match session.execute(query) {
Ok(result) => {
let query_duration = query_start.elapsed();
let stmt_type = result.statement_type();
let rows_affected = result.rows_affected();
if let Some(metrics) = self.observability.metrics() {
metrics.record_query(query_duration, stmt_type, true, rows_affected);
}
self.send_query_result(result).await?;
self.send_ready_for_query(TransactionStatus::Idle).await?;
Ok(())
}
Err(e) => {
error!("Query error: {}", e);
if let Some(metrics) = self.observability.metrics() {
metrics.record_query_error("execution_error", None);
}
self.send_error_response(&format!("{}", e)).await?;
self.send_ready_for_query(TransactionStatus::Idle).await?;
Ok(())
}
}
}
async fn handle_subscribe(
&mut self,
query: &str,
params: Vec<Option<Vec<u8>>>,
) -> Result<()> {
let session = self.session.as_mut().ok_or_else(|| anyhow::anyhow!("No session"))?;
let parsed = match vibesql_parser::Parser::parse_sql(query) {
Ok(stmt) => stmt,
Err(e) => {
let error_id = [0u8; 16];
self.send_subscription_error(&error_id, &format!("Parse error: {}", e))
.await?;
return Ok(());
}
};
let table_dependencies = table_extractor::extract_tables_from_statement(&parsed);
let subscription_id =
self.subscription_manager
.subscribe(query.to_string(), params, table_dependencies);
match session.execute(query) {
Ok(ExecutionResult::Select { rows, .. }) => {
let wire_rows: Vec<Vec<Option<Vec<u8>>>> = rows
.iter()
.map(|row| {
row.values
.iter()
.map(|v| Some(v.to_string().as_bytes().to_vec()))
.collect()
})
.collect();
self.send_subscription_data(
&subscription_id,
SubscriptionUpdateType::Full,
wire_rows,
)
.await?;
}
Ok(_) => {
self.subscription_manager.unsubscribe(&subscription_id);
self.send_subscription_error(
&subscription_id,
"Only SELECT queries can be subscribed to",
)
.await?;
}
Err(e) => {
self.subscription_manager.unsubscribe(&subscription_id);
self.send_subscription_error(&subscription_id, &format!("Execution error: {}", e))
.await?;
}
}
Ok(())
}
async fn send_query_result(&mut self, result: ExecutionResult) -> Result<()> {
match result {
ExecutionResult::Select { rows, columns } => {
let fields: Vec<FieldDescription> = columns
.iter()
.enumerate()
.map(|(i, col)| FieldDescription {
name: col.name.clone(),
table_oid: 0,
column_attr_number: i as i16,
data_type_oid: 25, data_type_size: -1, type_modifier: -1,
format_code: 0, })
.collect();
self.send_row_description(fields).await?;
let row_count = rows.len();
for row in rows {
let values: Vec<Option<Vec<u8>>> = row
.values
.iter()
.map(|v: &vibesql_types::SqlValue| Some(v.to_string().as_bytes().to_vec()))
.collect();
self.send_data_row(values).await?;
}
self.send_command_complete(&format!("SELECT {}", row_count)).await?;
}
ExecutionResult::Insert { rows_affected } => {
self.send_command_complete(&format!("INSERT 0 {}", rows_affected)).await?;
}
ExecutionResult::Update { rows_affected } => {
self.send_command_complete(&format!("UPDATE {}", rows_affected)).await?;
}
ExecutionResult::Delete { rows_affected } => {
self.send_command_complete(&format!("DELETE {}", rows_affected)).await?;
}
ExecutionResult::CreateTable
| ExecutionResult::CreateIndex
| ExecutionResult::CreateView => {
self.send_command_complete("CREATE TABLE").await?;
}
ExecutionResult::DropTable
| ExecutionResult::DropIndex
| ExecutionResult::DropView => {
self.send_command_complete("DROP TABLE").await?;
}
ExecutionResult::Analyze { tables_analyzed } => {
self.send_command_complete(&format!("ANALYZE {}", tables_analyzed)).await?;
}
ExecutionResult::Other { message } => {
self.send_command_complete(&message).await?;
}
ExecutionResult::Prepare { statement_name } => {
self.send_command_complete(&format!("PREPARE {}", statement_name)).await?;
}
ExecutionResult::Deallocate { statement_name } => {
self.send_command_complete(&format!("DEALLOCATE {}", statement_name)).await?;
}
ExecutionResult::DeclareCursor { cursor_name } => {
self.send_command_complete(&format!("DECLARE CURSOR {}", cursor_name)).await?;
}
ExecutionResult::OpenCursor { cursor_name } => {
self.send_command_complete(&format!("OPEN {}", cursor_name)).await?;
}
ExecutionResult::Fetch { rows, columns } => {
let fields: Vec<FieldDescription> = columns
.iter()
.enumerate()
.map(|(i, col)| FieldDescription {
name: col.name.clone(),
table_oid: 0,
column_attr_number: i as i16,
data_type_oid: 25, data_type_size: -1, type_modifier: -1,
format_code: 0, })
.collect();
self.send_row_description(fields).await?;
let row_count = rows.len();
for row in rows {
let values: Vec<Option<Vec<u8>>> = row
.values
.iter()
.map(|v: &vibesql_types::SqlValue| Some(v.to_string().as_bytes().to_vec()))
.collect();
self.send_data_row(values).await?;
}
self.send_command_complete(&format!("FETCH {}", row_count)).await?;
}
ExecutionResult::CloseCursor { cursor_name } => {
self.send_command_complete(&format!("CLOSE {}", cursor_name)).await?;
}
}
Ok(())
}
async fn send_authentication_ok(&mut self) -> Result<()> {
BackendMessage::AuthenticationOk.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_cleartext_password_request(&mut self) -> Result<()> {
BackendMessage::AuthenticationCleartextPassword.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_md5_password_request(&mut self, salt: &[u8; 4]) -> Result<()> {
BackendMessage::AuthenticationMD5Password { salt: *salt }.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_parameter_status(&mut self, name: &str, value: &str) -> Result<()> {
BackendMessage::ParameterStatus {
name: name.to_string(),
value: value.to_string(),
}
.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_backend_key_data(&mut self) -> Result<()> {
BackendMessage::BackendKeyData {
process_id: std::process::id() as i32,
secret_key: 12345, }
.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_ready_for_query(&mut self, status: TransactionStatus) -> Result<()> {
BackendMessage::ReadyForQuery { status }.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_row_description(&mut self, fields: Vec<FieldDescription>) -> Result<()> {
BackendMessage::RowDescription { fields }.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_data_row(&mut self, values: Vec<Option<Vec<u8>>>) -> Result<()> {
BackendMessage::DataRow { values }.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_command_complete(&mut self, tag: &str) -> Result<()> {
BackendMessage::CommandComplete { tag: tag.to_string() }.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_error_response(&mut self, message: &str) -> Result<()> {
let mut fields = HashMap::new();
fields.insert(b'S', "ERROR".to_string());
fields.insert(b'C', "XX000".to_string()); fields.insert(b'M', message.to_string());
BackendMessage::ErrorResponse { fields }.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_empty_query_response(&mut self) -> Result<()> {
BackendMessage::EmptyQueryResponse.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_subscription_data(
&mut self,
subscription_id: &[u8; 16],
update_type: SubscriptionUpdateType,
rows: Vec<Vec<Option<Vec<u8>>>>,
) -> Result<()> {
BackendMessage::SubscriptionData {
subscription_id: *subscription_id,
update_type,
rows,
}
.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn send_subscription_error(
&mut self,
subscription_id: &[u8; 16],
message: &str,
) -> Result<()> {
BackendMessage::SubscriptionError {
subscription_id: *subscription_id,
message: message.to_string(),
}
.encode(&mut self.write_buf);
self.flush_write_buffer().await
}
async fn read_message(&mut self) -> Result<()> {
let n = self.stream.read_buf(&mut self.read_buf).await?;
if n == 0 {
return Err(anyhow::anyhow!("Connection closed"));
}
Ok(())
}
async fn flush_write_buffer(&mut self) -> Result<()> {
self.stream.write_all(&self.write_buf).await?;
self.stream.flush().await?;
self.write_buf.clear();
Ok(())
}
}
impl Drop for ConnectionHandler {
    /// Gives back this connection's slot in the shared connection counter and, when a
    /// metrics sink is configured, records how long the connection stayed open.
    fn drop(&mut self) {
        let open_connections = &self.active_connections;
        open_connections.fetch_sub(1, Ordering::AcqRel);

        if let Some(sink) = self.observability.metrics() {
            let lifetime = self.connection_start.elapsed();
            sink.record_connection_duration(lifetime);
        }
    }
}