1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use std::convert::TryFrom;
use std::io::Cursor;
use std::net::SocketAddr;
use tokio::io::AsyncReadExt;

use cassandra_protocol::compression::Compression;
use cassandra_protocol::error;
use cassandra_protocol::frame::message_response::ResponseBody;
use cassandra_protocol::frame::{
    Direction, Envelope, Flags, Opcode, Version, LENGTH_LEN, STREAM_LEN,
};
use cassandra_protocol::types::data_serialization_types::decode_timeuuid;
use cassandra_protocol::types::{
    from_cursor_string_list, try_i16_from_bytes, try_i32_from_bytes, UUID_LEN,
};

async fn parse_raw_envelope<T: AsyncReadExt + Unpin>(
    cursor: &mut T,
    compressor: Compression,
) -> error::Result<Envelope> {
    let mut version_bytes = [0; Version::BYTE_LENGTH];
    let mut flag_bytes = [0; Flags::BYTE_LENGTH];
    let mut opcode_bytes = [0; Opcode::BYTE_LENGTH];
    let mut stream_bytes = [0; STREAM_LEN];
    let mut length_bytes = [0; LENGTH_LEN];

    // NOTE: order of reads matters
    cursor.read_exact(&mut version_bytes).await?;
    cursor.read_exact(&mut flag_bytes).await?;
    cursor.read_exact(&mut stream_bytes).await?;
    cursor.read_exact(&mut opcode_bytes).await?;
    cursor.read_exact(&mut length_bytes).await?;

    let version = Version::try_from(version_bytes[0])?;
    let direction = Direction::from(version_bytes[0]);
    let flags = Flags::from_bits_truncate(flag_bytes[0]);
    let stream_id = try_i16_from_bytes(&stream_bytes)?;
    let opcode = Opcode::try_from(opcode_bytes[0])?;
    let length = try_i32_from_bytes(&length_bytes)? as usize;

    let mut body_bytes = vec![0; length];

    cursor.read_exact(&mut body_bytes).await?;

    let full_body = if flags.contains(Flags::COMPRESSION) {
        compressor.decode(body_bytes)?
    } else {
        Compression::None.decode(body_bytes)?
    };

    let body_len = full_body.len();

    // Use cursor to get tracing id, warnings and actual body
    let mut body_cursor = Cursor::new(full_body.as_slice());

    let tracing_id = if flags.contains(Flags::TRACING) {
        let mut tracing_bytes = [0; UUID_LEN];
        std::io::Read::read_exact(&mut body_cursor, &mut tracing_bytes)?;

        decode_timeuuid(&tracing_bytes).ok()
    } else {
        None
    };

    let warnings = if flags.contains(Flags::WARNING) {
        from_cursor_string_list(&mut body_cursor)?
    } else {
        vec![]
    };

    let mut body = Vec::with_capacity(body_len - body_cursor.position() as usize);

    std::io::Read::read_to_end(&mut body_cursor, &mut body)?;

    let envelope = Envelope {
        version,
        direction,
        flags,
        opcode,
        stream_id,
        body,
        tracing_id,
        warnings,
    };

    Ok(envelope)
}

/// Reads one envelope from `cursor` and turns server error replies into `Err`.
///
/// This parses the raw envelope (decompressing with `compressor` when the
/// envelope is flagged as compressed). It then hands the result to
/// `convert_envelope_into_result`, which maps an `Opcode::Error` envelope to
/// `error::Error::Server` tagged with `addr` and passes any other envelope through.
///
/// # Errors
///
/// Propagates every read/decode error from parsing the envelope. It also returns
/// `error::Error::Server` (or a body-decoding error) when the peer sent an error
/// envelope.
pub async fn parse_envelope<T: AsyncReadExt + Unpin>(
    cursor: &mut T,
    compressor: Compression,
    addr: SocketAddr,
) -> error::Result<Envelope> {
    parse_raw_envelope(cursor, compressor)
        .await
        .and_then(|raw| convert_envelope_into_result(raw, addr))
}

pub(crate) fn convert_envelope_into_result(
    envelope: Envelope,
    addr: SocketAddr,
) -> error::Result<Envelope> {
    match envelope.opcode {
        Opcode::Error => envelope.response_body().and_then(|err| match err {
            ResponseBody::Error(err) => Err(error::Error::Server { body: err, addr }),
            _ => unreachable!(),
        }),
        _ => Ok(envelope),
    }
}