use crate::{
    protocol::{Header, Packet, ProtocolError, SetExtras, Status},
    ring::Ring,
};
use async_trait::async_trait;
use deadpool::managed::{Manager, RecycleResult};
use serde::{de::DeserializeOwned, Serialize};
use std::{
    collections::{HashMap, HashSet},
    error::Error as StdError,
    fmt::{Display, Formatter, Result as FmtResult},
    hash::Hash,
    marker::PhantomData,
};
19
/// Every failure the client can report.
///
/// Each variant wraps the underlying error; the `From` impls below allow
/// `?` to convert those error types into `Error` automatically.
#[derive(Debug)]
pub enum Error {
    /// Wraps a `std::io::Error`.
    IoError(std::io::Error),
    /// Wraps a `ProtocolError` from the `protocol` module.
    Protocol(ProtocolError),
    /// Wraps a `bincode::Error` (presumably raised while (de)serializing values — TODO confirm).
    Bincode(bincode::Error),
    /// A non-success `Status` taken from a server response.
    Status(Status),
}
32
33pub type BulkOkResponse<V> = HashMap<Vec<u8>, V>;
36
37pub type BulkErrResponse = HashMap<Vec<u8>, Error>;
42
43pub type BulkUpdateResponse = Result<BulkErrResponse, Error>;
45
46pub type BulkGetResponse<V> = Result<(BulkOkResponse<V>, BulkErrResponse), Error>;
51
52impl From<std::io::Error> for Error {
53 fn from(err: std::io::Error) -> Self {
54 Self::IoError(err)
55 }
56}
57
58impl From<ProtocolError> for Error {
59 fn from(err: ProtocolError) -> Self {
60 Self::Protocol(err)
61 }
62}
63
64impl From<bincode::Error> for Error {
65 fn from(err: bincode::Error) -> Self {
66 Self::Bincode(err)
67 }
68}
69
70impl From<Status> for Error {
71 fn from(err: Status) -> Self {
72 Self::Status(err)
73 }
74}
75
76impl Display for Error {
77 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
78 match self {
79 Error::IoError(err) => write!(f, "IoError: {}", err),
80 Error::Protocol(err) => write!(f, "ProtocolError: {}", err),
81 Error::Bincode(err) => write!(f, "BincodeError: {}", err),
82 Error::Status(err) => write!(f, "StatusError: {}", err),
83 }
84 }
85}
86
87impl StdError for Error {
88 fn source(&self) -> Option<&(dyn StdError + 'static)> {
89 match self {
90 Error::IoError(err) => Some(err),
91 Error::Protocol(err) => Some(err),
92 Error::Bincode(err) => Some(err),
93 Error::Status(err) => Some(err),
94 }
95 }
96}
97
98pub trait Compressor: Clone + Copy + Send + Sync {
105 fn compress(&self, packet: Packet) -> Result<Packet, Error>;
109 fn decompress(&self, packet: Packet) -> Result<Packet, Error>;
113}
114
/// A [`Compressor`] that leaves packets untouched in both directions.
#[derive(Debug, Clone, Copy)]
pub struct NoCompressor;
119
120impl Compressor for NoCompressor {
121 fn compress(&self, bytes: Packet) -> Result<Packet, Error> {
122 Ok(bytes)
123 }
124
125 fn decompress(&self, bytes: Packet) -> Result<Packet, Error> {
126 Ok(bytes)
127 }
128}
129
130#[async_trait]
133pub trait Connection: Clone + Sized + Send + Sync + 'static {
134 async fn connect(url: String) -> Result<Self, Error>;
136
137 async fn read(&mut self, buf: &mut Vec<u8>) -> Result<usize, Error>;
139
140 async fn write(&mut self, data: &[u8]) -> Result<(), Error>;
142
143 async fn read_packet<P: Compressor>(&mut self, compressor: P) -> Result<Packet, Error> {
146 let mut buf = vec![0_u8; 24];
147 self.read(&mut buf).await?;
148 let header = Header::read_response(&buf[..])?;
149 let mut body = vec![0_u8; header.body_len as usize];
150 if !body.is_empty() {
151 self.read(&mut body).await?;
152 }
153 let packet = header.read_packet(&body[..])?;
154 compressor.decompress(packet)
155 }
156
157 async fn write_packet<P: Compressor>(
160 &mut self,
161 compressor: P,
162 packet: Packet,
163 ) -> Result<(), Error> {
164 let packet = compressor.compress(packet)?;
165 let bytes: Vec<u8> = packet.into();
166 self.write(&bytes[..]).await
167 }
168}
169
170#[derive(Debug, Clone)]
172pub struct ClientConfig<C: Connection, P: Compressor> {
173 endpoints: Vec<String>,
174 compressor: P,
175 phantom: PhantomData<C>,
176}
177
178impl<C, P> ClientConfig<C, P>
179where
180 C: Connection,
181 P: Compressor,
182{
183 pub fn new(endpoints: Vec<String>, compressor: P) -> Self {
187 Self {
188 endpoints,
189 compressor,
190 phantom: PhantomData,
191 }
192 }
193}
194
195impl<C> ClientConfig<C, NoCompressor>
196where
197 C: Connection,
198{
199 pub fn new_uncompressed(endpoints: Vec<String>) -> Self {
201 Self::new(endpoints, NoCompressor)
202 }
203}
204
205#[derive(Debug, Clone)]
208pub struct Client<C: Connection, P: Compressor> {
209 ring: Ring<C>,
210 compressor: P,
211}
212
213impl<C: Connection, P: Compressor> Client<C, P> {
214 pub async fn new(config: ClientConfig<C, P>) -> Result<Self, Error> {
216 let ClientConfig {
217 endpoints,
218 compressor,
219 ..
220 } = config;
221 let ring = Ring::new(endpoints).await?;
222 Ok(Self { ring, compressor })
223 }
224
225 pub async fn get<K: AsRef<[u8]>, V: DeserializeOwned>(
228 &mut self,
229 key: K,
230 ) -> Result<Option<V>, Error> {
231 let key = key.as_ref();
232 let conn = self.ring.get_conn(key)?;
233 conn.write_packet(self.compressor, Packet::get(key)?)
234 .await?;
235
236 let packet = conn.read_packet(self.compressor).await?;
237 match packet.error_for_status() {
238 Ok(()) => Ok(Some(packet.deserialize_value()?)),
239 Err(Status::KeyNotFound) => Ok(None),
240 Err(status) => Err(status.into()),
241 }
242 }
243
244 pub async fn get_multi<'a, K: AsRef<[u8]>, V: DeserializeOwned>(
249 &mut self,
250 keys: &[K],
251 ) -> BulkGetResponse<V> {
252 let mut values = HashMap::new();
253 let mut errors = HashMap::new();
254
255 for (conn, mut pipeline) in self.ring.get_conns(keys) {
257 let last_key = pipeline.pop().unwrap();
258 let reqs = pipeline
259 .iter()
260 .map(Packet::getkq)
261 .chain(vec![Packet::getk(last_key)])
262 .collect::<Result<Vec<_>, _>>()?;
263
264 for packet in reqs {
265 let key = packet.key.clone();
266 let result = conn.write_packet(self.compressor, packet).await;
267 if let Err(err) = result {
268 errors.insert(key, err);
269 }
270 }
271 }
272
273 for (conn, mut pipeline) in self.ring.get_conns(keys) {
275 let last_key = pipeline.pop().unwrap();
276 let mut finished = false;
277 while !finished {
278 let packet = conn.read_packet(self.compressor).await?;
279 let key = packet.key.clone();
280 finished = key == last_key.as_ref();
281 match packet.error_for_status() {
282 Err(Status::KeyNotFound) => (),
283 Err(err) => {
284 errors.insert(key, Error::Status(err));
285 }
286 Ok(()) => {
287 values.insert(key, packet.deserialize_value()?);
288 }
289 }
290 }
291 }
292
293 Ok((values, errors))
294 }
295
296 pub async fn set<K: AsRef<[u8]>, V: Serialize + ?Sized>(
302 &mut self,
303 key: K,
304 data: &V,
305 expire: u32,
306 ) -> Result<(), Error> {
307 let key = key.as_ref();
308 let conn = self.ring.get_conn(key)?;
309 let packet = Packet::set(key, data, SetExtras::new(0, expire))?;
310 conn.write_packet(self.compressor, packet).await?;
311 conn.read_packet(self.compressor)
312 .await?
313 .error_for_status()?;
314 Ok(())
315 }
316
317 pub async fn set_multi<'a, V: Serialize, K: AsRef<[u8]> + Eq + Hash>(
323 &mut self,
324 data: HashMap<K, V>,
325 expire: u32,
326 ) -> BulkUpdateResponse {
327 let mut errors = HashMap::new();
328 let keys = data.keys().collect::<Vec<_>>();
329 let extras = SetExtras::new(0, expire);
330
331 for (conn, mut pipeline) in self.ring.get_conns(&keys[..]) {
333 let last_key = pipeline.pop().unwrap();
334 let last_val = data.get(last_key).unwrap();
335 let reqs = pipeline
336 .into_iter()
337 .map(|key| (key, data.get(key).unwrap()))
338 .map(|(key, value)| Packet::setq(key, value, extras))
339 .chain(vec![Packet::set(last_key, last_val, extras)])
340 .collect::<Result<Vec<_>, _>>()?;
341
342 for packet in reqs {
343 let key = packet.key.clone();
344 if let Err(err) = conn.write_packet(self.compressor, packet).await {
345 errors.insert(key, err);
346 }
347 }
348 }
349
350 for (conn, _) in self.ring.get_conns(&keys[..]) {
352 let mut finished = false;
353 while !finished {
354 let packet = conn.read_packet(self.compressor).await?;
355 let key = packet.key.clone();
356 finished = packet.header.vbucket_or_status == 0;
357 match packet.error_for_status() {
358 Ok(()) => (),
359 Err(Status::KeyNotFound) => (),
360 Err(err) => {
361 errors.insert(key, Error::Status(err));
362 }
363 }
364 }
365 }
366
367 Ok(errors)
368 }
369
370 pub async fn delete<K: AsRef<[u8]>>(&mut self, key: K) -> Result<(), Error> {
372 let key = key.as_ref();
373 let conn = self.ring.get_conn(key)?;
374 conn.write_packet(self.compressor, Packet::delete(key)?)
375 .await?;
376 conn.read_packet(self.compressor)
377 .await?
378 .error_for_status()?;
379 Ok(())
380 }
381
382 pub async fn delete_multi<K: AsRef<[u8]>>(&mut self, keys: &[K]) -> BulkUpdateResponse {
384 let mut errors = HashMap::new();
385
386 for (conn, pipeline) in self.ring.get_conns(keys) {
388 let reqs = pipeline
389 .into_iter()
390 .map(Packet::delete)
391 .collect::<Result<Vec<_>, _>>()?;
392 for packet in reqs {
393 let key = packet.key.clone();
394 if let Err(err) = conn.write_packet(self.compressor, packet).await {
395 errors.insert(key, err);
396 }
397 }
398 }
399
400 for (conn, pipeline) in self.ring.get_conns(keys) {
402 for _ in pipeline {
403 let packet = conn.read_packet(self.compressor).await?;
404 let key = packet.key.clone();
405 match packet.error_for_status() {
406 Ok(()) => (),
407 Err(err) => {
408 errors.insert(key, Error::Status(err));
409 }
410 }
411 }
412 }
413
414 Ok(errors)
415 }
416
417 async fn keep_alive(&mut self) -> Result<(), Error> {
418 for conn in self.ring.into_iter() {
420 conn.write_packet(self.compressor, Packet::noop()?).await?;
421 let packet = conn.read_packet(self.compressor).await?;
422 packet.error_for_status()?;
423 }
424 Ok(())
425 }
426}
427
428#[async_trait]
429impl<C, P> Manager for ClientConfig<C, P>
430where
431 C: Connection,
432 P: Compressor,
433{
434 type Type = Client<C, P>;
435 type Error = Error;
436
437 async fn create(&self) -> Result<Self::Type, Error> {
438 let mut client = Client::new(self.clone()).await?;
439 client.keep_alive().await?;
440 Ok(client)
441 }
442
443 async fn recycle(&self, client: &mut Self::Type) -> RecycleResult<Error> {
444 client.keep_alive().await?;
445 Ok(())
446 }
447}
448
/// A `deadpool` managed pool of [`Client`]s, created and recycled through the
/// [`Manager`] implementation on [`ClientConfig`].
pub type Pool<C, P> = deadpool::managed::Pool<ClientConfig<C, P>>;
454
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::protocol::{ProtocolError, Status};

    /// `Display` prints the variant label followed by the wrapped error's message.
    #[test]
    fn test_err_display() {
        let cases = [
            (
                Error::Protocol(ProtocolError::InvalidMagic(8)),
                "ProtocolError: Invalid magic byte: 8",
            ),
            (
                Error::Status(Status::KeyNotFound),
                "StatusError: Key not found",
            ),
        ];
        for (err, expected) in cases.iter() {
            assert_eq!(err.to_string(), *expected);
        }
    }
}