1use std::{
4    collections::HashMap,
5    fmt::Display,
6    io::{
7        Error,
8        ErrorKind,
9    },
10    marker::Unpin,
11};
12use tokio::io::{
13    AsyncBufReadExt,
14    AsyncRead,
15    AsyncReadExt,
16    AsyncWrite,
17    AsyncWriteExt,
18    BufReader,
19};
20
21pub struct Protocol<S> {
23    stream: BufReader<S>,
24    buf: Vec<u8>,
25}
26
27impl<S> Protocol<S>
28where
29    S: AsyncRead + AsyncWrite + Unpin,
30{
31    pub fn new(stream: S) -> Self {
33        Self {
34            stream: BufReader::new(stream),
35            buf: Vec::new(),
36        }
37    }
38
39    pub async fn get<K: AsRef<[u8]>>(
41        &mut self,
42        key: K,
43    ) -> Result<Vec<u8>, Error> {
44        let writer = self.stream.get_mut();
46        writer
47            .write_all(&[b"get ", key.as_ref(), b"\r\n"].concat())
48            .await?;
49        writer.flush().await?;
50
51        let (val, _) = self.read_get_response().await?;
52        Ok(val)
53    }
54
55    async fn read_get_response(&mut self) -> Result<(Vec<u8>, Option<u64>), Error> {
56        let header = self.read_line().await?;
58        let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?;
59
60        if header.starts_with("ERROR") ||
62            header.starts_with("CLIENT_ERROR") ||
63            header.starts_with("SERVER_ERROR")
64        {
65            return Err(Error::other(header));
66        } else if header.starts_with("END") {
67            return Err(ErrorKind::NotFound.into());
68        }
69
70        let mut parts = header.split(' ');
72        let length: usize = parts
73            .nth(3)
74            .and_then(|len| len.trim_end().parse().ok())
75            .ok_or(ErrorKind::InvalidData)?;
76
77        let cas: Option<u64> = parts
79            .next()
80            .and_then(|len| len.trim_end().parse().ok());
81
82        let mut buffer: Vec<u8> = vec![0; length];
84        self.stream
85            .read_exact(&mut buffer)
86            .await?;
87
88        self.read_line().await?; self.read_line().await?; Ok((buffer, cas))
93    }
94
95    pub async fn get_multi<K: AsRef<[u8]>>(
98        &mut self,
99        keys: &[K],
100    ) -> Result<HashMap<String, Vec<u8>>, Error> {
101        if keys.is_empty() {
102            return Ok(HashMap::new());
103        }
104
105        let writer = self.stream.get_mut();
107        writer
108            .write_all("get".as_bytes())
109            .await?;
110        for k in keys {
111            writer.write_all(b" ").await?;
112            writer
113                .write_all(k.as_ref())
114                .await?;
115        }
116        writer.write_all(b"\r\n").await?;
117        writer.flush().await?;
118
119        self.read_many_values().await
121    }
122
123    async fn read_many_values(&mut self) -> Result<HashMap<String, Vec<u8>>, Error> {
124        let mut map = HashMap::new();
125        loop {
126            let header = {
127                let buf = self.read_line().await?;
128                std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
129            }
130            .to_string();
131            let mut parts = header.split(' ');
132            match parts.next() {
133                Some("VALUE") => {
134                    if let (Some(key), _flags, Some(size_str)) = (
135                        parts.next(),
136                        parts.next(),
137                        parts.next(),
138                    ) {
139                        let size: usize = size_str
140                            .trim_end()
141                            .parse()
142                            .map_err(|_| Error::from(ErrorKind::InvalidData))?;
143                        let mut buffer: Vec<u8> = vec![0; size];
144                        self.stream
145                            .read_exact(&mut buffer)
146                            .await?;
147                        let mut crlf = vec![0; 2];
148                        self.stream
149                            .read_exact(&mut crlf)
150                            .await?;
151
152                        map.insert(key.to_owned(), buffer);
153                    } else {
154                        return Err(Error::new(
155                            ErrorKind::InvalidData,
156                            header,
157                        ));
158                    }
159                }
160                Some("END\r\n") => return Ok(map),
161                Some("ERROR") => {
162                    return Err(Error::other(header));
163                }
164                _ => {
165                    return Err(Error::new(
166                        ErrorKind::InvalidData,
167                        header,
168                    ));
169                }
170            }
171        }
172    }
173
174    pub async fn get_prefix<K: Display>(
177        &mut self,
178        key_prefix: K,
179        limit: Option<usize>,
180    ) -> Result<HashMap<String, Vec<u8>>, Error> {
181        let header = if let Some(limit) = limit {
183            format!("get_prefix {key_prefix} {limit}\r\n")
184        } else {
185            format!("get_prefix {key_prefix}\r\n")
186        };
187        self.stream
188            .write_all(header.as_bytes())
189            .await?;
190        self.stream.flush().await?;
191
192        self.read_many_values().await
194    }
195
196    pub async fn add<K: Display>(
198        &mut self,
199        key: K,
200        val: &[u8],
201        expiration: u32,
202    ) -> Result<(), Error> {
203        let header = format!(
205            "add {} 0 {} {}\r\n",
206            key,
207            expiration,
208            val.len()
209        );
210        self.stream
211            .write_all(header.as_bytes())
212            .await?;
213        self.stream
214            .write_all(val)
215            .await?;
216        self.stream
217            .write_all(b"\r\n")
218            .await?;
219        self.stream.flush().await?;
220
221        let header = {
223            let buf = self.read_line().await?;
224            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
225        };
226
227        if header.contains("ERROR") {
229            return Err(Error::other(header));
230        } else if header.starts_with("NOT_STORED") {
231            return Err(ErrorKind::AlreadyExists.into());
232        }
233
234        Ok(())
235    }
236
237    pub async fn set<K: Display>(
239        &mut self,
240        key: K,
241        val: &[u8],
242        expiration: u32,
243    ) -> Result<(), Error> {
244        let header = format!(
245            "set {} 0 {} {} noreply\r\n",
246            key,
247            expiration,
248            val.len()
249        );
250        self.stream
251            .write_all(header.as_bytes())
252            .await?;
253        self.stream
254            .write_all(val)
255            .await?;
256        self.stream
257            .write_all(b"\r\n")
258            .await?;
259        self.stream.flush().await?;
260        Ok(())
261    }
262
263    pub async fn append<K: Display>(
265        &mut self,
266        key: K,
267        val: &[u8],
268    ) -> Result<(), Error> {
269        let header = format!(
270            "append {} 0 0 {} noreply\r\n",
271            key,
272            val.len()
273        );
274        self.stream
275            .write_all(header.as_bytes())
276            .await?;
277        self.stream
278            .write_all(val)
279            .await?;
280        self.stream
281            .write_all(b"\r\n")
282            .await?;
283        self.stream.flush().await?;
284        Ok(())
285    }
286
287    pub async fn delete<K: Display>(
289        &mut self,
290        key: K,
291    ) -> Result<(), Error> {
292        let header = format!("delete {key} noreply\r\n");
293        self.stream
294            .write_all(header.as_bytes())
295            .await?;
296        self.stream.flush().await?;
297        Ok(())
298    }
299
300    pub async fn version(&mut self) -> Result<String, Error> {
302        self.stream
303            .write_all(b"version\r\n")
304            .await?;
305        self.stream.flush().await?;
306
307        let header = {
309            let buf = self.read_line().await?;
310            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
311        };
312
313        if !header.starts_with("VERSION") {
314            return Err(Error::other(header));
315        }
316        let version = header
317            .trim_start_matches("VERSION ")
318            .trim_end();
319        Ok(version.to_string())
320    }
321
322    pub async fn flush(&mut self) -> Result<(), Error> {
324        self.stream
325            .write_all(b"flush_all\r\n")
326            .await?;
327        self.stream.flush().await?;
328
329        let header = {
331            let buf = self.read_line().await?;
332            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
333        };
334
335        if header == "OK\r\n" {
336            Ok(())
337        } else {
338            Err(Error::other(header))
339        }
340    }
341
342    pub async fn increment<K: AsRef<[u8]>>(
345        &mut self,
346        key: K,
347        amount: u64,
348    ) -> Result<u64, Error> {
349        let writer = self.stream.get_mut();
351        let buf = &[
352            b"incr ",
353            key.as_ref(),
354            b" ",
355            amount.to_string().as_bytes(),
356            b"\r\n",
357        ]
358        .concat();
359        writer.write_all(buf).await?;
360        writer.flush().await?;
361
362        let header = {
364            let buf = self.read_line().await?;
365            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
366        };
367
368        if header == "NOT_FOUND\r\n" {
369            Err(ErrorKind::NotFound.into())
370        } else {
371            let value = header
372                .trim_end()
373                .parse::<u64>()
374                .map_err(|_| Error::from(ErrorKind::InvalidData))?;
375            Ok(value)
376        }
377    }
378
379    pub async fn decrement<K: AsRef<[u8]>>(
382        &mut self,
383        key: K,
384        amount: u64,
385    ) -> Result<u64, Error> {
386        let writer = self.stream.get_mut();
388        let buf = &[
389            b"decr ",
390            key.as_ref(),
391            b" ",
392            amount.to_string().as_bytes(),
393            b"\r\n",
394        ]
395        .concat();
396        writer.write_all(buf).await?;
397        writer.flush().await?;
398
399        let header = {
401            let buf = self.read_line().await?;
402            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
403        };
404
405        if header == "NOT_FOUND\r\n" {
406            Err(ErrorKind::NotFound.into())
407        } else {
408            let value = header
409                .trim_end()
410                .parse::<u64>()
411                .map_err(|_| Error::from(ErrorKind::InvalidData))?;
412            Ok(value)
413        }
414    }
415
416    async fn read_line(&mut self) -> Result<&[u8], Error> {
417        let Self { stream: io, buf } = self;
418        buf.clear();
419        io.read_until(b'\n', buf).await?;
420        if buf.last().copied() != Some(b'\n') {
421            return Err(ErrorKind::UnexpectedEof.into());
422        }
423        Ok(&buf[..])
424    }
425
426    pub async fn gets_cas<K: AsRef<[u8]>>(
428        &mut self,
429        key: K,
430    ) -> Result<(Vec<u8>, u64), Error> {
431        let writer = self.stream.get_mut();
433        writer
434            .write_all(&[b"gets ", key.as_ref(), b"\r\n"].concat())
435            .await?;
436        writer.flush().await?;
437
438        let (val, maybe_cas) = self.read_get_response().await?;
439        let cas = maybe_cas.ok_or(ErrorKind::InvalidData)?;
440
441        Ok((val, cas))
442    }
443
444    pub async fn cas<K: Display>(
449        &mut self,
450        key: K,
451        val: &[u8],
452        cas_id: u64,
453        expiration: u32,
454    ) -> Result<bool, Error> {
455        let header = format!(
456            "cas {} 0 {} {} {}\r\n",
457            key,
458            expiration,
459            val.len(),
460            cas_id
461        );
462        self.stream
463            .write_all(header.as_bytes())
464            .await?;
465        self.stream
466            .write_all(val)
467            .await?;
468        self.stream
469            .write_all(b"\r\n")
470            .await?;
471        self.stream.flush().await?;
472
473        let header = {
475            let buf = self.read_line().await?;
476            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
477        };
478
479        if header.starts_with("STORED") {
497            Ok(true)
498        } else if header.starts_with("EXISTS") || header.starts_with("NOT_STORED") {
499            Ok(false)
500        } else if header.starts_with("NOT FOUND") {
501            Err(ErrorKind::NotFound.into())
502        } else {
503            Err(Error::other(header))
504        }
505    }
506
507    pub async fn append_or_vivify<K: Display>(
509        &mut self,
510        key: K,
511        val: &[u8],
512        ttl: u32,
513    ) -> Result<(), Error> {
514        let header = format!(
527            "ms {} {} MA N{}\r\n",
528            key,
529            val.len(),
530            ttl
531        );
532        self.stream
533            .write_all(header.as_bytes())
534            .await?;
535        self.stream
536            .write_all(val)
537            .await?;
538        self.stream
539            .write_all(b"\r\n")
540            .await?;
541        self.stream.flush().await?;
542
543        let header = {
545            let buf = self.read_line().await?;
546            std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
547        };
548
549        if header.starts_with("HD") {
570            Ok(())
571        } else {
572            Err(Error::other(header))
573        }
574    }
575}