#![deny(warnings, missing_docs)]
use bytes::BytesMut;
use fxhash::FxHashMap;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
mod connection;
use self::connection::Connection;
mod error;
pub use self::error::Error;
mod parser;
use self::parser::{
parse_ascii_metadump_response, parse_ascii_response, parse_ascii_stats_response,
};
pub use self::parser::{
ErrorKind, KeyMetadata, MetadumpResponse, Response, StatsResponse, Status, Value,
};
pub mod proto;
pub use self::proto::{AsciiProtocol, MetaProtocol};
mod value_serializer;
pub use self::value_serializer::AsMemcachedValue;
/// Longest key, in bytes, the client will accept; longer keys are rejected client-side.
const MAX_KEY_LENGTH: usize = 250;
/// An asynchronous memcached client wrapping a single TCP or Unix-domain socket connection.
///
/// Replies are accumulated in an internal buffer and parsed incrementally.
pub struct Client {
    // Bytes read from the connection that have not been discarded yet.
    buf: BytesMut,
    // Length of the response most recently returned by `drive_receive`; those bytes are
    // still at the front of `buf` and are discarded at the start of the next receive.
    last_read_n: Option<usize>,
    // The underlying socket (TCP or Unix).
    conn: Connection,
}
impl Client {
pub async fn new<S: AsRef<str>>(dsn: S) -> Result<Client, Error> {
let connection = Connection::new(dsn).await?;
Ok(Client {
buf: BytesMut::new(),
last_read_n: None,
conn: connection,
})
}
pub(crate) async fn drive_receive<R, F>(&mut self, op: F) -> Result<R, Error>
where
F: Fn(&[u8]) -> Result<Option<(usize, R)>, ErrorKind>,
{
if let Some(n) = self.last_read_n {
if n > self.buf.len() {
return Err(Status::Error(ErrorKind::Client(
"Buffer length is less than last read length".to_string(),
))
.into());
}
let _ = self.buf.split_to(n);
}
let mut needs_more_data = false;
loop {
if self.buf.is_empty() || needs_more_data {
match self.conn {
Connection::Tcp(ref mut s) => {
self.buf.reserve(1024);
let n = s.read_buf(&mut self.buf).await?;
if n == 0 {
return Err(Error::Io(std::io::ErrorKind::UnexpectedEof.into()));
}
}
Connection::Unix(ref mut s) => {
self.buf.reserve(1024);
let n = s.read_buf(&mut self.buf).await?;
if n == 0 {
return Err(Error::Io(std::io::ErrorKind::UnexpectedEof.into()));
}
}
}
}
match op(&self.buf) {
Ok(Some((n, response))) => {
self.last_read_n = Some(n);
return Ok(response);
}
Ok(None) => {
needs_more_data = true;
continue;
}
Err(kind) => return Err(Status::Error(kind).into()),
}
}
}
pub(crate) async fn get_read_write_response(&mut self) -> Result<Response, Error> {
self.drive_receive(parse_ascii_response).await
}
pub(crate) async fn map_set_multi_responses<'a, K, V>(
&mut self,
kv: &'a [(K, V)],
) -> Result<FxHashMap<&'a K, Result<(), Error>>, Error>
where
K: AsRef<[u8]> + Eq + std::hash::Hash,
V: AsMemcachedValue,
{
let mut results = FxHashMap::with_capacity_and_hasher(kv.len(), Default::default());
for (key, _) in kv {
let kr = key.as_ref();
if kr.len() > MAX_KEY_LENGTH {
results.insert(
key,
Err(Error::Protocol(Status::Error(ErrorKind::Client(
"Key exceeds maximum length of 250 bytes".to_string(),
)))),
);
continue;
}
let result = match self.drive_receive(parse_ascii_response).await {
Ok(Response::Status(Status::Stored)) => Ok(()),
Ok(Response::Status(s)) => Err(s.into()),
Ok(_) => Err(Status::Error(ErrorKind::Protocol(None)).into()),
Err(e) => return Err(e),
};
results.insert(key, result);
}
Ok(results)
}
pub(crate) async fn get_metadump_response(&mut self) -> Result<MetadumpResponse, Error> {
self.drive_receive(parse_ascii_metadump_response).await
}
pub(crate) async fn get_stats_response(&mut self) -> Result<StatsResponse, Error> {
self.drive_receive(parse_ascii_stats_response).await
}
pub async fn version(&mut self) -> Result<String, Error> {
self.conn.write_all(b"version\r\n").await?;
self.conn.flush().await?;
let mut version = String::new();
let bytes = self.conn.read_line(&mut version).await?;
if bytes >= 8 && version.is_char_boundary(8) {
Ok(version.split_off(8))
} else {
Err(Error::from(Status::Error(ErrorKind::Protocol(Some(
format!("Invalid response for `version` command: `{version}`"),
)))))
}
}
pub async fn dump_keys(&mut self) -> Result<MetadumpIter<'_>, Error> {
self.conn.write_all(b"lru_crawler metadump all\r\n").await?;
self.conn.flush().await?;
Ok(MetadumpIter {
client: self,
done: false,
})
}
pub async fn stats(&mut self) -> Result<FxHashMap<String, String>, Error> {
let mut entries = FxHashMap::default();
self.conn.write_all(b"stats\r\n").await?;
self.conn.flush().await?;
while let StatsResponse::Entry(key, value) = self.get_stats_response().await? {
entries.insert(key, value);
}
Ok(entries)
}
pub async fn flush_all(&mut self) -> Result<(), Error> {
self.conn.write_all(b"flush_all\r\n").await?;
self.conn.flush().await?;
let mut response = String::new();
self.conn.read_line(&mut response).await?;
if response.trim() == "OK" {
Ok(())
} else {
Err(Error::from(Status::Error(ErrorKind::Protocol(Some(
format!("Invalid response for `flush_all` command: `{response}`"),
)))))
}
}
fn validate_key_length(kr: &[u8]) -> Result<&[u8], Error> {
if kr.len() > MAX_KEY_LENGTH {
return Err(Error::from(Status::Error(ErrorKind::KeyTooLong)));
}
Ok(kr)
}
fn validate_opaque_length(opaque: &[u8]) -> Result<&[u8], Error> {
if opaque.len() > 32 {
return Err(Error::from(Status::Error(ErrorKind::OpaqueTooLong)));
}
Ok(opaque)
}
async fn check_and_write_opaque(&mut self, opaque: Option<&[u8]>) -> Result<(), Error> {
if let Some(opaque) = &opaque {
self.conn.write_all(b" O").await?;
self.conn.write_all(opaque.as_ref()).await?;
}
Ok(())
}
async fn check_and_write_meta_flags(
&mut self,
meta_flags: Option<&[&str]>,
opaque: Option<&[u8]>,
) -> Result<(), Error> {
if let Some(meta_flags) = meta_flags {
for flag in meta_flags {
if flag.starts_with('q') || (flag.starts_with('O') && opaque.is_some()) {
continue;
} else {
self.conn.write_all(b" ").await?;
self.conn.write_all(flag.as_bytes()).await?;
}
}
}
Ok(())
}
async fn check_and_write_quiet_mode(&mut self, is_quiet: bool) -> Result<(), Error> {
if is_quiet {
self.conn.write_all(b" q\r\nmn\r\n").await?;
} else {
self.conn.write_all(b"\r\n").await?;
}
Ok(())
}
}
/// Asynchronous iterator over the key metadata produced by [`Client::dump_keys`].
///
/// It keeps the client mutably borrowed for its whole lifetime, so no other command can be
/// issued on that client until the iterator is dropped.
pub struct MetadumpIter<'a> {
    // Client whose connection is streaming the metadump reply.
    client: &'a mut Client,
    // Set once the dump has finished; later calls to `next` return `None` immediately.
    done: bool,
}
impl MetadumpIter<'_> {
    /// Fetches the next entry of the metadump started by [`Client::dump_keys`].
    ///
    /// Returns:
    ///
    /// - `Some(Ok(KeyMetadata))` for each key reported by the server;
    /// - `Some(Err(Error))` if the server refused the dump (`BADCLASS` / `BUSY`) or a
    ///   network/protocol error occurred;
    /// - `None` once the server signals the end of the dump.
    ///
    /// After `None` or any `Err`, the iterator is finished and every later call returns `None`.
    pub async fn next(&mut self) -> Option<Result<KeyMetadata, Error>> {
        if self.done {
            return None;
        }

        let last = match self.client.get_metadump_response().await {
            // The only non-terminal outcome: keep the iterator running.
            Ok(MetadumpResponse::Entry(km)) => return Some(Ok(km)),
            Ok(MetadumpResponse::End) => None,
            // Per memcached's protocol, BADCLASS and BUSY are the entire reply to
            // `lru_crawler metadump`. Waiting for more after them would block forever.
            Ok(MetadumpResponse::BadClass(s)) => {
                Some(Err(Error::Protocol(MetadumpResponse::BadClass(s).into())))
            }
            Ok(MetadumpResponse::Busy(s)) => {
                Some(Err(Error::Protocol(MetadumpResponse::Busy(s).into())))
            }
            // I/O errors mean the stream is gone. Parse errors leave the offending bytes in the
            // client's buffer, so retrying would only reproduce the same error endlessly.
            Err(e) => Some(Err(e)),
        };
        self.done = true;
        last
    }
}