/// Size in bytes of the fixed frame header; `Frame::encoded_len` adds this to
/// the payload length to get the total encoded size.
pub const FRAME_HEADER_SIZE: usize = 16;

/// Largest permitted frame size in bytes (16 MiB).
///
/// NOTE(review): enforcement is not visible here — presumably the codec
/// rejects frames whose encoded length exceeds this; confirm at the call site.
pub const MAX_FRAME_SIZE: u32 = 16 << 20;
21
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct Frame {
24 pub kind: MessageKind,
25 pub flags: Flags,
26 pub stream_id: u16,
27 pub correlation_id: u64,
28 pub payload: Vec<u8>,
29}
30
31impl Frame {
32 pub fn new(kind: MessageKind, correlation_id: u64, payload: Vec<u8>) -> Self {
33 Self {
34 kind,
35 flags: Flags::empty(),
36 stream_id: 0,
37 correlation_id,
38 payload,
39 }
40 }
41
42 pub fn with_stream(mut self, stream_id: u16) -> Self {
43 self.stream_id = stream_id;
44 self
45 }
46
47 pub fn with_flags(mut self, flags: Flags) -> Self {
48 self.flags = flags;
49 self
50 }
51
52 pub fn encoded_len(&self) -> u32 {
53 (FRAME_HEADER_SIZE + self.payload.len()) as u32
54 }
55}
56
/// Identifier for every message type in the catalog.
///
/// With `#[repr(u8)]`, each discriminant is the byte that `MessageKind::from_u8`
/// maps back to the variant. Bytes 0x00, 0x1C–0x1F and 0x29–0xFF are
/// unassigned (`from_u8` returns `None` for them).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum MessageKind {
    // 0x01–0x0F: queries, bulk inserts and prepared statements. 0x09–0x0C form
    // the streamed bulk sequence; the others are data-plane kinds.
    Query = 0x01,
    Result = 0x02,
    Error = 0x03,
    BulkInsert = 0x04,
    BulkOk = 0x05,
    BulkInsertBinary = 0x06,
    QueryBinary = 0x07,
    BulkInsertPrevalidated = 0x08,
    BulkStreamStart = 0x09,
    BulkStreamRows = 0x0A,
    BulkStreamCommit = 0x0B,
    BulkStreamAck = 0x0C,
    Prepare = 0x0D,
    PreparedOk = 0x0E,
    ExecutePrepared = 0x0F,

    // 0x10–0x18: handshake class (hello, auth, bye, ping/pong).
    Hello = 0x10,
    HelloAck = 0x11,
    AuthRequest = 0x12,
    AuthResponse = 0x13,
    AuthOk = 0x14,
    AuthFail = 0x15,
    Bye = 0x16,
    Ping = 0x17,
    Pong = 0x18,
    // 0x19–0x1B: Get/Delete requests and DeleteOk (data-plane).
    Get = 0x19,
    Delete = 0x1A,
    DeleteOk = 0x1B,

    // 0x20–0x23: control-plane kinds.
    Cancel = 0x20,
    Compress = 0x21,
    SetSession = 0x22,
    Notice = 0x23,

    // 0x24–0x25: streamed-class, server-to-client.
    RowDescription = 0x24,
    StreamEnd = 0x25,

    // 0x26–0x28: further data-plane requests.
    VectorSearch = 0x26,
    GraphTraverse = 0x27,
    QueryWithParams = 0x28,
}
108
/// Coarse grouping of [`MessageKind`]s, as returned by `MessageKind::class`.
///
/// The class drives flag policy: `Handshake` kinds may not carry
/// `Flags::COMPRESSED`, while every other class may.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageClass {
    /// Query, bulk-insert, prepared-statement, Get/Delete and the extended
    /// request kinds, plus their replies.
    DataPlane,
    /// Hello/auth kinds together with `Bye`, `Ping` and `Pong`.
    Handshake,
    /// `Cancel`, `Compress`, `SetSession` and `Notice`.
    ControlPlane,
    /// The `BulkStream*` kinds plus `RowDescription` and `StreamEnd`.
    Streamed,
}
124
/// Which peer originates a given [`MessageKind`], as returned by
/// `MessageKind::direction`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageDirection {
    /// Originated by the client.
    ClientToServer,
    /// Originated by the server.
    ServerToClient,
    /// May be sent by either peer (currently `Bye`, `Ping` and `Pong`).
    Both,
}
138
139impl MessageKind {
140 pub fn class(&self) -> MessageClass {
142 match self {
143 Self::Query
148 | Self::Result
149 | Self::Error
150 | Self::BulkInsert
151 | Self::BulkOk
152 | Self::BulkInsertBinary
153 | Self::QueryBinary
154 | Self::BulkInsertPrevalidated
155 | Self::Prepare
156 | Self::PreparedOk
157 | Self::ExecutePrepared
158 | Self::Get
159 | Self::Delete
160 | Self::DeleteOk
161 | Self::VectorSearch
162 | Self::GraphTraverse
163 | Self::QueryWithParams => MessageClass::DataPlane,
164
165 Self::BulkStreamStart
168 | Self::BulkStreamRows
169 | Self::BulkStreamCommit
170 | Self::BulkStreamAck
171 | Self::RowDescription
172 | Self::StreamEnd => MessageClass::Streamed,
173
174 Self::Hello
176 | Self::HelloAck
177 | Self::AuthRequest
178 | Self::AuthResponse
179 | Self::AuthOk
180 | Self::AuthFail
181 | Self::Bye
182 | Self::Ping
183 | Self::Pong => MessageClass::Handshake,
184
185 Self::Cancel | Self::Compress | Self::SetSession | Self::Notice => {
187 MessageClass::ControlPlane
188 }
189 }
190 }
191
192 pub fn allowed_flags(&self) -> Flags {
202 match self {
203 Self::Hello
206 | Self::HelloAck
207 | Self::AuthRequest
208 | Self::AuthResponse
209 | Self::AuthOk
210 | Self::AuthFail
211 | Self::Bye
212 | Self::Ping
213 | Self::Pong => Flags::MORE_FRAMES,
214
215 _ => Flags::COMPRESSED.insert(Flags::MORE_FRAMES),
217 }
218 }
219
220 pub fn is_handshake(&self) -> bool {
225 matches!(self.class(), MessageClass::Handshake)
226 }
227
228 pub fn permits_flags(&self, flags: Flags) -> bool {
234 let allowed = self.allowed_flags().bits();
235 (flags.bits() & !allowed) == 0
236 }
237
238 pub fn direction(&self) -> MessageDirection {
240 match self {
241 Self::Hello
243 | Self::AuthResponse
244 | Self::Query
245 | Self::QueryBinary
246 | Self::BulkInsert
247 | Self::BulkInsertBinary
248 | Self::BulkInsertPrevalidated
249 | Self::BulkStreamStart
250 | Self::BulkStreamRows
251 | Self::BulkStreamCommit
252 | Self::Prepare
253 | Self::ExecutePrepared
254 | Self::Get
255 | Self::Delete
256 | Self::Cancel
257 | Self::Compress
258 | Self::SetSession
259 | Self::VectorSearch
260 | Self::GraphTraverse
261 | Self::QueryWithParams => MessageDirection::ClientToServer,
262
263 Self::HelloAck
265 | Self::AuthRequest
266 | Self::AuthOk
267 | Self::AuthFail
268 | Self::Result
269 | Self::Error
270 | Self::BulkOk
271 | Self::BulkStreamAck
272 | Self::PreparedOk
273 | Self::DeleteOk
274 | Self::Notice
275 | Self::RowDescription
276 | Self::StreamEnd => MessageDirection::ServerToClient,
277
278 Self::Bye | Self::Ping | Self::Pong => MessageDirection::Both,
280 }
281 }
282
283 pub fn from_u8(byte: u8) -> Option<Self> {
284 match byte {
285 0x01 => Some(Self::Query),
286 0x02 => Some(Self::Result),
287 0x03 => Some(Self::Error),
288 0x04 => Some(Self::BulkInsert),
289 0x05 => Some(Self::BulkOk),
290 0x06 => Some(Self::BulkInsertBinary),
291 0x07 => Some(Self::QueryBinary),
292 0x08 => Some(Self::BulkInsertPrevalidated),
293 0x09 => Some(Self::BulkStreamStart),
294 0x0A => Some(Self::BulkStreamRows),
295 0x0B => Some(Self::BulkStreamCommit),
296 0x0C => Some(Self::BulkStreamAck),
297 0x0D => Some(Self::Prepare),
298 0x0E => Some(Self::PreparedOk),
299 0x0F => Some(Self::ExecutePrepared),
300 0x10 => Some(Self::Hello),
301 0x11 => Some(Self::HelloAck),
302 0x12 => Some(Self::AuthRequest),
303 0x13 => Some(Self::AuthResponse),
304 0x14 => Some(Self::AuthOk),
305 0x15 => Some(Self::AuthFail),
306 0x16 => Some(Self::Bye),
307 0x17 => Some(Self::Ping),
308 0x18 => Some(Self::Pong),
309 0x19 => Some(Self::Get),
310 0x1A => Some(Self::Delete),
311 0x1B => Some(Self::DeleteOk),
312 0x20 => Some(Self::Cancel),
313 0x21 => Some(Self::Compress),
314 0x22 => Some(Self::SetSession),
315 0x23 => Some(Self::Notice),
316 0x24 => Some(Self::RowDescription),
317 0x25 => Some(Self::StreamEnd),
318 0x26 => Some(Self::VectorSearch),
319 0x27 => Some(Self::GraphTraverse),
320 0x28 => Some(Self::QueryWithParams),
321 _ => None,
322 }
323 }
324}
325
/// Set of per-frame flag bits, stored in one byte.
///
/// Only [`Flags::COMPRESSED`] and [`Flags::MORE_FRAMES`] are defined.
/// [`Flags::from_bits`] keeps undefined bits as-is, so validate decoded values
/// with `MessageKind::permits_flags`, which rejects any bit a kind does not allow.
/// All operations take `self` by value and return a new value; none mutate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Flags(u8);

impl Flags {
    /// Bit 0, named for a compressed payload.
    pub const COMPRESSED: Self = Self(0b0000_0001);
    /// Bit 1, named to signal that more frames follow.
    pub const MORE_FRAMES: Self = Self(0b0000_0010);

    /// The set with no bits set.
    #[must_use]
    pub const fn empty() -> Self {
        Self(0)
    }

    /// Raw byte value of the set.
    #[must_use]
    pub const fn bits(self) -> u8 {
        self.0
    }

    /// Wraps a raw byte verbatim, including any undefined bits.
    #[must_use]
    pub const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns `true` if every bit of `other` is also set in `self`.
    /// `x.contains(Flags::empty())` is therefore always `true`.
    #[must_use]
    pub const fn contains(self, other: Self) -> bool {
        (self.0 & other.0) == other.0
    }

    /// Returns the union of `self` and `other`.
    ///
    /// Despite the name this does not modify the receiver (it takes `self` by
    /// value). A bare `flags.insert(X);` statement is a silent no-op, hence
    /// `#[must_use]`.
    #[must_use = "`insert` returns a new `Flags`; it does not modify the receiver"]
    pub const fn insert(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }
}

/// `a | b` is the union of two flag sets, identical to `a.insert(b)`.
impl std::ops::BitOr for Flags {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self {
        self.insert(rhs)
    }
}
360
#[cfg(test)]
mod catalog_tests {
    //! Pins the message catalog (byte values, classes, flag policy and
    //! direction of every `MessageKind`) so protocol changes surface as
    //! explicit test diffs.

    use super::*;

    /// Every catalog entry in byte order. When adding a variant, add it here
    /// too: `from_u8_accepts_exactly_the_catalog` fails if `from_u8` decodes a
    /// kind that is missing from this list.
    const ALL_KINDS: &[MessageKind] = &[
        MessageKind::Query,
        MessageKind::Result,
        MessageKind::Error,
        MessageKind::BulkInsert,
        MessageKind::BulkOk,
        MessageKind::BulkInsertBinary,
        MessageKind::QueryBinary,
        MessageKind::BulkInsertPrevalidated,
        MessageKind::BulkStreamStart,
        MessageKind::BulkStreamRows,
        MessageKind::BulkStreamCommit,
        MessageKind::BulkStreamAck,
        MessageKind::Prepare,
        MessageKind::PreparedOk,
        MessageKind::ExecutePrepared,
        MessageKind::Hello,
        MessageKind::HelloAck,
        MessageKind::AuthRequest,
        MessageKind::AuthResponse,
        MessageKind::AuthOk,
        MessageKind::AuthFail,
        MessageKind::Bye,
        MessageKind::Ping,
        MessageKind::Pong,
        MessageKind::Get,
        MessageKind::Delete,
        MessageKind::DeleteOk,
        MessageKind::Cancel,
        MessageKind::Compress,
        MessageKind::SetSession,
        MessageKind::Notice,
        MessageKind::RowDescription,
        MessageKind::StreamEnd,
        MessageKind::VectorSearch,
        MessageKind::GraphTraverse,
        MessageKind::QueryWithParams,
    ];

    #[test]
    fn class_matrix_is_pinned() {
        assert_eq!(MessageKind::Hello.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::HelloAck.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::AuthRequest.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::AuthResponse.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::AuthOk.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::AuthFail.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::Bye.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::Ping.class(), MessageClass::Handshake);
        assert_eq!(MessageKind::Pong.class(), MessageClass::Handshake);

        // Every data-plane kind is pinned, not just a sample.
        assert_eq!(MessageKind::Query.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::Result.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::Error.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::BulkInsert.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::BulkOk.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::BulkInsertBinary.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::QueryBinary.class(), MessageClass::DataPlane);
        assert_eq!(
            MessageKind::BulkInsertPrevalidated.class(),
            MessageClass::DataPlane
        );
        assert_eq!(MessageKind::Prepare.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::PreparedOk.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::ExecutePrepared.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::Get.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::Delete.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::DeleteOk.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::VectorSearch.class(), MessageClass::DataPlane);
        assert_eq!(MessageKind::GraphTraverse.class(), MessageClass::DataPlane);
        assert_eq!(
            MessageKind::QueryWithParams.class(),
            MessageClass::DataPlane
        );

        assert_eq!(MessageKind::BulkStreamStart.class(), MessageClass::Streamed);
        assert_eq!(MessageKind::BulkStreamRows.class(), MessageClass::Streamed);
        assert_eq!(
            MessageKind::BulkStreamCommit.class(),
            MessageClass::Streamed
        );
        assert_eq!(MessageKind::BulkStreamAck.class(), MessageClass::Streamed);
        assert_eq!(MessageKind::RowDescription.class(), MessageClass::Streamed);
        assert_eq!(MessageKind::StreamEnd.class(), MessageClass::Streamed);

        assert_eq!(MessageKind::Cancel.class(), MessageClass::ControlPlane);
        assert_eq!(MessageKind::Compress.class(), MessageClass::ControlPlane);
        assert_eq!(MessageKind::SetSession.class(), MessageClass::ControlPlane);
        assert_eq!(MessageKind::Notice.class(), MessageClass::ControlPlane);

        // `is_handshake` is shorthand for the class check and must agree with
        // it for every catalog entry.
        for k in ALL_KINDS {
            assert_eq!(
                k.is_handshake(),
                k.class() == MessageClass::Handshake,
                "{k:?}: is_handshake disagrees with class()"
            );
        }
    }

    #[test]
    fn allowed_flags_matrix_is_pinned() {
        let handshake = [
            MessageKind::Hello,
            MessageKind::HelloAck,
            MessageKind::AuthRequest,
            MessageKind::AuthResponse,
            MessageKind::AuthOk,
            MessageKind::AuthFail,
            MessageKind::Bye,
            MessageKind::Ping,
            MessageKind::Pong,
        ];
        for k in handshake {
            let f = k.allowed_flags();
            assert!(
                f.contains(Flags::MORE_FRAMES),
                "{k:?} must allow MORE_FRAMES"
            );
            assert!(
                !f.contains(Flags::COMPRESSED),
                "{k:?} must NOT allow COMPRESSED today"
            );
        }

        for k in ALL_KINDS {
            if handshake.contains(k) {
                continue;
            }
            let f = k.allowed_flags();
            assert!(
                f.contains(Flags::MORE_FRAMES),
                "{k:?} must allow MORE_FRAMES"
            );
            assert!(f.contains(Flags::COMPRESSED), "{k:?} must allow COMPRESSED");
        }
    }

    #[test]
    fn every_kind_has_unique_byte_value() {
        let mut seen = std::collections::HashSet::new();
        for k in ALL_KINDS {
            let byte = *k as u8;
            assert!(
                seen.insert(byte),
                "byte 0x{byte:02x} reused by {k:?}; catalog has a duplicate"
            );
        }
    }

    #[test]
    fn from_u8_round_trips_for_every_kind() {
        for k in ALL_KINDS {
            let byte = *k as u8;
            let decoded = MessageKind::from_u8(byte).unwrap_or_else(|| {
                panic!("from_u8 returned None for catalog entry {k:?} (0x{byte:02x})")
            });
            assert_eq!(
                decoded, *k,
                "from_u8(0x{byte:02x}) must round-trip back to {k:?}"
            );
        }
    }

    /// Sweeps all 256 byte values. `from_u8` must decode exactly the bytes in
    /// `ALL_KINDS` and reject everything else. This catches a variant added to
    /// `from_u8` but not to `ALL_KINDS`, and a decoder that accepts unassigned bytes.
    #[test]
    fn from_u8_accepts_exactly_the_catalog() {
        let mut decoded = 0usize;
        for byte in 0..=u8::MAX {
            match MessageKind::from_u8(byte) {
                Some(k) => {
                    decoded += 1;
                    assert_eq!(k as u8, byte, "from_u8(0x{byte:02x}) decoded {k:?}");
                    assert!(
                        ALL_KINDS.contains(&k),
                        "from_u8(0x{byte:02x}) = {k:?} is missing from ALL_KINDS"
                    );
                }
                None => assert!(
                    !ALL_KINDS.iter().any(|k| *k as u8 == byte),
                    "from_u8 rejects catalog byte 0x{byte:02x}"
                ),
            }
        }
        assert_eq!(
            decoded,
            ALL_KINDS.len(),
            "from_u8 and ALL_KINDS disagree on the catalog size"
        );
    }

    #[test]
    fn permits_flags_matches_allowed_flags() {
        assert!(MessageKind::Ping.permits_flags(Flags::MORE_FRAMES));
        assert!(MessageKind::Ping.permits_flags(Flags::empty()));
        assert!(!MessageKind::Ping.permits_flags(Flags::COMPRESSED));
        assert!(!MessageKind::Ping.permits_flags(Flags::COMPRESSED | Flags::MORE_FRAMES));

        assert!(MessageKind::BulkStreamRows.permits_flags(Flags::MORE_FRAMES));
        assert!(MessageKind::BulkStreamRows.permits_flags(Flags::COMPRESSED));
        assert!(MessageKind::RowDescription.permits_flags(Flags::MORE_FRAMES));
        assert!(MessageKind::StreamEnd.permits_flags(Flags::MORE_FRAMES));

        // Exhaustive sweep: a flag byte is permitted iff it sets no bit outside
        // `allowed_flags`, so undefined bits are always rejected.
        for k in ALL_KINDS {
            let allowed = k.allowed_flags().bits();
            for bits in 0..=u8::MAX {
                assert_eq!(
                    k.permits_flags(Flags::from_bits(bits)),
                    (bits & !allowed) == 0,
                    "{k:?} with flags 0b{bits:08b}"
                );
            }
            assert!(
                !k.permits_flags(Flags::from_bits(0b0000_0100)),
                "{k:?} must reject undefined flag bits"
            );
        }
    }

    #[test]
    fn direction_matrix_is_pinned() {
        let client_to_server = [
            MessageKind::Hello,
            MessageKind::AuthResponse,
            MessageKind::Query,
            MessageKind::QueryBinary,
            MessageKind::BulkInsert,
            MessageKind::BulkInsertBinary,
            MessageKind::BulkInsertPrevalidated,
            MessageKind::BulkStreamStart,
            MessageKind::BulkStreamRows,
            MessageKind::BulkStreamCommit,
            MessageKind::Prepare,
            MessageKind::ExecutePrepared,
            MessageKind::Get,
            MessageKind::Delete,
            MessageKind::Cancel,
            MessageKind::Compress,
            MessageKind::SetSession,
            MessageKind::VectorSearch,
            MessageKind::GraphTraverse,
            MessageKind::QueryWithParams,
        ];
        for k in client_to_server {
            assert_eq!(
                k.direction(),
                MessageDirection::ClientToServer,
                "{k:?} should be client-originated"
            );
        }

        let server_to_client = [
            MessageKind::HelloAck,
            MessageKind::AuthRequest,
            MessageKind::AuthOk,
            MessageKind::AuthFail,
            MessageKind::Result,
            MessageKind::Error,
            MessageKind::BulkOk,
            MessageKind::BulkStreamAck,
            MessageKind::PreparedOk,
            MessageKind::DeleteOk,
            MessageKind::Notice,
            MessageKind::RowDescription,
            MessageKind::StreamEnd,
        ];
        for k in server_to_client {
            assert_eq!(
                k.direction(),
                MessageDirection::ServerToClient,
                "{k:?} should be server-originated"
            );
        }

        let both = [MessageKind::Bye, MessageKind::Ping, MessageKind::Pong];
        for k in both {
            assert_eq!(
                k.direction(),
                MessageDirection::Both,
                "{k:?} should be symmetric"
            );
        }

        // A kind added to the catalog must also be placed in one of the lists above.
        assert_eq!(
            client_to_server.len() + server_to_client.len() + both.len(),
            ALL_KINDS.len(),
            "direction matrix must cover every catalog entry"
        );
    }
}