use std::fmt::{Display, Formatter};
use std::io::Cursor;
use std::os::unix::io::AsRawFd;
use std::sync::{Arc, atomic::{AtomicU8, Ordering}};
use sharded_slab::{Pool, Slab};
use serde::Serialize;
use tokio::sync::{mpsc, oneshot, Notify};
use tokio::net::{TcpStream, ToSocketAddrs, tcp::{OwnedWriteHalf, OwnedReadHalf}};
use serde::de::DeserializeOwned;
use tokio::io::{BufReader, BufWriter};
use futures::future::try_join;
use nix::sys::socket;
use thiserror::Error;
use crate::iproto::{consts, request, response};
use response::ResponseBody;
/// Byte buffer type held in the connection's buffer pool; encoded requests and
/// raw response frames are both stored in buffers of this type.
type Buffer = Vec<u8>;
#[derive(Error, Debug)]
pub enum Error {
#[error("tarantool error")]
TarantoolError(response::ErrorResponse),
#[error("invalid response")]
InvalidResponse,
#[error("decoding error")]
InvalidDecoding,
#[error("error")]
ErrorCode(u8),
}
/// A response as handed over from the reader task to the waiting requester:
/// the already decoded header plus a reference to where the rest of the
/// frame sits in the buffer pool.
#[derive(Debug)]
pub struct TarantoolResp {
/// Header decoded by the reader from the start of the response frame.
pub header: response::ResponseHeader,
/// Pool key and offset at which decoding of the body should continue.
pub cursor_ref: CursorRef,
}
/// Value sent through a request's oneshot channel: the raw response or an [`Error`].
pub type TarantoolResult = Result<TarantoolResp, Error>;
/// Points at a position inside a pooled buffer without borrowing the buffer.
#[derive(Debug)]
pub struct CursorRef {
/// Key of the buffer inside the connection's buffer pool.
pub buffer_key: usize,
/// Byte offset within that buffer; the reader sets it to the cursor
/// position reached after decoding the response header.
pub position: u64,
}
/// Bookkeeping entry for a request that has been issued and awaits its response.
struct RequestHandle {
/// Key of this entry in the pending-requests slab; the same value is handed
/// to the request builder as the request id.
request_id: usize,
/// Channel through which the reader task delivers the response to the requester.
tx: oneshot::Sender<TarantoolResult>,
}
/// A single connection to the server, shared behind an `Arc` between callers
/// and the background task running the reader and writer loops.
pub struct Connection {
/// Connection state flag; the reader and writer loops keep running while it
/// equals `CONNECTED_STATE`.
state: AtomicU8,
/// Queue of buffer-pool keys of encoded requests, consumed by the writer loop.
requests_to_process_tx: mpsc::Sender<usize>,
/// Requests awaiting a response; the slab key doubles as the request id.
pending_requests: Slab<RequestHandle>,
/// Created in `connect`.
// NOTE(review): not used by any code visible in this file; the name suggests
// back-pressure signalling — confirm or remove.
requests_not_full_notify: Notify,
/// Pool of reusable byte buffers for encoded requests and raw responses.
buffer_pool: Arc<Pool<Buffer>>,
/// Salt decoded from the server greeting; passed to the auth request.
salt: Vec<u8>,
/// Value of the `TcpMaxSeg` socket option read right after connecting.
// NOTE(review): stored but not read by any code visible in this file.
mss: u32,
}
/// Value of `Connection::state` once the connection is no longer usable.
const DISCONNECTED_STATE: u8 = 0;
/// Value of `Connection::state` while the reader and writer loops should keep
/// running. It must differ from `DISCONNECTED_STATE`, otherwise the loops'
/// state checks can never observe a disconnect.
const CONNECTED_STATE: u8 = 1;
impl Connection {
pub async fn connect<A: ToSocketAddrs>(addr: A) -> std::io::Result<Arc<Self>> {
use tokio::io::AsyncReadExt;
let stream = TcpStream::connect(addr).await?;
let mss = socket::getsockopt(stream.as_raw_fd(), socket::sockopt::TcpMaxSeg)?;
let (mut read_stream, write_stream) = stream.into_split();
let salt = {
let mut greeting_raw = [0; 128];
read_stream.read_exact(&mut greeting_raw).await?;
let salt_b64 = std::str::from_utf8(&greeting_raw[64..108]).unwrap().trim();
base64::decode(salt_b64).unwrap()
};
let (requests_to_process_tx, requests_to_process_rx) = mpsc::channel(128);
let conn = Arc::new(Connection {
state: AtomicU8::new(CONNECTED_STATE),
requests_to_process_tx,
pending_requests: Slab::new(),
requests_not_full_notify: Notify::new(),
buffer_pool: Arc::new(Pool::new()),
salt,
mss,
});
let conn_clone = conn.clone();
tokio::spawn(async move {
let writer_task = conn_clone.writer(requests_to_process_rx, write_stream);
let reader_task = conn_clone.reader(read_stream);
match try_join(writer_task, reader_task).await {
Ok(_) => {}
Err(_) => {}
}
});
Ok(conn)
}
/// Encodes `req` into a freshly acquired pooled buffer and returns the
/// buffer's pool key.
///
/// The buffer starts with a 5-byte prefix: the marker byte `0xCE` followed by
/// the big-endian length of the encoded request that comes after the prefix.
///
/// # Errors
///
/// Returns the encoder's error if `req` cannot be encoded. The pooled buffer
/// is released in that case so a failed request does not leak a pool slot.
///
/// # Panics
///
/// Panics if the buffer pool cannot provide another buffer.
fn write_req_to_buf<R>(&self, req: &R) -> Result<usize, rmp_serde::encode::Error>
where
    R: request::Request<Buffer>,
{
    const PREFIX_LEN: usize = 5;

    let mut write_buf = self.buffer_pool.create().expect("buffer pool exhausted");
    let buffer_key = write_buf.key();
    // Placeholder prefix; the four length bytes are patched in after encoding.
    write_buf.extend_from_slice(&[0xCE, 0, 0, 0, 0]);
    if let Err(err) = req.encode(write_buf.as_mut()) {
        // Release the guard first, then give the slot back to the pool.
        drop(write_buf);
        self.buffer_pool.clear(buffer_key);
        return Err(err.into());
    }
    let body_len = write_buf.len() - PREFIX_LEN;
    // The prefix carries a 32-bit length, so only the low 32 bits are stored.
    write_buf[1..PREFIX_LEN].copy_from_slice(&(body_len as u32).to_be_bytes());
    Ok(buffer_key)
}
pub async fn make_request<Req, Resp, F>(&self, f: F) -> Result<Resp, Error>
where
Req: request::Request<Buffer>,
Resp: response::ResponseBody,
F: FnOnce(usize) -> Req,
{
let (tx, rx) = oneshot::channel();
let request_id = {
let entry = self.pending_requests.vacant_entry().unwrap();
let request_id = entry.key();
entry.insert(RequestHandle { request_id, tx });
request_id
};
let req = f(request_id);
let buffer_key = self.write_req_to_buf(&req).unwrap();
self.requests_to_process_tx.send(buffer_key).await.unwrap();
let TarantoolResp {
header: response::ResponseHeader { response_code_indicator, .. },
cursor_ref: CursorRef { buffer_key, position },
} = rx.await.unwrap().unwrap();
let buffer = self.buffer_pool.get(buffer_key).unwrap();
let mut cursor: Cursor<&Buffer> = Cursor::new(buffer.as_ref());
cursor.set_position(position);
const IPROTO_OK: u32 = consts::IPROTO_OK as u32;
let result = match response_code_indicator {
IPROTO_OK => {
let data_resp = Resp::decode(&mut cursor).unwrap();
Ok(data_resp)
}
0x8000..=0x8fff => {
let _error_code = response_code_indicator - 0x8000;
let err_resp = response::ErrorResponse::decode(&mut cursor).unwrap();
Err(Error::TarantoolError(err_resp))
}
_ => { panic!("error") }
};
self.buffer_pool.clear(buffer_key);
result
}
/// Issues a call request for `name` with `data` as its argument payload and
/// returns the data extracted from the call response.
///
/// # Errors
///
/// Propagates whatever error `make_request` reports.
pub async fn call<T, R>(&self, name: &str, data: &T) -> Result<R, Error>
where
    T: Serialize,
    R: DeserializeOwned,
{
    self.make_request::<_, response::CallResponse<R>, _>(|request_id| {
        request::Call::new(request_id, name, data)
    })
    .await
    .map(|call_resp| call_resp.into_data())
}
/// Issues an auth request for `username` with the optional `password`, using
/// the salt stored on this connection.
///
/// # Errors
///
/// Propagates whatever error `make_request` reports.
pub async fn auth(&self, username: &str, password: Option<&str>) -> Result<(), Error> {
    self.make_request::<_, response::EmptyResponse, _>(|request_id| {
        request::Auth::new(request_id, &self.salt, username, password)
    })
    .await
    .map(|_empty| ())
}
async fn writer(&self, mut requests_to_process_rx: mpsc::Receiver<usize>, write_stream: OwnedWriteHalf) -> std::io::Result<()> {
use tokio::io::AsyncWriteExt;
let mut write_stream = BufWriter::with_capacity(128 * 1024, write_stream);
while self.state.load(Ordering::Relaxed) == CONNECTED_STATE {
let buffer_key = requests_to_process_rx.recv().await.unwrap();
{
let mut write_buf = self.buffer_pool.clone().get_owned(buffer_key).unwrap();
write_stream.write_all(&mut write_buf).await?;
self.buffer_pool.clear(buffer_key);
}
const OPTIMAL_PAYLOAD_SIZE: usize = 1000;
while write_stream.buffer().len() < OPTIMAL_PAYLOAD_SIZE {
if let Ok(buffer_key) = requests_to_process_rx.try_recv() {
let mut write_buf = self.buffer_pool.clone().get_owned(buffer_key).unwrap();
write_stream.write_all(&mut write_buf).await?;
self.buffer_pool.clear(buffer_key);
} else {
break;
}
}
write_stream.flush().await?;
}
Ok(())
}
/// Reader loop: reads response frames from the socket and hands each one to
/// the request that is waiting for it.
///
/// Every frame starts with a 5-byte prefix: the marker byte `0xCE` and a
/// big-endian 32-bit length of the bytes that follow. The frame is read into
/// a pooled buffer, its header is decoded, and the pending request with the
/// header's request id receives the header together with the buffer key and
/// the offset at which the body starts.
///
/// If the requester has gone away in the meantime (its receiving half was
/// dropped, e.g. because the request future was cancelled), the response is
/// discarded and its buffer released; the loop keeps running.
///
/// # Errors
///
/// Returns the I/O error if reading fails, and an error of kind
/// `InvalidData` if the frame marker is wrong, the header cannot be decoded,
/// or the header names a request id that is not pending. The response buffer
/// is released on all of these paths.
///
/// # Panics
///
/// Panics if the buffer pool cannot provide another buffer.
async fn reader(&self, read_stream: OwnedReadHalf) -> std::io::Result<()> {
    use tokio::io::AsyncReadExt;

    let invalid_data =
        |reason: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidData, reason);

    let mut read_stream = BufReader::with_capacity(128 * 1024, read_stream);
    let mut frame_prefix = [0u8; 5];
    while self.state.load(Ordering::Relaxed) == CONNECTED_STATE {
        read_stream.read_exact(&mut frame_prefix).await?;
        if frame_prefix[0] != 0xCE {
            return Err(invalid_data("unexpected response frame marker"));
        }
        let len = u32::from_be_bytes([
            frame_prefix[1],
            frame_prefix[2],
            frame_prefix[3],
            frame_prefix[4],
        ]) as usize;

        let mut resp_buf = self
            .buffer_pool
            .clone()
            .create_owned()
            .expect("buffer pool exhausted");
        let buffer_key = resp_buf.key();
        resp_buf.resize(len, 0);

        let read_outcome = read_stream.read_exact(&mut resp_buf).await;
        let decoded = match read_outcome {
            Ok(_) => {
                let resp_buf_ref: &mut Buffer = resp_buf.as_mut();
                let mut resp_reader = Cursor::new(resp_buf_ref);
                let parsed = response::ResponseHeader::decode(&mut resp_reader);
                let position = resp_reader.position();
                parsed
                    .map(|header| (header, position))
                    .map_err(|_| invalid_data("cannot decode response header"))
            }
            Err(err) => Err(err),
        };
        // The guard is no longer needed; the buffer itself stays in the pool
        // under `buffer_key` until it is explicitly cleared.
        drop(resp_buf);

        let (header, position) = match decoded {
            Ok(parts) => parts,
            Err(err) => {
                self.buffer_pool.clear(buffer_key);
                return Err(err);
            }
        };

        let request_id = header.request_id();
        let pending = match self.pending_requests.take(request_id) {
            Some(pending) => pending,
            None => {
                self.buffer_pool.clear(buffer_key);
                return Err(invalid_data("response refers to an unknown request id"));
            }
        };

        let result = Ok(TarantoolResp {
            header,
            cursor_ref: CursorRef { buffer_key, position },
        });
        if pending.tx.send(result).is_err() {
            // Nobody is waiting for this response any more, so nobody would
            // release its buffer either.
            self.buffer_pool.clear(buffer_key);
        }
    }
    Ok(())
}
}
// Integration tests: they talk to a live server at `TESTING_HOST`.
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::timeout;
    use super::{Connection, Error};

    /// Address of the server instance the tests connect to.
    const TESTING_HOST: &str = "localhost:3301";
    /// Upper bound for a single request so a stuck server fails the test
    /// instead of hanging it.
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(2);

    /// Opens a fresh connection to the test server.
    async fn conn() -> Arc<Connection> {
        Connection::connect(TESTING_HOST)
            .await
            .expect("cannot connect to the test server")
    }

    #[tokio::test]
    async fn client_test() {
        let conn = conn().await;
        timeout(REQUEST_TIMEOUT, conn.auth("guest", None)).await.unwrap().unwrap();
        // Issue the same call twice to check the connection stays usable.
        for _ in 0..2 {
            let (result, ): (usize, ) = timeout(REQUEST_TIMEOUT, conn.call("sum", &(1, 2)))
                .await
                .unwrap()
                .unwrap();
            assert_eq!(result, 3);
        }
    }

    // The two tests below assert the specific error instead of using
    // `#[should_panic]`, which would also pass when the connection itself
    // cannot be established.
    #[tokio::test]
    async fn test_tarantool_error() {
        let conn = conn().await;
        let result: Result<(), Error> =
            timeout(REQUEST_TIMEOUT, conn.call("not_existing_procedure", &(1, 2, 3)))
                .await
                .expect("request timed out");
        assert!(
            matches!(result, Err(Error::TarantoolError(_))),
            "unexpected result: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_invalid_user() {
        let conn = conn().await;
        let result: Result<(), Error> = timeout(REQUEST_TIMEOUT, conn.auth("kek", None))
            .await
            .expect("request timed out");
        assert!(
            matches!(result, Err(Error::TarantoolError(_))),
            "unexpected result: {:?}",
            result
        );
    }
}