Skip to main content

memcached/protocol/
binary_packet.rs

1use super::{
2    code::{Magic, Opcode},
3    parse,
4};
5use crate::{
6    error::{CommandError, MemcachedError, ServerError},
7    stream::Stream,
8    Result,
9};
10use byteorder::{BigEndian, ReadBytesExt};
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{borrow::Cow, collections::HashMap, io::Cursor};
13
/// Status value that `Response::err` treats as success; every other status
/// found in a response header is converted into a `CommandError`.
const OK_STATUS: u16 = 0x0;
15
/// Fixed-size header that precedes every packet exchanged with the server.
///
/// The fields are declared in the same order in which `PacketHeader::write`
/// sends them and `PacketHeader::read` receives them. Their widths add up to
/// 24 bytes (1 + 1 + 2 + 1 + 1 + 2 + 4 + 4 + 8).
#[derive(Debug, Default, Serialize, Deserialize, PartialEq)]
pub(super) struct PacketHeader {
    /// Packet marker; `read` rejects anything other than `Magic::Response`.
    pub(super) magic: u8,
    /// Command identifier; `parse_gets_response` compares it with `Opcode::Noop`.
    pub(super) opcode: u8,
    /// Number of key bytes in the packet body.
    pub(super) key_length: u16,
    /// Number of extras bytes in the packet body.
    pub(super) extras_length: u8,
    /// NOTE(review): not interpreted anywhere in this file — confirm meaning
    /// against the protocol specification.
    pub(super) data_type: u8,
    /// Shared slot: `Response::err` reads it as the status of a response.
    /// NOTE(review): presumably the vbucket id on requests, going by the
    /// field name — confirm against callers.
    pub(super) vbucket_id_or_status: u16,
    /// Size of the whole body; `parse_response` derives the value length from
    /// it by subtracting the key and extras lengths.
    pub(super) total_body_length: u32,
    /// NOTE(review): not interpreted anywhere in this file — presumably an
    /// opaque token echoed back by the server; confirm.
    pub(super) opaque: u32,
    /// CAS value of the item; `parse_gets_response` hands it back to callers.
    pub(super) cas: u64,
}
28
/// Pair of `flags` and `expiration` values.
///
/// NOTE(review): presumably the extras section sent with store-type
/// requests; nothing in this file constructs or reads it — confirm against
/// the callers in the parent module.
#[derive(Debug)]
pub(super) struct StoreExtras {
    /// Flags value attached to the request (interpretation not visible here).
    pub(super) flags: u32,
    /// Expiration value attached to the request.
    /// NOTE(review): unit (seconds vs. absolute timestamp) is not visible
    /// here — confirm.
    pub(super) expiration: u32,
}
34
/// Triple of `amount`, `initial_value` and `expiration` values.
///
/// NOTE(review): presumably the extras section sent with counter
/// (increment/decrement) requests; nothing in this file constructs or reads
/// it — confirm against the callers in the parent module.
#[derive(Debug)]
pub(super) struct CounterExtras {
    /// NOTE(review): presumably the delta applied to the counter — confirm.
    pub(super) amount: u64,
    /// NOTE(review): presumably the value used when the counter does not
    /// exist yet — confirm.
    pub(super) initial_value: u64,
    /// Expiration value attached to the request (unit not visible here).
    pub(super) expiration: u32,
}
41
42impl PacketHeader {
43    pub(super) async fn write(self, writer: &mut Stream) -> Result<()> {
44        writer.write_u8(self.magic).await?;
45        writer.write_u8(self.opcode).await?;
46        writer.write_u16(self.key_length).await?;
47        writer.write_u8(self.extras_length).await?;
48        writer.write_u8(self.data_type).await?;
49        writer.write_u16(self.vbucket_id_or_status).await?;
50        writer.write_u32(self.total_body_length).await?;
51        writer.write_u32(self.opaque).await?;
52        writer.write_u64(self.cas).await?;
53        Ok(())
54    }
55
56    pub(super) async fn read(stream: &mut Stream) -> Result<PacketHeader> {
57        let magic = stream.read_u8().await?;
58        if magic != Magic::Response as u8 {
59            return Err(ServerError::BadMagic(magic).into());
60        }
61        Ok(PacketHeader {
62            magic,
63            opcode: stream.read_u8().await?,
64            key_length: stream.read_u16().await?,
65            extras_length: stream.read_u8().await?,
66            data_type: stream.read_u8().await?,
67            vbucket_id_or_status: stream.read_u16().await?,
68            total_body_length: stream.read_u32().await?,
69            opaque: stream.read_u32().await?,
70            cas: stream.read_u64().await?,
71        })
72    }
73}
74
/// One complete response packet: the header plus the three body sections
/// as raw bytes, as assembled by `parse_response`.
#[derive(Debug, Deserialize)]
pub(super) struct Response {
    /// Header that announced the section lengths and carries the status.
    header: PacketHeader,
    /// Key section (`header.key_length` bytes).
    key: Vec<u8>,
    /// Extras section (`header.extras_length` bytes).
    extras: Vec<u8>,
    /// Value section (the rest of the body after extras and key).
    value: Vec<u8>,
}
82
83impl Response {
84    pub(super) fn err(self) -> Result<Self> {
85        let status = self.header.vbucket_id_or_status;
86        if status == OK_STATUS {
87            Ok(self)
88        } else {
89            Err(CommandError::from(status).into())
90        }
91    }
92}
93
94pub(super) async fn parse_response(stream: &mut Stream) -> Result<Response> {
95    let head = PacketHeader::read(stream).await?;
96    let mut extras = vec![0x0; head.extras_length as usize];
97    stream.read_exact(extras.as_mut_slice()).await?;
98
99    let mut key = vec![0x0; head.key_length as usize];
100    stream.read_exact(key.as_mut_slice()).await?;
101
102    let value_len = (head.total_body_length
103        - u32::from(head.key_length)
104        - u32::from(head.extras_length)) as usize;
105    // TODO: return error if total_body_length < extras_length + key_length
106    let mut value = vec![0x0; value_len];
107    stream.read_exact(&mut value).await?;
108
109    Ok(Response {
110        header: head,
111        key,
112        extras,
113        value,
114    })
115}
116
117pub(super) async fn parse_cas_response(stream: &mut Stream) -> Result<bool> {
118    match parse_response(stream).await?.err() {
119        Err(MemcachedError::CommandError(e))
120            if e == CommandError::KeyNotFound || e == CommandError::KeyExists =>
121        {
122            Ok(false)
123        }
124        Ok(_) => Ok(true),
125        Err(e) => Err(e),
126    }
127}
128
129pub(super) async fn parse_version_response(stream: &mut Stream) -> Result<String> {
130    let Response { value, .. } = parse_response(stream).await?.err()?;
131    Ok(parse::deserialize_bytes(&value)?)
132}
133
134pub(super) async fn parse_get_response<T: DeserializeOwned + 'static>(
135    stream: &mut Stream,
136) -> Result<Option<T>> {
137    match parse_response(stream).await?.err() {
138        Ok(Response { value, .. }) => Ok(Some(parse::deserialize_bytes(&value)?)),
139        Err(MemcachedError::CommandError(CommandError::KeyNotFound)) => Ok(None),
140        Err(e) => Err(e),
141    }
142}
143
144pub(super) async fn parse_gets_response<V: DeserializeOwned + 'static>(
145    stream: &mut Stream,
146    max_responses: usize,
147) -> Result<HashMap<String, (V, u32, Option<u64>)>> {
148    let mut result = HashMap::new();
149    for _ in 0..=max_responses {
150        let Response {
151            header,
152            key,
153            extras,
154            value,
155        } = parse_response(stream).await?.err()?;
156        if header.opcode == Opcode::Noop as u8 {
157            return Ok(result);
158        }
159        let flags = Cursor::new(extras).read_u32::<BigEndian>()?;
160        let key = parse::deserialize_bytes(&key)?;
161        let _ = result.insert(
162            key,
163            (parse::deserialize_bytes(&value)?, flags, Some(header.cas)),
164        );
165    }
166    Err(ServerError::BadResponse(Cow::Borrowed("Expected end of gets response")).into())
167}
168
169pub(super) async fn parse_delete_response(stream: &mut Stream) -> Result<bool> {
170    match parse_response(stream).await?.err() {
171        Ok(_) => Ok(true),
172        Err(MemcachedError::CommandError(CommandError::KeyNotFound)) => Ok(false),
173        Err(e) => Err(e),
174    }
175}
176
177pub(super) async fn parse_counter_response(stream: &mut Stream) -> Result<u64> {
178    let Response { value, .. } = parse_response(stream).await?.err()?;
179    Ok(Cursor::new(&value).read_u64::<BigEndian>()?)
180}
181
182pub(super) async fn parse_touch_response(stream: &mut Stream) -> Result<bool> {
183    match parse_response(stream).await?.err() {
184        Ok(_) => Ok(true),
185        Err(MemcachedError::CommandError(CommandError::KeyNotFound)) => Ok(false),
186        Err(e) => Err(e),
187    }
188}
189
190pub(super) async fn parse_stats_response(stream: &mut Stream) -> Result<HashMap<String, String>> {
191    let mut result = HashMap::new();
192    loop {
193        let Response { key, value, .. } = parse_response(stream).await?.err()?;
194        let key: String = parse::deserialize_bytes(&key)?;
195        let value: String = parse::deserialize_bytes(&value)?;
196        if key.is_empty() && value.is_empty() {
197            break;
198        }
199        let _ = result.insert(key, value);
200    }
201    Ok(result)
202}
203
204pub(super) async fn parse_start_auth_response(stream: &mut Stream) -> Result<bool> {
205    parse_response(stream).await?.err().map(|_| true)
206}