#![deny(warnings, missing_docs)]
use std::collections::HashMap;
use bytes::BytesMut;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
mod connection;
use self::connection::Connection;
mod error;
pub use self::error::Error;
mod parser;
use self::parser::{
parse_ascii_metadump_response, parse_ascii_response, parse_ascii_stats_response, Response,
};
pub use self::parser::{ErrorKind, KeyMetadata, MetadumpResponse, StatsResponse, Status, Value};
/// An asynchronous memcached client that speaks the text (ASCII) protocol over one connection.
///
/// Every command is written to the connection and its reply is awaited before the method
/// returns; all methods take `&mut self`, so a `Client` handles one request at a time.
pub struct Client {
    /// Bytes read from the connection that have not been discarded yet.
    buf: BytesMut,
    /// Length, in bytes, of the last response parsed out of `buf`; those bytes are discarded
    /// at the start of the next receive.
    last_read_n: Option<usize>,
    /// The underlying connection to the server.
    conn: Connection,
}
impl Client {
pub async fn new<S: AsRef<str>>(dsn: S) -> Result<Client, Error> {
let connection = Connection::new(dsn).await?;
Ok(Client {
buf: BytesMut::new(),
last_read_n: None,
conn: connection,
})
}
pub(crate) async fn drive_receive<R, F>(&mut self, op: F) -> Result<R, Error>
where
F: Fn(&[u8]) -> Result<Option<(usize, R)>, ErrorKind>,
{
if let Some(n) = self.last_read_n {
let _ = self.buf.split_to(n);
}
let mut needs_more_data = false;
loop {
if self.buf.is_empty() || needs_more_data {
match self.conn {
Connection::Tcp(ref mut s) => {
self.buf.reserve(1024);
let n = s.read_buf(&mut self.buf).await?;
if n == 0 {
return Err(Error::Io(std::io::ErrorKind::UnexpectedEof.into()));
}
}
}
}
match op(&self.buf) {
Ok(Some((n, response))) => {
self.last_read_n = Some(n);
return Ok(response);
}
Ok(None) => {
needs_more_data = true;
continue;
}
Err(kind) => return Err(Status::Error(kind).into()),
}
}
}
pub(crate) async fn get_read_write_response(&mut self) -> Result<Response, Error> {
self.drive_receive(parse_ascii_response).await
}
pub(crate) async fn get_metadump_response(&mut self) -> Result<MetadumpResponse, Error> {
self.drive_receive(parse_ascii_metadump_response).await
}
pub(crate) async fn get_stats_response(&mut self) -> Result<StatsResponse, Error> {
self.drive_receive(parse_ascii_stats_response).await
}
pub async fn get<K: AsRef<[u8]>>(&mut self, key: K) -> Result<Option<Value>, Error> {
self.conn.write_all(b"get ").await?;
self.conn.write_all(key.as_ref()).await?;
self.conn.write_all(b"\r\n").await?;
self.conn.flush().await?;
match self.get_read_write_response().await? {
Response::Status(Status::NotFound) => Ok(None),
Response::Status(s) => Err(s.into()),
Response::Data(d) => d
.map(|mut items| {
if items.len() != 1 {
Err(Status::Error(ErrorKind::Protocol(None)).into())
} else {
Ok(items.remove(0))
}
})
.transpose(),
_ => Err(Error::Protocol(Status::Error(ErrorKind::Protocol(None)))),
}
}
pub async fn get_many<I, K>(&mut self, keys: I) -> Result<Vec<Value>, Error>
where
I: IntoIterator<Item = K>,
K: AsRef<[u8]>,
{
self.conn.write_all(b"get ").await?;
for key in keys.into_iter() {
self.conn.write_all(key.as_ref()).await?;
self.conn.write_all(b" ").await?;
}
self.conn.write_all(b"\r\n").await?;
self.conn.flush().await?;
match self.get_read_write_response().await? {
Response::Status(s) => Err(s.into()),
Response::Data(d) => d.ok_or(Status::NotFound.into()),
_ => Err(Status::Error(ErrorKind::Protocol(None)).into()),
}
}
pub async fn set<K, V>(
&mut self,
key: K,
value: V,
ttl: Option<i64>,
flags: Option<u32>,
) -> Result<(), Error>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
{
let kr = key.as_ref();
let vr = value.as_ref();
self.conn.write_all(b"set ").await?;
self.conn.write_all(kr).await?;
let flags = flags.unwrap_or(0).to_string();
self.conn.write_all(b" ").await?;
self.conn.write_all(flags.as_ref()).await?;
let ttl = ttl.unwrap_or(0).to_string();
self.conn.write_all(b" ").await?;
self.conn.write_all(ttl.as_ref()).await?;
self.conn.write_all(b" ").await?;
let vlen = vr.len().to_string();
self.conn.write_all(vlen.as_ref()).await?;
self.conn.write_all(b"\r\n").await?;
self.conn.write_all(vr).await?;
self.conn.write_all(b"\r\n").await?;
self.conn.flush().await?;
match self.get_read_write_response().await? {
Response::Status(Status::Stored) => Ok(()),
Response::Status(s) => Err(s.into()),
_ => Err(Status::Error(ErrorKind::Protocol(None)).into()),
}
}
pub async fn add<K, V>(
&mut self,
key: K,
value: V,
ttl: Option<i64>,
flags: Option<u32>,
) -> Result<(), Error>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
{
let kr = key.as_ref();
let vr = value.as_ref();
self.conn
.write_all(
&[
b"add ",
kr,
b" ",
flags.unwrap_or(0).to_string().as_ref(),
b" ",
ttl.unwrap_or(0).to_string().as_ref(),
b" ",
vr.len().to_string().as_ref(),
b"\r\n",
vr,
b"\r\n",
]
.concat(),
)
.await?;
self.conn.flush().await?;
match self.get_read_write_response().await? {
Response::Status(Status::Stored) => Ok(()),
Response::Status(s) => Err(s.into()),
_ => Err(Status::Error(ErrorKind::Protocol(None)).into()),
}
}
pub async fn delete_no_reply<K>(&mut self, key: K) -> Result<(), Error>
where
K: AsRef<[u8]>,
{
let kr = key.as_ref();
self.conn
.write_all(&[b"delete ", kr, b" noreply\r\n"].concat())
.await?;
self.conn.flush().await?;
Ok(())
}
pub async fn delete<K>(&mut self, key: K) -> Result<(), Error>
where
K: AsRef<[u8]>,
{
let kr = key.as_ref();
self.conn
.write_all(&[b"delete ", kr, b"\r\n"].concat())
.await?;
self.conn.flush().await?;
match self.get_read_write_response().await? {
Response::Status(Status::Deleted) => Ok(()),
Response::Status(s) => Err(s.into()),
_ => Err(Status::Error(ErrorKind::Protocol(None)).into()),
}
}
pub async fn version(&mut self) -> Result<String, Error> {
self.conn.write_all(b"version\r\n").await?;
self.conn.flush().await?;
let mut version = String::new();
let bytes = self.conn.read_line(&mut version).await?;
if bytes >= 8 && version.is_char_boundary(8) {
Ok(version.split_off(8))
} else {
Err(Error::from(Status::Error(ErrorKind::Protocol(Some(
format!("Invalid response for `version` command: `{version}`"),
)))))
}
}
pub async fn dump_keys(&mut self) -> Result<MetadumpIter<'_>, Error> {
self.conn.write_all(b"lru_crawler metadump all\r\n").await?;
self.conn.flush().await?;
Ok(MetadumpIter {
client: self,
done: false,
})
}
pub async fn stats(&mut self) -> Result<HashMap<String, String>, Error> {
let mut entries = HashMap::new();
self.conn.write_all(b"stats\r\n").await?;
self.conn.flush().await?;
while let StatsResponse::Entry(key, value) = self.get_stats_response().await? {
entries.insert(key, value);
}
Ok(entries)
}
}
pub struct MetadumpIter<'a> {
client: &'a mut Client,
done: bool,
}
impl<'a> MetadumpIter<'a> {
pub async fn next(&mut self) -> Option<Result<KeyMetadata, Error>> {
if self.done {
return None;
}
match self.client.get_metadump_response().await {
Ok(MetadumpResponse::End) => {
self.done = true;
None
}
Ok(MetadumpResponse::BadClass(s)) => {
self.done = true;
Some(Err(Error::Protocol(MetadumpResponse::BadClass(s).into())))
}
Ok(MetadumpResponse::Busy(s)) => {
Some(Err(Error::Protocol(MetadumpResponse::Busy(s).into())))
}
Ok(MetadumpResponse::Entry(km)) => Some(Ok(km)),
Err(e) => Some(Err(e)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Address of the memcached instance these integration tests talk to.
    const SERVER: &str = "localhost:47386";

    // Each test uses its own key. The test harness runs tests in parallel, so a shared key
    // lets `test_delete`'s set/delete interleave with `test_add`'s delete/add and fail either
    // test at random.
    const ADD_KEY: &str = "async-memcache-test-key-add";
    const DELETE_KEY: &str = "async-memcache-test-key-delete";

    /// Opens a fresh client against the test server, panicking if it is unreachable.
    async fn connect() -> Client {
        Client::new(SERVER)
            .await
            .expect("Failed to connect to server")
    }

    #[tokio::test]
    async fn test_add() {
        let mut client = connect().await;

        // Remove any leftover from a previous run so that `add` is allowed to store the key.
        let result = client.delete_no_reply(ADD_KEY).await;
        assert!(result.is_ok(), "failed to delete {}, {:?}", ADD_KEY, result);

        let result = client.add(ADD_KEY, "value", None, None).await;
        assert!(result.is_ok(), "failed to add {}, {:?}", ADD_KEY, result);
    }

    #[tokio::test]
    async fn test_delete() {
        let mut client = connect().await;
        let value = rand::random::<u64>().to_string();

        let result = client.set(DELETE_KEY, &value, None, None).await;
        assert!(result.is_ok(), "failed to set {}, {:?}", DELETE_KEY, result);

        let result = client.get(DELETE_KEY).await;
        assert!(result.is_ok(), "failed to get {}, {:?}", DELETE_KEY, result);
        match result.unwrap() {
            Some(get_value) => assert_eq!(
                String::from_utf8(get_value.data).expect("failed to parse a string"),
                value
            ),
            None => panic!("failed to get {}", DELETE_KEY),
        }

        let result = client.delete(DELETE_KEY).await;
        assert!(result.is_ok(), "failed to delete {}, {:?}", DELETE_KEY, result);

        // The key must really be gone afterwards.
        let result = client.get(DELETE_KEY).await;
        assert!(
            matches!(result, Ok(None)),
            "{} still present after delete: {:?}",
            DELETE_KEY,
            result
        );
    }
}