rsmc_core/client.rs

//! This module implements the high-level client used to communicate with
//! a memcached cluster. Regardless of the async runtime used, all
//! implementations share the same client interface.

use crate::{
    protocol::{Header, Packet, ProtocolError, SetExtras, Status},
    ring::Ring,
};
use async_trait::async_trait;
use deadpool::managed::{Manager, RecycleResult};
use serde::{de::DeserializeOwned, Serialize};
use std::{
    collections::HashMap,
    error::Error as StdError,
    fmt::{Display, Formatter, Result as FmtResult},
    hash::Hash,
    marker::PhantomData,
};

/// An error occurring during client communication with memcached.
#[derive(Debug)]
pub enum Error {
    /// An error communicating over the wire.
    IoError(std::io::Error),
    /// An error caused by incorrectly implementing the memcached protocol.
    Protocol(ProtocolError),
    /// An error caused by (de-)serializing a value.
    Bincode(bincode::Error),
    /// An error caused by a non-zero status received from a packet.
    Status(Status),
}

/// The result of a get_multi() request: a map of all keys for which
/// memcached returned a found response, along with their corresponding values.
pub type BulkOkResponse<V> = HashMap<Vec<u8>, V>;

/// Part of the result of a *_multi() request: a map of the keys that produced
/// an error. These can be treated as get misses and ignored, but it may be
/// desirable to log them to uncover underlying issues.
pub type BulkErrResponse = HashMap<Vec<u8>, Error>;

/// The result of a set_multi(), delete_multi(), etc.
pub type BulkUpdateResponse = Result<BulkErrResponse, Error>;

/// The result of a get_multi(). The Ok variant is a tuple of (ok, err)
/// responses. The err responses can be treated as get misses, but should be
/// logged somewhere for visibility; a large number of them could indicate a
/// serious underlying issue.
pub type BulkGetResponse<V> = Result<(BulkOkResponse<V>, BulkErrResponse), Error>;

impl From<std::io::Error> for Error {
    fn from(err: std::io::Error) -> Self {
        Self::IoError(err)
    }
}

impl From<ProtocolError> for Error {
    fn from(err: ProtocolError) -> Self {
        Self::Protocol(err)
    }
}

impl From<bincode::Error> for Error {
    fn from(err: bincode::Error) -> Self {
        Self::Bincode(err)
    }
}

impl From<Status> for Error {
    fn from(err: Status) -> Self {
        Self::Status(err)
    }
}

impl Display for Error {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            Error::IoError(err) => write!(f, "IoError: {}", err),
            Error::Protocol(err) => write!(f, "ProtocolError: {}", err),
            Error::Bincode(err) => write!(f, "BincodeError: {}", err),
            Error::Status(err) => write!(f, "StatusError: {}", err),
        }
    }
}

impl StdError for Error {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Error::IoError(err) => Some(err),
            Error::Protocol(err) => Some(err),
            Error::Bincode(err) => Some(err),
            Error::Status(err) => Some(err),
        }
    }
}

/// A Compressor is used to implement compression of packet values. Two
/// implementations are provided: [`NoCompressor`] and [`ZlibCompressor`].
///
/// If another compression algorithm is desired, it is possible to implement
/// this trait yourself and pass it into [`Client::new`].
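///
/// # Example
///
/// A minimal sketch of a custom compressor that only uses this trait's own
/// methods: a hypothetical wrapper that delegates to an inner [`Compressor`],
/// e.g. to add logging or metrics around (de)compression. The wrapper type
/// and the import paths (`rsmc_core::client`, `rsmc_core::protocol`) are
/// illustrative assumptions, not part of this crate:
///
/// ```ignore
/// use rsmc_core::client::{Compressor, Error};
/// use rsmc_core::protocol::Packet;
///
/// /// Hypothetical wrapper; the inner compressor does the real work.
/// #[derive(Debug, Clone, Copy)]
/// struct InstrumentedCompressor<P: Compressor>(P);
///
/// impl<P: Compressor> Compressor for InstrumentedCompressor<P> {
///     fn compress(&self, packet: Packet) -> Result<Packet, Error> {
///         // A real implementation could record sizes or timings here.
///         self.0.compress(packet)
///     }
///
///     fn decompress(&self, packet: Packet) -> Result<Packet, Error> {
///         self.0.decompress(packet)
///     }
/// }
/// ```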
pub trait Compressor: Clone + Copy + Send + Sync {
    /// Consume a packet, returning a (possibly) modified packet with the
    /// packet value compressed. This should set the appropriate packet
    /// flags on the extras field.
    fn compress(&self, packet: Packet) -> Result<Packet, Error>;
    /// Consume a packet, returning a (possibly) modified packet with the
    /// packet value decompressed. This should unset the appropriate packet
    /// flags on the extras field.
    fn decompress(&self, packet: Packet) -> Result<Packet, Error>;
}

/// An implementation of [`Compressor`] that does nothing. This is useful if
/// you want to disable compression.
#[derive(Debug, Clone, Copy)]
pub struct NoCompressor;

impl Compressor for NoCompressor {
    fn compress(&self, packet: Packet) -> Result<Packet, Error> {
        Ok(packet)
    }

    fn decompress(&self, packet: Packet) -> Result<Packet, Error> {
        Ok(packet)
    }
}

/// A connection is an async interface to memcached, which requires a concrete
/// implementation using an underlying async runtime (e.g., Tokio or async-std).
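///
/// # Example
///
/// A minimal sketch of what a runtime-specific implementation might look
/// like, assuming the Tokio runtime. The `TokioConnection` type, the use of
/// `Arc<Mutex<_>>` to satisfy the `Clone` bound, and the import paths are
/// illustrative assumptions, not part of this crate:
///
/// ```ignore
/// use std::sync::Arc;
///
/// use async_trait::async_trait;
/// use rsmc_core::client::{Connection, Error};
/// use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// use tokio::net::TcpStream;
/// use tokio::sync::Mutex;
///
/// #[derive(Clone)]
/// struct TokioConnection {
///     // Shared, mutex-guarded stream so the connection can be cloned.
///     stream: Arc<Mutex<TcpStream>>,
/// }
///
/// #[async_trait]
/// impl Connection for TokioConnection {
///     async fn connect(url: String) -> Result<Self, Error> {
///         let stream = TcpStream::connect(url).await?;
///         Ok(Self {
///             stream: Arc::new(Mutex::new(stream)),
///         })
///     }
///
///     async fn read(&mut self, buf: &mut Vec<u8>) -> Result<usize, Error> {
///         // Fill the caller's buffer completely, as the default
///         // `read_packet` implementation expects.
///         let mut stream = self.stream.lock().await;
///         Ok(stream.read_exact(buf).await?)
///     }
///
///     async fn write(&mut self, data: &[u8]) -> Result<(), Error> {
///         let mut stream = self.stream.lock().await;
///         Ok(stream.write_all(data).await?)
///     }
/// }
/// ```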
#[async_trait]
pub trait Connection: Clone + Sized + Send + Sync + 'static {
    /// Connect to a memcached server over TCP.
    async fn connect(url: String) -> Result<Self, Error>;

    /// Read to fill the incoming buffer.
    async fn read(&mut self, buf: &mut Vec<u8>) -> Result<usize, Error>;

    /// Write an entire buffer to the TCP stream.
    async fn write(&mut self, data: &[u8]) -> Result<(), Error>;

    /// Read a packet response, possibly decompressing it. It is most likely
    /// unnecessary to implement this yourself.
    async fn read_packet<P: Compressor>(&mut self, compressor: P) -> Result<Packet, Error> {
        let mut buf = vec![0_u8; 24];
        self.read(&mut buf).await?;
        let header = Header::read_response(&buf[..])?;
        let mut body = vec![0_u8; header.body_len as usize];
        if !body.is_empty() {
            self.read(&mut body).await?;
        }
        let packet = header.read_packet(&body[..])?;
        compressor.decompress(packet)
    }

    /// Write a packet request, possibly compressing it. It is most likely
    /// unnecessary to implement this yourself.
    async fn write_packet<P: Compressor>(
        &mut self,
        compressor: P,
        packet: Packet,
    ) -> Result<(), Error> {
        let packet = compressor.compress(packet)?;
        let bytes: Vec<u8> = packet.into();
        self.write(&bytes[..]).await
    }
}

/// Set configuration values for a memcached client.
#[derive(Debug, Clone)]
pub struct ClientConfig<C: Connection, P: Compressor> {
    endpoints: Vec<String>,
    compressor: P,
    phantom: PhantomData<C>,
}

impl<C, P> ClientConfig<C, P>
where
    C: Connection,
    P: Compressor,
{
    /// Create a new client config from the given memcached servers and
    /// compressor. If no compression is desired, use
    /// [`ClientConfig::new_uncompressed`].
    pub fn new(endpoints: Vec<String>, compressor: P) -> Self {
        Self {
            endpoints,
            compressor,
            phantom: PhantomData,
        }
    }
}

impl<C> ClientConfig<C, NoCompressor>
where
    C: Connection,
{
    /// Create a new client config with no compression.
    pub fn new_uncompressed(endpoints: Vec<String>) -> Self {
        Self::new(endpoints, NoCompressor)
    }
}

/// A client manages connections to every node in a memcached cluster using
/// consistent hashing to decide which connection to use based on the key.
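///
/// # Example
///
/// A minimal sketch of typical usage. The `TokioConnection` type stands in
/// for whatever concrete [`Connection`] implementation you use, and the
/// import paths are assumptions about the crate layout:
///
/// ```ignore
/// use rsmc_core::client::{Client, ClientConfig};
///
/// async fn demo() -> Result<(), Box<dyn std::error::Error>> {
///     // One entry per memcached node; keys are spread across the ring.
///     let config = ClientConfig::<TokioConnection, _>::new_uncompressed(vec![
///         "127.0.0.1:11211".to_string(),
///         "127.0.0.1:11212".to_string(),
///     ]);
///     let mut client = Client::new(config).await?;
///
///     // Cache a value for 5 minutes, then read it back.
///     client.set("greeting", &"hello".to_string(), 300).await?;
///     let cached: Option<String> = client.get("greeting").await?;
///     assert_eq!(cached, Some("hello".to_string()));
///     Ok(())
/// }
/// ```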
#[derive(Debug, Clone)]
pub struct Client<C: Connection, P: Compressor> {
    ring: Ring<C>,
    compressor: P,
}

impl<C: Connection, P: Compressor> Client<C, P> {
    /// Create a new client using the client config provided.
    pub async fn new(config: ClientConfig<C, P>) -> Result<Self, Error> {
        let ClientConfig {
            endpoints,
            compressor,
            ..
        } = config;
        let ring = Ring::new(endpoints).await?;
        Ok(Self { ring, compressor })
    }

    /// Get a single value from memcached. Returns None when the key is not
    /// found (i.e., a miss).
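    ///
    /// # Example
    ///
    /// A brief sketch, given a connected `client` (see [`Client::new`]); the
    /// value type must be spelled out (or otherwise inferable) so the client
    /// knows what to deserialize into:
    ///
    /// ```ignore
    /// let count: Option<u64> = client.get("page:views").await?;
    /// let missing: Option<String> = client.get("no-such-key").await?;
    /// assert_eq!(missing, None);
    /// ```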
    pub async fn get<K: AsRef<[u8]>, V: DeserializeOwned>(
        &mut self,
        key: K,
    ) -> Result<Option<V>, Error> {
        let key = key.as_ref();
        let conn = self.ring.get_conn(key)?;
        conn.write_packet(self.compressor, Packet::get(key)?)
            .await?;

        let packet = conn.read_packet(self.compressor).await?;
        match packet.error_for_status() {
            Ok(()) => Ok(Some(packet.deserialize_value()?)),
            Err(Status::KeyNotFound) => Ok(None),
            Err(status) => Err(status.into()),
        }
    }

    /// Get multiple values from memcached at once. On success, it returns
    /// a tuple of (ok, err) responses. The error responses can be treated as
    /// misses, but should be logged for visibility: a large number of errors
    /// could indicate a serious underlying problem.
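    ///
    /// # Example
    ///
    /// A brief sketch, given a connected `client`, handling both halves of
    /// the result:
    ///
    /// ```ignore
    /// let keys = [b"a".to_vec(), b"b".to_vec(), b"c".to_vec()];
    /// let (found, errors) = client.get_multi::<_, String>(&keys).await?;
    /// for (key, err) in errors {
    ///     eprintln!("treating {:?} as a miss: {}", key, err);
    /// }
    /// for (key, value) in found {
    ///     println!("{:?} => {}", key, value);
    /// }
    /// ```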
    pub async fn get_multi<'a, K: AsRef<[u8]>, V: DeserializeOwned>(
        &mut self,
        keys: &[K],
    ) -> BulkGetResponse<V> {
        let mut values = HashMap::new();
        let mut errors = HashMap::new();

        // TODO: parallelize
        for (conn, mut pipeline) in self.ring.get_conns(keys) {
            // Pipeline quiet gets (getkq) for all but the last key; quiet
            // gets produce no response on a miss. The final non-quiet getk
            // always gets a response, marking the end of the pipeline.
            let last_key = pipeline.pop().unwrap();
            let reqs = pipeline
                .iter()
                .map(Packet::getkq)
                .chain(vec![Packet::getk(last_key)])
                .collect::<Result<Vec<_>, _>>()?;

            for packet in reqs {
                let key = packet.key.clone();
                let result = conn.write_packet(self.compressor, packet).await;
                if let Err(err) = result {
                    errors.insert(key, err);
                }
            }
        }

        // TODO: parallelize
        for (conn, mut pipeline) in self.ring.get_conns(keys) {
            // Read responses until the final getk's response arrives.
            let last_key = pipeline.pop().unwrap();
            let mut finished = false;
            while !finished {
                let packet = conn.read_packet(self.compressor).await?;
                let key = packet.key.clone();
                finished = key == last_key.as_ref();
                match packet.error_for_status() {
                    Err(Status::KeyNotFound) => (),
                    Err(err) => {
                        errors.insert(key, Error::Status(err));
                    }
                    Ok(()) => {
                        values.insert(key, packet.deserialize_value()?);
                    }
                }
            }
        }

        Ok((values, errors))
    }

    /// Set a single key/value pair in memcached with the given expiration, in
    /// seconds. A value of 0 means "never expire", but the value could still
    /// be evicted by the LRU cache. Important: if `expire` is greater than 30
    /// days (in seconds), memcached treats it as a unix timestamp instead of
    /// a duration.
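    ///
    /// # Example
    ///
    /// A brief sketch, given a connected `client`: cache a value for one hour.
    ///
    /// ```ignore
    /// client.set("session:42", &"some state".to_string(), 60 * 60).await?;
    /// ```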
    pub async fn set<K: AsRef<[u8]>, V: Serialize + ?Sized>(
        &mut self,
        key: K,
        data: &V,
        expire: u32,
    ) -> Result<(), Error> {
        let key = key.as_ref();
        let conn = self.ring.get_conn(key)?;
        let packet = Packet::set(key, data, SetExtras::new(0, expire))?;
        conn.write_packet(self.compressor, packet).await?;
        conn.read_packet(self.compressor)
            .await?
            .error_for_status()?;
        Ok(())
    }

    /// Set multiple key/value pairs in memcached with the given expiration,
    /// in seconds. A value of 0 means "never expire", but the values could
    /// still be evicted by the LRU cache. Important: if `expire` is greater
    /// than 30 days (in seconds), memcached treats it as a unix timestamp
    /// instead of a duration.
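    ///
    /// # Example
    ///
    /// A brief sketch, given a connected `client`: any failures come back
    /// keyed by the offending key.
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut data = HashMap::new();
    /// data.insert("a", 1_u32);
    /// data.insert("b", 2_u32);
    /// let errors = client.set_multi(data, 300).await?;
    /// for (key, err) in errors {
    ///     eprintln!("failed to set {:?}: {}", key, err);
    /// }
    /// ```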
    pub async fn set_multi<'a, V: Serialize, K: AsRef<[u8]> + Eq + Hash>(
        &mut self,
        data: HashMap<K, V>,
        expire: u32,
    ) -> BulkUpdateResponse {
        let mut errors = HashMap::new();
        let keys = data.keys().collect::<Vec<_>>();
        let extras = SetExtras::new(0, expire);

        // TODO: parallelize
        for (conn, mut pipeline) in self.ring.get_conns(&keys[..]) {
            // Pipeline quiet sets (setq) for all but the last key; quiet sets
            // only respond on error. The final non-quiet set always responds,
            // marking the end of the pipeline.
            let last_key = pipeline.pop().unwrap();
            let last_val = data.get(last_key).unwrap();
            let reqs = pipeline
                .into_iter()
                .map(|key| (key, data.get(key).unwrap()))
                .map(|(key, value)| Packet::setq(key, value, extras))
                .chain(vec![Packet::set(last_key, last_val, extras)])
                .collect::<Result<Vec<_>, _>>()?;

            for packet in reqs {
                let key = packet.key.clone();
                if let Err(err) = conn.write_packet(self.compressor, packet).await {
                    errors.insert(key, err);
                }
            }
        }

        // TODO: parallelize
        for (conn, _) in self.ring.get_conns(&keys[..]) {
            let mut finished = false;
            while !finished {
                let packet = conn.read_packet(self.compressor).await?;
                let key = packet.key.clone();
                // A success status can only come from the final non-quiet
                // set, so it marks the end of this connection's responses.
                finished = packet.header.vbucket_or_status == 0;
                match packet.error_for_status() {
                    Ok(()) => (),
                    Err(Status::KeyNotFound) => (),
                    Err(err) => {
                        errors.insert(key, Error::Status(err));
                    }
                }
            }
        }

        Ok(errors)
    }

    /// Delete a key from memcached. Does nothing if the key is not set.
    pub async fn delete<K: AsRef<[u8]>>(&mut self, key: K) -> Result<(), Error> {
        let key = key.as_ref();
        let conn = self.ring.get_conn(key)?;
        conn.write_packet(self.compressor, Packet::delete(key)?)
            .await?;
        conn.read_packet(self.compressor)
            .await?
            .error_for_status()?;
        Ok(())
    }

    /// Delete multiple keys from memcached. Does nothing when a key is unset.
    pub async fn delete_multi<K: AsRef<[u8]>>(&mut self, keys: &[K]) -> BulkUpdateResponse {
        let mut errors = HashMap::new();

        // TODO: parallelize
        for (conn, pipeline) in self.ring.get_conns(keys) {
            let reqs = pipeline
                .into_iter()
                .map(Packet::delete)
                .collect::<Result<Vec<_>, _>>()?;
            for packet in reqs {
                let key = packet.key.clone();
                if let Err(err) = conn.write_packet(self.compressor, packet).await {
                    errors.insert(key, err);
                }
            }
        }

        // TODO: parallelize
        for (conn, pipeline) in self.ring.get_conns(keys) {
            for _ in pipeline {
                let packet = conn.read_packet(self.compressor).await?;
                let key = packet.key.clone();
                match packet.error_for_status() {
                    Ok(()) => (),
                    Err(err) => {
                        errors.insert(key, Error::Status(err));
                    }
                }
            }
        }

        Ok(errors)
    }

    async fn keep_alive(&mut self) -> Result<(), Error> {
        // TODO: verify read_packet returns a noop code
        for conn in self.ring.into_iter() {
            conn.write_packet(self.compressor, Packet::noop()?).await?;
            let packet = conn.read_packet(self.compressor).await?;
            packet.error_for_status()?;
        }
        Ok(())
    }
}

#[async_trait]
impl<C, P> Manager for ClientConfig<C, P>
where
    C: Connection,
    P: Compressor,
{
    type Type = Client<C, P>;
    type Error = Error;

    async fn create(&self) -> Result<Self::Type, Error> {
        let mut client = Client::new(self.clone()).await?;
        client.keep_alive().await?;
        Ok(client)
    }

    async fn recycle(&self, client: &mut Self::Type) -> RecycleResult<Error> {
        client.keep_alive().await?;
        Ok(())
    }
}

/// A connection pool of clients. Using a pool is recommended for best
/// performance, since it eliminates the overhead of constantly recreating
/// TCP connections while also keeping the total number of open connections
/// under control.
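///
/// # Example
///
/// A minimal sketch using deadpool's builder API (assumed here); as above,
/// `TokioConnection` stands in for a concrete [`Connection`] implementation:
///
/// ```ignore
/// use rsmc_core::client::{ClientConfig, Pool};
///
/// async fn demo() -> Result<(), Box<dyn std::error::Error>> {
///     let config = ClientConfig::<TokioConnection, _>::new_uncompressed(vec![
///         "127.0.0.1:11211".to_string(),
///     ]);
///     let pool: Pool<TokioConnection, _> = Pool::builder(config).max_size(16).build()?;
///
///     // Each `get` checks out a pooled client; it is returned on drop.
///     let mut client = pool.get().await?;
///     client.set("key", &42_u32, 60).await?;
///     Ok(())
/// }
/// ```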
pub type Pool<C, P> = deadpool::managed::Pool<ClientConfig<C, P>>;

#[cfg(test)]
mod tests {
    use crate::protocol::ProtocolError;

    use super::Error;

    #[test]
    fn test_err_display() {
        assert_eq!(
            "ProtocolError: Invalid magic byte: 8",
            format!("{}", Error::Protocol(ProtocolError::InvalidMagic(8)))
        );
        assert_eq!(
            "StatusError: Key not found",
            format!("{}", Error::Status(crate::protocol::Status::KeyNotFound))
        );
    }
}