memcache_rawl/
lib.rs

//! A simplified implementation of [rust-memcache](https://github.com/aisk/rust-memcache),
//! ported to streams implementing tokio's `AsyncRead` + `AsyncWrite`.
3use std::{
4    collections::HashMap,
5    fmt::Display,
6    io::{
7        Error,
8        ErrorKind,
9    },
10    marker::Unpin,
11};
12use tokio::io::{
13    AsyncBufReadExt,
14    AsyncRead,
15    AsyncReadExt,
16    AsyncWrite,
17    AsyncWriteExt,
18    BufReader,
19};
20
21/// Memcache ASCII protocol implementation.
22pub struct Protocol<S> {
23    stream: BufReader<S>,
24    buf: Vec<u8>,
25}
26
27impl<S> Protocol<S>
28where
29    S: AsyncRead + AsyncWrite + Unpin,
30{
31    /// Creates the ASCII protocol on a stream.
32    pub fn new(stream: S) -> Self {
33        Self {
34            stream: BufReader::new(stream),
35            buf: Vec::new(),
36        }
37    }
38
39    /// Returns the value for given key as bytes. If the value doesn't exist, [`ErrorKind::NotFound`] is returned.
40    pub async fn get<K: AsRef<[u8]>>(
41        &mut self,
42        key: K,
43    ) -> Result<Vec<u8>, Error> {
44        // Send command
45        let writer = self.stream.get_mut();
46        writer
47            .write_all(&[b"get ", key.as_ref(), b"\r\n"].concat())
48            .await?;
49        writer.flush().await?;
50
51        let (val, _) = self.read_get_response().await?;
52        Ok(val)
53    }
54
55    async fn read_get_response(&mut self) -> Result<(Vec<u8>, Option<u64>), Error> {
56        // Read response header
57        let header = self.read_line().await?;
58        let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?;
59
60        // Check response header and parse value length
61        if header.starts_with("ERROR") ||
62            header.starts_with("CLIENT_ERROR") ||
63            header.starts_with("SERVER_ERROR")
64        {
65            return Err(Error::other(header));
66        } else if header.starts_with("END") {
67            return Err(ErrorKind::NotFound.into());
68        }
69
70        // VALUE <key> <flags> <bytes> [<cas unique>]\r\n
71        let mut parts = header.split(' ');
72        let length: usize = parts
73            .nth(3)
74            .and_then(|len| len.trim_end().parse().ok())
75            .ok_or(ErrorKind::InvalidData)?;
76
77        // cas is present only if gets is called
78        let cas: Option<u64> = parts
79            .next()
80            .and_then(|len| len.trim_end().parse().ok());
81
82        // Read value
83        let mut buffer: Vec<u8> = vec![0; length];
84        self.stream
85            .read_exact(&mut buffer)
86            .await?;
87
88        // Read the trailing header
89        self.read_line().await?; // \r\n
90        self.read_line().await?; // END\r\n
91
92        Ok((buffer, cas))
93    }
94
95    /// Returns values for multiple keys in a single call as a [`HashMap`] from keys to found values.
96    /// If a key is not present in memcached it will be absent from returned map.
97    pub async fn get_multi<K: AsRef<[u8]>>(
98        &mut self,
99        keys: &[K],
100    ) -> Result<HashMap<String, Vec<u8>>, Error> {
101        if keys.is_empty() {
102            return Ok(HashMap::new());
103        }
104
105        // Send command
106        let writer = self.stream.get_mut();
107        writer
108            .write_all("get".as_bytes())
109            .await?;
110        for k in keys {
111            writer.write_all(b" ").await?;
112            writer
113                .write_all(k.as_ref())
114                .await?;
115        }
116        writer.write_all(b"\r\n").await?;
117        writer.flush().await?;
118
119        // Read response header
120        self.read_many_values().await
121    }
122
123    async fn read_many_values(&mut self) -> Result<HashMap<String, Vec<u8>>, Error> {
124        let mut map = HashMap::new();
125        loop {
126            let header = {
127                let buf = self.read_line().await?;
128                std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
129            }
130            .to_string();
131            let mut parts = header.split(' ');
132            match parts.next() {
133                Some("VALUE") => {
134                    if let (Some(key), _flags, Some(size_str)) = (
135                        parts.next(),
136                        parts.next(),
137                        parts.next(),
138                    ) {
139                        let size: usize = size_str
140                            .trim_end()
141                            .parse()
142                            .map_err(|_| Error::from(ErrorKind::InvalidData))?;
143                        let mut buffer: Vec<u8> = vec![0; size];
144                        self.stream
145                            .read_exact(&mut buffer)
146                            .await?;
147                        let mut crlf = vec![0; 2];
148                        self.stream
149                            .read_exact(&mut crlf)
150                            .await?;
151
152                        map.insert(key.to_owned(), buffer);
153                    } else {
154                        return Err(Error::new(
155                            ErrorKind::InvalidData,
156                            header,
157                        ));
158                    }
159                }
160                Some("END\r\n") => return Ok(map),
161                Some("ERROR") => {
162                    return Err(Error::other(header));
163                }
164                _ => {
165                    return Err(Error::new(
166                        ErrorKind::InvalidData,
167                        header,
168                    ));
169                }
170            }
171        }
172    }
173
174    /// Get up to `limit` keys which match the given prefix. Returns a [HashMap] from keys to found values.
175    /// This is not part of the Memcached standard, but some servers implement it nonetheless.
176    pub async fn get_prefix<K: Display>(
177        &mut self,
178        key_prefix: K,
179        limit: Option<usize>,
180    ) -> Result<HashMap<String, Vec<u8>>, Error> {
181        // Send command
182        let header = if let Some(limit) = limit {
183            format!("get_prefix {key_prefix} {limit}\r\n")
184        } else {
185            format!("get_prefix {key_prefix}\r\n")
186        };
187        self.stream
188            .write_all(header.as_bytes())
189            .await?;
190        self.stream.flush().await?;
191
192        // Read response header
193        self.read_many_values().await
194    }
195
196    /// Add a key. If the value exists, [`ErrorKind::AlreadyExists`] is returned.
197    pub async fn add<K: Display>(
198        &mut self,
199        key: K,
200        val: &[u8],
201        expiration: u32,
202    ) -> Result<(), Error> {
203        // Send command
204        let header = format!(
205            "add {} 0 {} {}\r\n",
206            key,
207            expiration,
208            val.len()
209        );
210        self.stream
211            .write_all(header.as_bytes())
212            .await?;
213        self.stream
214            .write_all(val)
215            .await?;
216        self.stream
217            .write_all(b"\r\n")
218            .await?;
219        self.stream.flush().await?;
220
221        // Read response header
222        let header = {
223            let buf = self.read_line().await?;
224            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
225        };
226
227        // Check response header and parse value length
228        if header.contains("ERROR") {
229            return Err(Error::other(header));
230        } else if header.starts_with("NOT_STORED") {
231            return Err(ErrorKind::AlreadyExists.into());
232        }
233
234        Ok(())
235    }
236
237    /// Set key to given value and don't wait for response.
238    pub async fn set<K: Display>(
239        &mut self,
240        key: K,
241        val: &[u8],
242        expiration: u32,
243    ) -> Result<(), Error> {
244        let header = format!(
245            "set {} 0 {} {} noreply\r\n",
246            key,
247            expiration,
248            val.len()
249        );
250        self.stream
251            .write_all(header.as_bytes())
252            .await?;
253        self.stream
254            .write_all(val)
255            .await?;
256        self.stream
257            .write_all(b"\r\n")
258            .await?;
259        self.stream.flush().await?;
260        Ok(())
261    }
262
263    /// Append bytes to the value in memcached and don't wait for response.
264    pub async fn append<K: Display>(
265        &mut self,
266        key: K,
267        val: &[u8],
268    ) -> Result<(), Error> {
269        let header = format!(
270            "append {} 0 0 {} noreply\r\n",
271            key,
272            val.len()
273        );
274        self.stream
275            .write_all(header.as_bytes())
276            .await?;
277        self.stream
278            .write_all(val)
279            .await?;
280        self.stream
281            .write_all(b"\r\n")
282            .await?;
283        self.stream.flush().await?;
284        Ok(())
285    }
286
287    /// Delete a key and don't wait for response.
288    pub async fn delete<K: Display>(
289        &mut self,
290        key: K,
291    ) -> Result<(), Error> {
292        let header = format!("delete {key} noreply\r\n");
293        self.stream
294            .write_all(header.as_bytes())
295            .await?;
296        self.stream.flush().await?;
297        Ok(())
298    }
299
300    /// Return the version of the remote server.
301    pub async fn version(&mut self) -> Result<String, Error> {
302        self.stream
303            .write_all(b"version\r\n")
304            .await?;
305        self.stream.flush().await?;
306
307        // Read response header
308        let header = {
309            let buf = self.read_line().await?;
310            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
311        };
312
313        if !header.starts_with("VERSION") {
314            return Err(Error::other(header));
315        }
316        let version = header
317            .trim_start_matches("VERSION ")
318            .trim_end();
319        Ok(version.to_string())
320    }
321
322    /// Delete all keys from the cache.
323    pub async fn flush(&mut self) -> Result<(), Error> {
324        self.stream
325            .write_all(b"flush_all\r\n")
326            .await?;
327        self.stream.flush().await?;
328
329        // Read response header
330        let header = {
331            let buf = self.read_line().await?;
332            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
333        };
334
335        if header == "OK\r\n" {
336            Ok(())
337        } else {
338            Err(Error::other(header))
339        }
340    }
341
342    /// Increment a specific integer stored with a key by a given value. If the value doesn't exist, [`ErrorKind::NotFound`] is returned.
343    /// Otherwise the new value is returned
344    pub async fn increment<K: AsRef<[u8]>>(
345        &mut self,
346        key: K,
347        amount: u64,
348    ) -> Result<u64, Error> {
349        // Send command
350        let writer = self.stream.get_mut();
351        let buf = &[
352            b"incr ",
353            key.as_ref(),
354            b" ",
355            amount.to_string().as_bytes(),
356            b"\r\n",
357        ]
358        .concat();
359        writer.write_all(buf).await?;
360        writer.flush().await?;
361
362        // Read response header
363        let header = {
364            let buf = self.read_line().await?;
365            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
366        };
367
368        if header == "NOT_FOUND\r\n" {
369            Err(ErrorKind::NotFound.into())
370        } else {
371            let value = header
372                .trim_end()
373                .parse::<u64>()
374                .map_err(|_| Error::from(ErrorKind::InvalidData))?;
375            Ok(value)
376        }
377    }
378
379    /// Decrement a specific integer stored with a key by a given value. If the value doesn't exist, [`ErrorKind::NotFound`] is returned.
380    /// Otherwise the new value is returned
381    pub async fn decrement<K: AsRef<[u8]>>(
382        &mut self,
383        key: K,
384        amount: u64,
385    ) -> Result<u64, Error> {
386        // Send command
387        let writer = self.stream.get_mut();
388        let buf = &[
389            b"decr ",
390            key.as_ref(),
391            b" ",
392            amount.to_string().as_bytes(),
393            b"\r\n",
394        ]
395        .concat();
396        writer.write_all(buf).await?;
397        writer.flush().await?;
398
399        // Read response header
400        let header = {
401            let buf = self.read_line().await?;
402            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
403        };
404
405        if header == "NOT_FOUND\r\n" {
406            Err(ErrorKind::NotFound.into())
407        } else {
408            let value = header
409                .trim_end()
410                .parse::<u64>()
411                .map_err(|_| Error::from(ErrorKind::InvalidData))?;
412            Ok(value)
413        }
414    }
415
416    async fn read_line(&mut self) -> Result<&[u8], Error> {
417        let Self { stream: io, buf } = self;
418        buf.clear();
419        io.read_until(b'\n', buf).await?;
420        if buf.last().copied() != Some(b'\n') {
421            return Err(ErrorKind::UnexpectedEof.into());
422        }
423        Ok(&buf[..])
424    }
425
426    /// Call gets to also return CAS id, which can be used to run a second CAS command
427    pub async fn gets_cas<K: AsRef<[u8]>>(
428        &mut self,
429        key: K,
430    ) -> Result<(Vec<u8>, u64), Error> {
431        // Send command
432        let writer = self.stream.get_mut();
433        writer
434            .write_all(&[b"gets ", key.as_ref(), b"\r\n"].concat())
435            .await?;
436        writer.flush().await?;
437
438        let (val, maybe_cas) = self.read_get_response().await?;
439        let cas = maybe_cas.ok_or(ErrorKind::InvalidData)?;
440
441        Ok((val, cas))
442    }
443
444    // CAS: compare and swap a value. the value of `cas` can be obtained by first making a gets_cas
445    // call
446    // returns true/false to indicate the cas operation succeeded or failed
447    // returns an error for all other failures
448    pub async fn cas<K: Display>(
449        &mut self,
450        key: K,
451        val: &[u8],
452        cas_id: u64,
453        expiration: u32,
454    ) -> Result<bool, Error> {
455        let header = format!(
456            "cas {} 0 {} {} {}\r\n",
457            key,
458            expiration,
459            val.len(),
460            cas_id
461        );
462        self.stream
463            .write_all(header.as_bytes())
464            .await?;
465        self.stream
466            .write_all(val)
467            .await?;
468        self.stream
469            .write_all(b"\r\n")
470            .await?;
471        self.stream.flush().await?;
472
473        // Read response header
474        let header = {
475            let buf = self.read_line().await?;
476            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
477        };
478
479        /* From memcached docs:
480         *    After sending the command line and the data block the client awaits
481         *    the reply, which may be:
482         *
483         *    - "STORED\r\n", to indicate success.
484         *
485         *    - "NOT_STORED\r\n" to indicate the data was not stored, but not
486         *    because of an error. This normally means that the
487         *    condition for an "add" or a "replace" command wasn't met.
488         *
489         *    - "EXISTS\r\n" to indicate that the item you are trying to store with
490         *    a "cas" command has been modified since you last fetched it.
491         *
492         *    - "NOT_FOUND\r\n" to indicate that the item you are trying to store
493         *    with a "cas" command did not exist.
494         */
495
496        if header.starts_with("STORED") {
497            Ok(true)
498        } else if header.starts_with("EXISTS") || header.starts_with("NOT_STORED") {
499            Ok(false)
500        } else if header.starts_with("NOT FOUND") {
501            Err(ErrorKind::NotFound.into())
502        } else {
503            Err(Error::other(header))
504        }
505    }
506
507    /// Append bytes to the value in memcached, and creates the key if it is missing instead of failing compared to simple append
508    pub async fn append_or_vivify<K: Display>(
509        &mut self,
510        key: K,
511        val: &[u8],
512        ttl: u32,
513    ) -> Result<(), Error> {
514        /* From memcached docs:
515         * - M(token): mode switch to change behavior to add, replace, append, prepend. Takes a single character for the mode.
516         *       A: "append" command. If item exists, append the new value to its data.
517         * ----
518         * The "cas" command is supplanted by specifying the cas value with the 'C' flag.
519         * Append and Prepend modes will also respect a supplied cas value.
520         *
521         * - N(token): if in append mode, autovivify on miss with supplied TTL
522         *
523         * If N is supplied, and append reaches a miss, it will
524         * create a new item seeded with the data from the append command.
525         */
526        let header = format!(
527            "ms {} {} MA N{}\r\n",
528            key,
529            val.len(),
530            ttl
531        );
532        self.stream
533            .write_all(header.as_bytes())
534            .await?;
535        self.stream
536            .write_all(val)
537            .await?;
538        self.stream
539            .write_all(b"\r\n")
540            .await?;
541        self.stream.flush().await?;
542
543        // Read response header
544        let header = {
545            let buf = self.read_line().await?;
546            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
547        };
548
549        /* From memcached docs:
550         *    After sending the command line and the data block the client awaits
551         *    the reply, which is of the format:
552         *
553         *    <CD> <flags>*\r\n
554         *
555         *    Where CD is one of:
556         *
557         *    - "HD" (STORED), to indicate success.
558         *
559         *    - "NS" (NOT_STORED), to indicate the data was not stored, but not
560         *    because of an error.
561         *
562         *    - "EX" (EXISTS), to indicate that the item you are trying to store with
563         *    CAS semantics has been modified since you last fetched it.
564         *
565         *    - "NF" (NOT_FOUND), to indicate that the item you are trying to store
566         *    with CAS semantics did not exist.
567         */
568
569        if header.starts_with("HD") {
570            Ok(())
571        } else {
572            Err(Error::other(header))
573        }
574    }
575}