use super::serialize::ENCODING;
use super::Result;
use super::{qtype, K};
use async_trait::async_trait;
use io::BufRead;
use once_cell::sync::Lazy;
use sha1_smol::Sha1;
use std::collections::HashMap;
use std::convert::TryInto;
use std::net::{IpAddr, Ipv4Addr};
use std::path::Path;
use std::{env, fs, io, str};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use tokio_native_tls::native_tls::{
Identity, TlsAcceptor as TlsAcceptorInner, TlsConnector as TlsConnectorInner,
};
use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
use trust_dns_resolver::TokioAsyncResolver;
/// Message type codes stored in the second byte of a message header.
pub mod qmsg_type {
/// Code written by `send_async_message`; the sender does not wait for a reply.
pub const asynchronous: u8 = 0;
/// Code written by `send_sync_message`; the sender then waits for a `response` message.
pub const synchronous: u8 = 1;
/// Code that `send_sync_message` requires on the reply it reads back.
pub const response: u8 = 2;
}
/// User name to stored password map, loaded on first access from the file named by the
/// `KDBPLUS_ACCOUNT_FILE` environment variable.
///
/// Each line of the file is expected to look like `user:password`. Blank lines and lines
/// without a `:` are skipped, and trailing `\n` / `\r\n` is removed.
///
/// This is a `static` rather than a `const`: a `const` is re-instantiated at every use site, so
/// a `const Lazy` would re-open and re-parse the account file on every access.
///
/// # Panics
///
/// Panics on first access if the environment variable is not set or the file cannot be opened
/// or read.
static ACCOUNTS: Lazy<HashMap<String, String>> = Lazy::new(|| {
    let path = env::var("KDBPLUS_ACCOUNT_FILE").expect("KDBPLUS_ACCOUNT_FILE is not set");
    let file = fs::OpenOptions::new()
        .read(true)
        .open(path)
        .expect("failed to open account file");
    let mut reader = io::BufReader::new(file);
    let mut map: HashMap<String, String> = HashMap::new();
    // One buffer is reused for every line.
    let mut line = String::new();
    loop {
        match reader.read_line(&mut line) {
            Ok(0) => break,
            Ok(_) => {
                {
                    let mut fields = line
                        .trim_end_matches(|c: char| c == '\n' || c == '\r')
                        .split(':');
                    // Only the first two fields are used; malformed lines are ignored instead
                    // of aborting the whole process.
                    if let (Some(user), Some(password)) = (fields.next(), fields.next()) {
                        map.insert(user.to_string(), password.to_string());
                    }
                }
                line.clear();
            }
            Err(error) => panic!("failed to read account file: {}", error),
        }
    }
    map
});
/// Transport used by a `QStream`.
pub enum ConnectionMethod {
/// Plain TCP socket.
TCP = 0,
/// TLS session layered on a TCP socket.
TLS = 1,
/// Unix domain socket.
UDS = 2,
}
/// A value that can be turned into the bytes of one complete message.
#[async_trait]
pub trait Query: Send + Sync {
/// Builds the full message (header followed by body) for the given `message_type`.
///
/// `is_local` tells the implementation that the peer is on the same host; the `K`
/// implementation in this file uses it to skip compression.
async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8>;
}
/// Operations every concrete transport wrapped by `QStream` has to provide.
#[async_trait]
trait QStreamInner: Send + Sync {
/// Closes the connection; `is_server` is the `listener` flag of the owning `QStream`.
async fn shutdown(&mut self, is_server: bool) -> Result<()>;
/// Serializes `message` with an explicit `message_type` and writes it to the transport.
async fn send_message(
&mut self,
message: &dyn Query,
message_type: u8,
is_local: bool,
) -> Result<()>;
/// Sends `message` with the asynchronous message type.
async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()>;
/// Sends `message` with the synchronous message type and returns the response object.
async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K>;
/// Reads one message and returns its message type together with the decoded object.
async fn receive_message(&mut self) -> Result<(u8, K)>;
}
/// Handle to one connection, independent of the underlying transport.
pub struct QStream {
/// Transport-specific implementation.
stream: Box<dyn QStreamInner>,
/// Transport kind, reported by `get_connection_type`.
method: ConnectionMethod,
/// `true` when the handle was created by `accept`, `false` when created by `connect`.
listener: bool,
/// `true` when the peer is considered local; passed to `Query::serialize` as `is_local`.
local: bool,
}
/// Decoded form of the 8-byte header that precedes every message.
#[derive(Clone, Copy, Debug)]
struct MessageHeader {
/// Byte-order flag: `0` means the length field is big endian, anything else little endian.
encoding: u8,
/// One of the `qmsg_type` codes.
message_type: u8,
/// `1` when the body is compressed.
compressed: u8,
/// Fourth header byte; `MessageHeader::new` always stores `0` here.
_unused: u8,
/// Total message length in bytes, header included.
length: u32,
}
#[async_trait]
impl Query for &str {
    /// Serializes the text as a string object preceded by the 8-byte message header.
    ///
    /// Layout: `[ENCODING, message_type, 0, 0]`, total length (4 bytes), type byte
    /// `qtype::STRING`, attribute byte `0`, text length (4 bytes), then the text bytes.
    /// Text is never compressed, so the locality flag is ignored.
    async fn serialize(&self, message_type: u8, _: bool) -> Vec<u8> {
        // Type byte + attribute byte + 4-byte element count that precede the text.
        const LIST_PREFIX_SIZE: usize = 6;

        let text = self.as_bytes();
        let text_length = text.len() as u32;
        let total_length = MessageHeader::size() as u32 + LIST_PREFIX_SIZE as u32 + text_length;
        let (total_length_bytes, text_length_bytes) = match ENCODING {
            0 => (total_length.to_be_bytes(), text_length.to_be_bytes()),
            _ => (total_length.to_le_bytes(), text_length.to_le_bytes()),
        };

        // Reserve the whole message up front (header + list prefix + text) so that the buffer
        // never has to grow while it is being filled.
        let mut message =
            Vec::with_capacity(MessageHeader::size() + LIST_PREFIX_SIZE + text.len());
        message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
        message.extend_from_slice(&total_length_bytes);
        message.extend_from_slice(&[qtype::STRING as u8, 0]);
        message.extend_from_slice(&text_length_bytes);
        message.extend_from_slice(text);
        message
    }
}
#[async_trait]
impl Query for K {
    /// Encodes the object and frames it with the 8-byte message header.
    ///
    /// Payloads longer than 1992 bytes that are bound for a non-local peer are handed to
    /// `compress`; when compression does not pay off the plain frame is sent instead.
    async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8> {
        let mut payload = self.q_ipc_encode();
        let payload_size = payload.len();
        let header_size = MessageHeader::size();
        let framed_size = (header_size + payload_size) as u32;
        let framed_size_bytes = if ENCODING == 0 {
            framed_size.to_be_bytes()
        } else {
            framed_size.to_le_bytes()
        };

        let wants_compression = payload_size > 1992 && !is_local;
        if !wants_compression {
            let mut frame = Vec::with_capacity(payload_size + header_size);
            frame.extend_from_slice(&[ENCODING, message_type, 0, 0]);
            frame.extend_from_slice(&framed_size_bytes);
            frame.append(&mut payload);
            return frame;
        }

        // The length field is left zeroed here; either `compress` fills it in or it is patched
        // below when the frame comes back uncompressed.
        let mut frame = Vec::with_capacity(payload_size + 8);
        frame.extend_from_slice(&[ENCODING, message_type, 0, 0, 0, 0, 0, 0]);
        frame.append(&mut payload);
        let (is_compressed, mut frame) = compress(frame).await;
        if !is_compressed {
            frame[4..8].copy_from_slice(&framed_size_bytes);
        }
        frame
    }
}
impl QStream {
fn new(
stream: Box<dyn QStreamInner>,
method: ConnectionMethod,
is_listener: bool,
is_local: bool,
) -> Self {
QStream {
stream: stream,
method: method,
listener: is_listener,
local: is_local,
}
}
pub async fn connect(
method: ConnectionMethod,
host: &str,
port: u16,
credential: &str,
) -> Result<Self> {
match method {
ConnectionMethod::TCP => {
let stream = connect_tcp(host, port, credential).await?;
let is_local = match host {
"localhost" | "127.0.0.1" => true,
_ => false,
};
Ok(QStream::new(
Box::new(stream),
ConnectionMethod::TCP,
false,
is_local,
))
}
ConnectionMethod::TLS => {
let stream = connect_tls(host, port, credential).await?;
Ok(QStream::new(
Box::new(stream),
ConnectionMethod::TLS,
false,
false,
))
}
ConnectionMethod::UDS => {
let stream = connect_uds(port, credential).await?;
Ok(QStream::new(
Box::new(stream),
ConnectionMethod::UDS,
false,
true,
))
}
}
}
pub async fn accept(method: ConnectionMethod, host: &str, port: u16) -> Result<Self> {
match method {
ConnectionMethod::TCP => {
let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
let (mut socket, ip_address) = listener.accept().await?;
while let Err(_) = read_client_input(&mut socket).await {
socket = listener.accept().await?.0;
}
Ok(QStream::new(
Box::new(socket),
ConnectionMethod::TCP,
true,
ip_address.ip() == IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
))
}
ConnectionMethod::TLS => {
let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
let identity = build_identity_from_cert().await?;
let tls_acceptor = TlsAcceptor::from(TlsAcceptorInner::new(identity).unwrap());
let (mut socket, _) = listener.accept().await?;
let mut tls_socket = tls_acceptor
.accept(socket)
.await
.expect("failed to accept TLS connection");
while let Err(_) = read_client_input(&mut tls_socket).await {
socket = listener.accept().await?.0;
tls_socket = tls_acceptor
.accept(socket)
.await
.expect("failed to accept TLS connection");
}
let mut qstream = QStream::new(
Box::new(TlsStream::from(tls_socket)),
ConnectionMethod::TCP,
true,
false,
);
qstream
.send_async_message(&".kdbplus.close_tls_connection_:{[] hclose .z.w;}")
.await?;
Ok(qstream)
}
ConnectionMethod::UDS => {
let uds_path = create_sockfile_path(port)?;
let abstract_sockfile_ = format!("\x00{}", uds_path);
let abstract_sockfile = Path::new(&abstract_sockfile_);
let listener = UnixListener::bind(&abstract_sockfile).unwrap();
let (mut socket, _) = listener.accept().await?;
while let Err(_) = read_client_input(&mut socket).await {
socket = listener.accept().await?.0;
}
Ok(QStream::new(Box::new(socket), method, true, true))
}
}
}
pub async fn shutdown(mut self) -> Result<()> {
self.stream.shutdown(self.listener).await
}
pub async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> Result<()> {
self.stream
.send_message(message, message_type, self.local)
.await
}
pub async fn send_async_message(&mut self, message: &dyn Query) -> Result<()> {
self.stream.send_async_message(message, self.local).await
}
pub async fn send_sync_message(&mut self, message: &dyn Query) -> Result<K> {
self.stream.send_sync_message(message, self.local).await
}
pub async fn receive_message(&mut self) -> Result<(u8, K)> {
self.stream.receive_message().await
}
pub fn get_connection_type(&self) -> &str {
match self.method {
ConnectionMethod::TCP => "TCP",
ConnectionMethod::TLS => "TLS",
ConnectionMethod::UDS => "UDS",
}
}
pub fn enforce_compression(&mut self) {
self.local = false;
}
}
#[async_trait]
impl QStreamInner for TcpStream {
    /// Shuts down the write side of the socket; the listener flag is not used.
    async fn shutdown(&mut self, _: bool) -> Result<()> {
        Ok(AsyncWriteExt::shutdown(self).await?)
    }

    /// Serializes `message` with the given type and writes every byte to the socket.
    async fn send_message(
        &mut self,
        message: &dyn Query,
        message_type: u8,
        is_local: bool,
    ) -> Result<()> {
        let frame = message.serialize(message_type, is_local).await;
        write_all_cancellation_safe(self, &frame).await.map(|_| ())
    }

    /// Same as `send_message` with the asynchronous message type.
    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
        let frame = message.serialize(qmsg_type::asynchronous, is_local).await;
        write_all_cancellation_safe(self, &frame).await.map(|_| ())
    }

    /// Sends a synchronous message and reads the reply, which must be a response message.
    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
        let frame = message.serialize(qmsg_type::synchronous, is_local).await;
        write_all_cancellation_safe(self, &frame).await?;
        let (reply_type, reply) = receive_message(self).await?;
        if reply_type == qmsg_type::response {
            Ok(reply)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("expected a response: {}", reply),
            )
            .into())
        }
    }

    /// Reads one message from the socket.
    async fn receive_message(&mut self) -> Result<(u8, K)> {
        receive_message(self).await
    }
}
#[async_trait]
impl QStreamInner for TlsStream<TcpStream> {
    /// Closes the TLS connection.
    ///
    /// An acceptor asks the client to close by calling `.kdbplus.close_tls_connection_`;
    /// a connector shuts down the underlying TLS stream directly.
    async fn shutdown(&mut self, is_listener: bool) -> Result<()> {
        if !is_listener {
            self.get_mut().shutdown()?;
            return Ok(());
        }
        self.send_async_message(&".kdbplus.close_tls_connection_[]", false)
            .await
    }

    /// Serializes `message` with the given type and writes every byte to the stream.
    async fn send_message(
        &mut self,
        message: &dyn Query,
        message_type: u8,
        is_local: bool,
    ) -> Result<()> {
        let frame = message.serialize(message_type, is_local).await;
        write_all_cancellation_safe(self, &frame).await.map(|_| ())
    }

    /// Same as `send_message` with the asynchronous message type.
    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
        let frame = message.serialize(qmsg_type::asynchronous, is_local).await;
        write_all_cancellation_safe(self, &frame).await.map(|_| ())
    }

    /// Sends a synchronous message and reads the reply, which must be a response message.
    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
        let frame = message.serialize(qmsg_type::synchronous, is_local).await;
        write_all_cancellation_safe(self, &frame).await?;
        let (reply_type, reply) = receive_message(self).await?;
        if reply_type == qmsg_type::response {
            Ok(reply)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("expected a response: {}", reply),
            )
            .into())
        }
    }

    /// Reads one message from the stream.
    async fn receive_message(&mut self) -> Result<(u8, K)> {
        receive_message(self).await
    }
}
#[async_trait]
impl QStreamInner for UnixStream {
    /// Shuts down the write side of the socket; the listener flag is not used.
    async fn shutdown(&mut self, _: bool) -> Result<()> {
        Ok(AsyncWriteExt::shutdown(self).await?)
    }

    /// Serializes `message` with the given type and writes every byte to the socket.
    async fn send_message(
        &mut self,
        message: &dyn Query,
        message_type: u8,
        is_local: bool,
    ) -> Result<()> {
        let frame = message.serialize(message_type, is_local).await;
        write_all_cancellation_safe(self, &frame).await.map(|_| ())
    }

    /// Same as `send_message` with the asynchronous message type.
    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
        let frame = message.serialize(qmsg_type::asynchronous, is_local).await;
        write_all_cancellation_safe(self, &frame).await.map(|_| ())
    }

    /// Sends a synchronous message and reads the reply, which must be a response message.
    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
        let frame = message.serialize(qmsg_type::synchronous, is_local).await;
        write_all_cancellation_safe(self, &frame).await?;
        let (reply_type, reply) = receive_message(self).await?;
        if reply_type == qmsg_type::response {
            Ok(reply)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("expected a response: {}", reply),
            )
            .into())
        }
    }

    /// Reads one message from the socket.
    async fn receive_message(&mut self) -> Result<(u8, K)> {
        receive_message(self).await
    }
}
impl MessageHeader {
fn new(encoding: u8, message_type: u8, compressed: u8, length: u32) -> Self {
MessageHeader {
encoding: encoding,
message_type: message_type,
compressed: compressed,
_unused: 0,
length: length,
}
}
fn from_bytes(bytes: [u8; 8]) -> Self {
let encoding = bytes[0];
let length = match encoding {
0 => u32::from_be_bytes(bytes[4..8].try_into().unwrap()),
_ => u32::from_le_bytes(bytes[4..8].try_into().unwrap()),
};
MessageHeader::new(encoding, bytes[1], bytes[2], length)
}
fn size() -> usize {
8
}
}
async fn connect_tcp_impl(host: &str, port: u16) -> Result<TcpStream> {
let resolver =
TokioAsyncResolver::tokio_from_system_conf().expect("failed to create a resolver");
let ips;
if let Ok(ip) = host.parse::<IpAddr>() {
ips = vec![ip.to_string()]
} else {
ips = resolver
.ipv4_lookup(format!("{}.", host).as_str())
.await
.unwrap()
.iter()
.map(|result| result.to_string())
.collect()
};
for answer in ips {
let host_port = format!("{}:{}", answer, port);
match TcpStream::connect(&host_port).await {
Ok(socket) => {
println!("connected: {}", host_port);
return Ok(socket);
}
Err(_) => {
eprintln!("connection refused: {}. try next.", host_port);
}
}
}
Err(io::Error::new(io::ErrorKind::ConnectionRefused, "failed to connect").into())
}
/// Sends the credential followed by `method_bytes` and waits for the one-byte reply.
///
/// # Errors
///
/// Returns the write error as is; a failure to read the reply byte is reported as
/// `ConnectionAborted` with the message "authentication failure".
async fn handshake<S>(socket: &mut S, credential_: &str, method_bytes: &str) -> Result<()>
where
    S: Unpin + AsyncWriteExt + AsyncReadExt,
{
    let mut credential = String::with_capacity(credential_.len() + method_bytes.len());
    credential.push_str(credential_);
    credential.push_str(method_bytes);
    write_all_cancellation_safe(socket, credential.as_bytes()).await?;

    // The value of the reply byte is not inspected; receiving it is what counts.
    let mut cap = [0u8; 1];
    match socket.read_exact(&mut cap).await {
        Ok(_) => Ok(()),
        Err(_) => {
            Err(io::Error::new(io::ErrorKind::ConnectionAborted, "authentication failure").into())
        }
    }
}
/// Connects over TCP and performs the handshake with the bytes `\x03\x00` appended to the
/// credential.
async fn connect_tcp(host: &str, port: u16, credential: &str) -> Result<TcpStream> {
    let mut stream = connect_tcp_impl(host, port).await?;
    let handshake_result = handshake(&mut stream, credential, "\x03\x00").await;
    handshake_result.map(|_| stream)
}
async fn connect_tls(host: &str, port: u16, credential: &str) -> Result<TlsStream<TcpStream>> {
let socket_ = connect_tcp_impl(host, port).await?;
let connector = TlsConnector::from(TlsConnectorInner::new().unwrap());
let mut socket = connector
.connect(host, socket_)
.await
.expect("failed to create TLS session");
handshake(&mut socket, credential, "\x03\x00").await?;
Ok(socket)
}
/// Builds the socket file path `<dir>/kx.<port>`.
///
/// `<dir>` is the value of the `QUDSPATH` environment variable, or `/tmp` when the variable
/// cannot be read.
fn create_sockfile_path(port: u16) -> Result<String> {
    let directory = env::var("QUDSPATH").unwrap_or_else(|_| String::from("/tmp"));
    Ok(format!("{}/kx.{}", directory, port))
}
/// Connects to the Unix domain socket derived from `port` and performs the handshake with
/// the bytes `\x06\x00` appended to the credential.
///
/// The socket name is the path from `create_sockfile_path` prefixed with a NUL byte.
#[cfg(unix)]
async fn connect_uds(port: u16, credential: &str) -> Result<UnixStream> {
    let socket_path = create_sockfile_path(port)?;
    let mut abstract_name = String::with_capacity(socket_path.len() + 1);
    abstract_name.push('\x00');
    abstract_name.push_str(&socket_path);

    let mut socket = UnixStream::connect(Path::new(&abstract_name)).await?;
    handshake(&mut socket, credential, "\x06\x00").await?;
    Ok(socket)
}
/// Reads the credential a client sends right after connecting and authenticates it.
///
/// Bytes are collected until a capability byte (`0x03` or `0x06`) shows up. The bytes before
/// it are interpreted as `user:password`; the hex SHA-1 digest of the password is compared
/// with the entry stored for the user in `ACCOUNTS`. On success the capability byte is echoed
/// back to the client. On failure the write side of the socket is shut down.
///
/// Bytes that follow the capability byte in the same read are discarded.
///
/// # Errors
///
/// * `UnexpectedEof` when the client closes the connection before sending a capability byte.
/// * `InvalidData` ("authentication failed") when the credential is not valid UTF-8, has no
///   `:` separator, names an unknown user, or carries a wrong password.
/// * Any I/O error raised while reading from or writing to the socket.
async fn read_client_input<S>(socket: &mut S) -> Result<()>
where
    S: Unpin + AsyncWriteExt + AsyncReadExt,
{
    let mut chunk = [0u8; 32];
    // Raw bytes are accumulated and decoded once, so a multi-byte character split across two
    // reads is still decoded correctly.
    let mut received: Vec<u8> = Vec::new();
    let capacity = loop {
        let read = socket.read(&mut chunk).await?;
        if read == 0 {
            // A zero-length read means the peer is gone; retrying would spin forever.
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "connection closed during handshake",
            )
            .into());
        }
        // Only the bytes delivered by this read are meaningful; the rest of the chunk holds
        // zeroes or leftovers of an earlier read.
        let fresh = &chunk[..read];
        match fresh
            .iter()
            .position(|byte| *byte == 0x03 || *byte == 0x06)
        {
            Some(index) => {
                received.extend_from_slice(&fresh[..index]);
                break fresh[index];
            }
            None => received.extend_from_slice(fresh),
        }
    };

    // Malformed input from a client is an authentication failure, never a panic.
    let authenticated = match str::from_utf8(&received) {
        Ok(text) => {
            let mut fields = text.split(':');
            match (fields.next(), fields.next()) {
                (Some(user), Some(password)) => {
                    let mut hasher = Sha1::new();
                    hasher.update(password.as_bytes());
                    let encoded_password = hasher.digest().to_string();
                    let matched = ACCOUNTS
                        .get(user)
                        .map_or(false, |encoded| *encoded == encoded_password);
                    matched
                }
                _ => false,
            }
        }
        Err(_) => false,
    };

    if authenticated {
        socket.write_all(&[capacity; 1]).await?;
        Ok(())
    } else {
        socket.shutdown().await?;
        Err(io::Error::new(io::ErrorKind::InvalidData, "authentication failed").into())
    }
}
/// Loads the TLS identity used by the acceptor.
///
/// The PKCS #12 archive is read from the file named by `KDBPLUS_TLS_KEY_FILE` and decoded
/// with the password in `KDBPLUS_TLS_KEY_FILE_SECRET`.
///
/// # Errors
///
/// * `NotFound` when either environment variable cannot be read.
/// * The underlying I/O error when the file cannot be opened or read (previously a panic).
/// * `InvalidData` when the archive cannot be decoded with the given password.
async fn build_identity_from_cert() -> Result<Identity> {
    let path = env::var("KDBPLUS_TLS_KEY_FILE").map_err(|_| {
        io::Error::new(io::ErrorKind::NotFound, "KDBPLUS_TLS_KEY_FILE is not set")
    })?;
    let password = env::var("KDBPLUS_TLS_KEY_FILE_SECRET").map_err(|_| {
        io::Error::new(
            io::ErrorKind::NotFound,
            "KDBPLUS_TLS_KEY_FILE_SECRET is not set",
        )
    })?;

    let cert_file = tokio::fs::File::open(Path::new(&path)).await?;
    let mut reader = BufReader::new(cert_file);
    let mut der: Vec<u8> = Vec::new();
    reader.read_to_end(&mut der).await?;

    match Identity::from_pkcs12(&der, &password) {
        Ok(identity) => Ok(identity),
        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "authentication failed").into()),
    }
}
async fn read_exact_cancellation_safe<S>(socket: &mut S, buffer: &mut [u8]) -> Result<usize>
where
S: Unpin + AsyncReadExt,
{
let mut read_total = 0;
let to_read = buffer.len();
loop {
read_total += socket.read(buffer).await?;
if read_total == to_read {
break;
}
}
Ok(read_total)
}
async fn write_all_cancellation_safe<S>(socket: &mut S, buffer: &[u8]) -> Result<usize>
where
S: Unpin + AsyncWriteExt,
{
let mut write_total = 0;
let to_write = buffer.len();
loop {
write_total += socket.write(&buffer[write_total..]).await?;
if write_total == to_write {
break;
}
}
Ok(write_total)
}
/// Reads one message from `socket` and returns its message type and decoded object.
///
/// The 8-byte header is read first; its length field (which includes the header itself)
/// determines how many body bytes follow. A body flagged as compressed is decompressed
/// before decoding.
///
/// # Errors
///
/// * `ConnectionAborted` when the header cannot be read.
/// * `InvalidData` when the header announces a length shorter than the header itself, or
///   when a compressed body is too short to hold its size prefix.
/// * `UnexpectedEof` when the body cannot be read completely.
async fn receive_message<S>(socket: &mut S) -> Result<(u8, K)>
where
    S: Unpin + AsyncReadExt,
{
    let mut header_buffer = [0u8; 8];
    if let Err(err) = read_exact_cancellation_safe(socket, &mut header_buffer).await {
        return Err(io::Error::new(
            io::ErrorKind::ConnectionAborted,
            format!("Connection dropped: {}", err),
        )
        .into());
    }
    let header = MessageHeader::from_bytes(header_buffer);

    // The length comes from the peer; a value below the header size must not underflow.
    let body_length = match (header.length as usize).checked_sub(MessageHeader::size()) {
        Some(length) => length,
        None => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid message length in header: {}", header.length),
            )
            .into());
        }
    };

    let mut body = vec![0_u8; body_length];
    if let Err(err) = read_exact_cancellation_safe(socket, &mut body).await {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            format!("Failed to read body of message: {}", err),
        )
        .into());
    }

    if header.compressed == 0x01 {
        // `decompress` reads a 4-byte size prefix from the start of the body.
        if body.len() < 4 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "compressed message is too short",
            )
            .into());
        }
        body = decompress(body, header.encoding).await;
    }

    Ok((
        header.message_type,
        K::q_ipc_decode(&body, header.encoding).await,
    ))
}
/// Tries to compress a complete message (8-byte header followed by the payload).
///
/// Returns `(true, compressed)` on success. The compressed message starts with the first four
/// header bytes of `raw` with the compression flag (byte 2) set to `1`, followed by the
/// compressed total size (bytes 4..8), the size of `raw` (bytes 8..12) and the compressed data.
///
/// Returns `(false, raw)` with the input untouched when the output would not fit into half the
/// size of the input.
///
/// NOTE(review): the arithmetic below (`e - 17`, `t - 3`) presumes a reasonably long input; the
/// only caller in this file passes messages whose payload exceeds 1992 bytes.
async fn compress(raw: Vec<u8>) -> (bool, Vec<u8>) {
// Bit mask of the current item inside the control byte; 0 means a new control byte is due.
let mut i = 0_u8;
// Control byte being assembled: a set bit marks a back-reference, a clear bit a literal.
let mut f = 0_u8;
// Hash and position of the most recent literal, recorded in `a` one iteration later.
let mut h0 = 0_usize;
let mut h = 0_usize;
// `true` when the current byte has to be emitted as a literal.
let mut g: bool;
// The output buffer is half the input size; exceeding it aborts the compression.
let mut compressed: Vec<u8> = Vec::with_capacity((raw.len()) / 2);
compressed.resize((raw.len()) / 2, 0_u8);
// `c` is the index of the pending control byte, `d` the write cursor.
let mut c = 12;
let mut d = c;
let e = compressed.len();
// `p` is the candidate match position, `q` the match limit, `r` the match start.
let mut p = 0_usize;
let mut q: usize;
let mut r: usize;
let mut s0 = 0_usize;
// Read cursor; starts after the 8-byte header.
let mut s = 8_usize;
let t = raw.len();
// Table indexed by the XOR of two adjacent bytes, holding a position where that pair was seen.
let mut a = [0_i32; 256];
compressed[0..4].copy_from_slice(&raw[0..4]);
// Mark the message as compressed.
compressed[2] = 1;
let raw_size = match ENCODING {
0 => (t as u32).to_be_bytes(),
_ => (t as u32).to_le_bytes(),
};
compressed[8..12].copy_from_slice(&raw_size);
while s < t {
if i == 0 {
// A full group (control byte plus eight two-byte items) must still fit.
if d > e - 17 {
return (false, raw);
}
i = 1;
// Store the finished control byte and reserve the slot for the next one.
compressed[c] = f;
c = d;
d += 1;
f = 0;
}
// Too close to the end for a two-byte lookup: emit a literal.
g = s > t - 3;
if !g {
h = (raw[s] ^ raw[s + 1]) as usize;
p = a[h] as usize;
// No earlier position for this hash, or the first byte differs: emit a literal.
g = (0 == p) || (0 != (raw[s] ^ raw[p]));
}
if 0 < s0 {
a[h0] = s0 as i32;
s0 = 0;
}
if g {
h0 = h;
s0 = s;
compressed[d] = raw[s];
d += 1;
s += 1;
} else {
a[h] = s as i32;
f |= i;
p += 2;
s += 2;
r = s;
// Beyond the first two bytes a match extends by at most 255 more bytes.
q = if s + 255 > t { t } else { s + 255 };
while (s < q) && (raw[p] == raw[s]) {
s += 1;
if s < q {
p += 1;
}
}
// A back-reference is the hash byte followed by the number of extra matching bytes.
compressed[d] = h as u8;
d += 1;
compressed[d] = (s - r) as u8;
d += 1;
}
// Move to the next bit; after eight items the mask wraps around to 0.
i = i.wrapping_mul(2);
}
compressed[c] = f;
let compressed_size = match ENCODING {
0 => (d as u32).to_be_bytes(),
_ => (d as u32).to_le_bytes(),
};
compressed[4..8].copy_from_slice(&compressed_size);
// Drop the unused tail of the output buffer.
let _ = compressed.split_off(d);
(true, compressed)
}
/// Decompresses the body of a compressed message (the message without its 8-byte header).
///
/// The first four bytes of `compressed` hold the size of the uncompressed message including
/// its header, in the byte order selected by `encoding` (`0` big endian, otherwise little
/// endian). The returned vector has that size minus 8.
///
/// NOTE(review): the input is indexed without validation, so a truncated or malformed body
/// leads to an out-of-bounds panic — confirm callers only pass trusted data.
async fn decompress(compressed: Vec<u8>, encoding: u8) -> Vec<u8> {
// Number of extra bytes copied by the current back-reference.
let mut n = 0;
// Source position of the current back-reference in the output.
let mut r: usize;
// Current control byte; a set bit marks a back-reference, a clear bit a literal.
let mut f = 0_usize;
// Write cursor in the output.
let mut s = 0_usize;
// Next output position whose byte pair still has to be recorded in `aa`.
let mut p = s;
// Bit mask of the current item inside the control byte; 0 means a new control byte is due.
let mut i = 0_usize;
let size = match encoding {
0 => {
i32::from_be_bytes(
compressed[0..4]
.try_into()
.expect("slice does not have length 4"),
) - 8
}
_ => {
i32::from_le_bytes(
compressed[0..4]
.try_into()
.expect("slice does not have length 4"),
) - 8
}
};
let mut decompressed: Vec<u8> = Vec::with_capacity(size as usize);
decompressed.resize(size as usize, 0_u8);
// Read cursor; starts after the 4-byte size prefix.
let mut d = 4;
// Table indexed by the XOR of two adjacent output bytes, holding the position of that pair.
let mut aa = [0_i32; 256];
while s < decompressed.len() {
if i == 0 {
f = (0xff & compressed[d]) as usize;
d += 1;
i = 1;
}
if (f & i) != 0 {
// Back-reference: look up the position, copy two bytes, then `n` more.
r = aa[(0xff & compressed[d]) as usize] as usize;
d += 1;
decompressed[s] = decompressed[r];
s += 1;
r += 1;
decompressed[s] = decompressed[r];
s += 1;
r += 1;
n = (0xff & compressed[d]) as usize;
d += 1;
for m in 0..n {
decompressed[s + m] = decompressed[r + m];
}
} else {
// Literal byte.
decompressed[s] = compressed[d];
s += 1;
d += 1;
}
// Record the byte pairs written so far.
while p < s - 1 {
aa[((0xff & decompressed[p]) ^ (0xff & decompressed[p + 1])) as usize] = p as i32;
p += 1;
}
if (f & i) != 0 {
// Skip the extra bytes copied above; they are not recorded in `aa`.
s += n;
p = s;
}
i *= 2;
if i == 256 {
i = 0;
}
}
decompressed
}