1use bitflags::bitflags;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23
/// Protocol version spoken by this implementation.
///
/// [`Hello::new`] copies this value into [`Hello::protocol_version`].
pub const PROTOCOL_VERSION: u16 = 1;
26
27#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
33pub struct ServiceId(pub String);
34
35impl ServiceId {
36 #[must_use]
38 pub fn new(name: impl Into<String>) -> Self {
39 Self(name.into())
40 }
41
42 #[must_use]
44 pub fn as_str(&self) -> &str {
45 &self.0
46 }
47}
48
49impl std::fmt::Display for ServiceId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(f, "{}", self.0)
52 }
53}
54
55impl From<&str> for ServiceId {
56 fn from(s: &str) -> Self {
57 Self(s.to_owned())
58 }
59}
60
61impl From<String> for ServiceId {
62 fn from(s: String) -> Self {
63 Self(s)
64 }
65}
66
/// Metadata attached to an [`OpenRequest`] (see [`OpenRequest::metadata`]).
///
/// NOTE(review): the tests encode this through `BincodeCodec`. With a
/// positional codec like that, variant order is the wire tag, so append new
/// variants rather than reordering these.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum Metadata {
    /// No metadata. This is the `Default`, and it is what [`OpenRequest::new`] uses.
    #[default]
    Empty,
    /// Raw bytes; this module does not interpret them.
    Bytes(Vec<u8>),
    /// String-keyed map of typed values.
    // Presumably only meaningful when `Features::STRUCTURED_METADATA` was negotiated — confirm.
    Structured(HashMap<String, MetadataValue>),
}
81
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub enum MetadataValue {
85 String(String),
87 Integer(i64),
89 Boolean(bool),
91 Bytes(Vec<u8>),
93}
94
95impl Eq for MetadataValue {}
96
97bitflags! {
98 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
100 pub struct Features: u32 {
101 const STRUCTURED_METADATA = 0b0000_0001;
103 const PING_PONG = 0b0000_0010;
105 const STREAM_PRIORITY = 0b0000_0100;
107 }
108}
109
110impl Default for Features {
111 fn default() -> Self {
112 Self::empty()
113 }
114}
115
116bitflags! {
117 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
119 pub struct OpenFlags: u8 {
120 const UNIDIRECTIONAL = 0b0000_0001;
122 const HIGH_PRIORITY = 0b0000_0010;
124 }
125}
126
127impl Default for OpenFlags {
128 fn default() -> Self {
129 Self::empty()
130 }
131}
132
/// Envelope for every protocol message. The tests round-trip this type
/// through `BincodeCodec`.
///
/// NOTE(review): under a positional codec such as bincode, the variant index is
/// the wire tag. Append new variants at the end instead of reordering.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProtocolMessage {
    /// Version and feature offer; see [`Hello`].
    Hello(Hello),
    /// Selected version and features; see [`HelloAck`].
    HelloAck(HelloAck),
    /// Request to open a stream to a service; see [`OpenRequest`].
    OpenRequest(OpenRequest),
    /// Accept/reject answer to an open request; see [`OpenResponse`].
    OpenResponse(OpenResponse),
    /// Close of a logical stream; see [`StreamClose`].
    StreamClose(StreamClose),
    /// Sequence-numbered probe; see [`Ping`].
    Ping(Ping),
    /// Sequence-numbered reply; see [`Pong`].
    Pong(Pong),
}
151
152#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
157pub struct Hello {
158 pub protocol_version: u16,
160 pub features: Features,
162 pub agent: Option<String>,
164}
165
166impl Hello {
167 #[must_use]
169 pub const fn new(features: Features) -> Self {
170 Self {
171 protocol_version: PROTOCOL_VERSION,
172 features,
173 agent: None,
174 }
175 }
176
177 #[must_use]
179 pub fn with_agent(mut self, agent: impl Into<String>) -> Self {
180 self.agent = Some(agent.into());
181 self
182 }
183}
184
/// Version and feature set chosen in response to a [`Hello`] (presumed
/// semantics — the selection logic is not in this file).
///
/// NOTE(review): field order is part of the serialized layout; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HelloAck {
    /// Protocol version selected for the session.
    pub selected_version: u16,
    /// Feature set selected for the session (the tests show set intersection
    /// via `&` as one way to compute it).
    pub selected_features: Features,
}
193
194#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
200pub struct OpenRequest {
201 pub request_id: u64,
203 pub service: ServiceId,
205 pub metadata: Metadata,
207 pub flags: OpenFlags,
209}
210
211impl OpenRequest {
212 #[must_use]
214 pub fn new(request_id: u64, service: impl Into<ServiceId>) -> Self {
215 Self {
216 request_id,
217 service: service.into(),
218 metadata: Metadata::Empty,
219 flags: OpenFlags::empty(),
220 }
221 }
222
223 #[must_use]
225 pub fn with_metadata(mut self, metadata: Metadata) -> Self {
226 self.metadata = metadata;
227 self
228 }
229
230 #[must_use]
232 pub const fn with_flags(mut self, flags: OpenFlags) -> Self {
233 self.flags = flags;
234 self
235 }
236}
237
238#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
240pub struct OpenResponse {
241 pub request_id: u64,
243 pub status: OpenStatus,
245 pub reason: Option<String>,
247 pub logical_stream_id: Option<u64>,
249}
250
251impl OpenResponse {
252 #[must_use]
254 pub const fn accepted(request_id: u64, logical_stream_id: u64) -> Self {
255 Self {
256 request_id,
257 status: OpenStatus::Accepted,
258 reason: None,
259 logical_stream_id: Some(logical_stream_id),
260 }
261 }
262
263 #[must_use]
265 pub const fn rejected(request_id: u64, code: RejectCode, reason: Option<String>) -> Self {
266 Self {
267 request_id,
268 status: OpenStatus::Rejected(code),
269 reason,
270 logical_stream_id: None,
271 }
272 }
273}
274
/// Outcome carried in [`OpenResponse::status`].
///
/// NOTE(review): variant order is the wire tag under a positional codec such as
/// bincode; append new variants rather than reordering.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OpenStatus {
    /// Stream opened; [`OpenResponse::accepted`] pairs this with a stream id.
    Accepted,
    /// Stream refused for the given reason; [`OpenResponse::rejected`] pairs
    /// this with `logical_stream_id: None`.
    Rejected(RejectCode),
}
283
284#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
286pub enum RejectCode {
287 ServiceUnavailable,
289 UnsupportedService,
291 LimitExceeded,
293 Unauthorized,
295 InternalError,
297}
298
299impl std::fmt::Display for RejectCode {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 match self {
302 Self::ServiceUnavailable => write!(f, "service unavailable"),
303 Self::UnsupportedService => write!(f, "unsupported service"),
304 Self::LimitExceeded => write!(f, "limit exceeded"),
305 Self::Unauthorized => write!(f, "unauthorized"),
306 Self::InternalError => write!(f, "internal error"),
307 }
308 }
309}
310
311#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
313pub struct StreamClose {
314 pub logical_stream_id: u64,
316 pub code: CloseCode,
318 pub reason: Option<String>,
320}
321
322impl StreamClose {
323 #[must_use]
325 pub const fn normal(logical_stream_id: u64) -> Self {
326 Self {
327 logical_stream_id,
328 code: CloseCode::Normal,
329 reason: None,
330 }
331 }
332
333 #[must_use]
335 pub fn error(logical_stream_id: u64, reason: impl Into<String>) -> Self {
336 Self {
337 logical_stream_id,
338 code: CloseCode::Error,
339 reason: Some(reason.into()),
340 }
341 }
342}
343
344#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
346pub enum CloseCode {
347 Normal,
349 Error,
351 Timeout,
353 Reset,
355}
356
357impl CloseCode {
358 #[must_use]
360 pub const fn as_u8(self) -> u8 {
361 match self {
362 Self::Normal => 0,
363 Self::Error => 1,
364 Self::Timeout => 2,
365 Self::Reset => 3,
366 }
367 }
368}
369
/// Probe message carrying a sequence number.
// NOTE(review): presumably answered by a `Pong` with the same `sequence` and
// gated on `Features::PING_PONG` — confirm against the session logic.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Ping {
    /// Sender-chosen sequence number.
    pub sequence: u64,
}
376
/// Reply message carrying a sequence number.
// NOTE(review): presumably echoes the `sequence` of the `Ping` it answers — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Pong {
    /// Sequence number being acknowledged.
    pub sequence: u64,
}
383
/// Fixed-size binary preamble that binds a raw transport stream to a logical stream id.
///
/// Layout ([`StreamBind::ENCODED_SIZE`] = 13 bytes):
///
/// | bytes  | content                                   |
/// |--------|-------------------------------------------|
/// | 0..4   | [`StreamBind::MAGIC`] (ASCII `"QRBV"`)    |
/// | 4      | [`StreamBind::VERSION`]                   |
/// | 5..13  | `logical_stream_id`, big-endian `u64`     |
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamBind {
    /// Logical stream this transport stream belongs to.
    pub logical_stream_id: u64,
}

impl StreamBind {
    /// Magic prefix identifying a bind frame: ASCII `"QRBV"`.
    pub const MAGIC: [u8; 4] = [0x51, 0x52, 0x42, 0x56];
    /// Frame format version; [`StreamBind::decode`] rejects any other value.
    pub const VERSION: u8 = 1;
    /// Total encoded length, derived from the layout (magic + version byte + u64) = 13.
    pub const ENCODED_SIZE: usize = Self::ID_OFFSET + std::mem::size_of::<u64>();

    /// Offset of the version byte (right after the magic).
    const VERSION_OFFSET: usize = Self::MAGIC.len();
    /// Offset of the big-endian stream id (right after the version byte).
    const ID_OFFSET: usize = Self::VERSION_OFFSET + 1;

    /// Creates a bind frame for `logical_stream_id`.
    #[must_use]
    pub const fn new(logical_stream_id: u64) -> Self {
        Self { logical_stream_id }
    }

    /// Serializes the frame into its fixed 13-byte wire form.
    #[must_use]
    pub fn encode(&self) -> [u8; Self::ENCODED_SIZE] {
        let mut buf = [0u8; Self::ENCODED_SIZE];
        buf[..Self::VERSION_OFFSET].copy_from_slice(&Self::MAGIC);
        buf[Self::VERSION_OFFSET] = Self::VERSION;
        buf[Self::ID_OFFSET..].copy_from_slice(&self.logical_stream_id.to_be_bytes());
        buf
    }

    /// Parses a frame produced by [`StreamBind::encode`].
    ///
    /// Returns `None` if the magic prefix or the version byte does not match.
    #[must_use]
    pub fn decode(buf: &[u8; Self::ENCODED_SIZE]) -> Option<Self> {
        if buf[..Self::VERSION_OFFSET] != Self::MAGIC || buf[Self::VERSION_OFFSET] != Self::VERSION
        {
            return None;
        }
        // The remaining slice is exactly 8 bytes by construction of ENCODED_SIZE,
        // so this conversion cannot fail; `?` keeps the path panic-free anyway.
        let id_bytes: [u8; 8] = buf[Self::ID_OFFSET..].try_into().ok()?;
        Some(Self {
            logical_stream_id: u64::from_be_bytes(id_bytes),
        })
    }

    /// Like [`StreamBind::decode`], but accepts a slice of any length.
    ///
    /// Returns `None` unless `buf` is exactly [`StreamBind::ENCODED_SIZE`] bytes
    /// long and carries a valid magic prefix and version.
    #[must_use]
    pub fn decode_slice(buf: &[u8]) -> Option<Self> {
        let frame: &[u8; Self::ENCODED_SIZE] = buf.try_into().ok()?;
        Self::decode(frame)
    }
}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Property-based round-trip tests: arbitrary messages must survive
    /// encode → decode through the crate's `BincodeCodec` unchanged.
    mod proptest_tests {
        use super::*;
        use crate::{BincodeCodec, Codec};
        use proptest::prelude::*;

        /// Any combination of the three defined feature bits (0..8 spans 0b000..=0b111).
        fn arb_features() -> impl Strategy<Value = Features> {
            (0u32..8).prop_map(Features::from_bits_truncate)
        }

        /// Any combination of the two defined open-flag bits.
        fn arb_open_flags() -> impl Strategy<Value = OpenFlags> {
            (0u8..4).prop_map(OpenFlags::from_bits_truncate)
        }

        fn arb_service_id() -> impl Strategy<Value = ServiceId> {
            "[a-z][a-z0-9_-]{0,31}".prop_map(ServiceId::new)
        }

        fn arb_metadata_value() -> impl Strategy<Value = MetadataValue> {
            prop_oneof![
                ".*".prop_map(MetadataValue::String),
                any::<i64>().prop_map(MetadataValue::Integer),
                any::<bool>().prop_map(MetadataValue::Boolean),
                prop::collection::vec(any::<u8>(), 0..64).prop_map(MetadataValue::Bytes),
            ]
        }

        fn arb_metadata() -> impl Strategy<Value = Metadata> {
            prop_oneof![
                Just(Metadata::Empty),
                prop::collection::vec(any::<u8>(), 0..128).prop_map(Metadata::Bytes),
                prop::collection::hash_map("[a-z]{1,16}", arb_metadata_value(), 0..8)
                    .prop_map(Metadata::Structured),
            ]
        }

        fn arb_reject_code() -> impl Strategy<Value = RejectCode> {
            prop_oneof![
                Just(RejectCode::ServiceUnavailable),
                Just(RejectCode::UnsupportedService),
                Just(RejectCode::LimitExceeded),
                Just(RejectCode::Unauthorized),
                Just(RejectCode::InternalError),
            ]
        }

        fn arb_close_code() -> impl Strategy<Value = CloseCode> {
            prop_oneof![
                Just(CloseCode::Normal),
                Just(CloseCode::Reset),
                Just(CloseCode::Timeout),
                Just(CloseCode::Error),
            ]
        }

        fn arb_hello() -> impl Strategy<Value = Hello> {
            (arb_features(), proptest::option::of(".*")).prop_map(|(features, agent)| {
                let mut hello = Hello::new(features);
                hello.agent = agent;
                hello
            })
        }

        fn arb_hello_ack() -> impl Strategy<Value = HelloAck> {
            (any::<u16>(), arb_features()).prop_map(|(version, features)| HelloAck {
                selected_version: version,
                selected_features: features,
            })
        }

        fn arb_open_request() -> impl Strategy<Value = OpenRequest> {
            (
                any::<u64>(),
                arb_service_id(),
                arb_metadata(),
                arb_open_flags(),
            )
                .prop_map(|(request_id, service, metadata, flags)| OpenRequest {
                    request_id,
                    service,
                    metadata,
                    flags,
                })
        }

        /// Deliberately also generates inconsistent status/stream-id pairings:
        /// the codec must round-trip whatever the struct can hold.
        fn arb_open_response() -> impl Strategy<Value = OpenResponse> {
            (
                any::<u64>(),
                prop_oneof![
                    Just(OpenStatus::Accepted),
                    arb_reject_code().prop_map(OpenStatus::Rejected),
                ],
                proptest::option::of(".*"),
                proptest::option::of(any::<u64>()),
            )
                .prop_map(|(request_id, status, reason, logical_stream_id)| {
                    OpenResponse {
                        request_id,
                        status,
                        reason,
                        logical_stream_id,
                    }
                })
        }

        fn arb_stream_close() -> impl Strategy<Value = StreamClose> {
            (any::<u64>(), arb_close_code(), proptest::option::of(".*")).prop_map(
                |(logical_stream_id, code, reason)| StreamClose {
                    logical_stream_id,
                    code,
                    reason,
                },
            )
        }

        fn arb_ping() -> impl Strategy<Value = Ping> {
            any::<u64>().prop_map(|sequence| Ping { sequence })
        }

        fn arb_pong() -> impl Strategy<Value = Pong> {
            any::<u64>().prop_map(|sequence| Pong { sequence })
        }

        fn arb_protocol_message() -> impl Strategy<Value = ProtocolMessage> {
            prop_oneof![
                arb_hello().prop_map(ProtocolMessage::Hello),
                arb_hello_ack().prop_map(ProtocolMessage::HelloAck),
                arb_open_request().prop_map(ProtocolMessage::OpenRequest),
                arb_open_response().prop_map(ProtocolMessage::OpenResponse),
                arb_stream_close().prop_map(ProtocolMessage::StreamClose),
                arb_ping().prop_map(ProtocolMessage::Ping),
                arb_pong().prop_map(ProtocolMessage::Pong),
            ]
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(1000))]

            #[test]
            fn protocol_message_round_trip(msg in arb_protocol_message()) {
                let codec = BincodeCodec::new();
                let encoded = codec.encode(&msg).expect("encoding should succeed");
                let decoded: ProtocolMessage = codec.decode(&encoded).expect("decoding should succeed");
                prop_assert_eq!(msg, decoded);
            }

            #[test]
            fn hello_round_trip(msg in arb_hello()) {
                let codec = BincodeCodec::new();
                let wrapped = ProtocolMessage::Hello(msg.clone());
                let encoded = codec.encode(&wrapped).expect("encoding should succeed");
                let decoded: ProtocolMessage = codec.decode(&encoded).expect("decoding should succeed");
                prop_assert_eq!(ProtocolMessage::Hello(msg), decoded);
            }

            #[test]
            fn open_request_round_trip(msg in arb_open_request()) {
                let codec = BincodeCodec::new();
                let wrapped = ProtocolMessage::OpenRequest(msg.clone());
                let encoded = codec.encode(&wrapped).expect("encoding should succeed");
                let decoded: ProtocolMessage = codec.decode(&encoded).expect("decoding should succeed");
                prop_assert_eq!(ProtocolMessage::OpenRequest(msg), decoded);
            }

            #[test]
            fn stream_bind_round_trip(id in any::<u64>()) {
                let bind = StreamBind::new(id);
                let encoded = bind.encode();
                let decoded = StreamBind::decode(&encoded).expect("decode should succeed");
                prop_assert_eq!(bind.logical_stream_id, decoded.logical_stream_id);
            }

            #[test]
            fn service_id_preserves_content(s in "[a-z][a-z0-9_-]{0,63}") {
                let id = ServiceId::new(&s);
                prop_assert_eq!(id.as_str(), s.as_str());
                prop_assert_eq!(format!("{id}"), s);
            }
        }
    }

    #[test]
    fn service_id_from_str() {
        let id: ServiceId = "ssh".into();
        assert_eq!(id.as_str(), "ssh");
    }

    #[test]
    fn service_id_display() {
        let id = ServiceId::new("http");
        assert_eq!(format!("{id}"), "http");
    }

    #[test]
    fn hello_with_agent() {
        let hello = Hello::new(Features::PING_PONG).with_agent("test/1.0");
        assert_eq!(hello.protocol_version, PROTOCOL_VERSION);
        assert_eq!(hello.features, Features::PING_PONG);
        assert_eq!(hello.agent.as_deref(), Some("test/1.0"));
    }

    #[test]
    fn open_request_builder() {
        let req = OpenRequest::new(42, "tcp")
            .with_metadata(Metadata::Bytes(vec![1, 2, 3]))
            .with_flags(OpenFlags::HIGH_PRIORITY);

        assert_eq!(req.request_id, 42);
        assert_eq!(req.service.as_str(), "tcp");
        assert_eq!(req.metadata, Metadata::Bytes(vec![1, 2, 3]));
        assert!(req.flags.contains(OpenFlags::HIGH_PRIORITY));
    }

    /// `OpenRequest::new` must start with no metadata and no flags.
    #[test]
    fn open_request_new_has_empty_defaults() {
        let req = OpenRequest::new(1, "ssh");
        assert_eq!(req.metadata, Metadata::Empty);
        assert!(req.flags.is_empty());
    }

    #[test]
    fn open_response_accepted() {
        let resp = OpenResponse::accepted(42, 100);
        assert_eq!(resp.request_id, 42);
        assert_eq!(resp.status, OpenStatus::Accepted);
        assert_eq!(resp.logical_stream_id, Some(100));
    }

    #[test]
    fn open_response_rejected() {
        let resp = OpenResponse::rejected(42, RejectCode::Unauthorized, Some("denied".into()));
        assert_eq!(resp.request_id, 42);
        assert_eq!(resp.status, OpenStatus::Rejected(RejectCode::Unauthorized));
        assert_eq!(resp.reason.as_deref(), Some("denied"));
        assert_eq!(resp.logical_stream_id, None);
    }

    #[test]
    fn stream_close_normal() {
        let close = StreamClose::normal(99);
        assert_eq!(close.logical_stream_id, 99);
        assert_eq!(close.code, CloseCode::Normal);
        assert!(close.reason.is_none());
    }

    /// Previously untested constructor: must set `CloseCode::Error` and keep the reason.
    #[test]
    fn stream_close_error() {
        let close = StreamClose::error(7, "boom");
        assert_eq!(close.logical_stream_id, 7);
        assert_eq!(close.code, CloseCode::Error);
        assert_eq!(close.reason.as_deref(), Some("boom"));
    }

    /// Pins the numeric close-code mapping so accidental changes are caught.
    #[test]
    fn close_code_as_u8_values() {
        assert_eq!(CloseCode::Normal.as_u8(), 0);
        assert_eq!(CloseCode::Error.as_u8(), 1);
        assert_eq!(CloseCode::Timeout.as_u8(), 2);
        assert_eq!(CloseCode::Reset.as_u8(), 3);
    }

    /// Pins the human-readable reject-code strings.
    #[test]
    fn reject_code_display() {
        assert_eq!(RejectCode::ServiceUnavailable.to_string(), "service unavailable");
        assert_eq!(RejectCode::UnsupportedService.to_string(), "unsupported service");
        assert_eq!(RejectCode::LimitExceeded.to_string(), "limit exceeded");
        assert_eq!(RejectCode::Unauthorized.to_string(), "unauthorized");
        assert_eq!(RejectCode::InternalError.to_string(), "internal error");
    }

    /// All `Default` impls must produce the "nothing set" value.
    #[test]
    fn defaults_are_empty() {
        assert_eq!(Metadata::default(), Metadata::Empty);
        assert!(Features::default().is_empty());
        assert!(OpenFlags::default().is_empty());
    }

    #[test]
    fn features_intersection() {
        let a = Features::PING_PONG | Features::STRUCTURED_METADATA;
        let b = Features::PING_PONG | Features::STREAM_PRIORITY;
        let intersection = a & b;
        assert_eq!(intersection, Features::PING_PONG);
    }

    #[test]
    fn stream_bind_encode_decode() {
        let bind = StreamBind::new(0x0102_0304_0506_0708);
        let encoded = bind.encode();

        assert_eq!(&encoded[0..4], &StreamBind::MAGIC);
        assert_eq!(encoded[4], StreamBind::VERSION);
        assert_eq!(
            &encoded[5..13],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
        );

        let decoded = StreamBind::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.logical_stream_id, 0x0102_0304_0506_0708);
    }

    /// The frame size must equal magic + version byte + u64 id.
    #[test]
    fn stream_bind_encoded_size_matches_layout() {
        assert_eq!(StreamBind::ENCODED_SIZE, StreamBind::MAGIC.len() + 1 + 8);
        assert_eq!(StreamBind::new(0).encode().len(), StreamBind::ENCODED_SIZE);
    }

    #[test]
    fn stream_bind_invalid_magic() {
        let mut buf = [0u8; StreamBind::ENCODED_SIZE];
        buf[0..4].copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        buf[4] = StreamBind::VERSION;
        assert!(StreamBind::decode(&buf).is_none());
    }

    #[test]
    fn stream_bind_invalid_version() {
        let mut buf = [0u8; StreamBind::ENCODED_SIZE];
        buf[0..4].copy_from_slice(&StreamBind::MAGIC);
        buf[4] = 0xFF;
        assert!(StreamBind::decode(&buf).is_none());
    }
}