use std::fmt;
use std::fs::File;
use std::io::{BufReader, Read, Write};
use std::net::{Shutdown, TcpStream, ToSocketAddrs};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use rustls::client::ServerName;
use rustls::{ClientConfig as TlsClientConfig, ClientConnection, RootCertStore, StreamOwned};
use rustls_pemfile::certs;
use crate::net::BasicAuthConfig;
use crate::types::{StoreKey as Key, MAX_KEY_BYTES, MAX_VALUE_BYTES, MIN_KEY_BYTES};
/// Size in bytes of the little-endian key-width field read from the server right after the
/// authentication acknowledgement.
const KEY_BYTES_FIELD_WIDTH: usize = 2;
/// Size in bytes of the little-endian length prefix that precedes every value in a response.
const VALUE_LEN_WIDTH: usize = 2;
/// Maximum number of keys accepted per request; the key count is sent as a single `u8`.
const MAX_KEYS_PER_REQUEST: usize = 255;
/// Acknowledgement byte that `connect` treats as "authentication accepted"; any other value is
/// reported as a server error code.
const AUTH_READY: u8 = 0;
/// Settings consumed by `RemoteStoreClient::connect`.
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Transport security: TLS with a custom CA bundle, or plain TCP.
    pub security: ClientSecurity,
    /// Credentials used to build the authorization header sent right after connecting.
    pub auth: BasicAuthConfig,
    /// Optional socket timeout. When `None`, no socket timeouts are configured.
    pub timeout: Option<Duration>,
}
/// Transport security mode for a client connection.
#[derive(Clone, Debug)]
pub enum ClientSecurity {
    /// Wrap the TCP stream in TLS.
    Tls {
        /// Name the server certificate is validated against.
        server_name: String,
        /// Path to a PEM file holding the trusted root certificate(s).
        ca_certificate: PathBuf,
    },
    /// Use the raw TCP stream without encryption.
    Plain,
}
impl ClientConfig {
    /// Creates a TLS configuration; equivalent to [`ClientConfig::with_tls`].
    pub fn new(
        server_name: impl Into<String>,
        ca_certificate: impl Into<PathBuf>,
        auth: BasicAuthConfig,
    ) -> Self {
        Self::with_tls(server_name, ca_certificate, auth)
    }

    /// Creates a configuration that secures the connection with TLS.
    ///
    /// `server_name` is the name used for certificate validation and `ca_certificate` is the
    /// path to the PEM file with the trusted roots. No timeout is set.
    pub fn with_tls(
        server_name: impl Into<String>,
        ca_certificate: impl Into<PathBuf>,
        auth: BasicAuthConfig,
    ) -> Self {
        let security = ClientSecurity::Tls {
            server_name: server_name.into(),
            ca_certificate: ca_certificate.into(),
        };
        Self {
            security,
            auth,
            timeout: None,
        }
    }

    /// Creates a configuration that talks plain TCP. No timeout is set.
    pub fn without_tls(auth: BasicAuthConfig) -> Self {
        let security = ClientSecurity::Plain;
        Self {
            security,
            auth,
            timeout: None,
        }
    }

    /// Returns the configuration with its timeout replaced by `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }
}
/// Blocking client for the remote key-value store protocol.
pub struct RemoteStoreClient {
    /// Authenticated connection to the server (TLS or plain TCP).
    stream: ClientStream,
    /// Scratch buffer reused to assemble each outgoing request.
    request_buf: Vec<u8>,
    /// Scratch buffer reused to receive each value before it is copied out.
    response_buf: Vec<u8>,
    /// Key width in bytes announced by the server during `connect`.
    key_bytes: usize,
}
impl fmt::Debug for RemoteStoreClient {
    /// Formats the client without exposing the stream or buffer contents: only the key width
    /// and the current buffer lengths are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("RemoteStoreClient");
        repr.field("key_bytes", &self.key_bytes);
        repr.field("request_buf_len", &self.request_buf.len());
        repr.field("response_buf_len", &self.response_buf.len());
        // The stream is deliberately omitted, hence the non-exhaustive marker.
        repr.finish_non_exhaustive()
    }
}
impl RemoteStoreClient {
    /// Connects to `addr`, authenticates, and validates the key width announced by the server.
    ///
    /// Handshake, in order:
    /// 1. open the TCP connection (bounded by `config.timeout` when one is set),
    /// 2. optionally wrap it in TLS,
    /// 3. send the authorization header followed by `'\n'`,
    /// 4. read a one-byte acknowledgement, which must equal `AUTH_READY`,
    /// 5. read the little-endian `u16` key width and check it against `Key::BYTES`.
    ///
    /// # Errors
    ///
    /// * [`ClientError::Timeout`] / [`ClientError::Io`] for socket failures.
    /// * [`ClientError::InvalidDnsName`], [`ClientError::Tls`], [`ClientError::IoPath`] or
    ///   [`ClientError::TlsConfig`] when the TLS setup fails.
    /// * [`ClientError::ConnectionClosed`] if the server closes the connection mid-handshake.
    /// * [`ClientError::Server`] if the acknowledgement byte is not `AUTH_READY`.
    /// * [`ClientError::InvalidResponse`] / [`ClientError::KeyWidthMismatch`] if the announced
    ///   key width is out of range or differs from `Key::BYTES`.
    pub fn connect<A: ToSocketAddrs>(addr: A, config: ClientConfig) -> Result<Self, ClientError> {
        let ClientConfig {
            security,
            auth,
            timeout,
        } = config;

        let stream = Self::open_socket(addr, timeout).map_err(map_io)?;
        // Disabling Nagle is only a latency optimisation, so a failure here is ignored.
        stream.set_nodelay(true).ok();
        if let Some(timeout) = timeout {
            stream.set_read_timeout(Some(timeout)).map_err(map_io)?;
            stream.set_write_timeout(Some(timeout)).map_err(map_io)?;
        }

        let mut stream = match security {
            ClientSecurity::Tls {
                server_name,
                ca_certificate,
            } => {
                let tls_config = build_tls_client(&ca_certificate)?;
                let server_name = ServerName::try_from(server_name.as_str())
                    .map_err(|_| ClientError::InvalidDnsName(server_name.clone()))?;
                let connection = ClientConnection::new(tls_config, server_name)?;
                ClientStream::Tls(Box::new(StreamOwned::new(connection, stream)))
            }
            ClientSecurity::Plain => ClientStream::Plain(stream),
        };

        let mut header_line = auth.authorization_header();
        header_line.push('\n');
        stream.write_all(header_line.as_bytes()).map_err(map_io)?;
        stream.flush().map_err(map_io)?;

        // A closed connection surfaces as `ConnectionClosed { received: 0, expected: 1 }`.
        let mut ack = [0u8; 1];
        read_response_chunk(&mut stream, &mut ack, false)?;
        if ack[0] != AUTH_READY {
            return Err(ClientError::Server { code: ack[0] });
        }

        let mut key_bytes_buf = [0u8; KEY_BYTES_FIELD_WIDTH];
        read_response_chunk(&mut stream, &mut key_bytes_buf, false)?;
        let negotiated_key_bytes = u16::from_le_bytes(key_bytes_buf) as usize;
        if negotiated_key_bytes < MIN_KEY_BYTES {
            return Err(ClientError::InvalidResponse("key width below minimum"));
        }
        if negotiated_key_bytes > MAX_KEY_BYTES {
            return Err(ClientError::InvalidResponse("key width above maximum"));
        }
        if negotiated_key_bytes != Key::BYTES {
            return Err(ClientError::KeyWidthMismatch {
                server: negotiated_key_bytes,
                client: Key::BYTES,
            });
        }

        Ok(Self {
            stream,
            // Large enough for the count byte plus a maximally sized batch of keys.
            request_buf: Vec::with_capacity(1 + negotiated_key_bytes * MAX_KEYS_PER_REQUEST),
            response_buf: Vec::with_capacity(1024),
            key_bytes: negotiated_key_bytes,
        })
    }

    /// Opens the TCP connection, bounding each connection attempt by `timeout` when present.
    ///
    /// Without a timeout this is exactly `TcpStream::connect`. With one, every resolved
    /// address is tried in turn with `TcpStream::connect_timeout`, so the total wait can be
    /// up to `timeout` per address. The error of the last failed attempt is returned.
    fn open_socket<A: ToSocketAddrs>(
        addr: A,
        timeout: Option<Duration>,
    ) -> std::io::Result<TcpStream> {
        let timeout = match timeout {
            Some(timeout) => timeout,
            None => return TcpStream::connect(addr),
        };

        let mut last_error = None;
        for candidate in addr.to_socket_addrs()? {
            match TcpStream::connect_timeout(&candidate, timeout) {
                Ok(stream) => return Ok(stream),
                Err(err) => last_error = Some(err),
            }
        }
        Err(last_error.unwrap_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "could not resolve to any addresses",
            )
        }))
    }

    /// Fetches the values for `keys`, returned in the same order as the keys.
    ///
    /// The request is one count byte followed by the raw key bytes. The response carries, for
    /// each key, a little-endian `u16` length followed by that many value bytes. A
    /// zero-length value yields an empty `Vec`.
    ///
    /// # Errors
    ///
    /// * [`ClientError::InvalidRequest`] if `keys` is empty or holds more than
    ///   `MAX_KEYS_PER_REQUEST` entries.
    /// * [`ClientError::Server`] if the server answers with a single byte and then closes the
    ///   connection.
    /// * [`ClientError::InvalidResponse`] if a value length exceeds `MAX_VALUE_BYTES`.
    /// * [`ClientError::ConnectionClosed`], [`ClientError::Timeout`] or [`ClientError::Io`]
    ///   for transport failures.
    ///
    /// On error the method returns immediately without draining the rest of the response, so
    /// unread bytes may remain on the connection. Reusing the client afterwards is therefore
    /// presumably unsafe.
    pub fn get(&mut self, keys: &[Key]) -> Result<Vec<Vec<u8>>, ClientError> {
        if keys.is_empty() {
            return Err(ClientError::InvalidRequest(
                "at least one key must be provided",
            ));
        }
        if keys.len() > MAX_KEYS_PER_REQUEST {
            return Err(ClientError::InvalidRequest(
                "too many keys in a single request",
            ));
        }

        self.request_buf.clear();
        // Cannot truncate: the count was bounded by MAX_KEYS_PER_REQUEST above.
        self.request_buf.push(keys.len() as u8);
        for key in keys {
            debug_assert_eq!(key.as_slice().len(), self.key_bytes);
            self.request_buf.extend_from_slice(key.as_slice());
        }
        self.stream.write_all(&self.request_buf).map_err(map_io)?;
        self.stream.flush().map_err(map_io)?;

        let mut values = Vec::with_capacity(keys.len());
        let mut len_buf = [0u8; VALUE_LEN_WIDTH];
        for index in 0..keys.len() {
            // Only the very first read of a response may turn out to be an error code.
            read_response_chunk(&mut self.stream, &mut len_buf, index == 0)?;
            let value_len = u16::from_le_bytes(len_buf) as usize;
            if value_len > MAX_VALUE_BYTES {
                return Err(ClientError::InvalidResponse(
                    "value length exceeds MAX_VALUE_BYTES",
                ));
            }
            // The scratch buffer only ever grows, so it is reused across values and calls.
            if self.response_buf.len() < value_len {
                self.response_buf.resize(value_len, 0);
            }
            if value_len > 0 {
                let buffer = &mut self.response_buf[..value_len];
                read_response_chunk(&mut self.stream, buffer, false)?;
            }
            values.push(self.response_buf[..value_len].to_vec());
        }
        Ok(values)
    }

    /// Fetches the value for a single key; a convenience wrapper around [`Self::get`].
    ///
    /// # Errors
    ///
    /// Propagates every error [`Self::get`] can return.
    pub fn get_one(&mut self, key: Key) -> Result<Vec<u8>, ClientError> {
        let values = self.get(std::slice::from_ref(&key))?;
        Ok(values.into_iter().next().unwrap_or_default())
    }

    /// Closes the connection, consuming the client.
    ///
    /// # Errors
    ///
    /// Returns the mapped I/O error if shutting the stream down fails.
    pub fn close(self) -> Result<(), ClientError> {
        self.stream.close().map_err(map_io)
    }
}
fn read_response_chunk(
stream: &mut ClientStream,
buffer: &mut [u8],
allow_error_code: bool,
) -> Result<(), ClientError> {
if buffer.is_empty() {
return Ok(());
}
let mut read = 0;
while read < buffer.len() {
let n = stream.read(&mut buffer[read..]).map_err(map_io)?;
if n == 0 {
if allow_error_code && read == 1 {
return Err(ClientError::Server { code: buffer[0] });
}
return Err(ClientError::ConnectionClosed {
received: read,
expected: buffer.len(),
});
}
read += n;
}
Ok(())
}
/// Builds a shared rustls client configuration that trusts only the certificates found in the
/// PEM file at `path` and presents no client certificate.
///
/// # Errors
///
/// Propagates any error from `load_root_store`.
fn build_tls_client(path: &Path) -> Result<Arc<TlsClientConfig>, ClientError> {
    load_root_store(path).map(|roots| {
        Arc::new(
            TlsClientConfig::builder()
                .with_safe_defaults()
                .with_root_certificates(roots)
                .with_no_client_auth(),
        )
    })
}
/// Loads the PEM certificates at `path` into a fresh root certificate store.
///
/// # Errors
///
/// * [`ClientError::IoPath`] if the file cannot be opened.
/// * [`ClientError::TlsConfig`] if the PEM data cannot be parsed, contains no certificates,
///   or none of the certificates could be added to the store.
fn load_root_store(path: &Path) -> Result<RootCertStore, ClientError> {
    let handle = match File::open(path) {
        Ok(handle) => handle,
        Err(source) => {
            return Err(ClientError::IoPath {
                path: path.to_path_buf(),
                source,
            });
        }
    };

    let parsed = match certs(&mut BufReader::new(handle)) {
        Ok(parsed) => parsed,
        Err(err) => {
            return Err(ClientError::TlsConfig(format!(
                "failed to parse {path:?}: {err}"
            )));
        }
    };
    if parsed.is_empty() {
        return Err(ClientError::TlsConfig(format!(
            "no certificates found in {:?}",
            path
        )));
    }

    let mut roots = RootCertStore::empty();
    // Only the number of accepted certificates matters; the rejected count is ignored.
    let (accepted, _) = roots.add_parsable_certificates(&parsed);
    if accepted > 0 {
        Ok(roots)
    } else {
        Err(ClientError::TlsConfig(format!(
            "no valid certificates in {:?}",
            path
        )))
    }
}
/// Converts an I/O error into a [`ClientError`].
///
/// `TimedOut` and `WouldBlock` both become [`ClientError::Timeout`]; every other kind is
/// wrapped unchanged in [`ClientError::Io`].
fn map_io(err: std::io::Error) -> ClientError {
    use std::io::ErrorKind;

    if matches!(err.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) {
        ClientError::Timeout
    } else {
        ClientError::Io(err)
    }
}
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
#[error("io error: {0}")]
Io(std::io::Error),
#[error("network operation timed out")]
Timeout,
#[error("tls error: {0}")]
Tls(#[from] rustls::Error),
#[error("invalid dns name: {0}")]
InvalidDnsName(String),
#[error("invalid request: {0}")]
InvalidRequest(&'static str),
#[error("invalid response from server: {0}")]
InvalidResponse(&'static str),
#[error("key width mismatch: server uses {server} bytes, client built with {client} bytes")]
KeyWidthMismatch { server: usize, client: usize },
#[error("server returned error code {code}")]
Server { code: u8 },
#[error("connection closed before response (received {received} of {expected} bytes)")]
ConnectionClosed { received: usize, expected: usize },
#[error("io error at {path:?}: {source}")]
IoPath {
path: PathBuf,
source: std::io::Error,
},
#[error("failed to load certificates: {0}")]
TlsConfig(String),
}
/// The transport underneath a client: a TLS session over TCP, or the bare TCP socket.
enum ClientStream {
    /// TLS session that owns its socket; boxed so the enum itself stays small.
    Tls(Box<StreamOwned<ClientConnection, TcpStream>>),
    /// Unencrypted TCP socket.
    Plain(TcpStream),
}
impl ClientStream {
    /// Shuts the transport down, consuming it.
    ///
    /// A plain socket is shut down in both directions. A TLS session queues a `close_notify`
    /// alert and flushes it to the peer.
    fn close(self) -> std::io::Result<()> {
        let mut tls = match self {
            ClientStream::Plain(socket) => return socket.shutdown(Shutdown::Both),
            ClientStream::Tls(tls) => tls,
        };
        tls.conn.send_close_notify();
        tls.flush()
    }
}
impl Read for ClientStream {
    /// Reads from whichever transport is active.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let source: &mut dyn Read = match self {
            ClientStream::Plain(socket) => socket,
            ClientStream::Tls(tls) => &mut **tls,
        };
        source.read(buf)
    }
}
impl Write for ClientStream {
    /// Writes to whichever transport is active.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let sink: &mut dyn Write = match self {
            ClientStream::Plain(socket) => socket,
            ClientStream::Tls(tls) => &mut **tls,
        };
        sink.write(buf)
    }

    /// Flushes whichever transport is active.
    fn flush(&mut self) -> std::io::Result<()> {
        let sink: &mut dyn Write = match self {
            ClientStream::Plain(socket) => socket,
            ClientStream::Tls(tls) => &mut **tls,
        };
        sink.flush()
    }
}