use compio::net::TcpStream;
#[cfg(unix)]
use compio::net::UnixStream;
use zerocopy::{FromBytes, FromZeros, IntoBytes};
use crate::PreparedStatement;
use crate::buffer::BufferSet;
use crate::buffer_pool::PooledBufferSet;
use crate::constant::CapabilityFlags;
use crate::error::{Error, Result};
use crate::protocol::TextRowPayload;
use crate::protocol::command::Action;
use crate::protocol::command::ColumnDefinition;
use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
use crate::protocol::command::prepared::{Exec, read_prepare_ok, write_execute, write_prepare};
use crate::protocol::command::query::{Query, write_query};
use crate::protocol::command::utility::{
DropHandler, FirstHandler, write_ping, write_reset_connection,
};
use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
use crate::protocol::packet::PacketHeader;
use crate::protocol::primitive::read_string_lenenc;
use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
use super::stream::Stream;
/// A single client connection to a MySQL or MariaDB server.
///
/// Owns the transport, the reusable protocol buffers and the state negotiated
/// during the handshake, plus two flags tracked across commands.
pub struct Conn {
    /// Transport the protocol runs over (TCP, Unix socket, or TLS-wrapped).
    stream: Stream,
    /// Reusable read/write/column-definition buffers taken from the pool in `Opts`.
    buffer_set: PooledBufferSet,
    /// Set once a command fails with an error whose `is_conn_broken()` is true.
    is_broken: bool,
    /// Whether a transaction begun through this connection is believed to be open.
    in_transaction: bool,
    /// Parsed initial handshake; its ranges index into `buffer_set.initial_handshake`.
    initial_handshake: InitialHandshake,
    /// Capability flags negotiated during the handshake.
    capability_flags: CapabilityFlags,
    /// MariaDB-specific capability flags negotiated during the handshake.
    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
}
impl Conn {
    /// Opens a new connection described by `opts`.
    ///
    /// If `opts.socket` is set, the connection goes over that Unix domain socket
    /// (Unix only). Otherwise a TCP connection to `opts.host:opts.port` is opened
    /// and `TCP_NODELAY` is set from `opts.tcp_nodelay`. The protocol handshake and
    /// the rest of the setup are done by [`Conn::new_with_stream`].
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - `opts` cannot be converted into [`crate::opts::Opts`].
    /// - A socket path is given on a non-Unix platform.
    /// - No host is given for a TCP connection.
    /// - Connecting fails, or [`Conn::new_with_stream`] fails.
    pub async fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
    where
        Error: From<O::Error>,
    {
        let opts: crate::opts::Opts = opts.try_into()?;
        let stream = Self::open_stream(&opts).await?;
        Self::new_with_stream(stream, &opts).await
    }

    /// Opens the transport for `opts`: the configured Unix socket if any, otherwise TCP.
    #[cfg(unix)]
    async fn open_stream(opts: &crate::opts::Opts) -> Result<Stream> {
        if let Some(socket_path) = &opts.socket {
            let stream = UnixStream::connect(socket_path).await?;
            return Ok(Stream::unix(stream));
        }
        Self::open_tcp_stream(opts).await
    }

    /// Opens the transport for `opts`. Unix sockets are rejected on this platform.
    #[cfg(not(unix))]
    async fn open_stream(opts: &crate::opts::Opts) -> Result<Stream> {
        if opts.socket.is_some() {
            return Err(Error::BadUsageError(
                "Unix sockets are not supported on this platform".to_string(),
            ));
        }
        Self::open_tcp_stream(opts).await
    }

    /// Connects to `opts.host:opts.port` over TCP and applies `opts.tcp_nodelay`.
    async fn open_tcp_stream(opts: &crate::opts::Opts) -> Result<Stream> {
        if opts.host.is_empty() {
            return Err(Error::BadUsageError(
                "Missing host in connection options".to_string(),
            ));
        }
        let addr = format!("{}:{}", opts.host, opts.port);
        let stream = TcpStream::connect(&addr).await?;
        stream.set_nodelay(opts.tcp_nodelay)?;
        Ok(Stream::tcp(stream))
    }

    /// Performs the protocol handshake over an already-connected `stream`.
    ///
    /// Steps:
    /// 1. Drives the handshake state machine, including the TLS upgrade when the
    ///    `compio-tls` feature is enabled.
    /// 2. On Unix, reconnects through the server's Unix socket when
    ///    `opts.upgrade_to_unix_socket` is set and `Stream::is_tcp_loopback`
    ///    returns true.
    /// 3. Runs `opts.init_command`, if any, exactly once on the resulting connection.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - Reading or writing a handshake packet fails.
    /// - The handshake state machine reports an error.
    /// - TLS is requested without the `compio-tls` feature.
    /// - `init_command` fails.
    ///
    /// A failed Unix-socket upgrade is not an error; the original connection is kept.
    pub async fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
        let mut conn_stream = stream;
        let mut buffer_set = opts.buffer_pool.get_buffer_set();
        #[cfg(feature = "compio-tls")]
        let host = opts.host.clone();
        let mut handshake = Handshake::new(opts);
        loop {
            match handshake.step(&mut buffer_set)? {
                HandshakeAction::ReadPacket(buffer) => {
                    buffer.clear();
                    read_payload(&mut conn_stream, buffer).await?;
                }
                HandshakeAction::WritePacket { sequence_id } => {
                    // Each handshake write is immediately followed by reading the
                    // server's reply into `read_buffer` for the next `step`.
                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
                    buffer_set.read_buffer.clear();
                    read_payload(&mut conn_stream, &mut buffer_set.read_buffer).await?;
                }
                #[cfg(feature = "compio-tls")]
                HandshakeAction::UpgradeTls { sequence_id } => {
                    // Send the pending packet (presumably the SSL request), then
                    // switch the transport to TLS before the handshake continues.
                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
                    conn_stream = conn_stream.upgrade_to_tls(&host).await?;
                }
                #[cfg(not(feature = "compio-tls"))]
                HandshakeAction::UpgradeTls { .. } => {
                    return Err(Error::BadUsageError(
                        "TLS requested but compio-tls feature is not enabled".to_string(),
                    ));
                }
                HandshakeAction::Finished => break,
            }
        }
        let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
        let conn = Self {
            stream: conn_stream,
            buffer_set,
            initial_handshake,
            capability_flags,
            mariadb_capabilities,
            in_transaction: false,
            is_broken: false,
        };
        #[cfg(unix)]
        let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
            conn.try_upgrade_to_unix_socket(opts).await
        } else {
            conn
        };
        #[cfg(not(unix))]
        let mut conn = conn;
        if let Some(init_command) = &opts.init_command {
            conn.query_drop(init_command).await?;
        }
        Ok(conn)
    }

    /// Returns the raw server version bytes from the initial handshake packet.
    pub fn server_version(&self) -> &[u8] {
        &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
    }

    /// Returns the capability flags negotiated during the handshake.
    pub fn capability_flags(&self) -> CapabilityFlags {
        self.capability_flags
    }

    /// Returns `CapabilityFlags::is_mysql` for the negotiated capability flags.
    pub fn is_mysql(&self) -> bool {
        self.capability_flags.is_mysql()
    }

    /// Returns `CapabilityFlags::is_mariadb` for the negotiated capability flags.
    pub fn is_mariadb(&self) -> bool {
        self.capability_flags.is_mariadb()
    }

    /// Returns the connection id reported in the server's initial handshake.
    pub fn connection_id(&self) -> u64 {
        self.initial_handshake.connection_id as u64
    }

    /// Returns the server status flags stored from the initial handshake.
    pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
        self.initial_handshake.status_flags
    }

    /// Returns true once a command has failed with a connection-breaking error.
    pub fn is_broken(&self) -> bool {
        self.is_broken
    }

    /// Passes `result` through unchanged.
    ///
    /// If `result` is an error that reports `is_conn_broken()`, the connection is
    /// also marked as broken.
    #[inline]
    fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
        if let Err(e) = &result
            && e.is_conn_broken()
        {
            self.is_broken = true;
        }
        result
    }

    /// Overrides the transaction-tracking flag; used by code outside this impl.
    pub(crate) fn set_in_transaction(&mut self, value: bool) {
        self.in_transaction = value;
    }

    /// Returns whether a transaction begun through this connection is believed open.
    pub fn in_transaction(&self) -> bool {
        self.in_transaction
    }

    /// Tries to replace this TCP connection with one over the server's Unix socket.
    ///
    /// Queries `@@socket`, connects to that path and performs a fresh handshake with
    /// `upgrade_to_unix_socket` disabled. This is best-effort: any failure returns
    /// `self` unchanged. On success the original connection is dropped.
    #[cfg(unix)]
    async fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
        let mut handler = SocketPathHandler { path: None };
        if self.query("SELECT @@socket", &mut handler).await.is_err() {
            return self;
        }
        let socket_path = match handler.path {
            Some(p) if !p.is_empty() => p,
            _ => return self,
        };
        let unix_stream = match UnixStream::connect(&socket_path).await {
            Ok(s) => s,
            Err(_) => return self,
        };
        let stream = Stream::unix(unix_stream);
        let mut opts_unix = opts.clone();
        opts_unix.upgrade_to_unix_socket = false;
        // `new_with_stream` (the visible caller) runs `init_command` on whatever
        // connection is returned from here. Clearing it avoids running it twice on
        // the new socket.
        opts_unix.init_command = None;
        // Boxed because the async recursion would otherwise give an infinitely
        // sized future.
        match Box::pin(Self::new_with_stream(stream, &opts_unix)).await {
            Ok(new_conn) => new_conn,
            Err(_) => self,
        }
    }

    /// Sends the command currently in the write buffer, starting at sequence id 0.
    ///
    /// Splits the payload into several packets when needed, then flushes the stream.
    async fn write_payload(&mut self) -> Result<()> {
        write_handshake_payload(&mut self.stream, &mut self.buffer_set, 0).await
    }

    /// Reads one response packet into `buffer_set.read_buffer`.
    ///
    /// Turns a server ERR packet (first byte 0xFF) into an error.
    async fn read_response_packet(&mut self) -> Result<()> {
        self.buffer_set.read_buffer.clear();
        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
        if self.buffer_set.read_buffer.first() == Some(&0xFF) {
            Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
        }
        Ok(())
    }

    /// Prepares `sql` on the server and returns the resulting statement handle.
    ///
    /// # Errors
    ///
    /// Returns the server's error if preparation is rejected, or an I/O/protocol
    /// error. Connection-breaking errors also mark the connection broken.
    pub async fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
        let result = self.prepare_inner(sql).await;
        self.check_error(result)
    }

    async fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
        use crate::protocol::command::ColumnDefinitions;
        write_prepare(self.buffer_set.new_write_buffer(), sql);
        self.write_payload().await?;
        self.read_response_packet().await?;
        let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
        let statement_id = prepare_ok.statement_id();
        let num_params = prepare_ok.num_params();
        let num_columns = prepare_ok.num_columns();
        // Parameter definition packets are read and discarded.
        // NOTE(review): assumes no EOF packet follows the definitions
        // (CLIENT_DEPRECATE_EOF negotiated) — confirm against the handshake.
        for _ in 0..num_params {
            let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
        }
        let column_definitions = if num_columns > 0 {
            self.read_column_definition_packets(num_columns as usize)
                .await?;
            Some(ColumnDefinitions::new(
                num_columns as usize,
                std::mem::take(&mut self.buffer_set.column_definition_buffer),
            )?)
        } else {
            None
        };
        let mut stmt = PreparedStatement::new(statement_id);
        if let Some(col_defs) = column_definitions {
            stmt.set_column_definitions(col_defs);
        }
        Ok(stmt)
    }

    /// Reads `num_columns` column-definition packets into
    /// `buffer_set.column_definition_buffer`, replacing its contents.
    ///
    /// Each packet is stored as a native-endian `u32` length followed by its
    /// payload. Returns the sequence id of the last header read (0 if none).
    ///
    /// NOTE(review): a definition split across multiple 0xFFFFFF-byte packets is
    /// not handled; presumably definitions never reach that size.
    async fn read_column_definition_packets(&mut self, num_columns: usize) -> Result<u8> {
        let mut header = PacketHeader::new_zeroed();
        let out = &mut self.buffer_set.column_definition_buffer;
        out.clear();
        for _ in 0..num_columns {
            self.stream.read_exact(header.as_mut_bytes()).await?;
            let length = header.length();
            out.extend((length as u32).to_ne_bytes());
            out.reserve(length);
            let spare = out.spare_capacity_mut();
            self.stream.read_buf_exact(&mut spare[..length]).await?;
            // SAFETY: relies on `read_buf_exact` having initialized every byte of
            // `spare[..length]` when it returns Ok.
            unsafe {
                out.set_len(out.len() + length);
            }
        }
        Ok(header.sequence_id)
    }

    /// Drives the binary-protocol result state machine for `stmt` to completion.
    ///
    /// Reads packets and column metadata from the stream as the machine requests them.
    async fn drive_exec<H: BinaryResultSetHandler>(
        &mut self,
        stmt: &mut crate::PreparedStatement,
        handler: &mut H,
    ) -> Result<()> {
        let cache_metadata = self
            .mariadb_capabilities
            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
        let mut exec = Exec::new(handler, stmt, cache_metadata);
        loop {
            match exec.step(&mut self.buffer_set)? {
                Action::NeedPacket(buffer) => {
                    buffer.clear();
                    let _ = read_payload(&mut self.stream, buffer).await?;
                }
                Action::ReadColumnMetadata { num_columns } => {
                    self.read_column_definition_packets(num_columns).await?;
                }
                Action::Finished => return Ok(()),
            }
        }
    }

    /// Drives the text-protocol query result state machine to completion.
    async fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
        let mut query = Query::new(handler);
        loop {
            match query.step(&mut self.buffer_set)? {
                Action::NeedPacket(buffer) => {
                    buffer.clear();
                    let _ = read_payload(&mut self.stream, buffer).await?;
                }
                Action::ReadColumnMetadata { num_columns } => {
                    self.read_column_definition_packets(num_columns).await?;
                }
                Action::Finished => return Ok(()),
            }
        }
    }

    /// Executes `stmt` with `params` and feeds the results to `handler`.
    pub async fn exec<P, H>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
        handler: &mut H,
    ) -> Result<()>
    where
        P: Params,
        H: BinaryResultSetHandler,
    {
        let result = self.exec_inner(stmt, params, handler).await;
        self.check_error(result)
    }

    async fn exec_inner<P, H>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
        handler: &mut H,
    ) -> Result<()>
    where
        P: Params,
        H: BinaryResultSetHandler,
    {
        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
        self.write_payload().await?;
        self.drive_exec(stmt, handler).await
    }

    /// Drives the bulk-execute result state machine for `stmt` to completion.
    async fn drive_bulk_exec<H: BinaryResultSetHandler>(
        &mut self,
        stmt: &mut crate::PreparedStatement,
        handler: &mut H,
    ) -> Result<()> {
        let cache_metadata = self
            .mariadb_capabilities
            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
        let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
        loop {
            match bulk_exec.step(&mut self.buffer_set)? {
                Action::NeedPacket(buffer) => {
                    buffer.clear();
                    let _ = read_payload(&mut self.stream, buffer).await?;
                }
                Action::ReadColumnMetadata { num_columns } => {
                    self.read_column_definition_packets(num_columns).await?;
                }
                Action::Finished => return Ok(()),
            }
        }
    }

    /// Executes `stmt` once per parameter set in `params`.
    ///
    /// On MariaDB a single bulk-execute command is sent and its results go to
    /// `handler`. On other servers each set is executed separately.
    pub async fn exec_bulk_insert_or_update<P, I, H>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
        flags: BulkFlags,
        handler: &mut H,
    ) -> Result<()>
    where
        P: BulkParamsSet + IntoIterator<Item = I>,
        I: Params,
        H: BinaryResultSetHandler,
    {
        let result = self
            .exec_bulk_insert_or_update_inner(stmt, params, flags, handler)
            .await;
        self.check_error(result)
    }

    async fn exec_bulk_insert_or_update_inner<P, I, H>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
        flags: BulkFlags,
        handler: &mut H,
    ) -> Result<()>
    where
        P: BulkParamsSet + IntoIterator<Item = I>,
        I: Params,
        H: BinaryResultSetHandler,
    {
        if !self.is_mariadb() {
            // NOTE(review): in this fallback, results are discarded and `handler` is
            // never invoked, unlike the MariaDB path — confirm this is intended.
            for param in params {
                self.exec_inner(stmt, param, &mut DropHandler::default())
                    .await?;
            }
            Ok(())
        } else {
            write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
            self.write_payload().await?;
            self.drive_bulk_exec(stmt, handler).await
        }
    }

    /// Executes `stmt` and returns the row captured by `FirstHandler`, if any.
    pub async fn exec_first<Row, P>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
    ) -> Result<Option<Row>>
    where
        Row: for<'buf> crate::raw::FromRow<'buf>,
        P: Params,
    {
        let result = self.exec_first_inner(stmt, params).await;
        self.check_error(result)
    }

    async fn exec_first_inner<Row, P>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
    ) -> Result<Option<Row>>
    where
        Row: for<'buf> crate::raw::FromRow<'buf>,
        P: Params,
    {
        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
        self.write_payload().await?;
        let mut handler = FirstHandler::<Row>::default();
        self.drive_exec(stmt, &mut handler).await?;
        Ok(handler.take())
    }

    /// Executes `stmt` and discards any results.
    pub async fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
    where
        P: Params,
    {
        self.exec(stmt, params, &mut DropHandler::default()).await
    }

    /// Executes `stmt` and collects the rows into a `Vec`.
    pub async fn exec_collect<Row, P>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
    ) -> Result<Vec<Row>>
    where
        Row: for<'buf> crate::raw::FromRow<'buf>,
        P: Params,
    {
        let mut handler = crate::handler::CollectHandler::<Row>::default();
        self.exec(stmt, params, &mut handler).await?;
        Ok(handler.into_rows())
    }

    /// Executes `stmt` and calls `f` with each decoded row.
    pub async fn exec_foreach<Row, P, F>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
        f: F,
    ) -> Result<()>
    where
        Row: for<'buf> crate::raw::FromRow<'buf>,
        P: Params,
        F: FnMut(Row) -> Result<()>,
    {
        let mut handler = crate::handler::ForEachHandler::<Row, F>::new(f);
        self.exec(stmt, params, &mut handler).await
    }

    /// Executes `stmt` and calls `f` with a reference to each row.
    pub async fn exec_foreach_ref<Row, P, F>(
        &mut self,
        stmt: &mut PreparedStatement,
        params: P,
        f: F,
    ) -> Result<()>
    where
        Row: for<'buf> crate::ref_row::RefFromRow<'buf>,
        P: Params,
        F: for<'buf> FnMut(&'buf Row) -> Result<()>,
    {
        let mut handler = crate::handler::ForEachRefHandler::<Row, F>::new(f);
        self.exec(stmt, params, &mut handler).await
    }

    /// Runs `sql` as a text-protocol query and feeds the results to `handler`.
    pub async fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
    where
        H: TextResultSetHandler,
    {
        let result = self.query_inner(sql, handler).await;
        self.check_error(result)
    }

    async fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
    where
        H: TextResultSetHandler,
    {
        write_query(self.buffer_set.new_write_buffer(), sql);
        self.write_payload().await?;
        self.drive_query(handler).await
    }

    /// Runs `sql` as a text-protocol query and discards any results.
    pub async fn query_drop(&mut self, sql: &str) -> Result<()> {
        let result = self.query_drop_inner(sql).await;
        self.check_error(result)
    }

    async fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
        write_query(self.buffer_set.new_write_buffer(), sql);
        self.write_payload().await?;
        self.drive_query(&mut DropHandler::default()).await
    }

    /// Sends a ping command.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors or if the server answers with an ERR packet.
    pub async fn ping(&mut self) -> Result<()> {
        let result = self.ping_inner().await;
        self.check_error(result)
    }

    async fn ping_inner(&mut self) -> Result<()> {
        write_ping(self.buffer_set.new_write_buffer());
        self.write_payload().await?;
        self.read_response_packet().await
    }

    /// Sends a reset-connection command; on success clears the transaction flag.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors or if the server answers with an ERR packet. In that
    /// case the transaction flag is left unchanged.
    pub async fn reset(&mut self) -> Result<()> {
        let result = self.reset_inner().await;
        self.check_error(result)
    }

    async fn reset_inner(&mut self) -> Result<()> {
        write_reset_connection(self.buffer_set.new_write_buffer());
        self.write_payload().await?;
        self.read_response_packet().await?;
        self.in_transaction = false;
        Ok(())
    }

    /// Runs `f` inside a transaction started with `BEGIN`.
    ///
    /// What happens when `f` returns:
    /// - If `in_transaction()` is still true and `f` returned `Ok`, `COMMIT` is sent.
    /// - If `in_transaction()` is still true and `f` returned `Err`, `ROLLBACK` is
    ///   attempted. A rollback failure is ignored.
    /// - If `in_transaction()` is already false (e.g. `f` ended the transaction
    ///   itself), nothing more is sent.
    ///
    /// # Errors
    ///
    /// Returns one of:
    /// - [`Error::NestedTransaction`] if a transaction is already active.
    /// - The error from `BEGIN` or `COMMIT`.
    /// - The error returned by `f`.
    pub async fn transaction<F, R>(&mut self, f: F) -> Result<R>
    where
        F: std::ops::AsyncFnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
    {
        if self.in_transaction {
            return Err(Error::NestedTransaction);
        }
        self.in_transaction = true;
        if let Err(err) = self.query_drop("BEGIN").await {
            self.in_transaction = false;
            return Err(err);
        }
        let tx = super::transaction::Transaction::new(self.connection_id());
        let result = f(self, tx).await;
        if self.in_transaction {
            self.in_transaction = false;
            match &result {
                Ok(_) => self.query_drop("COMMIT").await?,
                Err(_) => {
                    let _ = self.query_drop("ROLLBACK").await;
                }
            }
        }
        result
    }
}
/// Reads one logical protocol payload from `reader` into `buffer`, replacing its
/// previous contents.
///
/// A packet whose length is exactly 0xFFFFFF is continued by the next packet. So
/// headers keep being read until a shorter packet arrives, and all payloads are
/// appended to `buffer`. Returns the sequence id of the last header read.
async fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
    const CONTINUATION_LEN: usize = 0xFFFFFF;
    let mut header = PacketHeader::new_zeroed();
    buffer.clear();
    loop {
        reader.read_exact(header.as_mut_bytes()).await?;
        let chunk_len = header.length();
        let filled = buffer.len();
        buffer.reserve(chunk_len);
        reader
            .read_buf_exact(&mut buffer.spare_capacity_mut()[..chunk_len])
            .await?;
        // SAFETY: relies on `read_buf_exact` having initialized the first
        // `chunk_len` spare bytes when it returns Ok.
        unsafe {
            buffer.set_len(filled + chunk_len);
        }
        if chunk_len != CONTINUATION_LEN {
            return Ok(header.sequence_id);
        }
    }
}
/// Writes the write buffer of `buffer_set` to `stream` as one or more protocol
/// packets, numbered from `sequence_id`, then flushes.
///
/// The first 4 bytes of the write buffer are reserved for a packet header and are
/// overwritten here. Payloads are split into chunks of at most 0xFFFFFF bytes. A
/// chunk of exactly that size is always followed by another packet, possibly an
/// empty one. Each later header is written in place over the last 4 bytes of the
/// chunk just sent, so no copying is needed.
async fn write_handshake_payload(
    stream: &mut Stream,
    buffer_set: &mut BufferSet,
    sequence_id: u8,
) -> Result<()> {
    const MAX_CHUNK: usize = 0xFFFFFF;
    let mut remaining = buffer_set.write_buffer_mut().as_mut_slice();
    let mut next_seq = sequence_id;
    loop {
        let payload_len = MAX_CHUNK.min(remaining[4..].len());
        let header = PacketHeader::mut_from_bytes(&mut remaining[..4])?;
        header.encode_in_place(payload_len, next_seq);
        stream.write_all(&remaining[..4 + payload_len]).await?;
        if payload_len != MAX_CHUNK {
            break;
        }
        next_seq = next_seq.wrapping_add(1);
        remaining = &mut remaining[MAX_CHUNK..];
    }
    stream.flush().await?;
    Ok(())
}
/// Text result-set handler that captures the value returned by `SELECT @@socket`.
///
/// `path` stays `None` when no row arrives, the value is NULL, or it is an empty
/// string. Otherwise it holds the lossily UTF-8-decoded value of the last such row.
#[cfg(unix)]
struct SocketPathHandler {
    /// Socket path reported by the server, if any.
    path: Option<String>,
}

#[cfg(unix)]
impl TextResultSetHandler for SocketPathHandler {
    fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
        Ok(())
    }

    fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
        Ok(())
    }

    fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
        Ok(())
    }

    fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
        let payload = row.0;
        // A leading 0xFB is the text-protocol NULL marker.
        if let Some(&0xFB) = payload.first() {
            return Ok(());
        }
        let (socket, _rest) = read_string_lenenc(payload)?;
        if socket.is_empty() {
            return Ok(());
        }
        self.path = Some(String::from_utf8_lossy(socket).into_owned());
        Ok(())
    }
}