use std::{
borrow::Cow,
fmt::Display,
marker::PhantomData,
mem::ManuallyDrop,
net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs},
ops::Range,
};
use bytes::{Buf, BufMut, BytesMut};
use thiserror::Error;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpSocket, UnixStream},
};
use crate::{
args::Args,
auth::{AuthPlugin, compute_auth},
bind::{Bind, BindError},
constants::{client, com},
decode::Column,
lru::{Entry, LRUCache},
package_parser::{DecodeError, DecodeResult, PackageParser},
row::{FromRow, Row},
};
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ConnectionErrorContent {
#[error("mysql error {code}: {message}")]
Mysql {
code: u16,
status: [u8; 5],
message: String,
},
#[error(transparent)]
Io(#[from] tokio::io::Error),
#[error("error reading {0}: {1}")]
Decode(&'static str, DecodeError),
#[error("error binding paramater {0}: {1}")]
Bind(u16, BindError),
#[error("protocol error {0}")]
ProtocolError(String),
#[error("fetch return no columns")]
ExpectedRows,
#[error("rows return for execute")]
UnexpectedRows,
#[cfg(feature = "cancel_testing")]
#[doc(hidden)]
#[error("await threshold reached")]
TestCancelled,
#[error("await threshold reached")]
TooFewListArguments,
#[error("await threshold reached")]
TooManyListArguments,
#[error("Invalid url")]
InvalidUrl,
}
/// Error returned by connection operations.
///
/// The content is boxed so the error is a single pointer wide (asserted below), which keeps
/// `ConnectionResult<T>` small on the happy path.
pub struct ConnectionError(Box<ConnectionErrorContent>);
const _: () = {
    assert!(size_of::<ConnectionError>() == size_of::<usize>());
};
impl ConnectionError {
    /// Borrow the detailed error content.
    pub fn content(&self) -> &ConnectionErrorContent {
        &self.0
    }
}
impl std::ops::Deref for ConnectionError {
    type Target = ConnectionErrorContent;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl<E: Into<ConnectionErrorContent>> From<E> for ConnectionError {
    fn from(value: E) -> Self {
        ConnectionError(Box::new(value.into()))
    }
}
impl std::fmt::Debug for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.0, f)
    }
}
impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
impl std::error::Error for ConnectionError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Behave as a transparent wrapper: expose the content's own source so error-chain
        // reporters can still reach the underlying cause.
        std::error::Error::source(&*self.0)
    }
}
/// Result alias used throughout the connection code.
pub type ConnectionResult<T> = std::result::Result<T, ConnectionError>;
/// Tags a decode result with the name of the field being decoded.
pub trait WithLoc<T> {
    /// Convert into a [`ConnectionResult`], wrapping any decode failure as
    /// [`ConnectionErrorContent::Decode`] labelled with `loc`.
    fn loc(self, loc: &'static str) -> ConnectionResult<T>;
}
impl<T> WithLoc<T> for DecodeResult<T> {
    fn loc(self, loc: &'static str) -> ConnectionResult<T> {
        match self {
            Ok(value) => Ok(value),
            Err(err) => Err(ConnectionErrorContent::Decode(loc, err).into()),
        }
    }
}
/// Decode a value and require it to equal an expected constant.
trait Except {
    type Value;
    /// Decode (labelled `loc`) and return a protocol error unless the value equals `expected`.
    fn ev(self, loc: &'static str, expected: Self::Value) -> ConnectionResult<()>;
}
impl<T: Eq + Display> Except for DecodeResult<T> {
    type Value = T;
    fn ev(self, loc: &'static str, expected: T) -> ConnectionResult<()> {
        let got = self.loc(loc)?;
        if got != expected {
            let message = format!("Expected {expected} for {loc} got {got}");
            return Err(ConnectionErrorContent::ProtocolError(message).into());
        }
        Ok(())
    }
}
/// Converts a raw [`Row`] into a user value, with a caller-chosen error type.
pub trait RowMap<'a> {
    /// Error type; must be able to carry connection errors.
    type E: From<ConnectionError>;
    /// Mapped output type.
    type T;
    /// Map one row.
    fn map(row: Row<'a>) -> Result<Self::T, Self::E>;
}
/// [`RowMap`] adapter that decodes rows through [`FromRow`].
struct FromRowMapper<T>(PhantomData<T>);
impl<'a, T: FromRow<'a>> RowMap<'a> for FromRowMapper<T> {
    type E = ConnectionError;
    type T = T;
    fn map(row: Row<'a>) -> Result<Self::T, Self::E> {
        let decoded = T::from_row(&row);
        decoded.loc("row")
    }
}
/// Read side of the server socket, over either TCP or a Unix domain socket.
enum OwnedReadHalf {
    /// Read half of a TCP stream.
    Tcp(tokio::net::tcp::OwnedReadHalf),
    /// Read half of a Unix domain socket stream.
    Unix(tokio::net::unix::OwnedReadHalf),
}
/// Write side of the server socket, over either TCP or a Unix domain socket.
enum OwnedWriteHalf {
    /// Write half of a TCP stream.
    Tcp(tokio::net::tcp::OwnedWriteHalf),
    /// Write half of a Unix domain socket stream.
    Unix(tokio::net::unix::OwnedWriteHalf),
}
/// Buffered reader that splits the incoming byte stream into MySQL packets.
///
/// Packet payloads are returned as ranges into `buff`. The payload returned last stays in
/// the buffer until the next read, which advances past it (`skip_on_read`) unless
/// `buffer_packages` is set.
struct Reader {
    buff: BytesMut,
    read: OwnedReadHalf,
    /// Number of already-returned bytes at the front of `buff`.
    skip_on_read: usize,
    /// When set, returned packets are kept in `buff` instead of being discarded.
    buffer_packages: bool,
}
impl Reader {
    fn new(read: OwnedReadHalf) -> Self {
        Self {
            read,
            buff: BytesMut::with_capacity(1234),
            skip_on_read: 0,
            buffer_packages: false,
        }
    }

    /// Read from the socket until at least `needed` bytes are buffered.
    ///
    /// Returns an `UnexpectedEof` I/O error if the peer closes the connection first. Without
    /// this check `read_buf` keeps returning `Ok(0)` and the loop would spin forever.
    async fn fill_to(&mut self, needed: usize) -> ConnectionResult<()> {
        while self.buff.remaining() < needed {
            let n = match &mut self.read {
                OwnedReadHalf::Tcp(r) => r.read_buf(&mut self.buff).await?,
                OwnedReadHalf::Unix(r) => r.read_buf(&mut self.buff).await?,
            };
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed by server",
                )
                .into());
            }
        }
        Ok(())
    }

    /// Read the next packet and return the range of its payload inside `buff`.
    ///
    /// # Errors
    /// I/O errors, EOF, or a protocol error for packets of 0xFFFFFF bytes or more
    /// (multi-packet payloads are not supported).
    async fn read_raw(&mut self) -> ConnectionResult<Range<usize>> {
        if !self.buffer_packages {
            self.buff.advance(self.skip_on_read);
            self.skip_on_read = 0;
        }
        let header_start = self.skip_on_read;
        self.fill_to(header_start + 4).await?;
        // Header: 3-byte little-endian payload length followed by a 1-byte sequence id.
        let header = u32::from_le_bytes(
            self.buff[header_start..header_start + 4]
                .try_into()
                .unwrap(),
        );
        let len: usize = (header & 0xFFFFFF).try_into().unwrap();
        if len == 0xFFFFFF {
            return Err(ConnectionErrorContent::ProtocolError(
                "Extended packages not supported".to_string(),
            )
            .into());
        }
        let body_start = header_start + 4;
        let body_end = body_start + len;
        self.fill_to(body_end).await?;
        self.skip_on_read = body_end;
        Ok(body_start..body_end)
    }

    /// Read the next packet and borrow its payload.
    #[inline]
    async fn read(&mut self) -> ConnectionResult<&[u8]> {
        let r = self.read_raw().await?;
        Ok(self.bytes(r))
    }

    /// Borrow a range of the internal buffer.
    #[inline]
    fn bytes(&self, r: Range<usize>) -> &[u8] {
        &self.buff[r]
    }
}
/// Write half of the connection together with the buffer of the packet being built.
struct Writer {
    write: OwnedWriteHalf,
    buff: BytesMut,
    /// Sequence id stamped into the next finalized packet.
    seq: u8,
}
impl Writer {
    fn new(write: OwnedWriteHalf) -> Self {
        Self {
            buff: BytesMut::with_capacity(1234),
            seq: 1,
            write,
        }
    }

    /// Start a new packet: clear the buffer and reserve four bytes for the header that
    /// `Composer::finalize` fills in.
    fn compose(&mut self) -> Composer<'_> {
        self.buff.clear();
        self.buff.put_u32(0);
        Composer { writer: self }
    }

    /// Write the whole buffer to the socket, draining it.
    async fn send(&mut self) -> ConnectionResult<()> {
        let pending = &mut self.buff;
        match &mut self.write {
            OwnedWriteHalf::Tcp(w) => w.write_all_buf(pending).await?,
            OwnedWriteHalf::Unix(w) => w.write_all_buf(pending).await?,
        }
        Ok(())
    }
}
/// Appends little-endian fields to the packet being built in a `Writer`.
struct Composer<'a> {
    writer: &'a mut Writer,
}
impl<'a> Composer<'a> {
    fn put_u32(&mut self, v: u32) {
        self.writer.buff.put_u32_le(v)
    }
    fn put_u16(&mut self, v: u16) {
        self.writer.buff.put_u16_le(v)
    }
    fn put_u8(&mut self, v: u8) {
        self.writer.buff.put_u8(v)
    }
    /// Append `s` followed by a NUL terminator.
    fn put_str_null(&mut self, s: &str) {
        self.put_bytes(s.as_bytes());
        self.put_u8(0);
    }
    fn put_bytes(&mut self, s: &[u8]) {
        self.writer.buff.put_slice(s);
    }
    /// Write the packet header (payload length and sequence id) into the four reserved
    /// bytes, then advance the writer's sequence id.
    fn finalize(self) {
        let writer = self.writer;
        let payload_len = (writer.buff.len() - 4) as u32;
        // NOTE(review): a payload of 0xFFFFFF bytes or more would spill into the sequence-id
        // byte here; confirm callers never build packets that large.
        let header = payload_len | (u32::from(writer.seq) << 24);
        writer.buff[..4].copy_from_slice(&header.to_le_bytes());
        writer.seq = writer.seq.wrapping_add(1);
    }
}
pub struct ConnectionOptions<'a> {
address: SocketAddr,
user: Cow<'a, str>,
password: Cow<'a, str>,
database: Option<Cow<'a, str>>,
statement_case_size: usize,
unix_socket: Option<Cow<'a, std::path::Path>>,
}
impl<'a> ConnectionOptions<'a> {
pub fn new() -> ConnectionOptions<'a> {
Default::default()
}
pub fn into_owned(self) -> ConnectionOptions<'static> {
ConnectionOptions {
address: self.address,
user: self.user.into_owned().into(),
password: self.password.into_owned().into(),
database: self.database.map(|v| v.into_owned().into()),
statement_case_size: self.statement_case_size,
unix_socket: self.unix_socket.map(|v| v.into_owned().into()),
}
}
pub fn from_url(url: &'a str) -> Result<Self, ConnectionError> {
let Some(v) = url.strip_prefix("mysql://") else {
return Err(ConnectionErrorContent::InvalidUrl.into());
};
let (authority, path) = v
.split_once('/')
.map(|(a, b)| (a, Some(b)))
.unwrap_or((v, None));
let (user_info, address) = authority
.split_once('@')
.map(|(a, b)| (Some(a), b))
.unwrap_or((None, authority));
let (user, password) = user_info
.map(|v| {
v.split_once(':')
.map(|(a, b)| (Some(a), Some(b)))
.unwrap_or((Some(v), None))
})
.unwrap_or_default();
let (host, port) = address
.rsplit_once(':')
.map(|(a, b)| (a, Some(b)))
.unwrap_or((address, None));
let port: u16 = match port {
Some(v) => v.parse().map_err(|_| ConnectionErrorContent::InvalidUrl)?,
None => 3306,
};
let (db, unix_socket) = path
.map(|v| {
v.split_once("?socket=")
.map(|(a, b)| (Some(a), Some(std::path::Path::new(b).into())))
.unwrap_or((Some(v), None))
})
.unwrap_or_default();
let mut addrs = (host, port).to_socket_addrs()?;
let Some(address) = addrs.next() else {
return Err(ConnectionErrorContent::InvalidUrl.into());
};
Ok(ConnectionOptions {
address,
user: user.unwrap_or("root").into(),
password: password.unwrap_or("password").into(),
database: db.map(|v| v.into()),
unix_socket,
..Default::default()
})
}
pub fn user(self, user: impl Into<Cow<'a, str>>) -> Self {
Self {
user: user.into(),
..self
}
}
pub fn password(self, password: impl Into<Cow<'a, str>>) -> Self {
Self {
password: password.into(),
..self
}
}
pub fn database(self, database: impl Into<Cow<'a, str>>) -> Self {
Self {
database: Some(database.into()),
..self
}
}
pub fn address(self, address: impl std::net::ToSocketAddrs) -> Result<Self, std::io::Error> {
match address.to_socket_addrs()?.next() {
Some(v) => Ok(Self { address: v, ..self }),
None => Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
"No host resolved",
)),
}
}
pub fn unix_socket(self, path: impl Into<Cow<'a, std::path::Path>>) -> Self {
Self {
unix_socket: Some(path.into()),
..self
}
}
pub fn statment_case_size(self, size: usize) -> Self {
Self {
statement_case_size: size,
..self
}
}
}
impl<'a> Default for ConnectionOptions<'a> {
fn default() -> Self {
Self {
address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3306),
user: Cow::Borrowed("root"),
password: Cow::Borrowed("password"),
database: None,
statement_case_size: 1024,
unix_socket: None,
}
}
}
/// Decoded column (or parameter) definition packet, borrowing its strings from the packet.
#[derive(Debug)]
#[non_exhaustive]
pub struct ColumnDefinition<'a> {
    pub schema: &'a str,
    pub table_alias: &'a str,
    pub table: &'a str,
    pub column_alias: &'a str,
    pub column: &'a str,
    /// Always `None`: this decoder does not read extended metadata.
    pub extended_matadata: Option<&'a str>,
    pub character_set_number: u16,
    pub max_column_size: u32,
    pub field_types: u8,
    pub field_detail_flags: u16,
    pub decimals: u8,
}
impl<'a> ColumnDefinition<'a> {
    /// Decode a definition packet. Fields are read strictly in wire order.
    fn new(data: &'a [u8]) -> DecodeResult<Self> {
        let mut parser = PackageParser::new(data);
        // The first length-encoded string (catalog) is not exposed.
        parser.skip_lenenc_str()?;
        let schema = parser.get_lenenc_str()?;
        let table_alias = parser.get_lenenc_str()?;
        let table = parser.get_lenenc_str()?;
        let column_alias = parser.get_lenenc_str()?;
        let column = parser.get_lenenc_str()?;
        // Length prefix of the fixed-size block that follows; not needed.
        parser.get_lenenc()?;
        let character_set_number = parser.get_u16()?;
        let max_column_size = parser.get_u32()?;
        let field_types = parser.get_u8()?;
        let field_detail_flags = parser.get_u16()?;
        let decimals = parser.get_u8()?;
        Ok(Self {
            schema,
            table_alias,
            table,
            column_alias,
            column,
            extended_matadata: None,
            character_set_number,
            max_column_size,
            field_types,
            field_detail_flags,
            decimals,
        })
    }
}
/// Lazily decoded list of column (or parameter) definitions of a prepared statement.
///
/// Each entry of `ranges` selects one definition packet inside `data`. Entries are decoded
/// with `ColumnDefinition::new` only when requested.
#[derive(Clone)]
pub struct ColumnsInformation<'a> {
    data: &'a [u8],
    ranges: &'a [Range<usize>],
}
impl<'a> ColumnsInformation<'a> {
    /// Decode the definition at position `idx` among the remaining entries without
    /// consuming anything. Returns `None` if `idx` is out of range.
    pub fn get(&self, idx: usize) -> Option<DecodeResult<ColumnDefinition<'a>>> {
        self.ranges
            .get(idx)
            .map(|v| ColumnDefinition::new(&self.data[v.clone()]))
    }
}
impl<'a> Iterator for ColumnsInformation<'a> {
    type Item = DecodeResult<ColumnDefinition<'a>>;
    fn next(&mut self) -> Option<Self::Item> {
        let range = self.ranges.split_off_first()?;
        Some(ColumnDefinition::new(&self.data[range.clone()]))
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.ranges.len(), Some(self.ranges.len()))
    }
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        // `Iterator::nth` must consume the `n` skipped entries as well as the returned one,
        // so a following `next()` continues after it.
        self.ranges = self.ranges.get(n..).unwrap_or(&[]);
        self.next()
    }
}
impl<'a> ExactSizeIterator for ColumnsInformation<'a> {
    fn len(&self) -> usize {
        self.ranges.len()
    }
}
/// Raw parameter and column definition packets captured when preparing a statement.
///
/// `ranges` holds the `num_params` parameter packets first, followed by the column packets,
/// each as a range into `info`.
pub struct StatementInformation {
    num_params: u16,
    info: Vec<u8>,
    ranges: Vec<Range<usize>>,
}
impl StatementInformation {
    /// Definitions of the result columns.
    pub fn columns(&self) -> ColumnsInformation<'_> {
        let first_column = usize::from(self.num_params);
        ColumnsInformation {
            data: self.info.as_slice(),
            ranges: &self.ranges[first_column..],
        }
    }
    /// Definitions of the statement parameters.
    pub fn parameters(&self) -> ColumnsInformation<'_> {
        let param_count = usize::from(self.num_params);
        ColumnsInformation {
            data: self.info.as_slice(),
            ranges: &self.ranges[..param_count],
        }
    }
}
/// A statement prepared on the server.
struct Statement {
    /// Server-assigned id used by execute and close commands.
    stmt_id: u32,
    /// Number of `?` placeholders.
    num_params: u16,
    /// Captured definitions; only present when requested at prepare time.
    information: Option<StatementInformation>,
}
/// Streams the rows of the result set of an executed query.
pub struct QueryIterator<'a> {
    connection: &'a mut RawConnection,
}
impl<'a> QueryIterator<'a> {
    /// Fetch the next row, or `Ok(None)` once the result set is exhausted.
    ///
    /// # Panics
    /// If the connection is neither idle nor reading rows.
    pub async fn next(&mut self) -> ConnectionResult<Option<Row<'_>>> {
        match self.connection.state {
            ConnectionState::Clean => return Ok(None),
            ConnectionState::QueryReadRows => (),
            _ => panic!("Logic error"),
        }
        self.connection.test_cancel()?;
        let start_instant = self.connection.stats.get_instant();
        let package = self.connection.reader.read().await?;
        self.connection.stats.add_fetch(start_instant);
        let mut pp = PackageParser::new(package);
        match pp.get_u8().loc("Row first byte")? {
            0x00 => Ok(Some(Row::new(&self.connection.columns, package))),
            0xFE => {
                self.connection.state = ConnectionState::Clean;
                Ok(None)
            }
            0xFF => {
                // An ERR packet terminates the result set, so no more rows follow.
                // Don't leave the connection waiting for packets that never arrive.
                self.connection.state = ConnectionState::Clean;
                match handle_mysql_error(&mut pp)? {}
            }
            v => {
                self.connection.state = ConnectionState::Broken;
                Err(ConnectionErrorContent::ProtocolError(format!(
                    "Unexpected response type {v} to row package"
                ))
                .into())
            }
        }
    }
}
/// Streams the rows of a result set, converting each with the [`RowMap`] `M`.
pub struct MapQueryIterator<'a, M>
where
    for<'b> M: RowMap<'b>,
{
    connection: &'a mut RawConnection,
    _phantom: PhantomData<M>,
}
impl<'a, M> MapQueryIterator<'a, M>
where
    for<'b> M: RowMap<'b>,
{
    /// Fetch and map the next row, or `Ok(None)` once the result set is exhausted.
    ///
    /// # Panics
    /// If the connection is neither idle nor reading rows.
    pub async fn next<'b>(
        &'b mut self,
    ) -> Result<Option<<M as RowMap<'b>>::T>, <M as RowMap<'b>>::E> {
        match self.connection.state {
            ConnectionState::Clean => return Ok(None),
            ConnectionState::QueryReadRows => (),
            _ => panic!("Logic error"),
        }
        self.connection.test_cancel()?;
        let start_instant = self.connection.stats.get_instant();
        let package = self.connection.reader.read().await?;
        self.connection.stats.add_fetch(start_instant);
        let mut pp = PackageParser::new(package);
        match pp.get_u8().loc("Row first byte")? {
            0x00 => Ok(Some(M::map(Row::new(&self.connection.columns, package))?)),
            0xFE => {
                self.connection.state = ConnectionState::Clean;
                Ok(None)
            }
            0xFF => {
                // An ERR packet terminates the result set, so no more rows follow.
                self.connection.state = ConnectionState::Clean;
                match handle_mysql_error(&mut pp)? {}
            }
            v => {
                self.connection.state = ConnectionState::Broken;
                Err(
                    ConnectionError::from(ConnectionErrorContent::ProtocolError(format!(
                        "Unexpected response type {v} to row package"
                    )))
                    .into(),
                )
            }
        }
    }
}
/// Outcome of a statement that returned an OK packet instead of rows.
#[derive(Debug, Clone)]
pub struct ExecuteResult {
    affected_rows: u64,
    last_insert_id: u64,
}
impl ExecuteResult {
    /// Number of rows changed by the statement, as reported by the server.
    pub fn affected_rows(&self) -> u64 {
        self.affected_rows
    }
    /// Last auto-increment id reported by the server.
    pub fn last_insert_id(&self) -> u64 {
        self.last_insert_id
    }
}
/// Response to executing a prepared statement.
enum QueryResult {
    /// A result set follows; its column definitions have been read.
    WithColumns,
    /// No result set; the statement completed with an OK packet.
    ExecuteResult(ExecuteResult),
}
/// Where the connection is in the current command exchange.
///
/// Recorded before each await so an interrupted operation can be drained later.
#[derive(Clone, Copy, Debug)]
enum ConnectionState {
    /// Idle; nothing pending on the wire.
    Clean,
    /// COM_STMT_PREPARE composed in the write buffer but not yet sent.
    PrepareStatementSend,
    /// Waiting for the prepare response header.
    PrepareStatementReadHead,
    /// Reading the remaining parameter and then column definition packets of `stmt_id`.
    PrepareStatementReadParams {
        params: u16,
        columns: u16,
        stmt_id: u32,
    },
    /// COM_STMT_CLOSE composed but not yet sent.
    ClosePreparedStatement,
    /// COM_STMT_EXECUTE composed but not yet sent.
    QuerySend,
    /// Waiting for the execute response header.
    QueryReadHead,
    /// This many column definition packets remain to be read.
    QueryReadColumns(u64),
    /// Reading row packets.
    QueryReadRows,
    /// COM_QUERY composed but not yet sent.
    UnpreparedSend,
    /// Waiting for the COM_QUERY response.
    UnpreparedRecv,
    /// COM_PING composed but not yet sent.
    PingSend,
    /// Waiting for the ping response.
    PingRecv,
    /// An unrecoverable protocol error occurred; the connection is unusable.
    Broken,
}
/// Timing hooks around prepare/execute/fetch; implemented by a real and a no-op collector.
trait IStats {
    /// Token captured at the start of a measured operation.
    type Instant: Sized;
    fn get_instant(&self) -> Self::Instant;
    fn add_prepare(&mut self, start_instant: Self::Instant);
    fn add_execute(&mut self, start_instant: Self::Instant);
    fn add_fetch(&mut self, start_instant: Self::Instant);
}
/// Accumulated operation counts and wall-clock durations.
#[allow(unused)]
#[derive(Default, Debug)]
pub struct Stats {
    pub prepare_counts: usize,
    pub prepare_time: std::time::Duration,
    pub execute_counts: usize,
    pub execute_time: std::time::Duration,
    pub fetch_time: std::time::Duration,
}
impl IStats for Stats {
    type Instant = std::time::Instant;
    fn get_instant(&self) -> Self::Instant {
        Self::Instant::now()
    }
    fn add_prepare(&mut self, start_instant: Self::Instant) {
        let spent = start_instant.elapsed();
        self.prepare_counts += 1;
        self.prepare_time += spent;
    }
    fn add_execute(&mut self, start_instant: Self::Instant) {
        let spent = start_instant.elapsed();
        self.execute_counts += 1;
        self.execute_time += spent;
    }
    fn add_fetch(&mut self, start_instant: Self::Instant) {
        let spent = start_instant.elapsed();
        self.fetch_time += spent;
    }
}
/// Collector that records nothing.
#[allow(unused)]
#[derive(Default)]
struct NoStats;
impl IStats for NoStats {
    type Instant = NoStats;
    fn get_instant(&self) -> Self::Instant {
        NoStats
    }
    fn add_prepare(&mut self, _: Self::Instant) {}
    fn add_execute(&mut self, _: Self::Instant) {}
    fn add_fetch(&mut self, _: Self::Instant) {}
}
/// Low-level protocol state for one server connection.
struct RawConnection {
    reader: Reader,
    writer: Writer,
    /// Progress of the current command; see `ConnectionState`.
    state: ConnectionState,
    /// Column definitions of the result set currently being read.
    columns: Vec<Column>,
    /// NOTE(review): not used in the visible code — confirm its purpose.
    ranges: Vec<Range<usize>>,
    /// Remaining awaits before a test cancellation is injected.
    #[cfg(feature = "cancel_testing")]
    cancel_count: Option<usize>,
    #[cfg(feature = "stats")]
    stats: Stats,
    #[cfg(not(feature = "stats"))]
    stats: NoStats,
    /// NOTE(review): presumably element counts for list parameters — confirm.
    #[cfg(feature = "list_hack")]
    list_lengths: Vec<usize>,
}
/// Decode the parts of a result-set column definition packet needed to read row values.
///
/// The six leading length-encoded strings are skipped. Each read is labelled so decode
/// errors name the offending field.
fn parse_column_definition(p: &mut PackageParser) -> ConnectionResult<Column> {
    for label in ["catalog", "schema", "table", "org_table", "name", "org_name"] {
        p.skip_lenenc_str().loc(label)?;
    }
    p.get_lenenc().loc("length of fixed length fields")?;
    let character_set = p.get_u16().loc("character_set")?;
    p.get_u32().loc("column_length")?;
    let r#type = p.get_u8().loc("type")?;
    let flags = p.get_u16().loc("flags")?;
    p.get_u8().loc("decimals")?;
    p.get_u16().loc("res")?;
    Ok(Column {
        r#type,
        flags,
        character_set,
    })
}
/// Decode the body of an ERR packet (after its 0xFF header byte) into an error.
///
/// Never returns `Ok`; the `Infallible` success type lets callers use `?` and then treat
/// the remainder as unreachable.
fn handle_mysql_error(pp: &mut PackageParser) -> ConnectionResult<std::convert::Infallible> {
    let code = pp.get_u16().loc("code")?;
    pp.get_u8().ev("sharp", b'#')?;
    let mut status = [0u8; 5];
    let labels = ["status0", "status1", "status2", "status3", "status4"];
    for (slot, label) in status.iter_mut().zip(labels) {
        *slot = pp.get_u8().loc(label)?;
    }
    let message = pp.get_eof_str().loc("message")?.to_string();
    Err(ConnectionErrorContent::Mysql {
        code,
        status,
        message,
    }
    .into())
}
/// SQL that opens a transaction at nesting level `depth`.
///
/// Level 0 is `BEGIN`; deeper levels create the savepoint `_sqly_savepoint_{depth}`.
/// Levels 0–3 are static strings, so they are returned borrowed.
fn begin_transaction_query(depth: usize) -> Cow<'static, str> {
    const FIXED: [&str; 4] = [
        "BEGIN",
        "SAVEPOINT _sqly_savepoint_1",
        "SAVEPOINT _sqly_savepoint_2",
        "SAVEPOINT _sqly_savepoint_3",
    ];
    match FIXED.get(depth) {
        Some(&sql) => Cow::Borrowed(sql),
        None => Cow::Owned(format!("SAVEPOINT _sqly_savepoint_{}", depth)),
    }
}
/// SQL that commits nesting level `depth`.
///
/// Level 0 is `COMMIT`; deeper levels release the savepoint `_sqly_savepoint_{depth}`.
fn commit_transaction_query(depth: usize) -> Cow<'static, str> {
    let fixed = match depth {
        0 => Some("COMMIT"),
        1 => Some("RELEASE SAVEPOINT _sqly_savepoint_1"),
        2 => Some("RELEASE SAVEPOINT _sqly_savepoint_2"),
        3 => Some("RELEASE SAVEPOINT _sqly_savepoint_3"),
        _ => None,
    };
    fixed.map_or_else(
        || Cow::Owned(format!("RELEASE SAVEPOINT _sqly_savepoint_{}", depth)),
        Cow::Borrowed,
    )
}
/// SQL that rolls back nesting level `depth`.
///
/// Level 0 is `ROLLBACK`; deeper levels roll back to the savepoint `_sqly_savepoint_{depth}`.
/// Every level ≥ 1 uses `ROLLBACK TO SAVEPOINT`. Previously depth ≥ 4 produced the invalid
/// `RELEASE TO SAVEPOINT`.
fn rollback_transaction_query(depth: usize) -> Cow<'static, str> {
    match depth {
        0 => "ROLLBACK".into(),
        1 => "ROLLBACK TO SAVEPOINT _sqly_savepoint_1".into(),
        2 => "ROLLBACK TO SAVEPOINT _sqly_savepoint_2".into(),
        3 => "ROLLBACK TO SAVEPOINT _sqly_savepoint_3".into(),
        v => format!("ROLLBACK TO SAVEPOINT _sqly_savepoint_{}", v).into(),
    }
}
impl RawConnection {
async fn connect(options: &ConnectionOptions<'_>) -> ConnectionResult<Self> {
let (read, write) = if let Some(path) = &options.unix_socket {
let socket = UnixStream::connect(path).await?;
let (read, write) = socket.into_split();
(OwnedReadHalf::Unix(read), OwnedWriteHalf::Unix(write))
} else {
let stream = if options.address.is_ipv4() {
let socket = TcpSocket::new_v4()?;
socket.connect(options.address).await?
} else {
let socket = TcpSocket::new_v6()?;
socket.connect(options.address).await?
};
let (read, write) = stream.into_split();
(OwnedReadHalf::Tcp(read), OwnedWriteHalf::Tcp(write))
};
let mut reader = Reader::new(read);
let mut writer = Writer::new(write);
let package = reader.read().await?;
let mut p = PackageParser::new(package);
p.get_u8().ev("protocol version", 10)?;
p.skip_null_str().loc("status")?;
let _wthread_id = p.get_u32().loc("thread_id")?;
let nonce1 = p.get_bytes(8).loc("nonce1")?;
p.get_u8().ev("nonce1_end", 0)?;
let capability_flags_1 = p.get_u16().loc("capability_flags_1")?;
let _character_set = p.get_u8().loc("character_set")?;
p.get_u16().loc("status_flags")?;
let capability_flags_2 = p.get_u16().loc("capability_flags_2")?;
let auth_plugin_data_len = p.get_u8().loc("auth_plugin_data_len")?;
let _capability_flags = capability_flags_1 as u32 | (capability_flags_2 as u32) << 16;
p.get_bytes(10).loc("reserved")?;
let nonce2 = p
.get_bytes(auth_plugin_data_len as usize - 9)
.loc("nonce2")?;
p.get_u8().ev("nonce2_end", 0)?;
let auth_plugin = p.get_null_str().loc("auth_plugin")?;
let auth_method = match auth_plugin {
"mysql_native_password" => AuthPlugin::NativePassword,
#[cfg(feature = "sha2_auth")]
"caching_sha2_password" => AuthPlugin::CachingSha2Password,
v => {
return Err(ConnectionErrorContent::ProtocolError(format!(
"Unhandled auth plugin {v}"
))
.into());
}
};
let mut p = writer.compose();
let mut opts = client::LONG_PASSWORD
| client::LONG_FLAG
| client::LOCAL_FILES
| client::PROTOCOL_41
| client::DEPRECATE_EOF
| client::TRANSACTIONS
| client::SECURE_CONNECTION
| client::MULTI_STATEMENTS
| client::MULTI_RESULTS
| client::PS_MULTI_RESULTS
| client::PLUGIN_AUTH;
if options.database.is_some() {
opts |= client::CONNECT_WITH_DB
}
p.put_u32(opts);
p.put_u32(0x1000000); p.put_u16(45); for _ in 0..22 {
p.put_u8(0);
}
p.put_str_null(&options.user);
let mut nonce = Vec::with_capacity(nonce1.len() + nonce2.len());
nonce.extend_from_slice(nonce1);
nonce.extend_from_slice(nonce2);
let auth = compute_auth(&options.password, &nonce, auth_method);
let auth = auth.as_slice();
p.put_u8(auth.len() as u8);
for v in auth {
p.put_u8(*v);
}
if let Some(database) = &options.database {
p.put_str_null(database);
}
p.put_str_null(auth_plugin);
p.finalize();
writer.send().await?;
loop {
let p = reader.read().await?;
let mut pp = PackageParser::new(p);
match pp.get_u8().loc("response type")? {
0xFF => {
handle_mysql_error(&mut pp)?;
}
0x00 => {
let _rows = pp.get_lenenc().loc("rows")?;
let _last_inserted_id = pp.get_lenenc().loc("last_inserted_id")?;
break;
}
0xFE => {
return Err(ConnectionErrorContent::ProtocolError(
"Unexpected auth switch".into(),
)
.into());
}
#[cfg(feature = "sha2_auth")]
0x01 if matches!(auth_method, AuthPlugin::CachingSha2Password) => {
match pp.get_u8().loc("auth_status")? {
0x03 => break,
0x04 => {
writer.seq = 3;
let mut p = writer.compose();
p.put_u8(0x02);
p.finalize();
writer.send().await?;
let p = reader.read().await?;
let mut pp = PackageParser::new(p);
pp.get_u8().ev("first", 1)?;
let pem = pp.get_eof_str().loc("pem")?;
let pwd = crate::auth::encrypt_rsa(pem, &options.password, &nonce)?;
writer.seq = 5;
let mut p = writer.compose();
p.put_bytes(&pwd);
p.finalize();
writer.send().await?;
}
v => {
return Err(ConnectionErrorContent::ProtocolError(format!(
"Unexpected auth status {v} to handshake response"
))
.into());
}
}
}
v => {
return Err(ConnectionErrorContent::ProtocolError(format!(
"Unexpected response type {v} to handshake response"
))
.into());
}
}
}
writer.seq = 0;
Ok(RawConnection {
reader,
writer,
state: ConnectionState::Clean,
columns: Vec::new(),
ranges: Vec::new(),
#[cfg(feature = "cancel_testing")]
cancel_count: None,
stats: Default::default(),
#[cfg(feature = "list_hack")]
list_lengths: Vec::new(),
})
}
#[inline]
fn test_cancel(&mut self) -> ConnectionResult<()> {
#[cfg(feature = "cancel_testing")]
if let Some(v) = &mut self.cancel_count {
if *v == 0 {
return Err(ConnectionErrorContent::TestCancelled.into());
}
*v -= 1;
}
Ok(())
}
async fn cleanup(&mut self) -> ConnectionResult<()> {
loop {
match self.state {
ConnectionState::Clean => break,
ConnectionState::PrepareStatementSend => {
self.test_cancel()?;
self.writer.send().await?;
self.state = ConnectionState::PrepareStatementReadHead;
continue;
}
ConnectionState::PrepareStatementReadHead => {
self.test_cancel()?;
let package = self.reader.read().await?;
let mut p = PackageParser::new(package);
match p.get_u8().loc("response type")? {
0 => {
let stmt_id = p.get_u32().loc("stmt_id")?;
let columns = p.get_u16().loc("num_columns")?;
let params = p.get_u16().loc("num_params")?;
self.state = ConnectionState::PrepareStatementReadParams {
params,
columns,
stmt_id,
};
continue;
}
255 => {
self.state = ConnectionState::Clean;
}
v => {
self.state = ConnectionState::Broken;
return Err(ConnectionErrorContent::ProtocolError(format!(
"Unexpected response type {v} to prepare statement"
))
.into());
}
}
}
ConnectionState::PrepareStatementReadParams {
params: 0,
columns: 0,
stmt_id,
} => {
self.writer.seq = 0;
let mut p = self.writer.compose();
p.put_u8(com::STMT_CLOSE);
p.put_u32(stmt_id);
p.finalize();
self.state = ConnectionState::ClosePreparedStatement;
}
ConnectionState::PrepareStatementReadParams {
params: 0,
columns,
stmt_id,
} => {
self.test_cancel()?;
self.reader.read().await?;
self.state = ConnectionState::PrepareStatementReadParams {
params: 0,
columns: columns - 1,
stmt_id,
};
}
ConnectionState::PrepareStatementReadParams {
params,
columns,
stmt_id,
} => {
self.test_cancel()?;
self.reader.read().await?;
self.state = ConnectionState::PrepareStatementReadParams {
params: params - 1,
columns,
stmt_id,
};
}
ConnectionState::ClosePreparedStatement => {
self.test_cancel()?;
self.writer.send().await?;
self.state = ConnectionState::Clean;
}
ConnectionState::QuerySend => {
self.test_cancel()?;
self.writer.send().await?;
self.state = ConnectionState::QueryReadHead;
}
ConnectionState::QueryReadHead => {
self.test_cancel()?;
let package = self.reader.read().await?;
{
let mut pp = PackageParser::new(package);
match pp.get_u8().loc("first_byte")? {
255 | 0 => {
self.state = ConnectionState::Clean;
continue;
}
_ => (),
}
}
let column_count = PackageParser::new(package)
.get_lenenc()
.loc("column_count")?;
self.state = ConnectionState::QueryReadColumns(column_count)
}
ConnectionState::QueryReadColumns(0) => {
self.state = ConnectionState::QueryReadRows;
}
ConnectionState::QueryReadColumns(cnt) => {
self.test_cancel()?;
self.reader.read().await?;
self.state = ConnectionState::QueryReadColumns(cnt - 1);
}
ConnectionState::QueryReadRows => {
self.test_cancel()?;
let package = self.reader.read().await?;
let mut pp = PackageParser::new(package);
match pp.get_u8().loc("Row first byte")? {
0x00 => (),
0xFE => {
self.state = ConnectionState::Clean;
}
0xFF => {
self.state = ConnectionState::Broken;
handle_mysql_error(&mut pp)?;
unreachable!()
}
v => {
self.state = ConnectionState::Broken;
return Err(ConnectionErrorContent::ProtocolError(format!(
"Unexpected response type {v} to row package"
))
.into());
}
}
}
ConnectionState::UnpreparedSend => {
self.test_cancel()?;
self.writer.send().await?;
self.state = ConnectionState::QueryReadHead;
}
ConnectionState::UnpreparedRecv => {
self.test_cancel()?;
let package = self.reader.read().await?;
let mut pp = PackageParser::new(package);
match pp.get_u8().loc("first_byte")? {
255 => {
self.state = ConnectionState::Broken;
handle_mysql_error(&mut pp)?;
unreachable!()
}
0 => {
self.state = ConnectionState::Clean;
return Ok(());
}
v => {
self.state = ConnectionState::Broken;
return Err(ConnectionErrorContent::ProtocolError(format!(
"Unexpected response type {v} to row package"
))
.into());
}
}
}
ConnectionState::PingSend => {
self.test_cancel()?;
self.writer.send().await?;
self.state = ConnectionState::PingRecv;
}
ConnectionState::PingRecv => {
self.test_cancel()?;
let package = self.reader.read().await?;
let mut pp = PackageParser::new(package);
match pp.get_u8().loc("first_byte")? {
255 => {
self.state = ConnectionState::Broken;
handle_mysql_error(&mut pp)?;
unreachable!()
}
0 => {
self.state = ConnectionState::Clean;
return Ok(());
}
v => {
self.state = ConnectionState::Broken;
return Err(ConnectionErrorContent::ProtocolError(format!(
"Unexpected response type {v} to ping"
))
.into());
}
}
}
ConnectionState::Broken => {
return Err(ConnectionErrorContent::ProtocolError(
"Previous protocol error reported".to_string(),
)
.into());
}
}
}
Ok(())
}
async fn prepare_query(&mut self, stmt: &str, with_info: bool) -> ConnectionResult<Statement> {
assert!(matches!(self.state, ConnectionState::Clean));
self.writer.seq = 0;
let mut p = self.writer.compose();
p.put_u8(com::STMT_PREPARE);
p.put_bytes(stmt.as_bytes());
p.finalize();
let start_instant = self.stats.get_instant();
self.state = ConnectionState::PrepareStatementSend;
self.test_cancel()?;
self.writer.send().await?;
self.state = ConnectionState::PrepareStatementReadHead;
self.test_cancel()?;
let package = self.reader.read().await?;
let mut p = PackageParser::new(package);
match p.get_u8().loc("response type")? {
0 => {
let stmt_id = p.get_u32().loc("stmt_id")?;
let num_columns = p.get_u16().loc("num_columns")?;
let num_params = p.get_u16().loc("num_params")?;
let mut info_bytes: Vec<_> = Vec::new();
let mut info_ranges = Vec::new();
for p in 0..num_params {
self.state = ConnectionState::PrepareStatementReadParams {
params: num_params - p,
columns: num_columns,
stmt_id,
};
self.test_cancel()?;
let pkg = self.reader.read().await?;
if with_info {
let start = info_bytes.len();
info_bytes.extend(pkg);
info_ranges.push(start..info_bytes.len())
}
}
for c in 0..num_columns {
self.state = ConnectionState::PrepareStatementReadParams {
params: 0,
columns: num_columns - c,
stmt_id,
};
self.test_cancel()?;
let pkg = self.reader.read().await?;
if with_info {
let start = info_bytes.len();
info_bytes.extend(pkg);
info_ranges.push(start..info_bytes.len())
}
}
let information = if with_info {
Some(StatementInformation {
num_params,
info: info_bytes,
ranges: info_ranges,
})
} else {
None
};
self.state = ConnectionState::Clean;
self.stats.add_prepare(start_instant);
Ok(Statement {
stmt_id,
num_params,
information,
})
}
255 => {
handle_mysql_error(&mut p)?;
unreachable!()
}
v => {
self.state = ConnectionState::Broken;
Err(ConnectionErrorContent::ProtocolError(format!(
"Unexpected response type {v} to prepare statement"
))
.into())
}
}
}
fn query<'a>(&'a mut self, statement: &'a Statement) -> Query<'a> {
assert!(matches!(self.state, ConnectionState::Clean));
self.writer.seq = 0;
let mut p = self.writer.compose();
p.put_u8(com::STMT_EXECUTE);
p.put_u32(statement.stmt_id);
p.put_u8(0); p.put_u32(1);
let null_offset = p.writer.buff.len();
let mut type_offset = null_offset;
if statement.num_params != 0 {
let null_bytes = statement.num_params.div_ceil(8);
for _ in 0..null_bytes {
p.put_u8(0);
}
p.put_u8(1);
type_offset = p.writer.buff.len();
for _ in 0..statement.num_params {
p.put_u16(0);
}
}
Query {
connection: self,
statement,
cur_param: 0,
null_offset,
type_offset,
}
}
async fn query_send(&mut self) -> ConnectionResult<QueryResult> {
let p = Composer {
writer: &mut self.writer,
};
p.finalize();
let start_instant = self.stats.get_instant();
self.state = ConnectionState::QuerySend;
self.test_cancel()?;
self.writer.send().await?;
self.state = ConnectionState::QueryReadHead;
self.test_cancel()?;
let package = self.reader.read().await?;
{
let mut pp = PackageParser::new(package);
match pp.get_u8().loc("first_byte")? {
255 => {
handle_mysql_error(&mut pp)?;
}
0 => {
self.stats.add_execute(start_instant);
self.state = ConnectionState::Clean;
let affected_rows = pp.get_lenenc().loc("affected_rows")?;
let last_insert_id = pp.get_lenenc().loc("last_insert_id")?;
return Ok(QueryResult::ExecuteResult(ExecuteResult {
affected_rows,
last_insert_id,
}));
}
_ => (),
}
}
let column_count = PackageParser::new(package)
.get_lenenc()
.loc("column_count")?;
self.columns.clear();
for c in 0..column_count {
self.state = ConnectionState::QueryReadColumns(column_count - c);
self.test_cancel()?;
let package = self.reader.read().await?;
let mut p = PackageParser::new(package);
self.columns.push(parse_column_definition(&mut p)?);
}
self.stats.add_execute(start_instant);
self.state = ConnectionState::QueryReadRows;
Ok(QueryResult::WithColumns)
}
fn execute_unprepared(
&mut self,
escaped_sql: Cow<'_, str>,
) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
assert!(matches!(self.state, ConnectionState::Clean));
self.writer.seq = 0;
let mut p = self.writer.compose();
p.put_u8(com::QUERY);
p.put_bytes(escaped_sql.as_bytes());
p.finalize();
self.state = ConnectionState::UnpreparedSend;
async move {
let start_time = self.stats.get_instant();
self.test_cancel()?;
self.writer.send().await?;
self.state = ConnectionState::UnpreparedRecv;
self.test_cancel()?;
let package = self.reader.read().await?;
{
let mut pp = PackageParser::new(package);
match pp.get_u8().loc("first_byte")? {
255 => {
handle_mysql_error(&mut pp)?;
unreachable!()
}
0 => {
self.stats.add_execute(start_time);
self.state = ConnectionState::Clean;
let affected_rows = pp.get_lenenc().loc("affected_rows")?;
let last_insert_id = pp.get_lenenc().loc("last_insert_id")?;
Ok(ExecuteResult {
affected_rows,
last_insert_id,
})
}
v => {
self.state = ConnectionState::Broken;
Err(ConnectionErrorContent::ProtocolError(format!(
"Unexpected response type {v} to row package"
))
.into())
}
}
}
}
}
/// Sends `COM_STMT_CLOSE` for the server-side statement `id`.
///
/// The packet is composed synchronously; the returned future only sends it.
/// No reply is read, so the connection returns to `Clean` right after sending.
fn close_prepared_statement(
    &mut self,
    id: u32,
) -> impl Future<Output = ConnectionResult<()>> + Send {
    assert!(matches!(self.state, ConnectionState::Clean));
    self.writer.seq = 0;
    let mut packet = self.writer.compose();
    packet.put_u8(com::STMT_CLOSE);
    packet.put_u32(id);
    packet.finalize();
    self.state = ConnectionState::ClosePreparedStatement;
    async move {
        let started = self.stats.get_instant();
        self.test_cancel()?;
        self.writer.send().await?;
        // Closing is accounted under the prepare statistic.
        self.stats.add_prepare(started);
        self.state = ConnectionState::Clean;
        Ok(())
    }
}
/// Sends `COM_PING` and waits for the server's OK packet.
///
/// The packet is composed synchronously; the returned future performs the
/// round trip.
///
/// # Errors
/// MySQL error packets are surfaced via `handle_mysql_error`; any other
/// non-OK first byte marks the connection `Broken`.
fn ping(&mut self) -> impl Future<Output = ConnectionResult<()>> + Send {
    assert!(matches!(self.state, ConnectionState::Clean));
    self.writer.seq = 0;
    let mut packet = self.writer.compose();
    packet.put_u8(com::PING);
    packet.finalize();
    self.state = ConnectionState::PingSend;
    async move {
        self.test_cancel()?;
        self.writer.send().await?;
        self.state = ConnectionState::PingRecv;
        self.test_cancel()?;
        let reply = self.reader.read().await?;
        let mut parser = PackageParser::new(reply);
        match parser.get_u8().loc("first_byte")? {
            0 => {}
            255 => {
                handle_mysql_error(&mut parser)?;
                unreachable!()
            }
            v => {
                self.state = ConnectionState::Broken;
                return Err(ConnectionErrorContent::ProtocolError(format!(
                    "Unexpected response type {v} to ping"
                ))
                .into());
            }
        }
        self.state = ConnectionState::Clean;
        Ok(())
    }
}
}
/// A MySQL connection with a prepared-statement cache and transaction
/// bookkeeping layered over the raw protocol connection.
pub struct Connection {
/// Prepared statements cached by SQL text (LRU eviction).
prepared_statements: LRUCache<Statement>,
/// Statement prepared with caching disabled (`QueryOptions::cache(false)`);
/// it is closed on the server by the next `cleanup`.
prepared_statement: Option<Statement>,
/// Underlying protocol-level connection.
raw: RawConnection,
/// Number of currently open (possibly nested) transactions.
transaction_depth: usize,
/// Transactions dropped without `commit`/`rollback`; `cleanup` rolls them back.
cleanup_rollbacks: usize,
}
/// A prepared statement whose parameters are being bound before execution.
pub struct Query<'a> {
/// Connection whose write buffer holds the execute packet being built.
connection: &'a mut RawConnection,
/// The prepared statement being executed.
statement: &'a Statement,
/// Index of the next parameter to bind (also the count bound so far).
cur_param: u16,
/// Offset in the write buffer of the NULL bitmap (one bit per parameter).
null_offset: usize,
/// Offset in the write buffer of the parameter type array (two bytes per parameter).
type_offset: usize,
}
/// Binding and execution of a prepared statement.
///
/// Parameters are bound one at a time with [`Query::bind`]; exactly one of the
/// `fetch*`/`execute` methods then consumes the query and runs it.
impl<'a> Query<'a> {
    /// Statement metadata, present when the statement was prepared with
    /// `QueryOptions::information(true)`.
    #[inline]
    pub fn information(&self) -> Option<&StatementInformation> {
        self.statement.information.as_ref()
    }

    /// Fails with `TooFewArgumentsBound` unless every placeholder has been bound.
    fn ensure_all_bound(&self) -> ConnectionResult<()> {
        if self.cur_param != self.statement.num_params {
            return Err(ConnectionErrorContent::Bind(
                self.cur_param,
                BindError::TooFewArgumentsBound,
            )
            .into());
        }
        Ok(())
    }

    /// Binds `v` as the next positional parameter.
    ///
    /// The value is written into the outgoing packet. If [`Bind::bind`] returns
    /// `false` (no value written), the parameter's bit in the NULL bitmap is set.
    /// The parameter's type byte is recorded, with the unsigned flag (`0x80`)
    /// added when `T::UNSIGNED`.
    ///
    /// # Errors
    /// `Bind(_, TooManyArgumentsBound)` when all parameters are already bound,
    /// or `Bind(index, e)` when encoding the value fails.
    #[inline]
    pub fn bind<T: Bind + ?Sized>(mut self, v: &T) -> ConnectionResult<Self> {
        let param = self.cur_param;
        if param == self.statement.num_params {
            return Err(
                ConnectionErrorContent::Bind(param, BindError::TooManyArgumentsBound).into(),
            );
        }
        let mut w = crate::bind::Writer::new(&mut self.connection.writer.buff);
        let has_value = v
            .bind(&mut w)
            .map_err(|e| ConnectionErrorContent::Bind(param, e))?;
        // Widen before doing arithmetic: `param * 2` in u16 overflows for
        // param >= 32768, while MySQL allows up to 65535 placeholders.
        let idx = usize::from(param);
        let buff = &mut self.connection.writer.buff;
        if !has_value {
            buff[self.null_offset + idx / 8] |= 1u8 << (idx % 8);
        }
        buff[self.type_offset + idx * 2] = T::TYPE;
        if T::UNSIGNED {
            buff[self.type_offset + idx * 2 + 1] = 128;
        }
        self.cur_param += 1;
        Ok(self)
    }

    /// Runs the query and maps at most one row with `M`.
    ///
    /// Returns `Ok(None)` when the result set is empty.
    ///
    /// # Errors
    /// `TooFewArgumentsBound` if not all parameters are bound, `ExpectedRows` if
    /// the statement produced no result set, `UnexpectedRows` if it produced more
    /// than one row, MySQL errors from the server, and a protocol error (the
    /// connection is then marked `Broken`) on an unknown packet type.
    pub async fn fetch_optional_map<M: RowMap<'a>>(self) -> Result<Option<M::T>, M::E> {
        self.ensure_all_bound()?;
        match self.connection.query_send().await? {
            QueryResult::WithColumns => (),
            QueryResult::ExecuteResult(_) => {
                return Err(ConnectionError::from(ConnectionErrorContent::ExpectedRows).into());
            }
        }
        let start_instant = self.connection.stats.get_instant();
        self.connection.test_cancel()?;
        let p1 = self.connection.reader.read_raw().await?;
        {
            let mut pp = PackageParser::new(self.connection.reader.bytes(p1.clone()));
            match pp.get_u8().loc("Row first byte")? {
                0x00 => (),
                0xFE => {
                    // Empty result set: the terminator has been consumed.
                    self.connection.state = ConnectionState::Clean;
                    self.connection.stats.add_fetch(start_instant);
                    return Ok(None);
                }
                0xFF => {
                    handle_mysql_error(&mut pp)?;
                    unreachable!()
                }
                v => {
                    // The packet stream can no longer be interpreted.
                    self.connection.state = ConnectionState::Broken;
                    return Err(ConnectionError::from(ConnectionErrorContent::ProtocolError(
                        format!("Unexpected response type {v} to row package"),
                    ))
                    .into());
                }
            }
        }
        // Presumably keeps p1's bytes valid while the next packet is read.
        self.connection.reader.buffer_packages = true;
        self.connection.test_cancel()?;
        let p2 = self.connection.reader.read_raw().await?;
        {
            let mut pp = PackageParser::new(self.connection.reader.bytes(p2));
            match pp.get_u8().loc("Row first byte")? {
                0x00 => {
                    return Err(
                        ConnectionError::from(ConnectionErrorContent::UnexpectedRows).into(),
                    );
                }
                0xFE => {
                    self.connection.state = ConnectionState::Clean;
                }
                0xFF => {
                    handle_mysql_error(&mut pp)?;
                    unreachable!()
                }
                v => {
                    self.connection.state = ConnectionState::Broken;
                    return Err(ConnectionError::from(ConnectionErrorContent::ProtocolError(
                        format!("Unexpected response type {v} to row package"),
                    ))
                    .into());
                }
            }
        }
        self.connection.stats.add_fetch(start_instant);
        let row = Row::new(&self.connection.columns, self.connection.reader.bytes(p1));
        Ok(Some(M::map(row)?))
    }

    /// Runs the query and decodes at most one row as `T`.
    pub fn fetch_optional<T: FromRow<'a>>(
        self,
    ) -> impl Future<Output = ConnectionResult<Option<T>>> + Send {
        self.fetch_optional_map::<FromRowMapper<T>>()
    }

    /// Runs the query and decodes exactly one row as `T`.
    ///
    /// # Errors
    /// `ExpectedRows` when the result set is empty; otherwise as
    /// [`Query::fetch_optional_map`].
    #[inline]
    pub async fn fetch_one<T: FromRow<'a>>(self) -> ConnectionResult<T> {
        match self.fetch_optional().await? {
            Some(v) => Ok(v),
            None => Err(ConnectionErrorContent::ExpectedRows.into()),
        }
    }

    /// Runs the query and maps every row with `M`.
    ///
    /// All row packets are read (and their byte ranges recorded) before any row
    /// is mapped.
    ///
    /// # Errors
    /// As [`Query::fetch_optional_map`], except that multiple rows are allowed.
    pub async fn fetch_all_map<M: RowMap<'a>>(self) -> Result<Vec<M::T>, M::E> {
        self.ensure_all_bound()?;
        match self.connection.query_send().await? {
            QueryResult::WithColumns => (),
            QueryResult::ExecuteResult(_) => {
                return Err(ConnectionError::from(ConnectionErrorContent::ExpectedRows).into());
            }
        }
        // Start timing after query_send (which records its own execute time) so
        // the fetch statistic covers row transfer only, as in fetch_optional_map.
        let start_instant = self.connection.stats.get_instant();
        self.connection.ranges.clear();
        loop {
            self.connection.test_cancel()?;
            let p = self.connection.reader.read_raw().await?;
            {
                let mut pp = PackageParser::new(self.connection.reader.bytes(p.clone()));
                match pp.get_u8().loc("Row first byte")? {
                    0x00 => self.connection.ranges.push(p),
                    0xFE => {
                        self.connection.state = ConnectionState::Clean;
                        break;
                    }
                    0xFF => {
                        handle_mysql_error(&mut pp)?;
                        unreachable!()
                    }
                    v => {
                        self.connection.state = ConnectionState::Broken;
                        return Err(ConnectionError::from(ConnectionErrorContent::ProtocolError(
                            format!("Unexpected response type {v} to row package"),
                        ))
                        .into());
                    }
                }
            }
            self.connection.reader.buffer_packages = true;
        }
        self.connection.stats.add_fetch(start_instant);
        let mut ans = Vec::with_capacity(self.connection.ranges.len());
        for p in &self.connection.ranges {
            let row = Row::new(
                &self.connection.columns,
                self.connection.reader.bytes(p.clone()),
            );
            ans.push(M::map(row)?);
        }
        Ok(ans)
    }

    /// Runs the query and decodes every row as `T`.
    pub fn fetch_all<T: FromRow<'a>>(
        self,
    ) -> impl Future<Output = ConnectionResult<Vec<T>>> + Send {
        self.fetch_all_map::<FromRowMapper<T>>()
    }

    /// Runs the query and returns an iterator over its rows.
    ///
    /// # Errors
    /// `TooFewArgumentsBound`, or `ExpectedRows` if no result set was produced.
    pub async fn fetch(self) -> ConnectionResult<QueryIterator<'a>> {
        self.ensure_all_bound()?;
        match self.connection.query_send().await? {
            QueryResult::ExecuteResult(_) => Err(ConnectionErrorContent::ExpectedRows.into()),
            QueryResult::WithColumns => Ok(QueryIterator {
                connection: self.connection,
            }),
        }
    }

    /// Runs the query and returns an iterator mapping each row with `M`.
    ///
    /// # Errors
    /// `TooFewArgumentsBound`, or `ExpectedRows` if no result set was produced.
    pub async fn fetch_map<M>(self) -> ConnectionResult<MapQueryIterator<'a, M>>
    where
        for<'b> M: RowMap<'b>,
    {
        self.ensure_all_bound()?;
        match self.connection.query_send().await? {
            QueryResult::ExecuteResult(_) => Err(ConnectionErrorContent::ExpectedRows.into()),
            QueryResult::WithColumns => Ok(MapQueryIterator {
                connection: self.connection,
                _phantom: Default::default(),
            }),
        }
    }

    /// Runs a statement that returns no rows.
    ///
    /// # Errors
    /// `TooFewArgumentsBound`, or `UnexpectedRows` if a result set was produced.
    pub async fn execute(self) -> ConnectionResult<ExecuteResult> {
        self.ensure_all_bound()?;
        match self.connection.query_send().await? {
            QueryResult::WithColumns => Err(ConnectionErrorContent::UnexpectedRows.into()),
            QueryResult::ExecuteResult(v) => Ok(v),
        }
    }
}
/// An open transaction on a [`Connection`].
///
/// Finish it with `commit` or `rollback`. If it is dropped instead, the drop
/// increments the connection's pending-rollback counter, and the next
/// `Connection::cleanup` issues the rollback.
pub struct Transaction<'a> {
/// The connection the transaction runs on.
connection: &'a mut Connection,
}
impl<'a> Transaction<'a> {
    /// Commits this transaction level.
    ///
    /// If cleaning up the connection fails, the transaction is dropped normally
    /// and therefore queued for rollback.
    pub async fn commit(self) -> ConnectionResult<()> {
        self.finish(true).await
    }

    /// Rolls back this transaction level.
    pub async fn rollback(self) -> ConnectionResult<()> {
        self.finish(false).await
    }

    /// Shared implementation of `commit` and `rollback`.
    ///
    /// After cleanup succeeds, the transaction is wrapped in `ManuallyDrop` so
    /// its `Drop` impl does not also queue a rollback for this level.
    async fn finish(self, commit: bool) -> ConnectionResult<()> {
        self.connection.cleanup().await?;
        let mut this = ManuallyDrop::new(self);
        let conn = &mut *this.connection;
        if commit {
            conn.commit_impl().await?;
        } else {
            conn.rollback_impl().await?;
        }
        Ok(())
    }
}
/// A transaction executes queries on its underlying [`Connection`].
impl<'a> Executor for Transaction<'a> {
    #[inline]
    fn query_raw(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
        self.connection.query_inner(stmt, options)
    }

    /// Opens a nested transaction one level below this one.
    #[inline]
    fn begin(&mut self) -> impl Future<Output = ConnectionResult<Transaction<'_>>> + Send {
        self.connection.begin_impl()
    }

    // `+ Send` spelled out so the signature matches the trait declaration.
    #[inline]
    fn query_with_args_raw(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
        self.connection.query_with_args_raw(stmt, options, args)
    }

    #[inline]
    fn execute_unprepared(
        &mut self,
        stmt: &str,
    ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
        self.connection.execute_unprepared(stmt)
    }

    #[inline]
    fn ping(&mut self) -> impl Future<Output = ConnectionResult<()>> + Send {
        self.connection.ping()
    }
}
impl<'a> Drop for Transaction<'a> {
    /// Queues this transaction level for rollback.
    ///
    /// The counter is consumed by `Connection::cleanup`, which issues the
    /// rollback before the connection's next operation.
    fn drop(&mut self) {
        let pending = &mut self.connection.cleanup_rollbacks;
        *pending += 1;
    }
}
/// Operations shared by [`Connection`] and [`Transaction`].
pub trait Executor: Sized + Send {
    /// Starts a new transaction, one level deeper than the current depth.
    fn begin(&mut self) -> impl Future<Output = ConnectionResult<Transaction<'_>>> + Send;

    /// Round-trips a ping to the server.
    fn ping(&mut self) -> impl Future<Output = ConnectionResult<()>> + Send;

    /// Prepares `stmt` according to `options` and returns a [`Query`] whose
    /// parameters are bound by the caller.
    fn query_raw(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Prepares `stmt` according to `options` and binds `args` to it.
    fn query_with_args_raw(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Runs `stmt` as a plain text query without preparing it.
    fn execute_unprepared(
        &mut self,
        stmt: &str,
    ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send;
}
/// Convenience helpers implemented for every [`Executor`].
///
/// The `fetch*`/`execute` helpers prepare `stmt` with default
/// [`QueryOptions`], bind `args`, and run the query in one call.
pub trait ExecutorExt {
    /// Prepares `stmt` with default options; the caller binds parameters.
    fn query(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Prepares `stmt` with the given `options`; the caller binds parameters.
    fn query_with_options(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        options: QueryOptions,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Prepares `stmt` with default options and binds `args`.
    fn query_with_args(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Runs a statement that returns no rows (`UnexpectedRows` otherwise).
    fn execute(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send;

    /// Decodes every returned row as `T`.
    fn fetch_all<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Vec<T>>> + Send;

    /// Maps every returned row with `M`.
    fn fetch_all_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<Vec<M::T>, M::E>> + Send;

    /// Decodes exactly one row as `T` (`ExpectedRows` when there is none).
    fn fetch_one<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<T>> + Send;

    /// Maps exactly one row with `M` (`ExpectedRows` when there is none).
    fn fetch_one_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<M::T, M::E>> + Send;

    /// Decodes at most one row as `T` (`UnexpectedRows` if several are returned).
    fn fetch_optional<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Option<T>>> + Send;

    /// Maps at most one row with `M` (`UnexpectedRows` if several are returned).
    fn fetch_optional_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<Option<M::T>, M::E>> + Send;

    /// Starts the query and returns a [`QueryIterator`] over its rows.
    fn fetch(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<QueryIterator<'_>>> + Send;

    /// Starts the query and returns a [`MapQueryIterator`] mapping rows with `M`.
    fn fetch_map<'a, M>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<MapQueryIterator<'a, M>>> + Send
    where
        for<'b> M: RowMap<'b>;
}
/// Prepares `stmt` with default options, binds `args`, and decodes all rows.
async fn fetch_all_impl<'a, E: Executor + Sized + Send, T: FromRow<'a>>(
    e: &'a mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> ConnectionResult<Vec<T>> {
    e.query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?
        .fetch_all()
        .await
}
/// Prepares `stmt` with default options, binds `args`, and maps all rows with `M`.
async fn fetch_all_map_impl<'a, E: Executor + Sized + Send, M: RowMap<'a>>(
    e: &'a mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> Result<Vec<M::T>, M::E> {
    e.query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?
        .fetch_all_map::<M>()
        .await
}
/// Prepares `stmt` with default options, binds `args`, and decodes exactly one row.
///
/// Delegates to `Query::fetch_one`, which fails with `ExpectedRows` when no
/// row is returned.
async fn fetch_one_impl<'a, E: Executor + Sized + Send, T: FromRow<'a>>(
    e: &'a mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> ConnectionResult<T> {
    let query = e
        .query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?;
    query.fetch_one().await
}
/// Prepares `stmt` with default options, binds `args`, and maps exactly one row
/// with `M`, failing with `ExpectedRows` when there is none.
async fn fetch_one_map_impl<'a, E: Executor + Sized + Send, M: RowMap<'a>>(
    e: &'a mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> Result<M::T, M::E> {
    let query = e
        .query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?;
    query
        .fetch_optional_map::<M>()
        .await?
        .ok_or_else(|| ConnectionError::from(ConnectionErrorContent::ExpectedRows).into())
}
/// Prepares `stmt` with default options, binds `args`, and decodes at most one row.
async fn fetch_optional_impl<'a, E: Executor + Sized + Send, T: FromRow<'a>>(
    e: &'a mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> ConnectionResult<Option<T>> {
    e.query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?
        .fetch_optional()
        .await
}
/// Prepares `stmt` with default options, binds `args`, and maps at most one row with `M`.
async fn fetch_optional_map_impl<'a, E: Executor + Sized + Send, M: RowMap<'a>>(
    e: &'a mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> Result<Option<M::T>, M::E> {
    e.query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?
        .fetch_optional_map::<M>()
        .await
}
/// Prepares `stmt` with default options, binds `args`, and executes it (no rows expected).
async fn execute_impl<E: Executor + Sized + Send>(
    e: &mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> ConnectionResult<ExecuteResult> {
    e.query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?
        .execute()
        .await
}
/// Prepares `stmt` with default options, binds `args`, and returns a row iterator.
async fn fetch_impl<'a, E: Executor + Sized + Send>(
    e: &'a mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> ConnectionResult<QueryIterator<'a>> {
    e.query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?
        .fetch()
        .await
}
/// Prepares `stmt` with default options, binds `args`, and returns an iterator
/// mapping each row with `M`.
async fn fetch_map_impl<'a, E: Executor + Sized + Send, M>(
    e: &'a mut E,
    stmt: Cow<'static, str>,
    args: impl Args + Send,
) -> ConnectionResult<MapQueryIterator<'a, M>>
where
    for<'b> M: RowMap<'b>,
{
    e.query_with_args_raw(stmt, QueryOptions::new(), args)
        .await?
        .fetch_map::<M>()
        .await
}
/// Blanket implementation: every [`Executor`] gets the convenience helpers.
/// The `fetch*`/`execute` helpers always use default [`QueryOptions`].
impl<E: Executor + Sized + Send> ExecutorExt for E {
    #[inline]
    fn query(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
        self.query_raw(stmt.into(), QueryOptions::new())
    }
    #[inline]
    fn query_with_options(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        options: QueryOptions,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
        self.query_raw(stmt.into(), options)
    }
    // `+ Send` spelled out so the signature matches the trait declaration.
    #[inline]
    fn query_with_args(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
        self.query_with_args_raw(stmt.into(), QueryOptions::new(), args)
    }
    #[inline]
    fn fetch_all<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Vec<T>>> + Send {
        fetch_all_impl(self, stmt.into(), args)
    }
    #[inline]
    fn fetch_all_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<Vec<M::T>, M::E>> + Send {
        fetch_all_map_impl::<E, M>(self, stmt.into(), args)
    }
    #[inline]
    fn fetch_one<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<T>> + Send {
        fetch_one_impl(self, stmt.into(), args)
    }
    #[inline]
    fn fetch_one_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<M::T, M::E>> + Send {
        fetch_one_map_impl::<E, M>(self, stmt.into(), args)
    }
    #[inline]
    fn fetch_optional<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Option<T>>> + Send {
        fetch_optional_impl(self, stmt.into(), args)
    }
    #[inline]
    fn fetch_optional_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<Option<M::T>, M::E>> + Send {
        fetch_optional_map_impl::<E, M>(self, stmt.into(), args)
    }
    #[inline]
    fn execute(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
        execute_impl(self, stmt.into(), args)
    }
    #[inline]
    fn fetch(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<QueryIterator<'_>>> + Send {
        fetch_impl(self, stmt.into(), args)
    }
    #[inline]
    fn fetch_map<'a, M>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<MapQueryIterator<'a, M>>> + Send
    where
        for<'b> M: RowMap<'b>,
    {
        fetch_map_impl::<E, M>(self, stmt.into(), args)
    }
}
/// Per-query preparation options.
///
/// A small plain value, so it is `Copy`; `Debug`/`PartialEq` allow it to be
/// logged and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueryOptions {
    /// Store the prepared statement in the connection's LRU statement cache.
    /// When `false`, it is prepared for one use and closed at the next cleanup.
    cache: bool,
    /// Request statement information (see `Query::information`).
    information: bool,
}
impl QueryOptions {
pub fn new() -> Self {
Default::default()
}
pub fn cache(self, enable: bool) -> Self {
QueryOptions {
cache: enable,
..self
}
}
pub fn information(self, enable: bool) -> Self {
QueryOptions {
information: enable,
..self
}
}
}
impl Default for QueryOptions {
fn default() -> Self {
Self {
cache: true,
information: false,
}
}
}
/// Statement caching, transaction nesting and deferred cleanup on top of the
/// raw protocol connection.
impl Connection {
/// Opens a raw connection using `options` and wraps it with an empty
/// prepared-statement cache whose capacity is `options.statement_case_size`.
pub async fn connect(options: &ConnectionOptions<'_>) -> ConnectionResult<Self> {
let raw = RawConnection::connect(options).await?;
Ok(Connection {
raw,
prepared_statements: LRUCache::new(options.statement_case_size),
transaction_depth: 0,
cleanup_rollbacks: 0,
prepared_statement: None,
})
}
/// Returns `true` when the raw connection is in the `Clean` state (no
/// operation in progress).
pub fn is_clean(&self) -> bool {
matches!(self.raw.state, ConnectionState::Clean)
}
/// Brings the connection back to a usable state before the next operation.
///
/// Runs `RawConnection::cleanup` (defined elsewhere; presumably finishes any
/// interrupted exchange). It then issues one rollback statement for depth
/// `transaction_depth - cleanup_rollbacks`, which presumably discards every
/// transaction dropped without `commit`/`rollback`. Finally it closes an
/// uncached prepared statement left from a previous query.
///
/// # Errors
/// Propagates errors from the raw cleanup, the rollback, or the statement close.
pub async fn cleanup(&mut self) -> ConnectionResult<()> {
self.raw.cleanup().await?;
assert!(self.cleanup_rollbacks <= self.transaction_depth);
if self.cleanup_rollbacks != 0 {
let statement =
rollback_transaction_query(self.transaction_depth - self.cleanup_rollbacks);
// Bookkeeping is updated before awaiting. `execute_unprepared` composes
// the request packet synchronously, before returning its future.
self.transaction_depth -= self.cleanup_rollbacks;
self.cleanup_rollbacks = 0;
self.raw.execute_unprepared(statement).await?;
}
// `take()` before awaiting, so the statement is not closed twice.
if let Some(v) = self.prepared_statement.take() {
self.raw.close_prepared_statement(v.stmt_id).await?
}
Ok(())
}
/// Cleans up, then prepares `stmt` and returns a [`Query`] for it.
///
/// With `options.cache` off, the statement is prepared for one use, kept in
/// `prepared_statement`, and closed by the next `cleanup`. Otherwise it is
/// looked up in the LRU cache by SQL text. A cached entry is re-prepared when
/// information is requested but the cached copy lacks it. Replaced or evicted
/// statements are closed on the server.
async fn query_inner(
&mut self,
stmt: Cow<'static, str>,
options: QueryOptions,
) -> ConnectionResult<Query<'_>> {
self.cleanup().await?;
if !options.cache {
let r = self.raw.prepare_query(&stmt, options.information).await?;
self.prepared_statement = Some(r);
Ok(self.raw.query(self.prepared_statement.as_ref().unwrap()))
} else {
let statement = match self.prepared_statements.entry(stmt) {
Entry::Occupied(mut e) => {
// Re-prepare only if information is wanted but was not collected.
if e.get().information.is_none() && options.information {
let r = self.raw.prepare_query(e.key(), options.information).await?;
let old = e.insert(r);
self.raw.close_prepared_statement(old.stmt_id).await?
}
e.bump();
e.into_mut()
}
Entry::Vacant(e) => {
let r = self.raw.prepare_query(e.key(), options.information).await?;
let (r, old) = e.insert(r);
// Inserting may evict an older statement; free it on the server.
if let Some((_, old)) = old {
self.raw.close_prepared_statement(old.stmt_id).await?
}
r
}
};
Ok(self.raw.query(statement))
}
}
/// Cleans up and opens a transaction one level deeper than the current depth.
async fn begin_impl(&mut self) -> ConnectionResult<Transaction<'_>> {
self.cleanup().await?;
assert_eq!(self.cleanup_rollbacks, 0);
let q = begin_transaction_query(self.transaction_depth);
self.transaction_depth += 1;
// Guard: if the BEGIN fails or this future is dropped before it completes,
// `cleanup_rollbacks` stays 1 and the next `cleanup` rolls this level back.
self.cleanup_rollbacks = 1;
self.raw.execute_unprepared(q).await?;
assert_eq!(self.cleanup_rollbacks, 1);
self.cleanup_rollbacks = 0;
Ok(Transaction { connection: self })
}
/// Cleans up, then pings the server.
async fn ping_impl(&mut self) -> ConnectionResult<()> {
self.cleanup().await?;
self.raw.ping().await?;
Ok(())
}
/// Rolls back the innermost transaction level.
///
/// Not `async`: the depth is decremented and the request composed before the
/// future is returned; the future only performs the round trip.
fn rollback_impl(&mut self) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
assert!(matches!(self.raw.state, ConnectionState::Clean));
assert_eq!(self.cleanup_rollbacks, 0);
assert_ne!(self.transaction_depth, 0);
self.transaction_depth -= 1;
self.raw
.execute_unprepared(rollback_transaction_query(self.transaction_depth))
}
/// Commits the innermost transaction level (same structure as `rollback_impl`).
fn commit_impl(&mut self) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
assert!(matches!(self.raw.state, ConnectionState::Clean));
assert_eq!(self.cleanup_rollbacks, 0);
assert_ne!(self.transaction_depth, 0);
self.transaction_depth -= 1;
self.raw
.execute_unprepared(commit_transaction_query(self.transaction_depth))
}
/// Test hook: sets the raw connection's cancellation counter.
#[cfg(feature = "cancel_testing")]
#[doc(hidden)]
pub fn set_cancel_count(&mut self, cnt: Option<usize>) {
self.raw.cancel_count = cnt;
}
/// Timing statistics collected by the raw connection.
#[cfg(feature = "stats")]
pub fn stats(&self) -> &Stats {
&self.raw.stats
}
/// Resets the collected statistics to their default values.
#[cfg(feature = "stats")]
pub fn clear_stats(&mut self) {
self.raw.stats = Default::default()
}
}
/// Expands each `_LIST_` marker in `stmt` into `lengths[i]` comma-separated
/// `?` placeholders, or `NULL` when that list is empty.
///
/// # Errors
/// `TooFewListArguments` when there are more markers than lengths, and
/// `TooManyListArguments` when there are more lengths than markers.
#[cfg(feature = "list_hack")]
fn convert_list_query(
    stmt: Cow<'static, str>,
    lengths: &[usize],
) -> ConnectionResult<Cow<'static, str>> {
    let Some((head, tail)) = stmt.split_once("_LIST_") else {
        // No markers: the statement is used verbatim, so no lists may be given.
        return if lengths.is_empty() {
            Ok(stmt)
        } else {
            Err(ConnectionErrorContent::TooManyListArguments.into())
        };
    };
    let mut out = String::with_capacity(stmt.len() + 2 * lengths.iter().sum::<usize>());
    out.push_str(head);
    let mut remaining = lengths.iter().copied();
    for segment in tail.split("_LIST_") {
        let count = match remaining.next() {
            Some(count) => count,
            None => return Err(ConnectionErrorContent::TooFewListArguments.into()),
        };
        if count == 0 {
            out.push_str("NULL");
        } else {
            out.push('?');
            for _ in 1..count {
                out.push_str(", ?");
            }
        }
        out.push_str(segment);
    }
    if remaining.next().is_some() {
        return Err(ConnectionErrorContent::TooManyListArguments.into());
    }
    Ok(Cow::Owned(out))
}
/// A plain connection executes queries directly, running `cleanup` first.
impl Executor for Connection {
    #[inline]
    fn query_raw(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
        self.query_inner(stmt, options)
    }

    #[inline]
    fn begin(&mut self) -> impl Future<Output = ConnectionResult<Transaction<'_>>> + Send {
        self.begin_impl()
    }

    /// Prepares `stmt` and binds `args`.
    ///
    /// With the `list_hack` feature, `_LIST_` markers are first expanded from
    /// the list lengths reported by `args`.
    #[inline]
    async fn query_with_args_raw(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
        args: impl Args + Send,
    ) -> ConnectionResult<Query<'_>> {
        #[cfg(feature = "list_hack")]
        let stmt = {
            self.raw.list_lengths.clear();
            args.list_lengths(&mut self.raw.list_lengths);
            convert_list_query(stmt, &self.raw.list_lengths)?
        };
        self.cleanup().await?;
        let prepared = self.query_inner(stmt, options).await?;
        args.bind_args(prepared)
    }

    /// Cleans up, then runs `stmt` through the text protocol.
    async fn execute_unprepared(&mut self, stmt: &str) -> ConnectionResult<ExecuteResult> {
        self.cleanup().await?;
        let sql: Cow<'_, str> = Cow::Borrowed(stmt);
        self.raw.execute_unprepared(sql).await
    }

    #[inline]
    fn ping(&mut self) -> impl Future<Output = ConnectionResult<()>> + Send {
        self.ping_impl()
    }
}