use std::convert::TryInto;
use std::io;
use async_std::net::{Shutdown, TcpStream};
use byteorder::{ByteOrder, LittleEndian};
use futures_core::future::BoxFuture;
use crate::cache::StatementCache;
use crate::connection::Connection;
use crate::io::{Buf, BufMut, BufStream};
use crate::mysql::error::MySqlError;
use crate::mysql::protocol::{
Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake, HandshakeResponse, OkPacket,
};
use crate::url::Url;
/// A connection to a MySQL server over TCP.
pub struct MySqlConnection {
    /// Buffered TCP stream; outgoing packets are staged in its write buffer
    /// (see `write`) and only sent when the stream is flushed.
    pub(super) stream: BufStream<TcpStream>,
    /// Capability flags in effect for this connection, negotiated during `open`.
    pub(super) capabilities: Capabilities,
    /// Prepared-statement cache. NOTE(review): the `u32` values are presumably
    /// server-side statement ids — confirm against the code that fills it.
    pub(super) statement_cache: StatementCache<u32>,
    /// Owned copy of the most recently received packet payload; cleared and
    /// refilled on every receive.
    rbuf: Vec<u8>,
    /// Sequence id to stamp on the next outgoing packet.
    next_seq_no: u8,
    /// Set to `true` when the connection is opened. NOTE(review): its exact
    /// meaning is defined by code outside this file section — confirm.
    pub(super) ready: bool,
}
impl MySqlConnection {
    /// Starts a new command phase by resetting the outgoing sequence id to 0.
    pub(super) fn begin_command_phase(&mut self) {
        self.next_seq_no = 0;
    }

    /// Appends `packet` to the stream's write buffer, framed with a packet header.
    ///
    /// The 4-byte header holds the payload length in its low 3 bytes (little-endian)
    /// followed by the current sequence id, which is then incremented. Nothing is
    /// sent until the stream is flushed.
    pub(super) fn write(&mut self, packet: impl Encode + std::fmt::Debug) {
        let buf = self.stream.buffer_mut();
        let header_offset = buf.len();

        // Reserve room for the header; it is filled in once the payload length is known.
        buf.advance(4);
        packet.encode(buf, self.capabilities);

        let len = buf.len() - header_offset - 4;

        // The length field is only 3 bytes wide, and a length of 0xFFFFFF means the
        // payload continues in another packet. Splitting is not implemented, so larger
        // payloads would be framed with a truncated length.
        debug_assert!(
            len < 0xFF_FF_FF,
            "MySQL packet payload of {} bytes exceeds the single-packet limit",
            len
        );

        let header = &mut buf[header_offset..header_offset + 4];
        // Write the length as a u32, then overwrite its high byte with the sequence id;
        // this leaves the low 3 bytes as the little-endian payload length.
        LittleEndian::write_u32(header, len as u32);
        header[3] = self.next_seq_no;

        self.next_seq_no = self.next_seq_no.wrapping_add(1);
    }

    /// Receives the next packet and decodes it as an OK packet.
    ///
    /// # Errors
    /// Returns a [`MySqlError`] if the server sent an ERR packet (`0xFF`), and a
    /// protocol error if the packet is empty or starts with any byte other than
    /// `0x00`, `0xFE` or `0xFF`.
    async fn receive_ok(&mut self) -> crate::Result<OkPacket> {
        let packet = self.receive().await?;

        // `first()` instead of `packet[0]` so a zero-length packet is reported as a
        // protocol error rather than panicking on the index.
        let id = match packet.first() {
            Some(&id) => id,
            None => {
                return Err(protocol_err!(
                    "unexpected empty packet when expecting 0x00 or 0xFE (OK) or 0xFF (ERR)"
                )
                .into());
            }
        };

        Ok(match id {
            0xfe | 0x00 => OkPacket::decode(packet)?,
            0xff => {
                return Err(MySqlError(ErrPacket::decode(packet)?).into());
            }
            id => {
                return Err(protocol_err!(
                    "unexpected packet identifier 0x{:X?} when expecting 0x00 or 0xFE (OK) \
                     or 0xFF (ERR)",
                    id
                )
                .into());
            }
        })
    }

    /// Consumes an EOF packet, unless `DEPRECATE_EOF` was negotiated.
    ///
    /// With `DEPRECATE_EOF`, no packet is read and this returns immediately.
    pub(super) async fn receive_eof(&mut self) -> crate::Result<()> {
        if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
            let _eof = EofPacket::decode(self.receive().await?)?;
        }

        Ok(())
    }

    /// Receives the next packet payload, treating end of stream as an error.
    ///
    /// # Errors
    /// Fails with `io::ErrorKind::UnexpectedEof` when `try_receive` yields `None`,
    /// and propagates any error from `try_receive`.
    pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> {
        Ok(self
            .try_receive()
            .await?
            .ok_or(io::ErrorKind::UnexpectedEof)?)
    }

    /// Reads one packet from the stream into the internal read buffer.
    ///
    /// Returns the payload, borrowed from a buffer that is cleared on the next call.
    /// `ret_if_none!` (defined elsewhere) is presumably what yields `Ok(None)` when
    /// `peek` reports no more data — confirm. Payloads split across several
    /// packets (length 0xFFFFFF) are not reassembled here.
    pub(super) async fn try_receive(&mut self) -> crate::Result<Option<&[u8]>> {
        self.rbuf.clear();

        // Packet header: 3-byte little-endian payload length, then the sequence id.
        let mut header = ret_if_none!(self.stream.peek(4).await?);
        let payload_len = header.get_uint::<LittleEndian>(3)? as usize;
        // Any packet we send next continues the server's sequence.
        self.next_seq_no = header.get_u8()?.wrapping_add(1);
        self.stream.consume(4);

        // NOTE(review): if the payload is missing after the header was consumed, this
        // returns through `ret_if_none!` the same way as a clean end of stream —
        // confirm that a truncated packet should not be an error instead.
        let payload = ret_if_none!(self.stream.peek(payload_len).await?);
        self.rbuf.extend_from_slice(payload);
        self.stream.consume(payload_len);

        Ok(Some(&self.rbuf[..payload_len]))
    }
}
impl MySqlConnection {
pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?;
let stream = TcpStream::connect((url.host(), url.port(3306))).await?;
let mut self_ = Self {
stream: BufStream::new(stream),
capabilities: Capabilities::empty(),
rbuf: Vec::with_capacity(8192),
next_seq_no: 0,
statement_cache: StatementCache::new(),
ready: true,
};
let handshake_packet = self_.receive().await?;
let handshake = Handshake::decode(handshake_packet)?;
let client_capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::FOUND_ROWS
| Capabilities::CONNECT_WITH_DB;
self_.capabilities =
(client_capabilities & handshake.server_capabilities) | Capabilities::PROTOCOL_41;
self_.write(HandshakeResponse {
client_collation: 192, max_packet_size: 1024,
username: url.username().unwrap_or("root"),
database: url.database().expect("required database"),
});
self_.stream.flush().await?;
let _ok = self_.receive_ok().await?;
Ok(self_)
}
async fn close(mut self) -> crate::Result<()> {
self.stream.flush().await?;
self.stream.stream.shutdown(Shutdown::Both)?;
Ok(())
}
}
impl Connection for MySqlConnection {
    /// Converts `url` and opens a new MySQL connection as a boxed `'static` future.
    ///
    /// The conversion result, `Ok` or `Err`, is passed straight to the inherent
    /// `open`. That function reports a failed conversion from inside the returned future.
    fn open<T>(url: T) -> BoxFuture<'static, crate::Result<Self>>
    where
        T: TryInto<Url, Error = crate::Error>,
        Self: Sized,
    {
        let converted = url.try_into();
        let connecting = MySqlConnection::open(converted);

        Box::pin(connecting)
    }

    /// Closes the connection, consuming it.
    fn close(self) -> BoxFuture<'static, crate::Result<()>> {
        // Inherent methods take precedence over trait methods, so this calls the
        // inherent async `close` rather than recursing into this trait method.
        let closing = self.close();

        Box::pin(closing)
    }
}