// cdrs_tokio/envelope_parser.rs

1use std::convert::TryFrom;
2use std::io::Cursor;
3use std::net::SocketAddr;
4use tokio::io::AsyncReadExt;
5
6use cassandra_protocol::compression::Compression;
7use cassandra_protocol::error;
8use cassandra_protocol::frame::message_response::ResponseBody;
9use cassandra_protocol::frame::{
10    Direction, Envelope, Flags, Opcode, Version, LENGTH_LEN, STREAM_LEN,
11};
12use cassandra_protocol::types::data_serialization_types::decode_timeuuid;
13use cassandra_protocol::types::{
14    from_cursor_string_list, try_i16_from_bytes, try_i32_from_bytes, UUID_LEN,
15};
16
17async fn parse_raw_envelope<T: AsyncReadExt + Unpin>(
18    cursor: &mut T,
19    compressor: Compression,
20) -> error::Result<Envelope> {
21    let mut version_bytes = [0; Version::BYTE_LENGTH];
22    let mut flag_bytes = [0; Flags::BYTE_LENGTH];
23    let mut opcode_bytes = [0; Opcode::BYTE_LENGTH];
24    let mut stream_bytes = [0; STREAM_LEN];
25    let mut length_bytes = [0; LENGTH_LEN];
26
27    // NOTE: order of reads matters
28    cursor.read_exact(&mut version_bytes).await?;
29    cursor.read_exact(&mut flag_bytes).await?;
30    cursor.read_exact(&mut stream_bytes).await?;
31    cursor.read_exact(&mut opcode_bytes).await?;
32    cursor.read_exact(&mut length_bytes).await?;
33
34    let version = Version::try_from(version_bytes[0])?;
35    let direction = Direction::from(version_bytes[0]);
36    let flags = Flags::from_bits_truncate(flag_bytes[0]);
37    let stream_id = try_i16_from_bytes(&stream_bytes)?;
38    let opcode = Opcode::try_from(opcode_bytes[0])?;
39    let length = try_i32_from_bytes(&length_bytes)? as usize;
40
41    let mut body_bytes = vec![0; length];
42
43    cursor.read_exact(&mut body_bytes).await?;
44
45    let full_body = if flags.contains(Flags::COMPRESSION) {
46        compressor.decode(body_bytes)?
47    } else {
48        Compression::None.decode(body_bytes)?
49    };
50
51    let body_len = full_body.len();
52
53    // Use cursor to get tracing id, warnings and actual body
54    let mut body_cursor = Cursor::new(full_body.as_slice());
55
56    let tracing_id = if flags.contains(Flags::TRACING) {
57        let mut tracing_bytes = [0; UUID_LEN];
58        std::io::Read::read_exact(&mut body_cursor, &mut tracing_bytes)?;
59
60        decode_timeuuid(&tracing_bytes).ok()
61    } else {
62        None
63    };
64
65    let warnings = if flags.contains(Flags::WARNING) {
66        from_cursor_string_list(&mut body_cursor)?
67    } else {
68        vec![]
69    };
70
71    let mut body = Vec::with_capacity(body_len - body_cursor.position() as usize);
72
73    std::io::Read::read_to_end(&mut body_cursor, &mut body)?;
74
75    let envelope = Envelope {
76        version,
77        direction,
78        flags,
79        opcode,
80        stream_id,
81        body,
82        tracing_id,
83        warnings,
84    };
85
86    Ok(envelope)
87}
88
89pub async fn parse_envelope<T: AsyncReadExt + Unpin>(
90    cursor: &mut T,
91    compressor: Compression,
92    addr: SocketAddr,
93) -> error::Result<Envelope> {
94    let envelope = parse_raw_envelope(cursor, compressor).await?;
95    convert_envelope_into_result(envelope, addr)
96}
97
98pub(crate) fn convert_envelope_into_result(
99    envelope: Envelope,
100    addr: SocketAddr,
101) -> error::Result<Envelope> {
102    match envelope.opcode {
103        Opcode::Error => envelope.response_body().and_then(|err| match err {
104            ResponseBody::Error(err) => Err(error::Error::Server { body: err, addr }),
105            _ => unreachable!(),
106        }),
107        _ => Ok(envelope),
108    }
109}