/// Size in bytes of the fixed header that precedes every frame payload.
///
/// [`Frame::encoded_len`] adds this to the payload length.
/// NOTE(review): presumably the header carries kind, flags, stream_id,
/// correlation_id and a length prefix — confirm against the codec.
pub const FRAME_HEADER_SIZE: usize = 16;

/// Largest permitted frame size in bytes: 16 MiB (16 << 20 == 16 * 1024 * 1024).
///
/// NOTE(review): presumably compared against [`Frame::encoded_len`] (header +
/// payload) — confirm whether the limit includes the header.
pub const MAX_FRAME_SIZE: u32 = 16 << 20;
21
/// One protocol message: header fields plus raw payload bytes.
///
/// Construct with [`Frame::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
    /// Message type of this frame.
    pub kind: MessageKind,
    /// Per-frame flag bits; which bits are legal depends on `kind`
    /// (see [`MessageKind::permits_flags`]).
    pub flags: Flags,
    /// Stream identifier; [`Frame::new`] initialises it to 0.
    pub stream_id: u16,
    /// Caller-supplied identifier.
    /// NOTE(review): presumably used to pair responses with requests — confirm.
    pub correlation_id: u64,
    /// Message body. The encoded size is `FRAME_HEADER_SIZE + payload.len()`
    /// (see [`Frame::encoded_len`]).
    pub payload: Vec<u8>,
}
30
31impl Frame {
32 pub fn new(kind: MessageKind, correlation_id: u64, payload: Vec<u8>) -> Self {
33 Self {
34 kind,
35 flags: Flags::empty(),
36 stream_id: 0,
37 correlation_id,
38 payload,
39 }
40 }
41
42 pub fn with_stream(mut self, stream_id: u16) -> Self {
43 self.stream_id = stream_id;
44 self
45 }
46
47 pub fn with_flags(mut self, flags: Flags) -> Self {
48 self.flags = flags;
49 self
50 }
51
52 pub fn encoded_len(&self) -> u32 {
53 (FRAME_HEADER_SIZE + self.payload.len()) as u32
54 }
55}
56
/// Type of a [`Frame`].
///
/// The explicit `#[repr(u8)]` discriminants are the byte values: `kind as u8`
/// yields the byte and [`MessageKind::from_u8`] maps it back. Bytes 0x00,
/// 0x1C–0x1F and 0x28 and above are unassigned (`from_u8` returns `None`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageKind {
    // 0x01–0x0F: queries, results/errors, bulk inserts (one-shot and streamed)
    // and prepared statements.
    Query = 0x01,
    Result = 0x02,
    Error = 0x03,
    BulkInsert = 0x04,
    BulkOk = 0x05,
    BulkInsertBinary = 0x06,
    QueryBinary = 0x07,
    BulkInsertPrevalidated = 0x08,
    BulkStreamStart = 0x09,
    BulkStreamRows = 0x0A,
    BulkStreamCommit = 0x0B,
    BulkStreamAck = 0x0C,
    Prepare = 0x0D,
    PreparedOk = 0x0E,
    ExecutePrepared = 0x0F,

    // 0x10–0x18: handshake, authentication and Bye/Ping/Pong
    // (all MessageClass::Handshake); 0x19–0x1B: Get/Delete (DataPlane).
    Hello = 0x10,
    HelloAck = 0x11,
    AuthRequest = 0x12,
    AuthResponse = 0x13,
    AuthOk = 0x14,
    AuthFail = 0x15,
    Bye = 0x16,
    Ping = 0x17,
    Pong = 0x18,
    Get = 0x19,
    Delete = 0x1A,
    DeleteOk = 0x1B,

    // 0x20–0x23: control-plane messages (MessageClass::ControlPlane).
    Cancel = 0x20,
    Compress = 0x21,
    SetSession = 0x22,
    Notice = 0x23,

    // 0x24–0x25: server-sent stream framing (MessageClass::Streamed).
    RowDescription = 0x24,
    StreamEnd = 0x25,

    // 0x26–0x27: vector-search and graph-traversal requests (DataPlane).
    VectorSearch = 0x26,
    GraphTraverse = 0x27,
}
107
/// Coarse category of a [`MessageKind`], as returned by [`MessageKind::class`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageClass {
    /// Queries, results, errors, one-shot bulk inserts, prepared statements,
    /// Get/Delete and vector/graph requests.
    DataPlane,
    /// Hello/auth exchange plus Bye, Ping and Pong. These kinds may not carry
    /// `Flags::COMPRESSED` (see [`MessageKind::allowed_flags`]).
    Handshake,
    /// Cancel, Compress, SetSession and Notice.
    ControlPlane,
    /// Bulk-stream messages and result-stream framing (RowDescription, StreamEnd).
    Streamed,
}
123
/// Which peer sends a given [`MessageKind`], as returned by [`MessageKind::direction`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageDirection {
    /// Sent by the client.
    ClientToServer,
    /// Sent by the server.
    ServerToClient,
    /// May be sent by either peer (Bye, Ping, Pong).
    Both,
}
137
138impl MessageKind {
139 pub fn class(&self) -> MessageClass {
141 match self {
142 Self::Query
147 | Self::Result
148 | Self::Error
149 | Self::BulkInsert
150 | Self::BulkOk
151 | Self::BulkInsertBinary
152 | Self::QueryBinary
153 | Self::BulkInsertPrevalidated
154 | Self::Prepare
155 | Self::PreparedOk
156 | Self::ExecutePrepared
157 | Self::Get
158 | Self::Delete
159 | Self::DeleteOk
160 | Self::VectorSearch
161 | Self::GraphTraverse => MessageClass::DataPlane,
162
163 Self::BulkStreamStart
166 | Self::BulkStreamRows
167 | Self::BulkStreamCommit
168 | Self::BulkStreamAck
169 | Self::RowDescription
170 | Self::StreamEnd => MessageClass::Streamed,
171
172 Self::Hello
174 | Self::HelloAck
175 | Self::AuthRequest
176 | Self::AuthResponse
177 | Self::AuthOk
178 | Self::AuthFail
179 | Self::Bye
180 | Self::Ping
181 | Self::Pong => MessageClass::Handshake,
182
183 Self::Cancel | Self::Compress | Self::SetSession | Self::Notice => {
185 MessageClass::ControlPlane
186 }
187 }
188 }
189
190 pub fn allowed_flags(&self) -> Flags {
200 match self {
201 Self::Hello
204 | Self::HelloAck
205 | Self::AuthRequest
206 | Self::AuthResponse
207 | Self::AuthOk
208 | Self::AuthFail
209 | Self::Bye
210 | Self::Ping
211 | Self::Pong => Flags::MORE_FRAMES,
212
213 _ => Flags::COMPRESSED.insert(Flags::MORE_FRAMES),
215 }
216 }
217
218 pub fn is_handshake(&self) -> bool {
223 matches!(self.class(), MessageClass::Handshake)
224 }
225
226 pub fn permits_flags(&self, flags: Flags) -> bool {
232 let allowed = self.allowed_flags().bits();
233 (flags.bits() & !allowed) == 0
234 }
235
236 pub fn direction(&self) -> MessageDirection {
238 match self {
239 Self::Hello
241 | Self::AuthResponse
242 | Self::Query
243 | Self::QueryBinary
244 | Self::BulkInsert
245 | Self::BulkInsertBinary
246 | Self::BulkInsertPrevalidated
247 | Self::BulkStreamStart
248 | Self::BulkStreamRows
249 | Self::BulkStreamCommit
250 | Self::Prepare
251 | Self::ExecutePrepared
252 | Self::Get
253 | Self::Delete
254 | Self::Cancel
255 | Self::Compress
256 | Self::SetSession
257 | Self::VectorSearch
258 | Self::GraphTraverse => MessageDirection::ClientToServer,
259
260 Self::HelloAck
262 | Self::AuthRequest
263 | Self::AuthOk
264 | Self::AuthFail
265 | Self::Result
266 | Self::Error
267 | Self::BulkOk
268 | Self::BulkStreamAck
269 | Self::PreparedOk
270 | Self::DeleteOk
271 | Self::Notice
272 | Self::RowDescription
273 | Self::StreamEnd => MessageDirection::ServerToClient,
274
275 Self::Bye | Self::Ping | Self::Pong => MessageDirection::Both,
277 }
278 }
279
280 pub fn from_u8(byte: u8) -> Option<Self> {
281 match byte {
282 0x01 => Some(Self::Query),
283 0x02 => Some(Self::Result),
284 0x03 => Some(Self::Error),
285 0x04 => Some(Self::BulkInsert),
286 0x05 => Some(Self::BulkOk),
287 0x06 => Some(Self::BulkInsertBinary),
288 0x07 => Some(Self::QueryBinary),
289 0x08 => Some(Self::BulkInsertPrevalidated),
290 0x09 => Some(Self::BulkStreamStart),
291 0x0A => Some(Self::BulkStreamRows),
292 0x0B => Some(Self::BulkStreamCommit),
293 0x0C => Some(Self::BulkStreamAck),
294 0x0D => Some(Self::Prepare),
295 0x0E => Some(Self::PreparedOk),
296 0x0F => Some(Self::ExecutePrepared),
297 0x10 => Some(Self::Hello),
298 0x11 => Some(Self::HelloAck),
299 0x12 => Some(Self::AuthRequest),
300 0x13 => Some(Self::AuthResponse),
301 0x14 => Some(Self::AuthOk),
302 0x15 => Some(Self::AuthFail),
303 0x16 => Some(Self::Bye),
304 0x17 => Some(Self::Ping),
305 0x18 => Some(Self::Pong),
306 0x19 => Some(Self::Get),
307 0x1A => Some(Self::Delete),
308 0x1B => Some(Self::DeleteOk),
309 0x20 => Some(Self::Cancel),
310 0x21 => Some(Self::Compress),
311 0x22 => Some(Self::SetSession),
312 0x23 => Some(Self::Notice),
313 0x24 => Some(Self::RowDescription),
314 0x25 => Some(Self::StreamEnd),
315 0x26 => Some(Self::VectorSearch),
316 0x27 => Some(Self::GraphTraverse),
317 _ => None,
318 }
319 }
320}
321
/// Set of per-frame flag bits, stored as a `u8`.
///
/// All operations are by value and never mutate in place. [`Flags::insert`]
/// returns a new set rather than modifying `self`, unlike the `bitflags` crate's
/// `insert(&mut self, ..)`. [`Flags::from_bits`] keeps every bit, including ones
/// no constant defines; use `MessageKind::permits_flags` to validate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Flags(u8);

impl Flags {
    /// Bit 0.
    /// NOTE(review): presumably marks the payload as compressed — confirm.
    pub const COMPRESSED: Self = Self(0b0000_0001);
    /// Bit 1.
    /// NOTE(review): presumably signals that further frames follow — confirm.
    pub const MORE_FRAMES: Self = Self(0b0000_0010);

    /// The set with no bits set.
    pub const fn empty() -> Self {
        Self(0)
    }

    /// Returns the raw bit pattern.
    pub const fn bits(self) -> u8 {
        self.0
    }

    /// Wraps a raw bit pattern as-is; unknown bits are retained, not rejected.
    pub const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns `true` if every bit of `other` is also set in `self`.
    ///
    /// `x.contains(Flags::empty())` is always `true`.
    pub const fn contains(self, other: Self) -> bool {
        (self.0 & other.0) == other.0
    }

    /// Returns the union of `self` and `other`; `self` is left unchanged.
    ///
    /// Marked `#[must_use]` because the name suggests in-place insertion.
    /// Without it, a statement like `flags.insert(X);` compiles and silently has
    /// no effect.
    #[must_use = "`Flags::insert` returns the combined set; it does not modify `self`"]
    pub const fn insert(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }
}

/// `a | b` is the union of two flag sets, equivalent to `a.insert(b)`.
impl std::ops::BitOr for Flags {
    type Output = Self;
    fn bitor(self, rhs: Self) -> Self {
        self.insert(rhs)
    }
}
356
#[cfg(test)]
mod catalog_tests {
    use super::*;

    /// Every `MessageKind` variant, listed in ascending byte order.
    ///
    /// This list is maintained by hand. `all_kinds_covers_every_decodable_byte`
    /// fails if a byte accepted by `MessageKind::from_u8` is missing here, so the
    /// table-driven tests below cannot silently skip a newly added variant.
    const ALL_KINDS: &[MessageKind] = &[
        MessageKind::Query,
        MessageKind::Result,
        MessageKind::Error,
        MessageKind::BulkInsert,
        MessageKind::BulkOk,
        MessageKind::BulkInsertBinary,
        MessageKind::QueryBinary,
        MessageKind::BulkInsertPrevalidated,
        MessageKind::BulkStreamStart,
        MessageKind::BulkStreamRows,
        MessageKind::BulkStreamCommit,
        MessageKind::BulkStreamAck,
        MessageKind::Prepare,
        MessageKind::PreparedOk,
        MessageKind::ExecutePrepared,
        MessageKind::Hello,
        MessageKind::HelloAck,
        MessageKind::AuthRequest,
        MessageKind::AuthResponse,
        MessageKind::AuthOk,
        MessageKind::AuthFail,
        MessageKind::Bye,
        MessageKind::Ping,
        MessageKind::Pong,
        MessageKind::Get,
        MessageKind::Delete,
        MessageKind::DeleteOk,
        MessageKind::Cancel,
        MessageKind::Compress,
        MessageKind::SetSession,
        MessageKind::Notice,
        MessageKind::RowDescription,
        MessageKind::StreamEnd,
        MessageKind::VectorSearch,
        MessageKind::GraphTraverse,
    ];

    #[test]
    fn all_kinds_covers_every_decodable_byte() {
        // Together with `every_kind_has_unique_byte_value`, this makes ALL_KINDS
        // equal (as a set) to every kind reachable through `from_u8`.
        let decodable: Vec<MessageKind> =
            (0..=u8::MAX).filter_map(MessageKind::from_u8).collect();
        for k in &decodable {
            assert!(
                ALL_KINDS.contains(k),
                "{k:?} decodes via from_u8 but is missing from ALL_KINDS"
            );
        }
        assert_eq!(
            decodable.len(),
            ALL_KINDS.len(),
            "ALL_KINDS and from_u8 disagree on how many kinds exist"
        );
    }

    #[test]
    fn class_matrix_is_pinned() {
        assert_eq!(MessageKind::Hello.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::HelloAck.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::AuthRequest.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::AuthResponse.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::AuthOk.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::AuthFail.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::Bye.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::Ping.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::Pong.class(), MessageClass::Handshake);

        assert_eq!(MessageKind::Query.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::Result.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::Error.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::BulkInsert.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::BulkOk.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::BulkInsertBinary.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::QueryBinary.class(), MessageClass::DataPlane);
        assert_eq!(
            MessageKind::BulkInsertPrevalidated.class(),
            MessageClass::DataPlane
        );
        assert_eq!(MessageKind::Prepare.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::PreparedOk.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::ExecutePrepared.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::Get.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::Delete.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::DeleteOk.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::VectorSearch.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::GraphTraverse.class(), MessageClass::DataPlane);

        assert_eq!(MessageKind::BulkStreamStart.class(), MessageClass::Streamed);
        assert_eq!(MessageKind::BulkStreamRows.class(), MessageClass::Streamed);
        assert_eq!(
            MessageKind::BulkStreamCommit.class(),
            MessageClass::Streamed
        );
        assert_eq!(MessageKind::BulkStreamAck.class(), MessageClass::Streamed);
        assert_eq!(MessageKind::RowDescription.class(), MessageClass::Streamed);
        assert_eq!(MessageKind::StreamEnd.class(), MessageClass::Streamed);

        assert_eq!(MessageKind::Cancel.class(), MessageClass::ControlPlane);
        assert_eq!(MessageKind::Compress.class(), MessageClass::ControlPlane);
        assert_eq!(MessageKind::SetSession.class(), MessageClass::ControlPlane);
        assert_eq!(MessageKind::Notice.class(), MessageClass::ControlPlane);

        // The convenience predicate must agree with the class for every kind.
        for &k in ALL_KINDS {
            assert_eq!(
                k.is_handshake(),
                k.class() == MessageClass::Handshake,
                "{k:?}: is_handshake() disagrees with class()"
            );
        }
    }

    #[test]
    fn allowed_flags_matrix_is_pinned() {
        let handshake = [
            MessageKind::Hello,
            MessageKind::HelloAck,
            MessageKind::AuthRequest,
            MessageKind::AuthResponse,
            MessageKind::AuthOk,
            MessageKind::AuthFail,
            MessageKind::Bye,
            MessageKind::Ping,
            MessageKind::Pong,
        ];
        for k in handshake {
            let f = k.allowed_flags();
            assert!(
                f.contains(Flags::MORE_FRAMES),
                "{k:?} must allow MORE_FRAMES"
            );
            assert!(
                !f.contains(Flags::COMPRESSED),
                "{k:?} must NOT allow COMPRESSED today"
            );
        }

        for k in ALL_KINDS {
            if handshake.contains(k) {
                continue;
            }
            let f = k.allowed_flags();
            assert!(
                f.contains(Flags::MORE_FRAMES),
                "{k:?} must allow MORE_FRAMES"
            );
            assert!(f.contains(Flags::COMPRESSED), "{k:?} must allow COMPRESSED");
        }
    }

    #[test]
    fn every_kind_has_unique_byte_value() {
        let mut seen = std::collections::HashSet::new();
        for k in ALL_KINDS {
            let byte = *k as u8;
            assert!(
                seen.insert(byte),
                "byte 0x{byte:02x} reused by {k:?}; catalog has a duplicate"
            );
        }
    }

    #[test]
    fn from_u8_round_trips_for_every_kind() {
        for k in ALL_KINDS {
            let byte = *k as u8;
            let decoded = MessageKind::from_u8(byte).unwrap_or_else(|| {
                panic!("from_u8 returned None for catalog entry {k:?} (0x{byte:02x})")
            });
            assert_eq!(
                decoded, *k,
                "from_u8(0x{byte:02x}) must round-trip back to {k:?}"
            );
        }
    }

    #[test]
    fn permits_flags_matches_allowed_flags() {
        assert!(MessageKind::Ping.permits_flags(Flags::MORE_FRAMES));
        assert!(MessageKind::Ping.permits_flags(Flags::empty()));
        assert!(!MessageKind::Ping.permits_flags(Flags::COMPRESSED));
        assert!(!MessageKind::Ping.permits_flags(Flags::COMPRESSED | Flags::MORE_FRAMES));

        assert!(MessageKind::BulkStreamRows.permits_flags(Flags::MORE_FRAMES));
        assert!(MessageKind::BulkStreamRows.permits_flags(Flags::COMPRESSED));
        assert!(MessageKind::RowDescription.permits_flags(Flags::MORE_FRAMES));
        assert!(MessageKind::StreamEnd.permits_flags(Flags::MORE_FRAMES));

        // Exhaustive: every kind against every possible flag byte, including
        // bits that no `Flags` constant defines.
        for &k in ALL_KINDS {
            let allowed = k.allowed_flags().bits();
            for bits in 0..=u8::MAX {
                assert_eq!(
                    k.permits_flags(Flags::from_bits(bits)),
                    (bits & !allowed) == 0,
                    "{k:?} with flag bits 0b{bits:08b}"
                );
            }
        }
    }

    #[test]
    fn direction_matrix_is_pinned() {
        for k in [
            MessageKind::Hello,
            MessageKind::AuthResponse,
            MessageKind::Query,
            MessageKind::QueryBinary,
            MessageKind::BulkInsert,
            MessageKind::BulkInsertBinary,
            MessageKind::BulkInsertPrevalidated,
            MessageKind::BulkStreamStart,
            MessageKind::BulkStreamRows,
            MessageKind::BulkStreamCommit,
            MessageKind::Prepare,
            MessageKind::ExecutePrepared,
            MessageKind::Get,
            MessageKind::Delete,
            MessageKind::Cancel,
            MessageKind::Compress,
            MessageKind::SetSession,
            MessageKind::VectorSearch,
            MessageKind::GraphTraverse,
        ] {
            assert_eq!(
                k.direction(),
                MessageDirection::ClientToServer,
                "{k:?} should be client-originated"
            );
        }

        for k in [
            MessageKind::HelloAck,
            MessageKind::AuthRequest,
            MessageKind::AuthOk,
            MessageKind::AuthFail,
            MessageKind::Result,
            MessageKind::Error,
            MessageKind::BulkOk,
            MessageKind::BulkStreamAck,
            MessageKind::PreparedOk,
            MessageKind::DeleteOk,
            MessageKind::Notice,
            MessageKind::RowDescription,
            MessageKind::StreamEnd,
        ] {
            assert_eq!(
                k.direction(),
                MessageDirection::ServerToClient,
                "{k:?} should be server-originated"
            );
        }

        for k in [MessageKind::Bye, MessageKind::Ping, MessageKind::Pong] {
            assert_eq!(
                k.direction(),
                MessageDirection::Both,
                "{k:?} should be symmetric"
            );
        }
    }
}