use super::serialize::ENCODING;
use super::{K, qtype};
use std::convert::TryInto;
use std::path::Path;
use std::net::IpAddr;
use std::{io, env, str, fs};
use std::collections::HashMap;
use io::BufRead;
use async_trait::async_trait;
use trust_dns_resolver::TokioAsyncResolver;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpStream, TcpListener};
use tokio_native_tls::native_tls::{TlsConnector as TlsConnectorInner, TlsAcceptor as TlsAcceptorInner, Identity};
use tokio_native_tls::{TlsStream, TlsConnector, TlsAcceptor};
use sha1::Sha1;
use once_cell::sync::Lazy;
#[cfg(unix)]
use tokio::net::{UnixStream, UnixListener};
/// Message-type tags carried in the second byte of every IPC message header.
pub mod qmsg_type {
    /// Fire-and-forget message; the sender does not wait for a reply.
    pub const asynchronous: u8 = 0;
    /// Request for which the sender waits for a `response` message.
    pub const synchronous: u8 = 1;
    /// Reply to a `synchronous` message.
    pub const response: u8 = 2;
}
/// User name -> stored password digest, loaded once on first access from the
/// file named by the `KDBPLUS_ACCOUNT_FILE` environment variable.
///
/// Each line of the file has the form `user:digest`. The digest is compared
/// against the textual SHA-1 digest of the password a connecting client sends.
/// Lines without a `:` separator (for example a trailing blank line) are skipped.
///
/// This is a `static`, not a `const`: a `const` `Lazy` is instantiated afresh
/// at every use site, which would re-read the file on every lookup.
///
/// # Panics
///
/// The initializer panics if `KDBPLUS_ACCOUNT_FILE` is not set, or if the file
/// cannot be opened or read.
static ACCOUNTS: Lazy<HashMap<String, String>> = Lazy::new(|| {
    let path = env::var("KDBPLUS_ACCOUNT_FILE").expect("KDBPLUS_ACCOUNT_FILE is not set");
    let file = fs::OpenOptions::new()
        .read(true)
        .open(path)
        .expect("failed to open account file");

    let mut map: HashMap<String, String> = HashMap::new();
    // `lines()` strips both "\n" and "\r\n" terminators.
    for line in io::BufReader::new(file).lines() {
        let line = line.expect("failed to read account file");
        let mut fields = line.split(':');
        if let (Some(user), Some(password)) = (fields.next(), fields.next()) {
            map.insert(user.to_string(), password.to_string());
        }
    }
    map
});
/// Transport used for a connection.
pub enum ConnectionMethod {
    /// Plain TCP socket.
    TCP = 0,
    /// TCP socket wrapped in TLS.
    TLS = 1,
    /// Unix domain socket.
    UDS = 2,
}
/// Something that can be sent over a connection as a single IPC message.
pub trait Query: Send + Sync {
    /// Build the complete wire representation (header and body) of `self`,
    /// tagging it with `message_type` (see `qmsg_type`).
    fn serialize(&self, message_type: u8) -> Vec<u8>;
}
/// Transport-level operations implemented by every concrete stream that a
/// `QStream` can wrap.
#[async_trait]
trait QStreamInner: Send + Sync {
    /// Close the connection. `is_server` is `true` when this end accepted the
    /// connection rather than initiating it.
    async fn shutdown(&mut self, is_server: bool) -> io::Result<()>;

    /// Serialize `message` with the given `message_type` and write it out.
    async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> io::Result<()>;

    /// Send `message` as an asynchronous message; no reply is read.
    async fn send_async_message(&mut self, message: &dyn Query) -> io::Result<()>;

    /// Send `message` as a synchronous message and return the reply.
    async fn send_sync_message(&mut self, message: &dyn Query) -> io::Result<K>;

    /// Read the next message, returning its message type and decoded body.
    async fn receive_message(&mut self) -> io::Result<(u8, K)>;
}
/// Handle to an established IPC connection over any supported transport.
pub struct QStream {
    /// Transport-specific stream that performs the actual I/O.
    stream: Box<dyn QStreamInner>,
    /// Transport this connection was created with.
    method: ConnectionMethod,
    /// `true` when the connection was obtained through `accept`.
    listener: bool,
}
/// Decoded form of the 8-byte header that precedes every IPC message.
#[derive(Clone, Copy, Debug)]
struct MessageHeader {
    /// Byte order of the message: `0` is big endian, anything else little endian.
    encoding: u8,
    /// One of the `qmsg_type` tags.
    message_type: u8,
    /// `1` when the body is compressed.
    compressed: u8,
    /// Reserved byte; always set to `0` by `MessageHeader::new`.
    unused: u8,
    /// Total message length in bytes, header included.
    length: u32,
}
impl Query for &str {
    /// Serialize the text as a string (char list) message.
    ///
    /// Layout: 8-byte header (`ENCODING`, `message_type`, two zero bytes,
    /// total length), then the type byte `qtype::STRING`, one zero attribute
    /// byte, the 4-byte text length and finally the text bytes. Both lengths
    /// are written big endian when `ENCODING` is `0`, little endian otherwise.
    ///
    /// The text length is narrowed to `u32` with `as`, so texts of 4 GiB or
    /// more are not representable.
    fn serialize(&self, message_type: u8) -> Vec<u8> {
        let byte_message = self.as_bytes();
        let message_length = byte_message.len() as u32;
        // 6 = type byte + attribute byte + 4-byte element count.
        let total_length = MessageHeader::size() as u32 + 6 + message_length;
        let (total_length_bytes, length_info) = match ENCODING {
            0 => (total_length.to_be_bytes(), message_length.to_be_bytes()),
            _ => (total_length.to_le_bytes(), message_length.to_le_bytes()),
        };

        // Reserve the full message size (header + 6 metadata bytes + text) so
        // the vector never has to grow while it is being filled.
        let mut message = Vec::with_capacity(total_length as usize);
        message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
        message.extend_from_slice(&total_length_bytes);
        message.extend_from_slice(&[qtype::STRING as u8, 0]);
        message.extend_from_slice(&length_info);
        message.extend_from_slice(byte_message);
        message
    }
}
impl Query for K {
    /// Serialize the object as a complete IPC message.
    ///
    /// Bodies of at most 1992 bytes are sent as-is behind an 8-byte header.
    /// Larger bodies are handed to `compress`; when compression does not pay
    /// off, the uncompressed message is returned with its length field filled in.
    fn serialize(&self, message_type: u8) -> Vec<u8> {
        let mut payload = self.q_ipc_encode();
        let payload_length = payload.len();
        let header_size = MessageHeader::size();
        let total_length = (header_size + payload_length) as u32;
        let length_field = if ENCODING == 0 {
            total_length.to_be_bytes()
        } else {
            total_length.to_le_bytes()
        };

        if payload_length <= 1992 {
            let mut message = Vec::with_capacity(payload_length + header_size);
            message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
            message.extend_from_slice(&length_field);
            message.append(&mut payload);
            return message;
        }

        // Header with a zeroed length placeholder; `compress` writes the
        // length itself when it succeeds.
        let mut message = Vec::with_capacity(payload_length + 8);
        message.extend_from_slice(&[ENCODING, message_type, 0, 0, 0, 0, 0, 0]);
        message.append(&mut payload);

        let (is_compressed, mut bytes) = compress(message);
        if !is_compressed {
            bytes[4..8].copy_from_slice(&length_field);
        }
        bytes
    }
}
impl QStream{
fn new(stream: Box<dyn QStreamInner>, method: ConnectionMethod, is_listener: bool) -> Self{
QStream{
stream: stream,
method: method,
listener: is_listener
}
}
pub async fn connect(method: ConnectionMethod, host: &str, port: u16, credential: &str) -> io::Result<Self>{
match method{
ConnectionMethod::TCP => {
let stream=connect_tcp(host, port, credential).await?;
Ok(QStream::new(Box::new(stream), ConnectionMethod::TCP, false))
},
ConnectionMethod::TLS => {
let stream=connect_tls(host, port, credential).await?;
Ok(QStream::new(Box::new(stream), ConnectionMethod::TLS, false))
},
ConnectionMethod::UDS => {
let stream=connect_uds(port, credential).await?;
Ok(QStream::new(Box::new(stream), ConnectionMethod::UDS, false))
}
}
}
pub async fn accept(method: ConnectionMethod, host: &str, port: u16) -> io::Result<Self>{
match method{
ConnectionMethod::TCP => {
let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
let (mut socket, _) = listener.accept().await?;
while let Err(_) = read_client_input(&mut socket).await{
socket = listener.accept().await?.0;
}
Ok(QStream::new(Box::new(socket), ConnectionMethod::TCP, true))
},
ConnectionMethod::TLS => {
let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
let identity = build_identity_from_cert().await?;
let tls_acceptor = TlsAcceptor::from(TlsAcceptorInner::new(identity).unwrap());
let (mut socket, _) = listener.accept().await?;
let mut tls_socket=tls_acceptor.accept(socket).await.expect("failed to accept TLS connection");
while let Err(_) = read_client_input(&mut tls_socket).await{
socket = listener.accept().await?.0;
tls_socket=tls_acceptor.accept(socket).await.expect("failed to accept TLS connection");
}
let mut qstream=QStream::new(Box::new(TlsStream::from(tls_socket)), ConnectionMethod::TCP, true);
qstream.send_async_message(&".kdbplus.close_tls_connection_:{[] hclose .z.w;}").await?;
Ok(qstream)
}
ConnectionMethod::UDS => {
let uds_path=create_sockfile_path(port)?;
let abstract_sockfile_=format!("\x00{}", uds_path);
let abstract_sockfile=Path::new(&abstract_sockfile_);
let listener = UnixListener::bind(&abstract_sockfile).unwrap();
let (mut socket, _) = listener.accept().await?;
while let Err(_) = read_client_input(&mut socket).await{
socket = listener.accept().await?.0;
}
Ok(QStream::new(Box::new(socket), method, true))
}
}
}
pub async fn shutdown(mut self)-> io::Result<()>{
self.stream.shutdown(self.listener).await
}
pub async fn send_message(&mut self, message: &dyn Query, message_type: u8)-> io::Result<()>{
self.stream.send_message(message, message_type).await
}
pub async fn send_async_message(&mut self, message: &dyn Query)-> io::Result<()>{
self.stream.send_async_message(message).await
}
pub async fn send_sync_message(&mut self, message: &dyn Query)-> io::Result<K>{
self.stream.send_sync_message(message).await
}
pub async fn receive_message(&mut self) -> io::Result<(u8, K)>{
self.stream.receive_message().await
}
pub fn get_connection_type(&self) -> &str{
match self.method{
ConnectionMethod::TCP => "TCP",
ConnectionMethod::TLS => "TLS",
ConnectionMethod::UDS => "UDS"
}
}
}
/// Plain TCP transport.
#[async_trait]
impl QStreamInner for TcpStream {
    /// Shut down the write half of the socket; the flag is not used for TCP.
    async fn shutdown(&mut self, _: bool) -> io::Result<()> {
        AsyncWriteExt::shutdown(self).await
    }

    /// Serialize `message` with `message_type` and write all of it.
    async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> io::Result<()> {
        let bytes = message.serialize(message_type);
        self.write_all(&bytes).await
    }

    /// Send `message` tagged as asynchronous.
    async fn send_async_message(&mut self, message: &dyn Query) -> io::Result<()> {
        self.send_message(message, qmsg_type::asynchronous).await
    }

    /// Send `message` tagged as synchronous and wait for the reply; any
    /// incoming message that is not a response is reported as `InvalidData`.
    async fn send_sync_message(&mut self, message: &dyn Query) -> io::Result<K> {
        self.send_message(message, qmsg_type::synchronous).await?;
        let (kind, reply) = receive_message(self).await?;
        if kind == qmsg_type::response {
            Ok(reply)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("expected a response: {}", reply),
            ))
        }
    }

    /// Read one complete message from the socket.
    async fn receive_message(&mut self) -> io::Result<(u8, K)> {
        receive_message(self).await
    }
}
/// TLS-over-TCP transport.
#[async_trait]
impl QStreamInner for TlsStream<TcpStream> {
    /// Close the TLS connection.
    ///
    /// The connecting side shuts the TLS session down directly. The listening
    /// side instead asks the peer to close, by calling the
    /// `.kdbplus.close_tls_connection_` function that `QStream::accept`
    /// defined on the peer.
    async fn shutdown(&mut self, is_listener: bool) -> io::Result<()> {
        if !is_listener {
            return self.get_mut().shutdown();
        }
        self.send_async_message(&".kdbplus.close_tls_connection_[]").await
    }

    /// Serialize `message` with `message_type` and write all of it.
    async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> io::Result<()> {
        let bytes = message.serialize(message_type);
        self.write_all(&bytes).await
    }

    /// Send `message` tagged as asynchronous.
    async fn send_async_message(&mut self, message: &dyn Query) -> io::Result<()> {
        let bytes = message.serialize(qmsg_type::asynchronous);
        self.write_all(&bytes).await
    }

    /// Send `message` tagged as synchronous and wait for the reply; any
    /// incoming message that is not a response is reported as `InvalidData`.
    async fn send_sync_message(&mut self, message: &dyn Query) -> io::Result<K> {
        let bytes = message.serialize(qmsg_type::synchronous);
        self.write_all(&bytes).await?;
        let (kind, reply) = receive_message(self).await?;
        if kind == qmsg_type::response {
            Ok(reply)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("expected a response: {}", reply),
            ))
        }
    }

    /// Read one complete message from the stream.
    async fn receive_message(&mut self) -> io::Result<(u8, K)> {
        receive_message(self).await
    }
}
/// Unix-domain-socket transport.
///
/// Guarded with `cfg(unix)` because `UnixStream` is only imported on unix
/// targets; without the guard this impl cannot compile elsewhere.
#[cfg(unix)]
#[async_trait]
impl QStreamInner for UnixStream {
    /// Shut down the write half of the socket; the flag is not used for UDS.
    async fn shutdown(&mut self, _: bool) -> io::Result<()> {
        AsyncWriteExt::shutdown(self).await
    }

    /// Serialize `message` with `message_type` and write all of it.
    async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> io::Result<()> {
        let byte_message = message.serialize(message_type);
        self.write_all(&byte_message).await
    }

    /// Send `message` tagged as asynchronous.
    async fn send_async_message(&mut self, message: &dyn Query) -> io::Result<()> {
        let byte_message = message.serialize(qmsg_type::asynchronous);
        self.write_all(&byte_message).await
    }

    /// Send `message` tagged as synchronous and wait for the reply; any
    /// incoming message that is not a response is reported as `InvalidData`.
    async fn send_sync_message(&mut self, message: &dyn Query) -> io::Result<K> {
        let byte_message = message.serialize(qmsg_type::synchronous);
        self.write_all(&byte_message).await?;
        match receive_message(self).await? {
            (qmsg_type::response, response) => Ok(response),
            (_, message) => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("expected a response: {}", message),
            )),
        }
    }

    /// Read one complete message from the socket.
    async fn receive_message(&mut self) -> io::Result<(u8, K)> {
        receive_message(self).await
    }
}
impl MessageHeader{
fn new(encoding: u8, message_type: u8, compressed: u8, length: u32) -> Self{
MessageHeader{
encoding: encoding,
message_type: message_type,
compressed: compressed,
unused: 0,
length: length
}
}
fn from_bytes(bytes: [u8; 8]) -> Self{
let encoding=bytes[0];
let length=match encoding{
0 => u32::from_be_bytes(bytes[4..8].try_into().unwrap()),
_ => u32::from_le_bytes(bytes[4..8].try_into().unwrap())
};
MessageHeader::new(encoding, bytes[1], bytes[2], length)
}
fn size() -> usize{
8
}
}
async fn connect_tcp_impl(host: &str, port: u16) -> io::Result<TcpStream>{
let resolver=TokioAsyncResolver::tokio_from_system_conf().expect("failed to create a resolver");
let ips;
if let Ok(ip) = host.parse::<IpAddr>(){
ips=vec![ip.to_string()]
}
else{
ips=resolver.ipv4_lookup(format!("{}.", host).as_str()).await.unwrap().iter().map(|result| result.to_string()).collect()
};
for answer in ips{
let host_port=format!("{}:{}", answer, port);
match TcpStream::connect(&host_port).await{
Ok(socket) => {
println!("connected: {}", host_port);
return Ok(socket);
},
Err(_) => {
eprintln!("connection refused: {}. try next.", host_port);
}
}
}
Err(io::Error::new(io::ErrorKind::ConnectionRefused, "failed to connect"))
}
/// Perform the client side of the login handshake.
///
/// Sends `credential_` immediately followed by `method_bytes`, then waits for
/// the single byte the server sends back on success. Failing to read that
/// byte is reported as an authentication failure.
async fn handshake<S>(socket: &mut S, credential_: &str, method_bytes: &str) -> io::Result<()>
where
    S: Unpin + AsyncWriteExt + AsyncReadExt,
{
    let mut login = String::with_capacity(credential_.len() + method_bytes.len());
    login.push_str(credential_);
    login.push_str(method_bytes);
    socket.write_all(login.as_bytes()).await?;

    let mut reply = [0u8; 1];
    socket
        .read_exact(&mut reply)
        .await
        .map(|_| ())
        .map_err(|_| io::Error::new(io::ErrorKind::ConnectionAborted, "authentication failure"))
}
/// Connect over plain TCP and log in with `credential`.
async fn connect_tcp(host: &str, port: u16, credential: &str) -> io::Result<TcpStream> {
    let mut stream = connect_tcp_impl(host, port).await?;
    handshake(&mut stream, credential, "\x03\x00")
        .await
        .map(|()| stream)
}
async fn connect_tls(host: &str, port: u16, credential: &str) -> io::Result<TlsStream<TcpStream>>{
let socket_=connect_tcp_impl(host, port).await?;
let connector = TlsConnector::from(TlsConnectorInner::new().unwrap());
let mut socket = connector.connect(host, socket_).await.expect("failed to create TLS session");
handshake(&mut socket, credential, "\x03\x00").await?;
Ok(socket)
}
/// Build the Unix socket path for `port`: `$QUDSPATH/kx.<port>`, or
/// `/tmp/kx.<port>` when `QUDSPATH` is not set (or is not valid Unicode).
fn create_sockfile_path(port: u16) -> io::Result<String> {
    let directory = env::var("QUDSPATH").unwrap_or_else(|_| String::from("/tmp"));
    Ok(format!("{}/kx.{}", directory, port))
}
/// Connect over a Unix domain socket and log in with `credential`.
///
/// The socket name is the path from `create_sockfile_path` prefixed with a
/// NUL byte.
#[cfg(unix)]
async fn connect_uds(port: u16, credential: &str) -> io::Result<UnixStream> {
    let mut socket_name = String::from("\x00");
    socket_name.push_str(&create_sockfile_path(port)?);
    let mut stream = UnixStream::connect(Path::new(&socket_name)).await?;
    handshake(&mut stream, credential, "\x06\x00").await?;
    Ok(stream)
}
/// Perform the server side of the login handshake.
///
/// Reads bytes from `socket` until a capability byte (`0x03` or `0x06`)
/// appears; everything before it is the credential, expected as
/// `user:password`. The SHA-1 digest of the password is compared with the
/// entry stored for the user in `ACCOUNTS`. On success the capability byte is
/// echoed back to the client.
///
/// # Errors
///
/// * `UnexpectedEof` when the client disconnects before sending a capability
///   byte (previously this case looped forever).
/// * `InvalidData` ("authentication failed") when the credential is not valid
///   UTF-8, has no `:` separator, names an unknown user or carries a wrong
///   password. The socket is shut down first.
/// * Any I/O error raised while reading from or writing to the socket.
async fn read_client_input<S>(socket: &mut S) -> io::Result<()>
where
    S: Unpin + AsyncWriteExt + AsyncReadExt,
{
    let mut client_input = [0u8; 32];
    // Collected as raw bytes so that a multi-byte character split across two
    // reads, or invalid UTF-8 from the client, cannot cause a panic.
    let mut passed_credential: Vec<u8> = Vec::new();

    let capacity = loop {
        let read = socket.read(&mut client_input).await?;
        if read == 0 {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "connection closed before the credential was received",
            ));
        }
        // Only look at the bytes delivered by this read; the rest of the
        // buffer holds stale data.
        let received = &client_input[..read];
        match received.iter().position(|byte| *byte == 0x03 || *byte == 0x06) {
            Some(index) => {
                passed_credential.extend_from_slice(&received[..index]);
                break received[index];
            }
            None => passed_credential.extend_from_slice(received),
        }
    };

    let authenticated = match str::from_utf8(&passed_credential) {
        Ok(text) => {
            let mut fields = text.split(':');
            match (fields.next(), fields.next()) {
                (Some(user), Some(password)) => {
                    let mut hasher = Sha1::new();
                    hasher.update(password.as_bytes());
                    let encoded_password = hasher.digest().to_string();
                    ACCOUNTS
                        .get(user)
                        .map_or(false, |encoded| *encoded == encoded_password)
                }
                _ => false,
            }
        }
        Err(_) => false,
    };

    if authenticated {
        socket.write_all(&[capacity; 1]).await?;
        Ok(())
    } else {
        socket.shutdown().await?;
        Err(io::Error::new(io::ErrorKind::InvalidData, "authentication failed"))
    }
}
async fn build_identity_from_cert() -> io::Result<Identity>{
if let Ok(path) = env::var("KDBPLUS_TLS_KEY_FILE"){
if let Ok(password) = env::var("KDBPLUS_TLS_KEY_FILE_SECRET"){
let cert_file=tokio::fs::File::open(Path::new(&path)).await.unwrap();
let mut reader=BufReader::new(cert_file);
let mut der: Vec<u8>=Vec::new();
reader.read_to_end(&mut der).await?;
if let Ok(identity) = Identity::from_pkcs12(&der, &password){
return Ok(identity);
}
else{
return Err(io::Error::new(io::ErrorKind::InvalidData, "authentication failed"));
}
}
else{
return Err(io::Error::new(io::ErrorKind::NotFound, "KDBPLUS_TLS_KEY_FILE_SECRET is not set"));
}
}
else{
return Err(io::Error::new(io::ErrorKind::NotFound, "KDBPLUS_TLS_KEY_FILE is not set"));
}
}
/// Read one complete message from `socket`.
///
/// Reads the 8-byte header, then the body whose size is the header's total
/// length minus the header size, decompresses the body when the header's
/// compression flag is `1`, and decodes it.
///
/// Returns the message type from the header together with the decoded object.
///
/// # Errors
///
/// * `ConnectionAborted` when the header cannot be read.
/// * `InvalidData` when the header declares a total length smaller than the
///   header itself, or when a compressed body is too short to hold its 4-byte
///   size prefix.
/// * `UnexpectedEof` when the body cannot be read in full.
async fn receive_message<S>(socket: &mut S) -> io::Result<(u8, K)>
where
    S: Unpin + AsyncReadExt,
{
    let mut header_buffer = [0u8; 8];
    if let Err(err) = socket.read_exact(&mut header_buffer).await {
        return Err(io::Error::new(
            io::ErrorKind::ConnectionAborted,
            format!("Connection dropped: {}", err),
        ));
    }
    let header = MessageHeader::from_bytes(header_buffer);

    // The length comes from the peer; a value below the header size would
    // underflow the subtraction.
    let body_length = (header.length as usize)
        .checked_sub(MessageHeader::size())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid message length in header: {}", header.length),
            )
        })?;

    let mut body: Vec<u8> = vec![0_u8; body_length];
    if let Err(err) = socket.read_exact(&mut body).await {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            format!("Failed to read body of message: {}", err),
        ));
    }

    if header.compressed == 0x01 {
        if body.len() < 4 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "compressed message is too short",
            ));
        }
        body = decompress(body, header.encoding);
    }
    Ok((header.message_type, K::q_ipc_decode(&body, header.encoding)))
}
/// Try to compress a complete serialized message (8-byte header followed by the body).
///
/// Returns `(true, compressed)` on success. The compressed message starts with the
/// first four header bytes of `raw` (with the compression flag at index 2 set to 1),
/// followed by the compressed length at bytes 4..8 and the length of `raw` at bytes
/// 8..12, both in the byte order selected by `ENCODING`.
///
/// Returns `(false, raw)`, handing the input back untouched, when the output would
/// not fit into a buffer of half the input size.
///
/// NOTE(review): `e-17` below underflows when `raw` is shorter than 34 bytes, and the
/// header writes need at least 24 bytes of input; the caller visible in this file only
/// invokes this for bodies longer than 1992 bytes.
fn compress(raw: Vec<u8>) -> (bool, Vec<u8>){
// Bit mask of the current item within the flag byte; 0 means a new flag byte must be started.
let mut i = 0_u8;
// Flag byte under construction; a set bit marks the corresponding item as a back reference.
let mut f = 0_u8;
// Hash and position of the most recent literal, entered into the table one iteration later.
let mut h0 = 0_usize;
// Hash of the current position: XOR of two consecutive input bytes.
let mut h = 0_usize;
// `true` when the current item must be emitted as a literal byte.
let mut g: bool;
// Output buffer, pre-filled with zeros and limited to half of the input size.
let mut compressed: Vec<u8> = Vec::with_capacity((raw.len()) / 2);
compressed.resize((raw.len()) / 2, 0_u8);
// `c` is the index where the current flag byte is stored; output starts after the 12 header bytes.
let mut c = 12;
// `d` is the next write position in the output.
let mut d = c;
let e = compressed.len();
// Earlier input position suggested by the hash table.
let mut p = 0_usize;
let mut q: usize;
let mut r: usize;
let mut s0 = 0_usize;
// Read position in the input; starts right after the 8-byte header.
let mut s = 8_usize;
let t = raw.len();
// Hash table: hash of a byte pair -> last input position recorded for it.
let mut a =[0_i32; 256];
// Copy encoding, message type and the two following bytes, then set the compression flag.
compressed[0..4].copy_from_slice(&raw[0..4]);
compressed[2]=1;
// Record the uncompressed size (header included).
let raw_size=match ENCODING{
0 => (t as u32).to_be_bytes(),
_ => (t as u32).to_le_bytes()
};
compressed[8..12].copy_from_slice(&raw_size);
while s < t {
if i == 0 {
// A flag byte plus eight items of at most two bytes each needs up to 17 bytes;
// give up when that much space is no longer available.
if d > e-17 {
return (false, raw);
}
// Store the finished flag byte and reserve the slot for the next one.
i = 1;
compressed[c] = f;
c = d;
d += 1;
f = 0;
}
// Close to the end of the input only literals are emitted.
g = s > t-3;
if !g {
h = (raw[s] ^ raw[s+1]) as usize;
p = a[h] as usize;
// Literal when the table has no entry for this hash or the first byte differs.
g = (0 == p) || (0 != (raw[s] ^ raw[p]));
}
// Enter the previous literal's position into the table now.
if 0 < s0 {
a[h0] = s0 as i32;
s0 = 0;
}
if g {
// Emit one literal byte.
h0 = h;
s0 = s;
compressed[d] = raw[s];
d += 1;
s += 1;
}
else {
// Emit a back reference: the hash byte, then the number of matching bytes
// beyond the first two.
a[h] = s as i32;
f |= i;
p += 2;
s += 2;
r = s;
// The match is extended over at most 255 further bytes.
q = if s+255 > t {t}else{s+255};
while (s < q) && (raw[p] == raw[s]) {
s += 1;
if s < q {
p += 1;
}
}
compressed[d] = h as u8;
d += 1;
compressed[d] = (s - r) as u8;
d += 1;
}
// Advance to the next flag bit; after eight items the mask wraps around to 0.
i=i.wrapping_mul(2);
}
// Store the last, possibly partial, flag byte.
compressed[c] = f;
// Record the compressed size (header included).
let compressed_size=match ENCODING{
0 => (d as u32).to_be_bytes(),
_ => (d as u32).to_le_bytes()
};
compressed[4..8].copy_from_slice(&compressed_size);
// Drop the unused tail of the output buffer.
let _ = compressed.split_off(d);
(true, compressed)
}
/// Decompress the body of a compressed IPC message.
///
/// `compressed` is the message without its 8-byte header. Its first four
/// bytes hold the uncompressed message size (header included) in the byte
/// order given by `encoding` (`0` is big endian, anything else little endian).
/// The rest is a sequence of groups, each a flag byte followed by up to eight
/// items. A clear flag bit means the item is one literal byte. A set bit
/// means a back reference: a hash byte selecting an earlier output position,
/// then a count of extra bytes to copy after the first two.
///
/// Returns the uncompressed body (without header).
///
/// # Panics
///
/// Panics when `compressed` is shorter than four bytes, or when the data is
/// malformed so that an index falls outside the input or output buffer.
fn decompress(compressed: Vec<u8>, encoding: u8) -> Vec<u8> {
    // Count of extra bytes copied by the current back reference.
    let mut n = 0;
    // Source position of the current back reference in the output.
    let mut r: usize;
    // Current flag byte and the bit mask selecting the current item in it.
    let mut f = 0_usize;
    let mut i = 0_usize;
    // Write position in the output.
    let mut s = 0_usize;
    // Next output position whose byte pair still has to be entered into `aa`.
    let mut p = s;

    // The size prefix is four bytes long; a three-byte slice can never be
    // converted into `[u8; 4]`.
    let size_bytes: [u8; 4] = compressed[0..4]
        .try_into()
        .expect("slice does not have length 4");
    let size = match encoding {
        0 => i32::from_be_bytes(size_bytes) - 8,
        _ => i32::from_le_bytes(size_bytes) - 8,
    };
    let mut decompressed: Vec<u8> = vec![0_u8; size as usize];

    // Read position in the input, just past the size prefix.
    let mut d = 4;
    // Hash table: XOR of two adjacent output bytes -> position of the first one.
    let mut aa = [0_i32; 256];

    while s < decompressed.len() {
        if i == 0 {
            f = compressed[d] as usize;
            d += 1;
            i = 1;
        }
        if (f & i) != 0 {
            r = aa[compressed[d] as usize] as usize;
            d += 1;
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            n = compressed[d] as usize;
            d += 1;
            // Copied byte by byte on purpose: source and destination may overlap.
            for m in 0..n {
                decompressed[s + m] = decompressed[r + m];
            }
        } else {
            decompressed[s] = compressed[d];
            s += 1;
            d += 1;
        }
        // Enter all byte pairs completed so far into the hash table.
        while p < s - 1 {
            aa[(decompressed[p] ^ decompressed[p + 1]) as usize] = p as i32;
            p += 1;
        }
        if (f & i) != 0 {
            s += n;
            p = s;
        }
        i *= 2;
        if i == 256 {
            i = 0;
        }
    }
    decompressed
}