bolt_proto/
message.rs

1use std::{
2    mem,
3    panic::{catch_unwind, UnwindSafe},
4};
5
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use futures_util::io::{AsyncRead, AsyncReadExt};
8
9pub use begin::Begin;
10pub use discard::Discard;
11pub use failure::Failure;
12pub use hello::Hello;
13pub use init::Init;
14pub use pull::Pull;
15pub use record::Record;
16pub use route::Route;
17pub use route_with_metadata::RouteWithMetadata;
18pub use run::Run;
19pub use run_with_metadata::RunWithMetadata;
20pub use success::Success;
21
22use crate::{error::*, serialization::*, value::MARKER_TINY_STRUCT};
23
24pub(crate) mod begin;
25pub(crate) mod discard;
26pub(crate) mod failure;
27pub(crate) mod hello;
28pub(crate) mod init;
29pub(crate) mod pull;
30pub(crate) mod record;
31pub(crate) mod route;
32pub(crate) mod route_with_metadata;
33pub(crate) mod run;
34pub(crate) mod run_with_metadata;
35pub(crate) mod success;
36
// Message signature bytes, grouped by the protocol version of the message they identify.
//
// Several newer messages reuse an older message's signature byte (INIT/HELLO, RUN/
// RUN_WITH_METADATA, DISCARD_ALL/DISCARD, PULL_ALL/PULL), so `Message::deserialize` has to tell
// them apart by field count; the two ROUTE variants share SIGNATURE_ROUTE.

// Bolt v1
pub(crate) const SIGNATURE_INIT: u8 = 0x01;
pub(crate) const SIGNATURE_ACK_FAILURE: u8 = 0x0E;
pub(crate) const SIGNATURE_RESET: u8 = 0x0F;
pub(crate) const SIGNATURE_RUN: u8 = 0x10;
pub(crate) const SIGNATURE_DISCARD_ALL: u8 = 0x2F;
pub(crate) const SIGNATURE_PULL_ALL: u8 = 0x3F;
pub(crate) const SIGNATURE_SUCCESS: u8 = 0x70;
pub(crate) const SIGNATURE_RECORD: u8 = 0x71;
pub(crate) const SIGNATURE_IGNORED: u8 = 0x7E;
pub(crate) const SIGNATURE_FAILURE: u8 = 0x7F;

// Bolt v3
pub(crate) const SIGNATURE_HELLO: u8 = 0x01; // same byte as SIGNATURE_INIT
pub(crate) const SIGNATURE_GOODBYE: u8 = 0x02;
pub(crate) const SIGNATURE_RUN_WITH_METADATA: u8 = 0x10; // same byte as SIGNATURE_RUN
pub(crate) const SIGNATURE_BEGIN: u8 = 0x11;
pub(crate) const SIGNATURE_COMMIT: u8 = 0x12;
pub(crate) const SIGNATURE_ROLLBACK: u8 = 0x13;

// Bolt v4
pub(crate) const SIGNATURE_DISCARD: u8 = 0x2F; // same byte as SIGNATURE_DISCARD_ALL
pub(crate) const SIGNATURE_PULL: u8 = 0x3F; // same byte as SIGNATURE_PULL_ALL

// Bolt v4.3+ (ROUTE and ROUTE-with-metadata)
pub(crate) const SIGNATURE_ROUTE: u8 = 0x66;

// Largest payload carried by one outgoing chunk: the official driver's default maximum chunk
// size (16383 bytes) minus the u16 length header that precedes every chunk.
const CHUNK_SIZE: usize = 16383 - mem::size_of::<u16>();
59
60#[derive(Debug, Clone, Eq, PartialEq)]
61pub enum Message {
62    // v1-compatible message types
63    Init(Init),
64    Run(Run),
65    DiscardAll,
66    PullAll,
67    AckFailure,
68    Reset,
69    Record(Record),
70    Success(Success),
71    Failure(Failure),
72    Ignored,
73
74    // v3-compatible message types
75    Hello(Hello),
76    Goodbye,
77    RunWithMetadata(RunWithMetadata),
78    Begin(Begin),
79    Commit,
80    Rollback,
81
82    // v4-compatible message types
83    Discard(Discard),
84    Pull(Pull),
85
86    // v4.3-compatible message types
87    Route(Route),
88
89    // v4.4-compatible message types
90    RouteWithMetadata(RouteWithMetadata),
91}
92
93impl Message {
94    pub async fn from_stream(mut stream: impl AsyncRead + Unpin) -> DeserializeResult<Message> {
95        let mut bytes = BytesMut::new();
96        let mut chunk_len = 0;
97        // Ignore any no-op messages
98        while chunk_len == 0 {
99            let mut u16_bytes = [0, 0];
100            stream.read_exact(&mut u16_bytes).await?;
101            chunk_len = u16::from_be_bytes(u16_bytes);
102        }
103        // Messages end in a 0_u16
104        while chunk_len > 0 {
105            let mut buf = vec![0; chunk_len as usize];
106            stream.read_exact(&mut buf).await?;
107            bytes.put_slice(&buf);
108            let mut u16_bytes = [0, 0];
109            stream.read_exact(&mut u16_bytes).await?;
110            chunk_len = u16::from_be_bytes(u16_bytes);
111        }
112        let (message, remaining) = Message::deserialize(bytes)?;
113        debug_assert_eq!(remaining.len(), 0);
114
115        Ok(message)
116    }
117
118    pub fn into_chunks(self) -> SerializeResult<Vec<Bytes>> {
119        let bytes = self.serialize()?;
120
121        // Big enough to hold all the chunks, plus a partial chunk, plus the message footer
122        let mut result: Vec<Bytes> = Vec::with_capacity(bytes.len() / CHUNK_SIZE + 2);
123        for slice in bytes.chunks(CHUNK_SIZE) {
124            // 16-bit size, then the chunk data
125            let mut chunk = BytesMut::with_capacity(mem::size_of::<u16>() + slice.len());
126            // Length of slice is at most CHUNK_SIZE, which can fit in a u16
127            chunk.put_u16(slice.len() as u16);
128            chunk.put(slice);
129            result.push(chunk.freeze());
130        }
131        // End message
132        result.push(Bytes::from_static(&[0, 0]));
133
134        Ok(result)
135    }
136}
137
/// Deserializes the payload type named `$variant` from the buffer variable `$buf` and wraps it
/// in the `Message` variant of the same name.
///
/// Evaluates to `Ok((message, remaining_bytes))`, or to the payload deserializer's error.
macro_rules! deserialize_struct {
    ($variant:ident, $buf:ident) => {
        $variant::deserialize($buf).map(|(payload, rest)| (Message::$variant(payload), rest))
    };
}
145
146impl BoltValue for Message {
147    fn marker(&self) -> SerializeResult<u8> {
148        match self {
149            Message::Init(init) => init.marker(),
150            Message::Run(run) => run.marker(),
151            Message::Record(record) => record.marker(),
152            Message::Success(success) => success.marker(),
153            Message::Failure(failure) => failure.marker(),
154            Message::Hello(hello) => hello.marker(),
155            Message::RunWithMetadata(run_with_metadata) => run_with_metadata.marker(),
156            Message::Begin(begin) => begin.marker(),
157            Message::Discard(discard) => discard.marker(),
158            Message::Pull(pull) => pull.marker(),
159            Message::Route(route) => route.marker(),
160            Message::RouteWithMetadata(route_with_metadata) => route_with_metadata.marker(),
161            _ => Ok(MARKER_TINY_STRUCT),
162        }
163    }
164
165    fn serialize(self) -> SerializeResult<Bytes> {
166        match self {
167            Message::Init(init) => init.serialize(),
168            Message::Run(run) => run.serialize(),
169            Message::Record(record) => record.serialize(),
170            Message::Success(success) => success.serialize(),
171            Message::Failure(failure) => failure.serialize(),
172            Message::Hello(hello) => hello.serialize(),
173            Message::RunWithMetadata(run_with_metadata) => run_with_metadata.serialize(),
174            Message::Begin(begin) => begin.serialize(),
175            Message::Discard(discard) => discard.serialize(),
176            Message::Pull(pull) => pull.serialize(),
177            Message::Route(route) => route.serialize(),
178            Message::RouteWithMetadata(route_with_metadata) => route_with_metadata.serialize(),
179            other => Ok(Bytes::from(vec![other.marker()?, other.signature()])),
180        }
181    }
182
183    fn deserialize<B: Buf + UnwindSafe>(mut bytes: B) -> DeserializeResult<(Self, B)> {
184        catch_unwind(move || {
185            let marker = bytes.get_u8();
186            let (size, signature) = get_structure_info(marker, &mut bytes)?;
187
188            match signature {
189                SIGNATURE_INIT => {
190                    // Conflicting signatures, so we have to check for metadata.
191                    // HELLO has 1 field, while INIT has 2.
192                    match size {
193                        1 => deserialize_struct!(Hello, bytes),
194                        2 => deserialize_struct!(Init, bytes),
195                        _ => Err(DeserializationError::InvalidSize { size, signature }),
196                    }
197                }
198                SIGNATURE_RUN => {
199                    // Conflicting signatures, so we have to check for metadata.
200                    // RUN has 2 fields, while RUN_WITH_METADATA has 3.
201                    match size {
202                        2 => deserialize_struct!(Run, bytes),
203                        3 => deserialize_struct!(RunWithMetadata, bytes),
204                        _ => Err(DeserializationError::InvalidSize { size, signature }),
205                    }
206                }
207                SIGNATURE_DISCARD_ALL => {
208                    // Conflicting signatures, so we have to check for metadata.
209                    // DISCARD_ALL has 0 fields, while DISCARD has 1.
210                    match size {
211                        0 => Ok((Message::DiscardAll, bytes)),
212                        1 => deserialize_struct!(Discard, bytes),
213                        _ => Err(DeserializationError::InvalidSize { size, signature }),
214                    }
215                }
216                SIGNATURE_PULL_ALL => {
217                    // Conflicting signatures, so we have to check for metadata.
218                    // PULL_ALL has 0 fields, while PULL has 1.
219                    match size {
220                        0 => Ok((Message::PullAll, bytes)),
221                        1 => deserialize_struct!(Pull, bytes),
222                        _ => Err(DeserializationError::InvalidSize { size, signature }),
223                    }
224                }
225                SIGNATURE_ACK_FAILURE => Ok((Message::AckFailure, bytes)),
226                SIGNATURE_RESET => Ok((Message::Reset, bytes)),
227                SIGNATURE_RECORD => deserialize_struct!(Record, bytes),
228                SIGNATURE_SUCCESS => deserialize_struct!(Success, bytes),
229                SIGNATURE_FAILURE => deserialize_struct!(Failure, bytes),
230                SIGNATURE_IGNORED => Ok((Message::Ignored, bytes)),
231                SIGNATURE_GOODBYE => Ok((Message::Goodbye, bytes)),
232                SIGNATURE_BEGIN => deserialize_struct!(Begin, bytes),
233                SIGNATURE_COMMIT => Ok((Message::Commit, bytes)),
234                SIGNATURE_ROLLBACK => Ok((Message::Rollback, bytes)),
235                SIGNATURE_ROUTE => match RouteWithMetadata::deserialize(bytes.chunk()) {
236                    Ok(_) => {
237                        // Actually consume the bytes
238                        let (message, remaining) = RouteWithMetadata::deserialize(bytes)?;
239                        bytes = remaining;
240                        Ok((Message::RouteWithMetadata(message), bytes))
241                    }
242                    Err(_) => {
243                        // Fall back to v4.3-compatible ROUTE message
244                        let (message, remaining) = Route::deserialize(bytes)?;
245                        bytes = remaining;
246                        Ok((Message::Route(message), bytes))
247                    }
248                },
249                _ => Err(DeserializationError::InvalidSignatureByte(signature)),
250            }
251        })
252        .map_err(|_| DeserializationError::Panicked)?
253    }
254}
255
256impl BoltStructure for Message {
257    fn signature(&self) -> u8 {
258        match self {
259            Message::Init(_) => SIGNATURE_INIT,
260            Message::Run(_) => SIGNATURE_RUN,
261            Message::DiscardAll => SIGNATURE_DISCARD_ALL,
262            Message::PullAll => SIGNATURE_PULL_ALL,
263            Message::AckFailure => SIGNATURE_ACK_FAILURE,
264            Message::Reset => SIGNATURE_RESET,
265            Message::Record(_) => SIGNATURE_RECORD,
266            Message::Success(_) => SIGNATURE_SUCCESS,
267            Message::Failure(_) => SIGNATURE_FAILURE,
268            Message::Ignored => SIGNATURE_IGNORED,
269            Message::Hello(_) => SIGNATURE_HELLO,
270            Message::Goodbye => SIGNATURE_GOODBYE,
271            Message::RunWithMetadata(_) => SIGNATURE_RUN_WITH_METADATA,
272            Message::Begin(_) => SIGNATURE_BEGIN,
273            Message::Commit => SIGNATURE_COMMIT,
274            Message::Rollback => SIGNATURE_ROLLBACK,
275            Message::Discard(_) => SIGNATURE_DISCARD,
276            Message::Pull(_) => SIGNATURE_PULL,
277            Message::Route(_) | Message::RouteWithMetadata(_) => SIGNATURE_ROUTE,
278        }
279    }
280}