1use std::{
4 collections::HashMap,
5 fmt::Display,
6 io::{
7 Error,
8 ErrorKind,
9 },
10 marker::Unpin,
11};
12use tokio::io::{
13 AsyncBufReadExt,
14 AsyncRead,
15 AsyncReadExt,
16 AsyncWrite,
17 AsyncWriteExt,
18 BufReader,
19};
20
21pub struct Protocol<S> {
23 stream: BufReader<S>,
24 buf: Vec<u8>,
25}
26
27impl<S> Protocol<S>
28where
29 S: AsyncRead + AsyncWrite + Unpin,
30{
31 pub fn new(stream: S) -> Self {
33 Self {
34 stream: BufReader::new(stream),
35 buf: Vec::new(),
36 }
37 }
38
39 pub async fn get<K: AsRef<[u8]>>(
41 &mut self,
42 key: K,
43 ) -> Result<Vec<u8>, Error> {
44 let writer = self.stream.get_mut();
46 writer
47 .write_all(&[b"get ", key.as_ref(), b"\r\n"].concat())
48 .await?;
49 writer.flush().await?;
50
51 let (val, _) = self.read_get_response().await?;
52 Ok(val)
53 }
54
55 async fn read_get_response(&mut self) -> Result<(Vec<u8>, Option<u64>), Error> {
56 let header = self.read_line().await?;
58 let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?;
59
60 if header.starts_with("ERROR") ||
62 header.starts_with("CLIENT_ERROR") ||
63 header.starts_with("SERVER_ERROR")
64 {
65 return Err(Error::other(header));
66 } else if header.starts_with("END") {
67 return Err(ErrorKind::NotFound.into());
68 }
69
70 let mut parts = header.split(' ');
72 let length: usize = parts
73 .nth(3)
74 .and_then(|len| len.trim_end().parse().ok())
75 .ok_or(ErrorKind::InvalidData)?;
76
77 let cas: Option<u64> = parts
79 .next()
80 .and_then(|len| len.trim_end().parse().ok());
81
82 let mut buffer: Vec<u8> = vec![0; length];
84 self.stream
85 .read_exact(&mut buffer)
86 .await?;
87
88 self.read_line().await?; self.read_line().await?; Ok((buffer, cas))
93 }
94
95 pub async fn get_multi<K: AsRef<[u8]>>(
98 &mut self,
99 keys: &[K],
100 ) -> Result<HashMap<String, Vec<u8>>, Error> {
101 if keys.is_empty() {
102 return Ok(HashMap::new());
103 }
104
105 let writer = self.stream.get_mut();
107 writer
108 .write_all("get".as_bytes())
109 .await?;
110 for k in keys {
111 writer.write_all(b" ").await?;
112 writer
113 .write_all(k.as_ref())
114 .await?;
115 }
116 writer.write_all(b"\r\n").await?;
117 writer.flush().await?;
118
119 self.read_many_values().await
121 }
122
123 async fn read_many_values(&mut self) -> Result<HashMap<String, Vec<u8>>, Error> {
124 let mut map = HashMap::new();
125 loop {
126 let header = {
127 let buf = self.read_line().await?;
128 std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
129 }
130 .to_string();
131 let mut parts = header.split(' ');
132 match parts.next() {
133 Some("VALUE") => {
134 if let (Some(key), _flags, Some(size_str)) = (
135 parts.next(),
136 parts.next(),
137 parts.next(),
138 ) {
139 let size: usize = size_str
140 .trim_end()
141 .parse()
142 .map_err(|_| Error::from(ErrorKind::InvalidData))?;
143 let mut buffer: Vec<u8> = vec![0; size];
144 self.stream
145 .read_exact(&mut buffer)
146 .await?;
147 let mut crlf = vec![0; 2];
148 self.stream
149 .read_exact(&mut crlf)
150 .await?;
151
152 map.insert(key.to_owned(), buffer);
153 } else {
154 return Err(Error::new(
155 ErrorKind::InvalidData,
156 header,
157 ));
158 }
159 }
160 Some("END\r\n") => return Ok(map),
161 Some("ERROR") => {
162 return Err(Error::other(header));
163 }
164 _ => {
165 return Err(Error::new(
166 ErrorKind::InvalidData,
167 header,
168 ));
169 }
170 }
171 }
172 }
173
174 pub async fn get_prefix<K: Display>(
177 &mut self,
178 key_prefix: K,
179 limit: Option<usize>,
180 ) -> Result<HashMap<String, Vec<u8>>, Error> {
181 let header = if let Some(limit) = limit {
183 format!("get_prefix {key_prefix} {limit}\r\n")
184 } else {
185 format!("get_prefix {key_prefix}\r\n")
186 };
187 self.stream
188 .write_all(header.as_bytes())
189 .await?;
190 self.stream.flush().await?;
191
192 self.read_many_values().await
194 }
195
196 pub async fn add<K: Display>(
198 &mut self,
199 key: K,
200 val: &[u8],
201 expiration: u32,
202 ) -> Result<(), Error> {
203 let header = format!(
205 "add {} 0 {} {}\r\n",
206 key,
207 expiration,
208 val.len()
209 );
210 self.stream
211 .write_all(header.as_bytes())
212 .await?;
213 self.stream
214 .write_all(val)
215 .await?;
216 self.stream
217 .write_all(b"\r\n")
218 .await?;
219 self.stream.flush().await?;
220
221 let header = {
223 let buf = self.read_line().await?;
224 std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
225 };
226
227 if header.contains("ERROR") {
229 return Err(Error::other(header));
230 } else if header.starts_with("NOT_STORED") {
231 return Err(ErrorKind::AlreadyExists.into());
232 }
233
234 Ok(())
235 }
236
237 pub async fn set<K: Display>(
239 &mut self,
240 key: K,
241 val: &[u8],
242 expiration: u32,
243 ) -> Result<(), Error> {
244 let header = format!(
245 "set {} 0 {} {} noreply\r\n",
246 key,
247 expiration,
248 val.len()
249 );
250 self.stream
251 .write_all(header.as_bytes())
252 .await?;
253 self.stream
254 .write_all(val)
255 .await?;
256 self.stream
257 .write_all(b"\r\n")
258 .await?;
259 self.stream.flush().await?;
260 Ok(())
261 }
262
263 pub async fn append<K: Display>(
265 &mut self,
266 key: K,
267 val: &[u8],
268 ) -> Result<(), Error> {
269 let header = format!(
270 "append {} 0 0 {} noreply\r\n",
271 key,
272 val.len()
273 );
274 self.stream
275 .write_all(header.as_bytes())
276 .await?;
277 self.stream
278 .write_all(val)
279 .await?;
280 self.stream
281 .write_all(b"\r\n")
282 .await?;
283 self.stream.flush().await?;
284 Ok(())
285 }
286
287 pub async fn delete<K: Display>(
289 &mut self,
290 key: K,
291 ) -> Result<(), Error> {
292 let header = format!("delete {key} noreply\r\n");
293 self.stream
294 .write_all(header.as_bytes())
295 .await?;
296 self.stream.flush().await?;
297 Ok(())
298 }
299
300 pub async fn version(&mut self) -> Result<String, Error> {
302 self.stream
303 .write_all(b"version\r\n")
304 .await?;
305 self.stream.flush().await?;
306
307 let header = {
309 let buf = self.read_line().await?;
310 std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
311 };
312
313 if !header.starts_with("VERSION") {
314 return Err(Error::other(header));
315 }
316 let version = header
317 .trim_start_matches("VERSION ")
318 .trim_end();
319 Ok(version.to_string())
320 }
321
322 pub async fn flush(&mut self) -> Result<(), Error> {
324 self.stream
325 .write_all(b"flush_all\r\n")
326 .await?;
327 self.stream.flush().await?;
328
329 let header = {
331 let buf = self.read_line().await?;
332 std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
333 };
334
335 if header == "OK\r\n" {
336 Ok(())
337 } else {
338 Err(Error::other(header))
339 }
340 }
341
342 pub async fn increment<K: AsRef<[u8]>>(
345 &mut self,
346 key: K,
347 amount: u64,
348 ) -> Result<u64, Error> {
349 let writer = self.stream.get_mut();
351 let buf = &[
352 b"incr ",
353 key.as_ref(),
354 b" ",
355 amount.to_string().as_bytes(),
356 b"\r\n",
357 ]
358 .concat();
359 writer.write_all(buf).await?;
360 writer.flush().await?;
361
362 let header = {
364 let buf = self.read_line().await?;
365 std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
366 };
367
368 if header == "NOT_FOUND\r\n" {
369 Err(ErrorKind::NotFound.into())
370 } else {
371 let value = header
372 .trim_end()
373 .parse::<u64>()
374 .map_err(|_| Error::from(ErrorKind::InvalidData))?;
375 Ok(value)
376 }
377 }
378
379 pub async fn decrement<K: AsRef<[u8]>>(
382 &mut self,
383 key: K,
384 amount: u64,
385 ) -> Result<u64, Error> {
386 let writer = self.stream.get_mut();
388 let buf = &[
389 b"decr ",
390 key.as_ref(),
391 b" ",
392 amount.to_string().as_bytes(),
393 b"\r\n",
394 ]
395 .concat();
396 writer.write_all(buf).await?;
397 writer.flush().await?;
398
399 let header = {
401 let buf = self.read_line().await?;
402 std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
403 };
404
405 if header == "NOT_FOUND\r\n" {
406 Err(ErrorKind::NotFound.into())
407 } else {
408 let value = header
409 .trim_end()
410 .parse::<u64>()
411 .map_err(|_| Error::from(ErrorKind::InvalidData))?;
412 Ok(value)
413 }
414 }
415
416 async fn read_line(&mut self) -> Result<&[u8], Error> {
417 let Self { stream: io, buf } = self;
418 buf.clear();
419 io.read_until(b'\n', buf).await?;
420 if buf.last().copied() != Some(b'\n') {
421 return Err(ErrorKind::UnexpectedEof.into());
422 }
423 Ok(&buf[..])
424 }
425
426 pub async fn gets_cas<K: AsRef<[u8]>>(
428 &mut self,
429 key: K,
430 ) -> Result<(Vec<u8>, u64), Error> {
431 let writer = self.stream.get_mut();
433 writer
434 .write_all(&[b"gets ", key.as_ref(), b"\r\n"].concat())
435 .await?;
436 writer.flush().await?;
437
438 let (val, maybe_cas) = self.read_get_response().await?;
439 let cas = maybe_cas.ok_or(ErrorKind::InvalidData)?;
440
441 Ok((val, cas))
442 }
443
444 pub async fn cas<K: Display>(
449 &mut self,
450 key: K,
451 val: &[u8],
452 cas_id: u64,
453 expiration: u32,
454 ) -> Result<bool, Error> {
455 let header = format!(
456 "cas {} 0 {} {} {}\r\n",
457 key,
458 expiration,
459 val.len(),
460 cas_id
461 );
462 self.stream
463 .write_all(header.as_bytes())
464 .await?;
465 self.stream
466 .write_all(val)
467 .await?;
468 self.stream
469 .write_all(b"\r\n")
470 .await?;
471 self.stream.flush().await?;
472
473 let header = {
475 let buf = self.read_line().await?;
476 std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
477 };
478
479 if header.starts_with("STORED") {
497 Ok(true)
498 } else if header.starts_with("EXISTS") || header.starts_with("NOT_STORED") {
499 Ok(false)
500 } else if header.starts_with("NOT FOUND") {
501 Err(ErrorKind::NotFound.into())
502 } else {
503 Err(Error::other(header))
504 }
505 }
506
507 pub async fn append_or_vivify<K: Display>(
509 &mut self,
510 key: K,
511 val: &[u8],
512 ttl: u32,
513 ) -> Result<(), Error> {
514 let header = format!(
527 "ms {} {} MA N{}\r\n",
528 key,
529 val.len(),
530 ttl
531 );
532 self.stream
533 .write_all(header.as_bytes())
534 .await?;
535 self.stream
536 .write_all(val)
537 .await?;
538 self.stream
539 .write_all(b"\r\n")
540 .await?;
541 self.stream.flush().await?;
542
543 let header = {
545 let buf = self.read_line().await?;
546 std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))?
547 };
548
549 if header.starts_with("HD") {
570 Ok(())
571 } else {
572 Err(Error::other(header))
573 }
574 }
575}