pub mod credential;
pub mod error;
pub mod pool;
pub mod proto;
pub use credential::Credential;
pub use error::Error;
pub use pool::{Pool, PoolConfig};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::Instant;
use bytes::Bytes;
use ringline::{ConnCtx, ParseResult};
use crate::proto::{
CacheCommand, CacheResponse, CacheResponseResult, DecodedMessage, StatusCode, UnaryCommand,
decode_length_delimited_message_bytes,
};
/// Maximum number of consecutive decoded frames that may fail to match the request
/// being waited for before `Client::recv` (or the authentication handshake) gives up,
/// closes the connection, and returns a protocol error.
pub const MAX_RECV_SKIPS: usize = 256;
/// Handle identifying one request issued with a `fire_*` call.
///
/// It is echoed back in the matching [`CompletedOp`] so callers can correlate
/// pipelined completions with the requests that produced them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestId(u64);

impl RequestId {
    /// Returns the raw wire message id backing this handle.
    pub fn value(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
/// The kind of cache command described by a [`CommandResult`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CommandType {
    /// A `get` (lookup) command.
    Get,
    /// A `set` (store) command.
    Set,
    /// A `delete` command.
    Delete,
}
/// The outcome of a pipelined request, as returned by `Client::recv`.
///
/// Every variant carries:
/// - `id`: the [`RequestId`] returned by the matching `fire_*` call,
/// - `key`: the key that was sent,
/// - `user_data`: the caller's opaque value passed to `fire_*`,
/// - `latency_ns`: the measured latency when the client is instrumented (an
///   `on_result` callback or metrics are configured), and `0` otherwise.
#[derive(Debug)]
#[non_exhaustive]
pub enum CompletedOp {
    /// Completion of `fire_get`. `Ok(None)` is returned when the server reports
    /// `NotFound` or its get reply carries no value.
    Get {
        id: RequestId,
        key: Bytes,
        result: Result<Option<Bytes>, Error>,
        user_data: u64,
        latency_ns: u64,
    },
    /// Completion of `fire_set`.
    Set {
        id: RequestId,
        key: Bytes,
        result: Result<(), Error>,
        user_data: u64,
        latency_ns: u64,
    },
    /// Completion of `fire_delete`.
    Delete {
        id: RequestId,
        key: Bytes,
        result: Result<(), Error>,
        user_data: u64,
        latency_ns: u64,
    },
}
impl CompletedOp {
    /// Returns `self` with its `latency_ns` field overwritten by `latency_ns`.
    ///
    /// Every other field is carried over unchanged; this is used once the response has
    /// been timed, since dispatching builds the op with a placeholder latency of 0.
    fn set_latency(mut self, latency_ns: u64) -> Self {
        match &mut self {
            Self::Get {
                latency_ns: slot, ..
            }
            | Self::Set {
                latency_ns: slot, ..
            }
            | Self::Delete {
                latency_ns: slot, ..
            } => *slot = latency_ns,
        }
        self
    }
}
/// Per-command measurements delivered to the `on_result` callback and to the optional
/// metrics recorder each time a request completes.
#[derive(Debug, Clone)]
pub struct CommandResult {
    /// Which command completed.
    pub command: CommandType,
    /// Latency in nanoseconds: the kernel receive timestamp minus the send timestamp
    /// when kernel timestamps are usable, otherwise the `Instant` elapsed time since
    /// just after the request was handed to the connection.
    pub latency_ns: u64,
    /// `true` when the command's result was `Ok`.
    pub success: bool,
    /// Kernel receive timestamp minus send timestamp, when kernel timestamps are enabled
    /// and the receive timestamp is non-zero and later than the send timestamp.
    pub ttfb_ns: Option<u64>,
    /// Size in bytes of the encoded, length-delimited request frame.
    pub tx_bytes: u32,
    /// Bytes consumed from the connection for the matching response frame.
    pub rx_bytes: u32,
}
/// Callback invoked with the measurements of every completed command.
type ResultCallback = Box<dyn Fn(&CommandResult)>;

/// Which command a pending request was issued for; selects how its response is decoded.
enum PendingOpKind {
    Get,
    Set,
    Delete,
}

/// Bookkeeping for a request that has been sent but whose response has not arrived.
struct PendingOp {
    /// Command kind, used to interpret the response.
    kind: PendingOpKind,
    /// Key sent with the request; echoed back in the `CompletedOp`.
    key: Bytes,
    /// Realtime send timestamp in ns when the client is instrumented and kernel
    /// timestamps are enabled; otherwise 0.
    send_ts: u64,
    /// Monotonic start time when the client is instrumented; otherwise `None`.
    start: Option<Instant>,
    /// Opaque caller value passed to `fire_*`.
    user_data: u64,
    /// Size of the encoded request frame in bytes.
    tx_bytes: u32,
}
/// Aggregate statistics collected by a client built with `ClientBuilder::with_metrics`.
#[cfg(feature = "metrics")]
pub struct ClientMetrics {
    /// Distribution of per-command latency in nanoseconds.
    pub latency: histogram::Histogram,
    /// Number of completed commands recorded.
    pub requests: u64,
    /// Number of recorded commands whose result was an error.
    pub errors: u64,
}
#[cfg(feature = "metrics")]
impl ClientMetrics {
    /// Creates an empty metrics set with zeroed counters.
    ///
    /// The histogram parameters are constants, so the `unwrap` can only fail if the
    /// `histogram` crate rejects `(7, 64)`.
    fn new() -> Self {
        let latency = histogram::Histogram::new(7, 64).unwrap();
        Self {
            requests: 0,
            errors: 0,
            latency,
        }
    }

    /// Counts one completed command, its latency sample, and whether it failed.
    fn record(&mut self, result: &CommandResult) {
        self.requests += 1;
        // Best-effort: a sample the histogram refuses is dropped, not reported.
        let _ = self.latency.increment(result.latency_ns);
        self.errors += u64::from(!result.success);
    }
}
/// Builder for a [`Client`] over an already-established connection.
///
/// Created with `Client::builder`. Unlike `Client::connect`, `build` does not run the
/// authentication handshake.
pub struct ClientBuilder {
    /// Connection the built client will use.
    conn: ConnCtx,
    /// Optional per-command result callback.
    on_result: Option<ResultCallback>,
    /// Namespace override; empty means "use each call's `cache` argument".
    namespace: Bytes,
    /// Limit on pending requests (defaults to `usize::MAX`).
    max_in_flight: usize,
    /// Whether to use kernel receive timestamps for latency.
    #[cfg(feature = "timestamps")]
    use_kernel_ts: bool,
    /// Whether to collect aggregate `ClientMetrics`.
    #[cfg(feature = "metrics")]
    with_metrics: bool,
}
impl ClientBuilder {
pub(crate) fn new(conn: ConnCtx) -> Self {
Self {
conn,
on_result: None,
namespace: Bytes::new(),
max_in_flight: usize::MAX,
#[cfg(feature = "timestamps")]
use_kernel_ts: false,
#[cfg(feature = "metrics")]
with_metrics: false,
}
}
pub fn max_in_flight(mut self, n: usize) -> Self {
self.max_in_flight = n;
self
}
pub fn namespace(mut self, ns: impl AsRef<[u8]>) -> Self {
self.namespace = Bytes::copy_from_slice(ns.as_ref());
self
}
pub fn on_result<F: Fn(&CommandResult) + 'static>(mut self, f: F) -> Self {
self.on_result = Some(Box::new(f));
self
}
#[cfg(feature = "timestamps")]
pub fn kernel_timestamps(mut self, enabled: bool) -> Self {
self.use_kernel_ts = enabled;
self
}
#[cfg(feature = "metrics")]
pub fn with_metrics(mut self) -> Self {
self.with_metrics = true;
self
}
pub fn build(self) -> Client {
Client {
conn: self.conn,
next_message_id: 1,
pending: HashMap::new(),
send_buf: Vec::with_capacity(4096),
on_result: self.on_result,
namespace: self.namespace,
max_in_flight: self.max_in_flight,
#[cfg(feature = "timestamps")]
use_kernel_ts: self.use_kernel_ts,
#[cfg(feature = "metrics")]
metrics: if self.with_metrics {
Some(ClientMetrics::new())
} else {
None
},
}
}
}
/// A pipelining cache client bound to a single connection.
///
/// Requests can be issued without waiting (`fire_get`/`fire_set`/`fire_delete`) and
/// collected with `recv`, which matches responses to requests by message id. The
/// one-request-at-a-time `get`/`set`/`delete` helpers are also available. Protocol or
/// connection failures inside `recv` "poison" the client: every pending request is
/// dropped and the connection is closed.
pub struct Client {
    /// Underlying connection handle.
    conn: ConnCtx,
    /// Next wire message id to assign; starts at 1.
    next_message_id: u64,
    /// Requests awaiting a response, keyed by message id.
    pending: HashMap<u64, PendingOp>,
    /// Reused buffer holding the most recently encoded outgoing frame.
    send_buf: Vec<u8>,
    /// Optional per-command result callback.
    on_result: Option<ResultCallback>,
    /// Namespace override; when empty, each call's `cache` argument is used.
    namespace: Bytes,
    /// Pending-request limit enforced by `fire_*`.
    max_in_flight: usize,
    /// Whether kernel receive timestamps are used for latency.
    #[cfg(feature = "timestamps")]
    use_kernel_ts: bool,
    /// Optional aggregate metrics.
    #[cfg(feature = "metrics")]
    metrics: Option<ClientMetrics>,
}
impl Client {
pub async fn connect(credential: &Credential) -> Result<Self, Error> {
let host = credential.host();
let port = credential.port();
let addr: SocketAddr = Self::resolve_addr(host, port)?;
let tls_host = credential.tls_host();
let conn = ringline::connect_tls(addr, tls_host)?.await?;
let mut client = Self {
conn,
next_message_id: 1,
pending: HashMap::new(),
send_buf: Vec::with_capacity(4096),
on_result: None,
namespace: Bytes::new(),
max_in_flight: usize::MAX,
#[cfg(feature = "timestamps")]
use_kernel_ts: false,
#[cfg(feature = "metrics")]
metrics: None,
};
client.authenticate(credential.token()).await?;
Ok(client)
}
pub async fn connect_with_timeout(
credential: &Credential,
timeout_ms: u64,
) -> Result<Self, Error> {
let host = credential.host();
let port = credential.port();
let addr: SocketAddr = Self::resolve_addr(host, port)?;
let tls_host = credential.tls_host();
let conn = ringline::connect_tls_with_timeout(addr, tls_host, timeout_ms)?.await?;
let mut client = Self {
conn,
next_message_id: 1,
pending: HashMap::new(),
send_buf: Vec::with_capacity(4096),
on_result: None,
namespace: Bytes::new(),
max_in_flight: usize::MAX,
#[cfg(feature = "timestamps")]
use_kernel_ts: false,
#[cfg(feature = "metrics")]
metrics: None,
};
client.authenticate(credential.token()).await?;
Ok(client)
}
pub fn builder(conn: ConnCtx) -> ClientBuilder {
ClientBuilder::new(conn)
}
pub fn conn(&self) -> ConnCtx {
self.conn
}
#[cfg(feature = "metrics")]
pub fn metrics(&self) -> Option<&ClientMetrics> {
self.metrics.as_ref()
}
#[cfg(feature = "metrics")]
pub fn metrics_mut(&mut self) -> Option<&mut ClientMetrics> {
self.metrics.as_mut()
}
pub fn set_namespace(&mut self, ns: impl AsRef<[u8]>) {
self.namespace = Bytes::copy_from_slice(ns.as_ref());
}
pub fn pending_count(&self) -> usize {
self.pending.len()
}
fn poison(&mut self) {
self.pending.clear();
self.conn.close();
}
#[inline]
fn check_in_flight(&self) -> Result<(), Error> {
if self.pending.len() >= self.max_in_flight {
Err(Error::TooManyInFlight)
} else {
Ok(())
}
}
pub fn fire_get(
&mut self,
cache: &str,
key: &[u8],
user_data: u64,
) -> Result<RequestId, Error> {
self.check_in_flight()?;
let message_id = self.next_id();
let ns = self.namespace_for(cache);
let key = Bytes::copy_from_slice(key);
let cmd = CacheCommand::new(
message_id,
UnaryCommand::Get {
namespace: ns,
key: key.clone(),
},
);
self.send_command(&cmd)?;
let tx_bytes = self.send_buf.len() as u32;
let (send_ts, start) = self.timing_start();
self.pending.insert(
message_id,
PendingOp {
kind: PendingOpKind::Get,
key,
send_ts,
start,
user_data,
tx_bytes,
},
);
Ok(RequestId(message_id))
}
pub fn fire_set(
&mut self,
cache: &str,
key: &[u8],
value: &[u8],
ttl_ms: u64,
user_data: u64,
) -> Result<RequestId, Error> {
self.check_in_flight()?;
let message_id = self.next_id();
let ns = self.namespace_for(cache);
let key = Bytes::copy_from_slice(key);
let cmd = CacheCommand::new(
message_id,
UnaryCommand::Set {
namespace: ns,
key: key.clone(),
value: Bytes::copy_from_slice(value),
ttl_millis: ttl_ms,
},
);
self.send_command(&cmd)?;
let tx_bytes = self.send_buf.len() as u32;
let (send_ts, start) = self.timing_start();
self.pending.insert(
message_id,
PendingOp {
kind: PendingOpKind::Set,
key,
send_ts,
start,
user_data,
tx_bytes,
},
);
Ok(RequestId(message_id))
}
pub fn fire_delete(
&mut self,
cache: &str,
key: &[u8],
user_data: u64,
) -> Result<RequestId, Error> {
self.check_in_flight()?;
let message_id = self.next_id();
let ns = self.namespace_for(cache);
let key = Bytes::copy_from_slice(key);
let cmd = CacheCommand::new(
message_id,
UnaryCommand::Delete {
namespace: ns,
key: key.clone(),
},
);
self.send_command(&cmd)?;
let tx_bytes = self.send_buf.len() as u32;
let (send_ts, start) = self.timing_start();
self.pending.insert(
message_id,
PendingOp {
kind: PendingOpKind::Delete,
key,
send_ts,
start,
user_data,
tx_bytes,
},
);
Ok(RequestId(message_id))
}
pub async fn recv(&mut self) -> Result<CompletedOp, Error> {
if self.pending.is_empty() {
return Err(Error::NoPending);
}
let mut skips = 0usize;
let (dr, total_bytes) = loop {
let pending = &mut self.pending;
let mut dispatch_result: Option<DispatchResult> = None;
let mut malformed = false;
let mut oversize = false;
let n = self
.conn
.with_bytes(
|bytes| match decode_length_delimited_message_bytes(&bytes) {
DecodedMessage::Message(consumed, msg_bytes) => {
if let Some(response) = CacheResponse::decode_bytes(msg_bytes) {
dispatch_result = dispatch_response(response, pending);
} else {
malformed = true;
}
ParseResult::Consumed(consumed)
}
DecodedMessage::Incomplete => ParseResult::Consumed(0),
DecodedMessage::Oversize => {
oversize = true;
ParseResult::Consumed(0)
}
},
)
.await;
if oversize {
self.poison();
return Err(Error::Protocol(
"inbound message exceeded MAX_MESSAGE_SIZE".into(),
));
}
if n == 0 {
self.poison();
return Err(Error::ConnectionClosed);
}
if malformed {
self.poison();
return Err(Error::Protocol("failed to decode response".into()));
}
if let Some(dr) = dispatch_result {
break (dr, n);
}
if self.pending.is_empty() {
return Err(Error::NoPending);
}
skips += 1;
if skips > MAX_RECV_SKIPS {
self.poison();
return Err(Error::Protocol(
"too many consecutive unmatched responses".into(),
));
}
};
let n = total_bytes;
let recv_ts = self.capture_recv_ts();
let rx_bytes = n as u32;
let ttfb_ns = Self::ttfb_from_timestamps(recv_ts, dr.send_ts);
let latency_ns = if self.is_instrumented() {
let latency_ns = self.finish_timing(recv_ts, dr.send_ts, dr.start);
self.record(&CommandResult {
command: dr.cmd_type,
latency_ns,
success: dr.success,
ttfb_ns,
tx_bytes: dr.tx_bytes,
rx_bytes,
});
latency_ns
} else {
0
};
Ok(dr.op.set_latency(latency_ns))
}
#[inline]
fn check_no_pending(&self) -> Result<(), Error> {
if self.pending.is_empty() {
Ok(())
} else {
Err(Error::PendingOpsInFlight)
}
}
pub async fn get(&mut self, cache: &str, key: &[u8]) -> Result<Option<Bytes>, Error> {
self.check_no_pending()?;
let _id = self.fire_get(cache, key, 0)?;
match self.recv().await? {
CompletedOp::Get { result, .. } => result,
_ => Err(Error::Protocol("unexpected response type".into())),
}
}
pub async fn set(
&mut self,
cache: &str,
key: &[u8],
value: &[u8],
ttl_ms: u64,
) -> Result<(), Error> {
self.check_no_pending()?;
let _id = self.fire_set(cache, key, value, ttl_ms, 0)?;
match self.recv().await? {
CompletedOp::Set { result, .. } => result,
_ => Err(Error::Protocol("unexpected response type".into())),
}
}
pub async fn delete(&mut self, cache: &str, key: &[u8]) -> Result<(), Error> {
self.check_no_pending()?;
let _id = self.fire_delete(cache, key, 0)?;
match self.recv().await? {
CompletedOp::Delete { result, .. } => result,
_ => Err(Error::Protocol("unexpected response type".into())),
}
}
#[inline]
fn namespace_for(&self, cache: &str) -> Bytes {
if self.namespace.is_empty() {
Bytes::copy_from_slice(cache.as_bytes())
} else {
self.namespace.clone()
}
}
fn next_id(&mut self) -> u64 {
let id = self.next_message_id;
self.next_message_id += 1;
id
}
fn send_command(&mut self, cmd: &CacheCommand) -> Result<(), Error> {
self.send_buf.clear();
cmd.encode_length_delimited_into(&mut self.send_buf);
self.conn.send_nowait(&self.send_buf)?;
Ok(())
}
async fn authenticate(&mut self, token: &str) -> Result<(), Error> {
let message_id = self.next_id();
let cmd = CacheCommand::new(
message_id,
UnaryCommand::Authenticate {
auth_token: token.to_string(),
},
);
self.send_buf.clear();
cmd.encode_length_delimited_into(&mut self.send_buf);
if let Err(e) = self.conn.send_nowait(&self.send_buf) {
self.conn.close();
return Err(Error::Io(e));
}
let mut skips = 0usize;
loop {
let mut auth_result: Option<Result<(), Error>> = None;
let mut malformed = false;
let mut oversize = false;
let n = self
.conn
.with_bytes(
|bytes| match decode_length_delimited_message_bytes(&bytes) {
DecodedMessage::Message(consumed, msg_bytes) => {
if let Some(response) = CacheResponse::decode_bytes(msg_bytes) {
if response.message_id == message_id {
match response.result {
CacheResponseResult::Authenticate => {
auth_result = Some(Ok(()));
}
CacheResponseResult::Error(err) => {
auth_result = Some(Err(Error::AuthFailed(err.message)));
}
_ => {
auth_result = Some(Err(Error::Protocol(
"unexpected auth response type".into(),
)));
}
}
}
} else {
malformed = true;
}
ParseResult::Consumed(consumed)
}
DecodedMessage::Incomplete => ParseResult::Consumed(0),
DecodedMessage::Oversize => {
oversize = true;
ParseResult::Consumed(0)
}
},
)
.await;
if oversize {
self.conn.close();
return Err(Error::Protocol(
"auth response exceeded MAX_MESSAGE_SIZE".into(),
));
}
if n == 0 {
self.conn.close();
return Err(Error::ConnectionClosed);
}
if malformed {
self.conn.close();
return Err(Error::Protocol("failed to decode auth response".into()));
}
if let Some(r) = auth_result {
return r;
}
skips += 1;
if skips > MAX_RECV_SKIPS {
self.conn.close();
return Err(Error::Protocol(
"too many unmatched messages before auth response".into(),
));
}
}
}
fn resolve_addr(host: &str, port: u16) -> Result<SocketAddr, Error> {
use std::net::ToSocketAddrs;
let addr_str = format!("{}:{}", host, port);
addr_str
.to_socket_addrs()
.map_err(|e| Error::Config(format!("failed to resolve {}: {}", addr_str, e)))?
.next()
.ok_or_else(|| Error::Config(format!("no addresses found for {}", addr_str)))
}
#[inline]
fn is_instrumented(&self) -> bool {
if self.on_result.is_some() {
return true;
}
#[cfg(feature = "metrics")]
if self.metrics.is_some() {
return true;
}
false
}
#[inline]
fn timing_start(&self) -> (u64, Option<Instant>) {
if self.is_instrumented() {
(self.send_timestamp(), Some(Instant::now()))
} else {
(0, None)
}
}
#[cfg(feature = "timestamps")]
#[inline]
fn capture_recv_ts(&self) -> u64 {
if self.use_kernel_ts {
self.conn.recv_timestamp()
} else {
0
}
}
#[cfg(not(feature = "timestamps"))]
#[inline]
fn capture_recv_ts(&self) -> u64 {
0
}
#[inline]
fn ttfb_from_timestamps(recv_ts: u64, send_ts: u64) -> Option<u64> {
if recv_ts > 0 && recv_ts > send_ts {
Some(recv_ts - send_ts)
} else {
None
}
}
#[cfg(feature = "timestamps")]
#[inline]
fn send_timestamp(&self) -> u64 {
if self.use_kernel_ts {
now_realtime_ns()
} else {
0
}
}
#[cfg(not(feature = "timestamps"))]
#[inline]
fn send_timestamp(&self) -> u64 {
0
}
#[inline]
fn finish_timing(&self, recv_ts: u64, send_ts: u64, start: Option<Instant>) -> u64 {
if recv_ts > 0 && recv_ts > send_ts {
return recv_ts - send_ts;
}
start.map_or(0, |s| s.elapsed().as_nanos() as u64)
}
fn record(&mut self, result: &CommandResult) {
if let Some(ref cb) = self.on_result {
cb(result);
}
#[cfg(feature = "metrics")]
if let Some(ref mut m) = self.metrics {
m.record(result);
}
}
}
/// A response matched to its pending request, together with the data needed to time
/// and report it.
struct DispatchResult {
    /// The completed operation; its `latency_ns` is still the placeholder 0.
    op: CompletedOp,
    /// Kind of the completed command.
    cmd_type: CommandType,
    /// Whether the operation's result is `Ok`.
    success: bool,
    /// Send timestamp copied from the pending entry.
    send_ts: u64,
    /// Monotonic start time copied from the pending entry.
    start: Option<Instant>,
    /// Size of the original request frame.
    tx_bytes: u32,
}
fn dispatch_response(
response: CacheResponse,
pending: &mut HashMap<u64, PendingOp>,
) -> Option<DispatchResult> {
let message_id = response.message_id;
let id = RequestId(message_id);
let op = pending.remove(&message_id)?;
let send_ts = op.send_ts;
let start = op.start;
let user_data = op.user_data;
let tx_bytes = op.tx_bytes;
match op.kind {
PendingOpKind::Get => {
let result = match response.result {
CacheResponseResult::Get { value } => Ok(value),
CacheResponseResult::Error(ref err) if err.code == StatusCode::NotFound => Ok(None),
CacheResponseResult::Error(err) => Err(Error::Protocol(format!(
"{}: {}",
err.code as u32, err.message
))),
_ => Err(Error::Protocol("unexpected response type for get".into())),
};
let success = result.is_ok();
Some(DispatchResult {
op: CompletedOp::Get {
id,
key: op.key,
result,
user_data,
latency_ns: 0,
},
cmd_type: CommandType::Get,
success,
send_ts,
start,
tx_bytes,
})
}
PendingOpKind::Set => {
let result = match response.result {
CacheResponseResult::Set => Ok(()),
CacheResponseResult::Error(err) => Err(Error::Protocol(format!(
"{}: {}",
err.code as u32, err.message
))),
_ => Err(Error::Protocol("unexpected response type for set".into())),
};
let success = result.is_ok();
Some(DispatchResult {
op: CompletedOp::Set {
id,
key: op.key,
result,
user_data,
latency_ns: 0,
},
cmd_type: CommandType::Set,
success,
send_ts,
start,
tx_bytes,
})
}
PendingOpKind::Delete => {
let result = match response.result {
CacheResponseResult::Delete => Ok(()),
CacheResponseResult::Error(err) => Err(Error::Protocol(format!(
"{}: {}",
err.code as u32, err.message
))),
_ => Err(Error::Protocol(
"unexpected response type for delete".into(),
)),
};
let success = result.is_ok();
Some(DispatchResult {
op: CompletedOp::Delete {
id,
key: op.key,
result,
user_data,
latency_ns: 0,
},
cmd_type: CommandType::Delete,
success,
send_ts,
start,
tx_bytes,
})
}
}
}
/// Returns the current wall-clock time in nanoseconds since the Unix epoch.
///
/// Used as the send-side timestamp that is compared against `ConnCtx::recv_timestamp`.
/// NOTE(review): this assumes ringline's kernel receive timestamps are also
/// `CLOCK_REALTIME`-based; confirm.
///
/// Returns 0 if the clock reads earlier than the epoch, and saturates at `u64::MAX`.
#[cfg(feature = "timestamps")]
fn now_realtime_ns() -> u64 {
    // On Unix, `SystemTime::now()` is backed by `clock_gettime(CLOCK_REALTIME)`. This
    // reads the same clock as the previous raw libc call, without `unsafe` and without
    // ignoring a failure return code.
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| {
            u64::try_from(since_epoch.as_nanos()).unwrap_or(u64::MAX)
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::CacheResponse;

    /// Builds a pending `Get` for `key` tagged with `user_data = 42` and no timing data.
    fn pending_get(key: &[u8]) -> PendingOp {
        PendingOp {
            kind: PendingOpKind::Get,
            key: Bytes::copy_from_slice(key),
            send_ts: 0,
            start: None,
            user_data: 42,
            tx_bytes: 0,
        }
    }

    #[test]
    fn max_recv_skips_is_positive() {
        const { assert!(MAX_RECV_SKIPS > 0) };
    }

    #[test]
    fn pending_ops_in_flight_display() {
        let rendered = Error::PendingOpsInFlight.to_string();
        assert!(
            rendered.contains("fire_*"),
            "PendingOpsInFlight display should mention fire_*, got: {rendered}"
        );
    }

    #[test]
    fn dispatch_response_returns_some_on_match() {
        let mut table: HashMap<u64, PendingOp> = HashMap::from([(7, pending_get(b"k"))]);
        let reply = CacheResponse::get_hit(7, Bytes::from_static(b"v"));
        let matched = dispatch_response(reply, &mut table);
        assert!(matched.is_some(), "expected matched dispatch");
        assert!(table.is_empty(), "matched op should be drained");
    }

    #[test]
    fn dispatch_response_returns_none_on_unmatched_id() {
        let mut table: HashMap<u64, PendingOp> = HashMap::from([(1, pending_get(b"k"))]);
        let reply = CacheResponse::get_hit(99, Bytes::from_static(b"v"));
        let matched = dispatch_response(reply, &mut table);
        assert!(matched.is_none(), "unmatched id must return None");
        assert_eq!(table.len(), 1, "pending must be unchanged on miss");
    }

    #[test]
    fn dispatch_response_get_returns_protocol_error_on_wrong_kind() {
        let mut table: HashMap<u64, PendingOp> = HashMap::from([(3, pending_get(b"k"))]);
        let outcome =
            dispatch_response(CacheResponse::set_ok(3), &mut table).expect("matched id");
        assert!(table.is_empty());
        assert!(!outcome.success, "kind mismatch should mark op as failed");
        let CompletedOp::Get { result, .. } = outcome.op else {
            panic!("expected Get op kind preserved");
        };
        assert!(result.is_err());
    }
}