memcached/protocol/
binary_packet.rs1use super::{
2 code::{Magic, Opcode},
3 parse,
4};
5use crate::{
6 error::{CommandError, MemcachedError, ServerError},
7 stream::Stream,
8 Result,
9};
10use byteorder::{BigEndian, ReadBytesExt};
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{borrow::Cow, collections::HashMap, io::Cursor};
13
/// Status code that `Response::err` treats as success; any other status is
/// converted into a `CommandError`.
const OK_STATUS: u16 = 0x0;
15
16#[derive(Debug, Default, Serialize, Deserialize, PartialEq)]
17pub(super) struct PacketHeader {
18 pub(super) magic: u8,
19 pub(super) opcode: u8,
20 pub(super) key_length: u16,
21 pub(super) extras_length: u8,
22 pub(super) data_type: u8,
23 pub(super) vbucket_id_or_status: u16,
24 pub(super) total_body_length: u32,
25 pub(super) opaque: u32,
26 pub(super) cas: u64,
27}
28
29#[derive(Debug)]
30pub(super) struct StoreExtras {
31 pub(super) flags: u32,
32 pub(super) expiration: u32,
33}
34
35#[derive(Debug)]
36pub(super) struct CounterExtras {
37 pub(super) amount: u64,
38 pub(super) initial_value: u64,
39 pub(super) expiration: u32,
40}
41
42impl PacketHeader {
43 pub(super) async fn write(self, writer: &mut Stream) -> Result<()> {
44 writer.write_u8(self.magic).await?;
45 writer.write_u8(self.opcode).await?;
46 writer.write_u16(self.key_length).await?;
47 writer.write_u8(self.extras_length).await?;
48 writer.write_u8(self.data_type).await?;
49 writer.write_u16(self.vbucket_id_or_status).await?;
50 writer.write_u32(self.total_body_length).await?;
51 writer.write_u32(self.opaque).await?;
52 writer.write_u64(self.cas).await?;
53 Ok(())
54 }
55
56 pub(super) async fn read(stream: &mut Stream) -> Result<PacketHeader> {
57 let magic = stream.read_u8().await?;
58 if magic != Magic::Response as u8 {
59 return Err(ServerError::BadMagic(magic).into());
60 }
61 Ok(PacketHeader {
62 magic,
63 opcode: stream.read_u8().await?,
64 key_length: stream.read_u16().await?,
65 extras_length: stream.read_u8().await?,
66 data_type: stream.read_u8().await?,
67 vbucket_id_or_status: stream.read_u16().await?,
68 total_body_length: stream.read_u32().await?,
69 opaque: stream.read_u32().await?,
70 cas: stream.read_u64().await?,
71 })
72 }
73}
74
75#[derive(Debug, Deserialize)]
76pub(super) struct Response {
77 header: PacketHeader,
78 key: Vec<u8>,
79 extras: Vec<u8>,
80 value: Vec<u8>,
81}
82
83impl Response {
84 pub(super) fn err(self) -> Result<Self> {
85 let status = self.header.vbucket_id_or_status;
86 if status == OK_STATUS {
87 Ok(self)
88 } else {
89 Err(CommandError::from(status).into())
90 }
91 }
92}
93
94pub(super) async fn parse_response(stream: &mut Stream) -> Result<Response> {
95 let head = PacketHeader::read(stream).await?;
96 let mut extras = vec![0x0; head.extras_length as usize];
97 stream.read_exact(extras.as_mut_slice()).await?;
98
99 let mut key = vec![0x0; head.key_length as usize];
100 stream.read_exact(key.as_mut_slice()).await?;
101
102 let value_len = (head.total_body_length
103 - u32::from(head.key_length)
104 - u32::from(head.extras_length)) as usize;
105 let mut value = vec![0x0; value_len];
107 stream.read_exact(&mut value).await?;
108
109 Ok(Response {
110 header: head,
111 key,
112 extras,
113 value,
114 })
115}
116
117pub(super) async fn parse_cas_response(stream: &mut Stream) -> Result<bool> {
118 match parse_response(stream).await?.err() {
119 Err(MemcachedError::CommandError(e))
120 if e == CommandError::KeyNotFound || e == CommandError::KeyExists =>
121 {
122 Ok(false)
123 }
124 Ok(_) => Ok(true),
125 Err(e) => Err(e),
126 }
127}
128
129pub(super) async fn parse_version_response(stream: &mut Stream) -> Result<String> {
130 let Response { value, .. } = parse_response(stream).await?.err()?;
131 Ok(parse::deserialize_bytes(&value)?)
132}
133
134pub(super) async fn parse_get_response<T: DeserializeOwned + 'static>(
135 stream: &mut Stream,
136) -> Result<Option<T>> {
137 match parse_response(stream).await?.err() {
138 Ok(Response { value, .. }) => Ok(Some(parse::deserialize_bytes(&value)?)),
139 Err(MemcachedError::CommandError(CommandError::KeyNotFound)) => Ok(None),
140 Err(e) => Err(e),
141 }
142}
143
144pub(super) async fn parse_gets_response<V: DeserializeOwned + 'static>(
145 stream: &mut Stream,
146 max_responses: usize,
147) -> Result<HashMap<String, (V, u32, Option<u64>)>> {
148 let mut result = HashMap::new();
149 for _ in 0..=max_responses {
150 let Response {
151 header,
152 key,
153 extras,
154 value,
155 } = parse_response(stream).await?.err()?;
156 if header.opcode == Opcode::Noop as u8 {
157 return Ok(result);
158 }
159 let flags = Cursor::new(extras).read_u32::<BigEndian>()?;
160 let key = parse::deserialize_bytes(&key)?;
161 let _ = result.insert(
162 key,
163 (parse::deserialize_bytes(&value)?, flags, Some(header.cas)),
164 );
165 }
166 Err(ServerError::BadResponse(Cow::Borrowed("Expected end of gets response")).into())
167}
168
169pub(super) async fn parse_delete_response(stream: &mut Stream) -> Result<bool> {
170 match parse_response(stream).await?.err() {
171 Ok(_) => Ok(true),
172 Err(MemcachedError::CommandError(CommandError::KeyNotFound)) => Ok(false),
173 Err(e) => Err(e),
174 }
175}
176
177pub(super) async fn parse_counter_response(stream: &mut Stream) -> Result<u64> {
178 let Response { value, .. } = parse_response(stream).await?.err()?;
179 Ok(Cursor::new(&value).read_u64::<BigEndian>()?)
180}
181
182pub(super) async fn parse_touch_response(stream: &mut Stream) -> Result<bool> {
183 match parse_response(stream).await?.err() {
184 Ok(_) => Ok(true),
185 Err(MemcachedError::CommandError(CommandError::KeyNotFound)) => Ok(false),
186 Err(e) => Err(e),
187 }
188}
189
190pub(super) async fn parse_stats_response(stream: &mut Stream) -> Result<HashMap<String, String>> {
191 let mut result = HashMap::new();
192 loop {
193 let Response { key, value, .. } = parse_response(stream).await?.err()?;
194 let key: String = parse::deserialize_bytes(&key)?;
195 let value: String = parse::deserialize_bytes(&value)?;
196 if key.is_empty() && value.is_empty() {
197 break;
198 }
199 let _ = result.insert(key, value);
200 }
201 Ok(result)
202}
203
204pub(super) async fn parse_start_auth_response(stream: &mut Stream) -> Result<bool> {
205 parse_response(stream).await?.err().map(|_| true)
206}