use bytes::{Buf, BufMut, Bytes};
use tokio;
use tokio::io::AsyncWriteExt;
/// A connection to an Rserve server over one of the supported transports.
///
/// Values are created by [`connect`], which picks the variant from the
/// address scheme (`tcp://` or `unix://`).
#[derive(Debug)]
pub enum RserveConnection {
    /// Connection over a TCP socket.
    Tcp(tokio::net::TcpStream),
    /// Connection over a Unix domain socket.
    Unix(tokio::net::UnixStream),
}
/// A value decoded from an Rserve response.
///
/// `Debug`, `Clone` and `PartialEq` are derived so results can be printed,
/// copied and compared in caller code and tests. `Eq` is intentionally not
/// derived because the type carries `f64` values.
#[derive(Debug, Clone, PartialEq)]
pub enum ReturnValue {
    /// A single byte interpreted as a character.
    Char(char),
    /// A single 32-bit integer.
    Int(i32),
    /// A single double-precision float.
    Double(f64),
    /// R's `NULL`; the payload carries the literal text `"NULL"`.
    Null(String),
    /// A single logical value.
    Bool(bool),
    /// A single string.
    Str(String),
    /// A vector of 32-bit integers.
    IntVec(Vec<i32>),
    /// A vector of double-precision floats.
    DoubleVec(Vec<f64>),
    /// A vector of logical values.
    BoolVec(Vec<bool>),
    /// A vector of strings.
    StrVec(Vec<String>),
}
impl RserveConnection {
async fn readable(&mut self) -> std::io::Result<()> {
match self {
RserveConnection::Tcp(stream) => stream.readable().await?,
RserveConnection::Unix(stream) => stream.readable().await?,
}
Ok(())
}
fn try_read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
RserveConnection::Tcp(stream) => stream.try_read(buf),
RserveConnection::Unix(stream) => stream.try_read(buf),
}
}
async fn writable(&self) -> std::io::Result<()> {
match self {
RserveConnection::Tcp(stream) => stream.writable().await?,
RserveConnection::Unix(stream) => stream.writable().await?,
}
Ok(())
}
fn try_write(&self, buf: &[u8]) -> std::io::Result<usize> {
match self {
RserveConnection::Tcp(stream) => stream.try_write(buf),
RserveConnection::Unix(stream) => stream.try_write(buf),
}
}
async fn shut_down(&mut self) -> std::io::Result<()> {
match self {
RserveConnection::Tcp(stream) => stream.shutdown().await?,
RserveConnection::Unix(stream) => stream.shutdown().await?,
}
Ok(())
}
pub async fn eval(
&mut self,
command: &str,
void: bool,
) -> Result<ReturnValue, Box<dyn std::error::Error>> {
self.writable().await?;
let cmd = Bytes::from(command.to_string());
let cmd_length = cmd.len() as i32;
let mut message_header = vec![];
if void {
message_header.put_i32_le(0x002_i32); } else {
message_header.put_i32_le(0x003_i32); }
message_header.put_i32_le(cmd_length + 4);
message_header.put_i32_le(0_i32);
message_header.put_i32_le(0_i32);
let mut data_header = vec![];
data_header.put_u8(0x04_u8);
data_header.put_i32_le(cmd_length);
let mut message = vec![];
message.put(&message_header[..]);
message.put(&data_header[..4]);
message.put(&cmd[..]);
match self.try_write(&message) {
Ok(n) => {
assert_eq!(n, message.len());
}
Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => {}
Err(e) => {
return Err(e.into());
}
};
loop {
self.readable().await?;
let mut data = vec![0_u8; 1024];
match self.try_read(&mut data) {
Ok(n) => {
let mut res_data = &data[..n];
let cmd_res = res_data.get_i32_le(); let err_code = (cmd_res >> 24) & 127;
let response_code = cmd_res & 0xfffff;
if response_code != (0x10000 | 0x0001) {
let err_info = format!("error code: {}", err_code);
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
err_info,
)));
}
res_data.advance(12);
let data_type = res_data.get_u8(); res_data.advance(3);
match data_type {
1_u8 => {
let a = res_data.get_i32_le();
return Ok(ReturnValue::Int(a));
}
2_u8 => {
let a = res_data.get_u8() as char;
return Ok(ReturnValue::Char(a));
}
3_u8 => {
let a = res_data.get_f64_le();
return Ok(ReturnValue::Double(a));
}
4_u8 => {
let a = res_data.chunk().to_vec();
let s = String::from_utf8(a).unwrap();
return Ok(ReturnValue::Str(s));
}
10_u8 => {
let expression_type = res_data.get_u8(); let raw_data_header_length2 = res_data.take(3); let mut dst = vec![];
dst.put(raw_data_header_length2);
dst.put_u8(0_u8);
let data_length2 = (&dst[..]).get_i32_le();
res_data.advance(3);
match expression_type {
0_u8 => {
return Ok(ReturnValue::Null("NULL".to_string()));
}
1_u8 => {
let a = res_data.get_i32_le();
return Ok(ReturnValue::Int(a));
}
2_u8 => {
let a = res_data.get_f64_le();
return Ok(ReturnValue::Double(a));
}
3_u8 => {
let a = res_data.chunk().to_vec();
let s = String::from_utf8(a).unwrap();
return Ok(ReturnValue::Str(s));
}
6_u8 => {
let a = res_data.get_u8();
if a == 1 {
return Ok(ReturnValue::Bool(true));
} else {
return Ok(ReturnValue::Bool(true));
}
}
32_u8 => {
let mut a: Vec<i32> = vec![];
for _ in 0..data_length2 / 4 {
a.push(res_data.get_i32_le());
}
return Ok(ReturnValue::IntVec(a));
}
33_u8 => {
let mut a: Vec<f64> = vec![];
for _ in 0..data_length2 / 8 {
a.push(res_data.get_f64_le());
}
return Ok(ReturnValue::DoubleVec(a));
}
34_u8 => {
let a: Vec<String> =
String::from_utf8(res_data.take(data_length2 as usize).chunk().to_vec())
.unwrap()
.split("\0")
.map(|word| word.to_string())
.collect();
return Ok(ReturnValue::StrVec(a));
}
36_u8 => {
let mut a: Vec<bool> = vec![];
for _ in 0..data_length2 {
let b = res_data.get_u8();
if b == 1 {
a.push(true);
} else {
a.push(false);
}
}
return Ok(ReturnValue::BoolVec(a));
}
_ => {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"unsupported outcome type!",
)));
}
}
}
_ => {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"unsupported outcome type!",
)));
}
};
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
Err(e) => {
return Err(e.into());
}
};
}
}
}
/// Reads and validates the 32-byte identification string the server sends
/// right after the connection is established.
///
/// Exactly 32 bytes are consumed so that nothing belonging to a later
/// response is swallowed and no part of the ID string is left unread.
async fn read_server_id(conn: &mut RserveConnection) -> std::io::Result<()> {
    const ID_LEN: usize = 32;
    const MAGIC: &[u8] = b"Rsrv01";

    let mut id = [0_u8; ID_LEN];
    let mut filled = 0;
    while filled < ID_LEN {
        conn.readable().await?;
        match conn.try_read(&mut id[filled..]) {
            Ok(0) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed during Rserve handshake",
                ));
            }
            Ok(n) => {
                filled += n;
                // Reject a foreign server as soon as the magic can be checked
                // instead of waiting for all 32 bytes.
                if filled >= MAGIC.len() && !id.starts_with(MAGIC) {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "server did not identify itself as Rserve (expected \"Rsrv01\")",
                    ));
                }
            }
            // Readiness can be a false positive; wait and try again.
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
            Err(e) => return Err(e),
        }
    }
    Ok(())
}

/// Connects to an Rserve server and performs the initial handshake.
///
/// `addr` must be `tcp://<host:port>` or `unix://<socket path>`.
///
/// # Errors
///
/// Returns an error when the address uses another scheme
/// (`"Invalid address format"`), when the connection cannot be established,
/// or when the server closes the connection or does not send an ID string
/// starting with `Rsrv01`.
pub async fn connect(addr: &str) -> Result<RserveConnection, Box<dyn std::error::Error>> {
    let mut conn = if let Some(host) = addr.strip_prefix("tcp://") {
        RserveConnection::Tcp(tokio::net::TcpStream::connect(host).await?)
    } else if let Some(path) = addr.strip_prefix("unix://") {
        RserveConnection::Unix(tokio::net::UnixStream::connect(path).await?)
    } else {
        return Err(Box::new(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "Invalid address format",
        )));
    };
    read_server_id(&mut conn).await?;
    Ok(conn)
}