use crate::{
protocol::{Header, Packet, ProtocolError, SetExtras, Status},
ring::Ring,
};
use async_trait::async_trait;
use deadpool::managed::{Manager, RecycleResult};
use serde::{de::DeserializeOwned, Serialize};
use std::{
collections::HashMap,
error::Error as StdError,
fmt::{Display, Formatter, Result as FmtResult},
hash::Hash,
marker::PhantomData,
};
/// Errors surfaced by the client.
///
/// Each variant wraps exactly one underlying error value. A `From` conversion
/// exists in this file for every wrapped type, so `?` can lift any of them
/// into `Error`.
#[derive(Debug)]
pub enum Error {
/// A `std::io::Error`.
IoError(std::io::Error),
/// A `ProtocolError` from the `protocol` module.
Protocol(ProtocolError),
/// A `bincode::Error`.
Bincode(bincode::Error),
/// A `Status` value reported as an error (for example `Status::KeyNotFound`).
Status(Status),
}
/// Values returned by a bulk read, keyed by the raw key bytes of the packet.
pub type BulkOkResponse<V> = HashMap<Vec<u8>, V>;
/// Per-key failures collected during a bulk operation, keyed by the raw key
/// bytes of the packet that failed.
pub type BulkErrResponse = HashMap<Vec<u8>, Error>;
/// Result of a bulk update (`set_multi`, `delete_multi`): `Ok` carries the
/// per-key errors, `Err` is a failure that aborted the whole call.
pub type BulkUpdateResponse = Result<BulkErrResponse, Error>;
/// Result of `get_multi`: `Ok` carries the retrieved values together with the
/// per-key errors, `Err` is a failure that aborted the whole call.
pub type BulkGetResponse<V> = Result<(BulkOkResponse<V>, BulkErrResponse), Error>;
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
Self::IoError(err)
}
}
impl From<ProtocolError> for Error {
fn from(err: ProtocolError) -> Self {
Self::Protocol(err)
}
}
impl From<bincode::Error> for Error {
fn from(err: bincode::Error) -> Self {
Self::Bincode(err)
}
}
impl From<Status> for Error {
fn from(err: Status) -> Self {
Self::Status(err)
}
}
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
Error::IoError(err) => write!(f, "IoError: {}", err),
Error::Protocol(err) => write!(f, "ProtocolError: {}", err),
Error::Bincode(err) => write!(f, "BincodeError: {}", err),
Error::Status(err) => write!(f, "StatusError: {}", err),
}
}
}
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
Error::IoError(err) => Some(err),
Error::Protocol(err) => Some(err),
Error::Bincode(err) => Some(err),
Error::Status(err) => Some(err),
}
}
}
/// Transformation applied to packets on their way to and from a connection.
///
/// `Connection::write_packet` calls `compress` on a packet before turning it
/// into bytes, and `Connection::read_packet` calls `decompress` on every
/// packet it has parsed. The `Copy` bound lets the client hand the compressor
/// to those methods by value on each call.
pub trait Compressor: Clone + Copy + Send + Sync {
/// Transforms an outgoing packet before it is written.
fn compress(&self, packet: Packet) -> Result<Packet, Error>;
/// Transforms an incoming packet after it has been read.
fn decompress(&self, packet: Packet) -> Result<Packet, Error>;
}
#[derive(Debug, Clone, Copy)]
pub struct NoCompressor;
impl Compressor for NoCompressor {
fn compress(&self, bytes: Packet) -> Result<Packet, Error> {
Ok(bytes)
}
fn decompress(&self, bytes: Packet) -> Result<Packet, Error> {
Ok(bytes)
}
}
/// Transport to a single server.
///
/// Implementors provide `connect`, `read` and `write`; packet framing is
/// supplied by the default `read_packet` / `write_packet` methods.
#[async_trait]
pub trait Connection: Clone + Sized + Send + Sync + 'static {
    /// Opens a connection to `url`.
    async fn connect(url: String) -> Result<Self, Error>;

    /// Reads from the connection into `buf` and returns a byte count.
    ///
    /// NOTE(review): `read_packet` ignores the returned count and treats
    /// `buf` as completely filled, so implementations presumably need
    /// "read exactly `buf.len()` bytes" semantics — confirm against the
    /// implementors.
    async fn read(&mut self, buf: &mut Vec<u8>) -> Result<usize, Error>;

    /// Writes `data` to the connection.
    async fn write(&mut self, data: &[u8]) -> Result<(), Error>;

    /// Reads one response packet and passes it through
    /// `compressor.decompress`.
    ///
    /// A 24-byte header is read first; its `body_len` field determines how
    /// many further bytes are read as the packet body.
    async fn read_packet<P: Compressor>(&mut self, compressor: P) -> Result<Packet, Error> {
        let mut header_bytes = vec![0_u8; 24];
        self.read(&mut header_bytes).await?;
        let header = Header::read_response(&header_bytes[..])?;

        let body_len = header.body_len as usize;
        let mut body_bytes = vec![0_u8; body_len];
        // A zero-length body needs no second read.
        if body_len > 0 {
            self.read(&mut body_bytes).await?;
        }

        compressor.decompress(header.read_packet(&body_bytes[..])?)
    }

    /// Passes `packet` through `compressor.compress`, converts it to bytes
    /// and writes those bytes to the connection.
    async fn write_packet<P: Compressor>(
        &mut self,
        compressor: P,
        packet: Packet,
    ) -> Result<(), Error> {
        let wire: Vec<u8> = compressor.compress(packet)?.into();
        self.write(&wire[..]).await
    }
}
/// Settings needed to build a `Client`: the server endpoints and the packet
/// compressor.
///
/// The type also implements `deadpool`'s `Manager` further down in this file,
/// so it can back a `Pool`.
#[derive(Debug, Clone)]
pub struct ClientConfig<C: Connection, P: Compressor> {
/// Endpoint strings handed to `Ring::new` when a client is created.
endpoints: Vec<String>,
/// Compressor copied into every client built from this config.
compressor: P,
/// Ties the config to a connection type `C` without storing a connection.
phantom: PhantomData<C>,
}
impl<C, P> ClientConfig<C, P>
where
C: Connection,
P: Compressor,
{
pub fn new(endpoints: Vec<String>, compressor: P) -> Self {
Self {
endpoints,
compressor,
phantom: PhantomData,
}
}
}
impl<C> ClientConfig<C, NoCompressor>
where
    C: Connection,
{
    /// Creates a configuration that uses [`NoCompressor`], i.e. packets are
    /// sent and received unmodified.
    pub fn new_uncompressed(endpoints: Vec<String>) -> Self {
        let compressor = NoCompressor;
        Self::new(endpoints, compressor)
    }
}
/// A client that routes each key to a connection chosen by a `Ring` and
/// exchanges packets with it through the configured compressor.
#[derive(Debug, Clone)]
pub struct Client<C: Connection, P: Compressor> {
/// Supplies the connection(s) for a key or set of keys.
ring: Ring<C>,
/// Passed by value to every `read_packet` / `write_packet` call.
compressor: P,
}
impl<C: Connection, P: Compressor> Client<C, P> {
pub async fn new(config: ClientConfig<C, P>) -> Result<Self, Error> {
let ClientConfig {
endpoints,
compressor,
..
} = config;
let ring = Ring::new(endpoints).await?;
Ok(Self { ring, compressor })
}
pub async fn get<K: AsRef<[u8]>, V: DeserializeOwned>(
&mut self,
key: K,
) -> Result<Option<V>, Error> {
let key = key.as_ref();
let conn = self.ring.get_conn(key)?;
conn.write_packet(self.compressor, Packet::get(key)?)
.await?;
let packet = conn.read_packet(self.compressor).await?;
match packet.error_for_status() {
Ok(()) => Ok(Some(packet.deserialize_value()?)),
Err(Status::KeyNotFound) => Ok(None),
Err(status) => Err(status.into()),
}
}
pub async fn get_multi<'a, K: AsRef<[u8]>, V: DeserializeOwned>(
&mut self,
keys: &[K],
) -> BulkGetResponse<V> {
let mut values = HashMap::new();
let mut errors = HashMap::new();
for (conn, mut pipeline) in self.ring.get_conns(keys) {
let last_key = pipeline.pop().unwrap();
let reqs = pipeline
.iter()
.map(Packet::getkq)
.chain(vec![Packet::getk(last_key)])
.collect::<Result<Vec<_>, _>>()?;
for packet in reqs {
let key = packet.key.clone();
let result = conn.write_packet(self.compressor, packet).await;
if let Err(err) = result {
errors.insert(key, err);
}
}
}
for (conn, mut pipeline) in self.ring.get_conns(keys) {
let last_key = pipeline.pop().unwrap();
let mut finished = false;
while !finished {
let packet = conn.read_packet(self.compressor).await?;
let key = packet.key.clone();
finished = key == last_key.as_ref();
match packet.error_for_status() {
Err(Status::KeyNotFound) => (),
Err(err) => {
errors.insert(key, Error::Status(err));
}
Ok(()) => {
values.insert(key, packet.deserialize_value()?);
}
}
}
}
Ok((values, errors))
}
pub async fn set<K: AsRef<[u8]>, V: Serialize + ?Sized>(
&mut self,
key: K,
data: &V,
expire: u32,
) -> Result<(), Error> {
let key = key.as_ref();
let conn = self.ring.get_conn(key)?;
let packet = Packet::set(key, data, SetExtras::new(0, expire))?;
conn.write_packet(self.compressor, packet).await?;
conn.read_packet(self.compressor)
.await?
.error_for_status()?;
Ok(())
}
/// Stores every entry of `data` with the given `expire` value, pipelining the
/// requests per connection.
///
/// Keys are grouped per connection by `Ring::get_conns`. For every connection
/// all keys but the last are sent with `Packet::setq` and the last one with
/// `Packet::set`; all requests for all connections are written first, then
/// the responses are read.
///
/// # Returns
/// `Ok(errors)` where `errors` holds write failures (keyed by the request
/// packet's key) and every response error status other than
/// `Status::KeyNotFound` (keyed by the response packet's key).
///
/// # Errors
/// Returns `Err` when a request packet cannot be built or a response cannot
/// be read from a connection.
///
/// # Panics
/// Panics if `Ring::get_conns` yields an empty pipeline (`pop().unwrap()`),
/// or a key that is not present in `data` (`get(..).unwrap()`).
///
/// NOTE(review): the read loop of a connection only ends once a response with
/// `vbucket_or_status == 0` has been read. If the final `set` is answered
/// with an error status, or its write failed, such a response may never
/// arrive and the loop would keep waiting — confirm, and consider identifying
/// the final response by another header field instead.
pub async fn set_multi<'a, V: Serialize, K: AsRef<[u8]> + Eq + Hash>(
&mut self,
data: HashMap<K, V>,
expire: u32,
) -> BulkUpdateResponse {
let mut errors = HashMap::new();
let keys = data.keys().collect::<Vec<_>>();
let extras = SetExtras::new(0, expire);
// Phase 1: write every request so all connections work concurrently.
for (conn, mut pipeline) in self.ring.get_conns(&keys[..]) {
// The last key of the pipeline is sent with `set`, the others with `setq`.
let last_key = pipeline.pop().unwrap();
let last_val = data.get(last_key).unwrap();
// All packets of a connection are built before any of them is written, so
// a build failure returns before that connection receives anything.
let reqs = pipeline
.into_iter()
.map(|key| (key, data.get(key).unwrap()))
.map(|(key, value)| Packet::setq(key, value, extras))
.chain(vec![Packet::set(last_key, last_val, extras)])
.collect::<Result<Vec<_>, _>>()?;
for packet in reqs {
let key = packet.key.clone();
// A failed write is recorded per key; the remaining packets are still sent.
if let Err(err) = conn.write_packet(self.compressor, packet).await {
errors.insert(key, err);
}
}
}
// Phase 2: read responses from every connection until one reports status 0.
for (conn, _) in self.ring.get_conns(&keys[..]) {
let mut finished = false;
while !finished {
let packet = conn.read_packet(self.compressor).await?;
let key = packet.key.clone();
// A zero status field is taken as the end of this connection's responses.
finished = packet.header.vbucket_or_status == 0;
match packet.error_for_status() {
Ok(()) => (),
// `KeyNotFound` is deliberately not reported as an error here.
Err(Status::KeyNotFound) => (),
Err(err) => {
errors.insert(key, Error::Status(err));
}
}
}
}
Ok(errors)
}
/// Deletes the entry stored under `key`.
///
/// # Errors
/// Returns `Err` when no connection can be obtained for the key, the request
/// cannot be built or written, the response cannot be read, or the response
/// carries an error status.
pub async fn delete<K: AsRef<[u8]>>(&mut self, key: K) -> Result<(), Error> {
    let key = key.as_ref();
    let conn = self.ring.get_conn(key)?;

    let request = Packet::delete(key)?;
    conn.write_packet(self.compressor, request).await?;

    let response = conn.read_packet(self.compressor).await?;
    response.error_for_status()?;
    Ok(())
}
pub async fn delete_multi<K: AsRef<[u8]>>(&mut self, keys: &[K]) -> BulkUpdateResponse {
let mut errors = HashMap::new();
for (conn, pipeline) in self.ring.get_conns(keys) {
let reqs = pipeline
.into_iter()
.map(Packet::delete)
.collect::<Result<Vec<_>, _>>()?;
for packet in reqs {
let key = packet.key.clone();
if let Err(err) = conn.write_packet(self.compressor, packet).await {
errors.insert(key, err);
}
}
}
for (conn, pipeline) in self.ring.get_conns(keys) {
for _ in pipeline {
let packet = conn.read_packet(self.compressor).await?;
let key = packet.key.clone();
match packet.error_for_status() {
Ok(()) => (),
Err(err) => {
errors.insert(key, Error::Status(err));
}
}
}
}
Ok(errors)
}
/// Sends a no-op request over every connection of the ring and checks the
/// status of each response.
///
/// Stops at, and returns, the first error encountered.
async fn keep_alive(&mut self) -> Result<(), Error> {
    for conn in self.ring.into_iter() {
        let noop = Packet::noop()?;
        conn.write_packet(self.compressor, noop).await?;
        conn.read_packet(self.compressor)
            .await?
            .error_for_status()?;
    }
    Ok(())
}
}
/// Lets a [`ClientConfig`] act as a `deadpool` manager that produces
/// [`Client`]s.
#[async_trait]
impl<C, P> Manager for ClientConfig<C, P>
where
    C: Connection,
    P: Compressor,
{
    type Type = Client<C, P>;
    type Error = Error;

    /// Builds a client from a clone of this config and checks it with a
    /// keep-alive round trip before handing it out.
    async fn create(&self) -> Result<Self::Type, Error> {
        let mut client = Client::new(self.clone()).await?;
        match client.keep_alive().await {
            Ok(()) => Ok(client),
            Err(err) => Err(err),
        }
    }

    /// Checks a pooled client with a keep-alive round trip before reuse.
    async fn recycle(&self, client: &mut Self::Type) -> RecycleResult<Error> {
        Ok(client.keep_alive().await?)
    }
}
/// A `deadpool` managed pool whose manager is a `ClientConfig<C, P>`; the
/// pooled objects are `Client<C, P>` values.
pub type Pool<C, P> = deadpool::managed::Pool<ClientConfig<C, P>>;
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::protocol::ProtocolError;

    /// `Display` output is the variant label followed by the inner error.
    #[test]
    fn test_err_display() {
        let protocol_err = Error::Protocol(ProtocolError::InvalidMagic(8));
        assert_eq!(
            "ProtocolError: Invalid magic byte: 8",
            protocol_err.to_string()
        );

        let status_err = Error::Status(crate::protocol::Status::KeyNotFound);
        assert_eq!("StatusError: Key not found", status_err.to_string());
    }
}