1use crate::consistency::Consistency;
2use crate::frame::{Direction, Envelope, Flags, FromCursor, Opcode, Serialize, Version};
3use crate::query::QueryFlags;
4use crate::query::QueryValues;
5use crate::types::value::Value;
6use crate::types::{
7 from_cursor_str, from_cursor_str_long, serialize_str, serialize_str_long, CBytesShort, CInt,
8 CIntShort, CLong,
9};
10use crate::{error, Error};
11use derive_more::{Constructor, Display};
12use std::convert::{TryFrom, TryInto};
13use std::io::{Cursor, Read};
14
/// Body of a BATCH request: the batched statements followed by batch-level options.
///
/// Each `Option` field that is `Some` causes the matching flag to be set when the body
/// is serialized, and the field's value is then written after the flags.
#[derive(Debug, Clone, Constructor, PartialEq, Eq)]
pub struct BodyReqBatch {
    /// Kind of batch; encoded as a single leading byte.
    pub batch_type: BatchType,
    /// Statements in the batch, serialized in this order after a 16-bit count.
    pub queries: Vec<BatchQuery>,
    /// Consistency level for the batch, encoded as a 16-bit value.
    pub consistency: Consistency,
    /// Optional serial consistency; sets `WITH_SERIAL_CONSISTENCY` when present.
    pub serial_consistency: Option<Consistency>,
    /// Optional default timestamp; sets `WITH_DEFAULT_TIMESTAMP` when present.
    pub timestamp: Option<CLong>,
    /// Optional keyspace; sets `WITH_KEYSPACE` when present.
    /// NOTE(review): presumably only meaningful for protocol v5+ — confirm callers.
    pub keyspace: Option<String>,
    /// Optional "now in seconds" override; sets `WITH_NOW_IN_SECONDS` when present.
    /// NOTE(review): presumably only meaningful for protocol v5+ — confirm callers.
    pub now_in_seconds: Option<CInt>,
}
25
26impl Serialize for BodyReqBatch {
27 fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
28 let batch_type = u8::from(self.batch_type);
29 batch_type.serialize(cursor, version);
30
31 let len = self.queries.len() as CIntShort;
32 len.serialize(cursor, version);
33
34 for query in &self.queries {
35 query.serialize(cursor, version);
36 }
37
38 let consistency: CIntShort = self.consistency.into();
39 consistency.serialize(cursor, version);
40
41 let mut flags = QueryFlags::empty();
42 if self.serial_consistency.is_some() {
43 flags.insert(QueryFlags::WITH_SERIAL_CONSISTENCY)
44 }
45
46 if self.timestamp.is_some() {
47 flags.insert(QueryFlags::WITH_DEFAULT_TIMESTAMP)
48 }
49
50 if self.keyspace.is_some() {
51 flags.insert(QueryFlags::WITH_KEYSPACE)
52 }
53
54 if self.now_in_seconds.is_some() {
55 flags.insert(QueryFlags::WITH_NOW_IN_SECONDS)
56 }
57
58 flags.serialize(cursor, version);
59
60 if let Some(serial_consistency) = self.serial_consistency {
61 let serial_consistency: CIntShort = serial_consistency.into();
62 serial_consistency.serialize(cursor, version);
63 }
64
65 if let Some(timestamp) = self.timestamp {
66 timestamp.serialize(cursor, version);
67 }
68
69 if let Some(keyspace) = &self.keyspace {
70 serialize_str(cursor, keyspace.as_str(), version);
71 }
72
73 if let Some(now_in_seconds) = self.now_in_seconds {
74 now_in_seconds.serialize(cursor, version);
75 }
76 }
77}
78
79impl FromCursor for BodyReqBatch {
80 fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> {
81 let mut batch_type = [0];
82 cursor.read_exact(&mut batch_type)?;
83
84 let batch_type = BatchType::try_from(batch_type[0])?;
85 let len = CIntShort::from_cursor(cursor, version)?;
86
87 let mut queries = Vec::with_capacity(len as usize);
88 for _ in 0..len {
89 queries.push(BatchQuery::from_cursor(cursor, version)?);
90 }
91
92 let consistency = CIntShort::from_cursor(cursor, version).and_then(TryInto::try_into)?;
93 let query_flags = QueryFlags::from_cursor(cursor, version)?;
94
95 let serial_consistency = if query_flags.contains(QueryFlags::WITH_SERIAL_CONSISTENCY) {
96 Some(CIntShort::from_cursor(cursor, version).and_then(TryInto::try_into)?)
97 } else {
98 None
99 };
100
101 let timestamp = if query_flags.contains(QueryFlags::WITH_DEFAULT_TIMESTAMP) {
102 Some(CLong::from_cursor(cursor, version)?)
103 } else {
104 None
105 };
106
107 let keyspace = if query_flags.contains(QueryFlags::WITH_KEYSPACE) {
108 Some(from_cursor_str(cursor).map(|keyspace| keyspace.to_string())?)
109 } else {
110 None
111 };
112
113 let now_in_seconds = if query_flags.contains(QueryFlags::WITH_NOW_IN_SECONDS) {
114 Some(CInt::from_cursor(cursor, version)?)
115 } else {
116 None
117 };
118
119 Ok(BodyReqBatch::new(
120 batch_type,
121 queries,
122 consistency,
123 serial_consistency,
124 timestamp,
125 keyspace,
126 now_in_seconds,
127 ))
128 }
129}
130
/// Kind of batch, encoded on the wire as a single byte.
///
/// NOTE(review): variant order determines the derived `Ord`/`PartialOrd`; the wire
/// values are assigned explicitly in the `u8` conversions, not by declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Ord, PartialOrd, Eq, Hash, Display)]
#[non_exhaustive]
pub enum BatchType {
    /// Wire value `0` (presumably a batch-log-backed "logged" batch — see protocol spec).
    Logged,
    /// Wire value `1` (presumably a batch without the batch log — see protocol spec).
    Unlogged,
    /// Wire value `2` (presumably a batch of counter updates — see protocol spec).
    Counter,
}
144
145impl TryFrom<u8> for BatchType {
146 type Error = Error;
147
148 fn try_from(value: u8) -> Result<Self, Self::Error> {
149 match value {
150 0 => Ok(BatchType::Logged),
151 1 => Ok(BatchType::Unlogged),
152 2 => Ok(BatchType::Counter),
153 _ => Err(Error::General(format!("Unknown batch type: {value}"))),
154 }
155 }
156}
157
158impl From<BatchType> for u8 {
159 fn from(value: BatchType) -> Self {
160 match value {
161 BatchType::Logged => 0,
162 BatchType::Unlogged => 1,
163 BatchType::Counter => 2,
164 }
165 }
166}
167
/// What a batched statement refers to.
///
/// On the wire it is preceded by a kind byte: `1` for a prepared id, `0` for a query
/// string. When decoding, any non-zero kind byte is treated as a prepared id.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum BatchQuerySubj {
    /// Id of a previously prepared statement, encoded as short bytes.
    PreparedId(CBytesShort),
    /// CQL query text, encoded as a long string.
    QueryString(String),
}
174
/// A single statement inside a batch: what to run plus its bound values.
#[derive(Debug, Clone, Constructor, PartialEq, Eq)]
pub struct BatchQuery {
    /// Either the id of a prepared statement or the CQL text itself.
    pub subject: BatchQuerySubj,
    /// Values bound to the statement, preceded on the wire by a 16-bit count.
    ///
    /// NOTE(review): the batch flags written by `BodyReqBatch` never announce named
    /// values, and decoding always produces `QueryValues::SimpleValues`, so
    /// `QueryValues::NamedValues` is presumably unsupported here — use simple values.
    pub values: QueryValues,
}
187
188impl Serialize for BatchQuery {
189 fn serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>, version: Version) {
190 match &self.subject {
191 BatchQuerySubj::PreparedId(id) => {
192 1u8.serialize(cursor, version);
193 id.serialize(cursor, version);
194 }
195 BatchQuerySubj::QueryString(s) => {
196 0u8.serialize(cursor, version);
197 serialize_str_long(cursor, s, version);
198 }
199 }
200
201 let len = self.values.len() as CIntShort;
202 len.serialize(cursor, version);
203
204 self.values.serialize(cursor, version);
205 }
206}
207
208impl FromCursor for BatchQuery {
209 fn from_cursor(cursor: &mut Cursor<&[u8]>, version: Version) -> error::Result<Self> {
210 let mut is_prepared = [0];
211 cursor.read_exact(&mut is_prepared)?;
212
213 let is_prepared = is_prepared[0] != 0;
214
215 let subject = if is_prepared {
216 BatchQuerySubj::PreparedId(CBytesShort::from_cursor(cursor, version)?)
217 } else {
218 BatchQuerySubj::QueryString(from_cursor_str_long(cursor).map(Into::into)?)
219 };
220
221 let len = CIntShort::from_cursor(cursor, version)?;
222
223 let mut values = Vec::with_capacity(len as usize);
226 for _ in 0..len {
227 values.push(Value::from_cursor(cursor, version)?);
228 }
229
230 Ok(BatchQuery::new(subject, QueryValues::SimpleValues(values)))
231 }
232}
233
234impl Envelope {
235 pub fn new_req_batch(query: BodyReqBatch, flags: Flags, version: Version) -> Envelope {
236 let direction = Direction::Request;
237 let opcode = Opcode::Batch;
238
239 Envelope::new(
240 version,
241 direction,
242 flags,
243 opcode,
244 0,
245 query.serialize_to_vec(version),
246 None,
247 vec![],
248 )
249 }
250}
251
#[cfg(test)]
mod tests {
    // Wire-layout and round-trip tests for BATCH request bodies.

    use crate::consistency::Consistency;
    use crate::frame::message_batch::{BatchQuery, BatchQuerySubj, BatchType, BodyReqBatch};
    use crate::frame::traits::Serialize;
    use crate::frame::{FromCursor, Version};
    use crate::query::QueryValues;
    use crate::types::prelude::Value;
    use std::io::Cursor;

    #[test]
    fn should_deserialize_query() {
        // kind 0 (query string), long string "A", 1 value, value = not set (-2)
        let data = [0, 0, 0, 0, 1, 65, 0, 1, 0xff, 0xff, 0xff, 0xfe];
        let mut cursor = Cursor::new(data.as_slice());

        let query = BatchQuery::from_cursor(&mut cursor, Version::V4).unwrap();
        assert_eq!(query.subject, BatchQuerySubj::QueryString("A".into()));
        assert_eq!(query.values, QueryValues::SimpleValues(vec![Value::NotSet]));
    }

    #[test]
    fn should_round_trip_prepared_query_bytes() {
        // kind 1 (prepared), short bytes [0xab, 0xcd], 1 value, value = not set (-2)
        let data = [1, 0, 2, 0xab, 0xcd, 0, 1, 0xff, 0xff, 0xff, 0xfe];

        let query =
            BatchQuery::from_cursor(&mut Cursor::new(data.as_slice()), Version::V4).unwrap();
        assert!(matches!(query.subject, BatchQuerySubj::PreparedId(_)));
        assert_eq!(query.values, QueryValues::SimpleValues(vec![Value::NotSet]));
        assert_eq!(query.serialize_to_vec(Version::V4), data);
    }

    #[test]
    fn should_deserialize_body() {
        // logged, 0 queries, consistency ANY, flags = serial consistency | timestamp,
        // serial consistency ONE, 8-byte timestamp
        let data = [0, 0, 0, 0, 0, 0x10 | 0x20, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8];
        let mut cursor = Cursor::new(data.as_slice());

        let body = BodyReqBatch::from_cursor(&mut cursor, Version::V4).unwrap();
        assert_eq!(body.batch_type, BatchType::Logged);
        assert!(body.queries.is_empty());
        assert_eq!(body.consistency, Consistency::Any);
        assert_eq!(body.serial_consistency, Some(Consistency::One));
        assert_eq!(body.timestamp, Some(0x0102030405060708));
    }

    #[test]
    fn should_round_trip_batch_type_byte() {
        for batch_type in [BatchType::Logged, BatchType::Unlogged, BatchType::Counter] {
            // batch type, 0 queries, consistency ANY, no flags
            let data = [u8::from(batch_type), 0, 0, 0, 0, 0];
            let body =
                BodyReqBatch::from_cursor(&mut Cursor::new(data.as_slice()), Version::V4)
                    .unwrap();
            assert_eq!(body.batch_type, batch_type);
        }
    }

    #[test]
    fn should_reject_unknown_batch_type() {
        let data = [3, 0, 0, 0, 0, 0];
        let result = BodyReqBatch::from_cursor(&mut Cursor::new(data.as_slice()), Version::V4);
        assert!(result.is_err());
    }

    #[test]
    fn should_round_trip_body_with_queries() {
        let body = BodyReqBatch::new(
            BatchType::Unlogged,
            vec![BatchQuery::new(
                BatchQuerySubj::QueryString("INSERT".into()),
                QueryValues::SimpleValues(vec![Value::NotSet]),
            )],
            Consistency::Quorum,
            Some(Consistency::Serial),
            Some(42),
            None,
            None,
        );

        let data = body.serialize_to_vec(Version::V4);
        let decoded =
            BodyReqBatch::from_cursor(&mut Cursor::new(data.as_slice()), Version::V4).unwrap();
        assert_eq!(decoded, body);
    }

    #[test]
    fn should_support_keyspace() {
        let keyspace = "abc";
        let body = BodyReqBatch::new(
            BatchType::Logged,
            vec![],
            Consistency::Any,
            None,
            None,
            Some(keyspace.into()),
            None,
        );

        let data = body.serialize_to_vec(Version::V5);
        let body =
            BodyReqBatch::from_cursor(&mut Cursor::new(data.as_slice()), Version::V5).unwrap();
        assert_eq!(body.keyspace, Some(keyspace.to_string()));
    }

    #[test]
    fn should_support_now_in_seconds() {
        let now_in_seconds = 4;
        let body = BodyReqBatch::new(
            BatchType::Logged,
            vec![],
            Consistency::Any,
            None,
            None,
            None,
            Some(now_in_seconds),
        );

        let data = body.serialize_to_vec(Version::V5);
        let body =
            BodyReqBatch::from_cursor(&mut Cursor::new(data.as_slice()), Version::V5).unwrap();
        assert_eq!(body.now_in_seconds, Some(now_in_seconds));
    }
}