use super::codec::{CompressionMode, KdbCodec, KdbMessage, ValidationMode};
use super::Result;
use super::K;
use futures::{SinkExt, StreamExt};
use io::BufRead;
use once_cell::sync::Lazy;
use sha1_smol::Sha1;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::path::Path;
use std::{env, fs, io, str};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use tokio_native_tls::native_tls::{
Identity, TlsAcceptor as TlsAcceptorInner, TlsConnector as TlsConnectorInner,
};
use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
use tokio_util::codec::Framed;
use trust_dns_resolver::TokioAsyncResolver;
/// Message-type codes passed to `KdbMessage::new` and returned by `QStream::receive_message`.
// The lowercase names are public API (callers match on e.g. `qmsg_type::response`), so the
// naming lint is silenced instead of renaming them.
#[allow(non_upper_case_globals)]
pub mod qmsg_type {
    /// Asynchronous message, as sent by `QStream::send_async_message` (no reply is awaited).
    pub const asynchronous: u8 = 0;
    /// Synchronous message, as sent by `QStream::send_sync_message`, which then awaits a reply.
    pub const synchronous: u8 = 1;
    /// Reply to a synchronous message.
    pub const response: u8 = 2;
}
/// Account file read when `KDBPLUS_ACCOUNT_FILE` is not set (a path relative to the working directory).
const DEFAULT_ACCOUNT_FILE: &str = "credential/kdbaccess";
/// Environment variable naming the account file loaded into [`ACCOUNTS`].
const ACCOUNT_FILE_ENV: &str = "KDBPLUS_ACCOUNT_FILE";
/// Map from user name to expected password digest, loaded lazily from the account file.
///
/// Each line is `user:digest`; only the first two `:`-separated fields are used and lines with
/// fewer fields are skipped. `read_client_input` compares the digest against
/// `Sha1::digest().to_string()` of the password the client sent. A missing or unreadable file
/// yields an empty map (nobody can authenticate); reading stops at the first I/O error.
///
/// Declared `static` so the file is read once per process: a `const` `Lazy` is instantiated
/// afresh (and the file re-read) at every use site.
static ACCOUNTS: Lazy<HashMap<String, String>> = Lazy::new(|| {
    let path = env::var(ACCOUNT_FILE_ENV).unwrap_or_else(|_| DEFAULT_ACCOUNT_FILE.to_string());
    let file = match fs::File::open(&path) {
        Ok(file) => file,
        Err(_) => return HashMap::new(),
    };
    let mut map: HashMap<String, String> = HashMap::new();
    // `map_while` stops at the first read error (e.g. a non-UTF-8 line), like the EOF case.
    for line in io::BufReader::new(file).lines().map_while(Result::ok) {
        let mut fields = line.trim_end().split(':');
        if let (Some(user), Some(digest)) = (fields.next(), fields.next()) {
            map.insert(user.to_string(), digest.to_string());
        }
    }
    map
});
/// Transport used by a [`QStream`].
///
/// Discriminants are explicit: `TCP = 0`, `TLS = 1`, `UDS = 2`.
// The acronym variant names are public API, so the clippy naming lint is silenced.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionMethod {
    /// Plain TCP connection.
    TCP = 0,
    /// TLS session over TCP.
    TLS = 1,
    /// Unix domain socket; only usable on Unix platforms.
    UDS = 2,
}
/// A value that can be sent over a [`QStream`].
///
/// `QStream::send_message` calls `to_kdb_message` with one of the `qmsg_type` codes.
pub trait Query: Send + Sync {
    /// Wraps `self` in a [`KdbMessage`] tagged with `message_type`.
    fn to_kdb_message(&self, message_type: u8) -> KdbMessage;
}
/// Transport underlying a [`QStream`]; every variant is framed with [`KdbCodec`].
enum FramedStream {
    /// Plain TCP connection.
    Tcp(Framed<TcpStream, KdbCodec>),
    /// TLS session over a TCP connection.
    Tls(Framed<TlsStream<TcpStream>, KdbCodec>),
    /// Unix domain socket connection (only compiled on Unix).
    #[cfg(unix)]
    Uds(Framed<UnixStream, KdbCodec>),
}
/// An authenticated IPC connection to or from a q process.
pub struct QStream {
    /// Framed transport carrying the messages.
    stream: FramedStream,
    /// Which transport `stream` uses; reported by `get_connection_type`.
    method: ConnectionMethod,
    /// `true` when created by `accept*` (server side), `false` when created by `connect*`.
    listener: bool,
}
impl Query for &str {
    /// Sends the text (typically q source to evaluate) as a `K` string payload.
    fn to_kdb_message(&self, message_type: u8) -> KdbMessage {
        KdbMessage::new(message_type, K::new_string(self.to_string(), 0))
    }
}
impl Query for K {
fn to_kdb_message(&self, message_type: u8) -> KdbMessage {
KdbMessage::new(message_type, self.clone())
}
}
#[bon::bon]
impl QStream {
fn new(stream: FramedStream, method: ConnectionMethod, is_listener: bool) -> Self {
QStream {
stream,
method,
listener: is_listener,
}
}
#[builder(on(String, into), on(&str, into))]
pub async fn builder(
method: ConnectionMethod,
#[builder(default = String::new())] host: String,
port: u16,
#[builder(default = String::new())] credential: String,
#[builder(default)] compression_mode: CompressionMode,
#[builder(default)] validation_mode: ValidationMode,
) -> Result<Self> {
Self::connect_with_options(
method,
&host,
port,
&credential,
compression_mode,
validation_mode,
)
.await
}
pub async fn connect(
method: ConnectionMethod,
host: &str,
port: u16,
credential: &str,
) -> Result<Self> {
Self::connect_with_options(
method,
host,
port,
credential,
CompressionMode::Auto,
ValidationMode::Strict,
)
.await
}
pub async fn connect_with_options(
method: ConnectionMethod,
host: &str,
port: u16,
credential: &str,
compression_mode: CompressionMode,
validation_mode: ValidationMode,
) -> Result<Self> {
match method {
ConnectionMethod::TCP => {
let stream = connect_tcp(host, port, credential).await?;
let is_local = matches!(host, "localhost" | "127.0.0.1");
let codec = KdbCodec::builder()
.is_local(is_local)
.compression_mode(compression_mode)
.validation_mode(validation_mode)
.build();
let framed = Framed::new(stream, codec);
Ok(QStream::new(
FramedStream::Tcp(framed),
ConnectionMethod::TCP,
false,
))
}
ConnectionMethod::TLS => {
let stream = connect_tls(host, port, credential).await?;
let codec = KdbCodec::builder()
.is_local(false)
.compression_mode(compression_mode)
.validation_mode(validation_mode)
.build(); let framed = Framed::new(stream, codec);
Ok(QStream::new(
FramedStream::Tls(framed),
ConnectionMethod::TLS,
false,
))
}
ConnectionMethod::UDS => {
let stream = connect_uds(port, credential).await?;
let codec = KdbCodec::builder()
.is_local(true)
.compression_mode(compression_mode)
.validation_mode(validation_mode)
.build(); let framed = Framed::new(stream, codec);
Ok(QStream::new(
FramedStream::Uds(framed),
ConnectionMethod::UDS,
false,
))
}
}
}
pub async fn accept(method: ConnectionMethod, host: &str, port: u16) -> Result<Self> {
Self::accept_with_options(
method,
host,
port,
CompressionMode::Auto,
ValidationMode::Strict,
)
.await
}
pub async fn accept_with_options(
method: ConnectionMethod,
host: &str,
port: u16,
compression_mode: CompressionMode,
validation_mode: ValidationMode,
) -> Result<Self> {
match method {
ConnectionMethod::TCP => {
let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
let (mut socket, ip_address) = listener.accept().await?;
while let Err(_) = read_client_input(&mut socket).await {
socket = listener.accept().await?.0;
}
let is_local = ip_address.ip() == IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let codec = KdbCodec::builder()
.is_local(is_local)
.compression_mode(compression_mode)
.validation_mode(validation_mode)
.build();
let framed = Framed::new(socket, codec);
Ok(QStream::new(
FramedStream::Tcp(framed),
ConnectionMethod::TCP,
true,
))
}
ConnectionMethod::TLS => {
let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
let identity = build_identity_from_cert().await?;
let tls_acceptor = TlsAcceptor::from(TlsAcceptorInner::new(identity).unwrap());
let (mut socket, _) = listener.accept().await?;
let mut tls_socket = tls_acceptor
.accept(socket)
.await
.expect("failed to accept TLS connection");
while let Err(_) = read_client_input(&mut tls_socket).await {
socket = listener.accept().await?.0;
tls_socket = tls_acceptor
.accept(socket)
.await
.expect("failed to accept TLS connection");
}
let codec = KdbCodec::builder()
.is_local(false)
.compression_mode(compression_mode)
.validation_mode(validation_mode)
.build();
let framed = Framed::new(tls_socket, codec);
let mut qstream =
QStream::new(FramedStream::Tls(framed), ConnectionMethod::TLS, true);
qstream
.send_async_message(&".kdbplus.close_tls_connection_:{[] hclose .z.w;}")
.await?;
Ok(qstream)
}
ConnectionMethod::UDS => {
let uds_path = create_sockfile_path(port)?;
let abstract_sockfile_ = format!("\x00{}", uds_path);
let abstract_sockfile = Path::new(&abstract_sockfile_);
let listener = UnixListener::bind(&abstract_sockfile).unwrap();
let (mut socket, _) = listener.accept().await?;
while let Err(_) = read_client_input(&mut socket).await {
socket = listener.accept().await?.0;
}
let codec = KdbCodec::builder()
.is_local(true)
.compression_mode(compression_mode)
.validation_mode(validation_mode)
.build();
let framed = Framed::new(socket, codec);
Ok(QStream::new(
FramedStream::Uds(framed),
ConnectionMethod::UDS,
true,
))
}
}
}
pub async fn shutdown(mut self) -> Result<()> {
if self.listener && matches!(self.method, ConnectionMethod::TLS) {
self.send_async_message(&".kdbplus.close_tls_connection_[]")
.await?;
}
match self.stream {
FramedStream::Tcp(framed) => {
AsyncWriteExt::shutdown(&mut framed.into_inner()).await?;
}
FramedStream::Tls(framed) => {
if !self.listener {
framed.into_inner().get_mut().shutdown()?;
}
}
#[cfg(unix)]
FramedStream::Uds(framed) => {
AsyncWriteExt::shutdown(&mut framed.into_inner()).await?;
}
}
Ok(())
}
pub async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> Result<()> {
let kdb_message = message.to_kdb_message(message_type);
match &mut self.stream {
FramedStream::Tcp(framed) => {
framed.send(kdb_message).await?;
}
FramedStream::Tls(framed) => {
framed.send(kdb_message).await?;
}
#[cfg(unix)]
FramedStream::Uds(framed) => {
framed.send(kdb_message).await?;
}
}
Ok(())
}
pub async fn send_async_message(&mut self, message: &dyn Query) -> Result<()> {
self.send_message(message, qmsg_type::asynchronous).await
}
pub async fn send_sync_message(&mut self, message: &dyn Query) -> Result<K> {
self.send_message(message, qmsg_type::synchronous).await?;
match self.receive_message().await? {
(qmsg_type::response, response) => Ok(response),
(_, message) => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected a response: {}", message),
)
.into()),
}
}
pub async fn receive_message(&mut self) -> Result<(u8, K)> {
match &mut self.stream {
FramedStream::Tcp(framed) => match framed.next().await {
Some(Ok(response)) => Ok((response.message_type, response.payload)),
Some(Err(e)) => Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
format!("Connection dropped: {}", e),
)
.into()),
None => Err(
io::Error::new(io::ErrorKind::ConnectionAborted, "Connection closed").into(),
),
},
FramedStream::Tls(framed) => match framed.next().await {
Some(Ok(response)) => Ok((response.message_type, response.payload)),
Some(Err(e)) => Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
format!("Connection dropped: {}", e),
)
.into()),
None => Err(
io::Error::new(io::ErrorKind::ConnectionAborted, "Connection closed").into(),
),
},
#[cfg(unix)]
FramedStream::Uds(framed) => match framed.next().await {
Some(Ok(response)) => Ok((response.message_type, response.payload)),
Some(Err(e)) => Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
format!("Connection dropped: {}", e),
)
.into()),
None => Err(
io::Error::new(io::ErrorKind::ConnectionAborted, "Connection closed").into(),
),
},
}
}
pub fn get_connection_type(&self) -> &str {
match self.method {
ConnectionMethod::TCP => "TCP",
ConnectionMethod::TLS => "TLS",
ConnectionMethod::UDS => "UDS",
}
}
}
/// Opens a TCP connection to `host` on `port`.
///
/// `host` may be an IPv4/IPv6 literal or a name resolved with the system DNS configuration.
/// Each resolved address is tried in order until one accepts.
///
/// # Errors
/// Returns an error if the resolver cannot be created, the name cannot be resolved, or no
/// address accepts the connection (`ConnectionRefused`, including the last connect error).
async fn connect_tcp_impl(host: &str, port: u16) -> Result<TcpStream> {
    let addresses: Vec<IpAddr> = if let Ok(ip) = host.parse::<IpAddr>() {
        vec![ip]
    } else {
        // DNS failures depend on user input and the environment: report them, don't panic.
        let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("failed to create DNS resolver: {}", e),
            )
        })?;
        let response = resolver.lookup_ip(host).await.map_err(|e| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("failed to resolve host: {}: {}", host, e),
            )
        })?;
        let resolved: Vec<IpAddr> = response.iter().collect();
        resolved
    };
    let mut last_error: Option<io::Error> = None;
    for ip in addresses {
        // An (IpAddr, u16) pair avoids "a::b:port" strings, which cannot be parsed for IPv6.
        match TcpStream::connect((ip, port)).await {
            Ok(socket) => return Ok(socket),
            Err(error) => last_error = Some(error),
        }
    }
    let detail = match last_error {
        Some(error) => format!("failed to connect: {}", error),
        None => "failed to connect".to_string(),
    };
    Err(io::Error::new(io::ErrorKind::ConnectionRefused, detail).into())
}
/// Performs the client side of the IPC handshake.
///
/// Writes `credential_` immediately followed by `method_bytes` in a single `write_all`, then
/// waits for the one-byte reply the server sends when it accepts the login.
///
/// # Errors
/// Propagates I/O errors from the write or the read; a server that closes the connection
/// instead of replying surfaces as the `UnexpectedEof` error of `read_exact`.
pub async fn handshake<S>(socket: &mut S, credential_: &str, method_bytes: &str) -> Result<()>
where
    S: Unpin + AsyncWriteExt + AsyncReadExt,
{
    let greeting = [credential_, method_bytes].concat();
    socket.write_all(greeting.as_bytes()).await?;
    let mut reply = [0u8];
    socket.read_exact(&mut reply).await?;
    Ok(())
}
/// Opens a TCP connection and logs in with `credential`, followed by the bytes `\x03\x00`.
async fn connect_tcp(host: &str, port: u16, credential: &str) -> Result<TcpStream> {
    let mut stream = connect_tcp_impl(host, port).await?;
    handshake(&mut stream, credential, "\x03\x00")
        .await
        .map(|()| stream)
}
async fn connect_tls(host: &str, port: u16, credential: &str) -> Result<TlsStream<TcpStream>> {
let socket_ = connect_tcp_impl(host, port).await?;
let connector = TlsConnector::from(TlsConnectorInner::new().unwrap());
let mut socket = connector
.connect(host, socket_)
.await
.expect("failed to create TLS session");
handshake(&mut socket, credential, "\x03\x00").await?;
Ok(socket)
}
/// Returns the Unix-socket path for `port`: `$QUDSPATH/kx.<port>`, or `/tmp/kx.<port>` when
/// `QUDSPATH` is unset (or not valid Unicode). Never fails.
fn create_sockfile_path(port: u16) -> Result<String> {
    let directory = env::var("QUDSPATH").unwrap_or_else(|_| String::from("/tmp"));
    Ok(format!("{}/kx.{}", directory, port))
}
/// Connects to the abstract Unix socket for `port` and logs in with `credential`,
/// followed by the bytes `\x06\x00`.
#[cfg(unix)]
async fn connect_uds(port: u16, credential: &str) -> Result<UnixStream> {
    // Leading NUL: address an abstract socket name rather than a filesystem path.
    let abstract_name = format!("\x00{}", create_sockfile_path(port)?);
    let mut stream = UnixStream::connect(Path::new(&abstract_name)).await?;
    handshake(&mut stream, credential, "\x06\x00").await?;
    Ok(stream)
}
/// Performs the server side of the IPC login on a freshly accepted `socket`.
///
/// Reads the client's `user:password` text up to its capability byte (`0x03` or `0x06`),
/// checks it against [`ACCOUNTS`], and on success writes the capability byte back.
/// Set `KDBPLUS_DEBUG_AUTH=1` to trace each decision on stderr.
///
/// # Errors
/// * `UnexpectedEof` ("client disconnected") if the client closes before finishing;
/// * `InvalidData` ("authentication failed") for a malformed credential, unknown user or
///   wrong password, after shutting the socket down;
/// * any I/O error from reading, writing or shutting down the socket.
async fn read_client_input<S>(socket: &mut S) -> Result<()>
where
    S: Unpin + AsyncWriteExt + AsyncReadExt,
{
    let debug_auth = matches!(std::env::var("KDBPLUS_DEBUG_AUTH").ok().as_deref(), Some("1"));
    let mut client_input = [0u8; 32];
    // Collect raw bytes and decode once the terminator arrives: the bytes are client-controlled
    // (decoding must not panic) and a multi-byte UTF-8 character may straddle two reads.
    let mut passed_credential: Vec<u8> = Vec::new();
    loop {
        let n = socket.read(&mut client_input).await?;
        if n == 0 {
            socket.shutdown().await?;
            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "client disconnected").into());
        }
        let chunk = &client_input[..n];
        match chunk.iter().position(|byte| *byte == 0x03 || *byte == 0x06) {
            None => passed_credential.extend_from_slice(chunk),
            Some(index) => {
                let capacity = chunk[index];
                passed_credential.extend_from_slice(&chunk[..index]);
                if credential_is_accepted(&passed_credential, capacity, debug_auth) {
                    socket.write_all(&[capacity]).await?;
                    return Ok(());
                }
                socket.shutdown().await?;
                return Err(
                    io::Error::new(io::ErrorKind::InvalidData, "authentication failed").into(),
                );
            }
        }
    }
}

/// Decides whether raw `user:password` bytes name a known account with the right password.
///
/// Non-UTF-8 input or a missing `:` counts as a malformed credential. The text is split at the
/// first `:`, so the password may itself contain `:`. The password's `Sha1` digest string is
/// compared with the entry stored in [`ACCOUNTS`].
fn credential_is_accepted(raw: &[u8], capacity: u8, debug_auth: bool) -> bool {
    let (user, password) = match str::from_utf8(raw).ok().and_then(|text| text.split_once(':')) {
        Some(pair) => pair,
        None => {
            if debug_auth {
                eprintln!("[acceptor auth] invalid credential format");
            }
            return false;
        }
    };
    if debug_auth {
        eprintln!(
            "[acceptor auth] user='{}' capacity=0x{:02x}",
            user, capacity
        );
    }
    match ACCOUNTS.get(user) {
        Some(expected_digest) => {
            let mut hasher = Sha1::new();
            hasher.update(password.as_bytes());
            if *expected_digest == hasher.digest().to_string() {
                if debug_auth {
                    eprintln!("[acceptor auth] success");
                }
                true
            } else {
                if debug_auth {
                    eprintln!("[acceptor auth] password mismatch");
                }
                false
            }
        }
        None => {
            if debug_auth {
                eprintln!("[acceptor auth] unknown user");
            }
            false
        }
    }
}
/// Loads the TLS server identity used when accepting `ConnectionMethod::TLS` clients.
///
/// Reads a PKCS#12 archive from the path in `KDBPLUS_TLS_KEY_FILE` and decrypts it with the
/// password in `KDBPLUS_TLS_KEY_FILE_SECRET`.
///
/// # Errors
/// * `NotFound` if either environment variable is unset (or not valid Unicode);
/// * any I/O error opening or reading the file (previously a panic);
/// * `InvalidData` if the archive cannot be decoded with the given password.
async fn build_identity_from_cert() -> Result<Identity> {
    let path = env::var("KDBPLUS_TLS_KEY_FILE").map_err(|_| {
        io::Error::new(io::ErrorKind::NotFound, "KDBPLUS_TLS_KEY_FILE is not set")
    })?;
    let password = env::var("KDBPLUS_TLS_KEY_FILE_SECRET").map_err(|_| {
        io::Error::new(
            io::ErrorKind::NotFound,
            "KDBPLUS_TLS_KEY_FILE_SECRET is not set",
        )
    })?;
    let cert_file = tokio::fs::File::open(Path::new(&path)).await?;
    let mut reader = BufReader::new(cert_file);
    let mut der: Vec<u8> = Vec::new();
    reader.read_to_end(&mut der).await?;
    let identity = Identity::from_pkcs12(&der, &password).map_err(|e| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("failed to load PKCS#12 identity from {}: {}", path, e),
        )
    })?;
    Ok(identity)
}