1use bitflags::bitflags;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23
/// Version of the control protocol implemented by this module.
///
/// [`Hello::new`] writes this value into [`Hello::protocol_version`].
pub const PROTOCOL_VERSION: u16 = 0x0001;
26
27#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
33pub struct ServiceId(pub String);
34
35impl ServiceId {
36 #[must_use]
38 pub fn new(name: impl Into<String>) -> Self {
39 Self(name.into())
40 }
41
42 #[must_use]
44 pub fn as_str(&self) -> &str {
45 &self.0
46 }
47}
48
49impl std::fmt::Display for ServiceId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(f, "{}", self.0)
52 }
53}
54
55impl From<&str> for ServiceId {
56 fn from(s: &str) -> Self {
57 Self(s.to_owned())
58 }
59}
60
61impl From<String> for ServiceId {
62 fn from(s: String) -> Self {
63 Self(s)
64 }
65}
66
67#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
72pub enum Metadata {
73 #[default]
75 Empty,
76 Bytes(Vec<u8>),
78 Structured(HashMap<String, MetadataValue>),
80}
81
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub enum MetadataValue {
85 String(String),
87 Integer(i64),
89 Boolean(bool),
91 Bytes(Vec<u8>),
93}
94
95impl Eq for MetadataValue {}
96
97bitflags! {
98 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
100 pub struct Features: u32 {
101 const STRUCTURED_METADATA = 0b0000_0001;
103 const PING_PONG = 0b0000_0010;
105 const STREAM_PRIORITY = 0b0000_0100;
107 }
108}
109
110impl Default for Features {
111 fn default() -> Self {
112 Self::empty()
113 }
114}
115
116bitflags! {
117 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
119 pub struct OpenFlags: u8 {
120 const UNIDIRECTIONAL = 0b0000_0001;
122 const HIGH_PRIORITY = 0b0000_0010;
124 }
125}
126
127impl Default for OpenFlags {
128 fn default() -> Self {
129 Self::empty()
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
135pub enum ProtocolMessage {
136 Hello(Hello),
138 HelloAck(HelloAck),
140 OpenRequest(OpenRequest),
142 OpenResponse(OpenResponse),
144 StreamClose(StreamClose),
146 Ping(Ping),
148 Pong(Pong),
150}
151
152#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
157pub struct Hello {
158 pub protocol_version: u16,
160 pub features: Features,
162 pub agent: Option<String>,
164}
165
166impl Hello {
167 #[must_use]
169 pub const fn new(features: Features) -> Self {
170 Self {
171 protocol_version: PROTOCOL_VERSION,
172 features,
173 agent: None,
174 }
175 }
176
177 #[must_use]
179 pub fn with_agent(mut self, agent: impl Into<String>) -> Self {
180 self.agent = Some(agent.into());
181 self
182 }
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
187pub struct HelloAck {
188 pub selected_version: u16,
190 pub selected_features: Features,
192}
193
194#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
200pub struct OpenRequest {
201 pub request_id: u64,
203 pub service: ServiceId,
205 pub metadata: Metadata,
207 pub flags: OpenFlags,
209}
210
211impl OpenRequest {
212 #[must_use]
214 pub fn new(request_id: u64, service: impl Into<ServiceId>) -> Self {
215 Self {
216 request_id,
217 service: service.into(),
218 metadata: Metadata::Empty,
219 flags: OpenFlags::empty(),
220 }
221 }
222
223 #[must_use]
225 pub fn with_metadata(mut self, metadata: Metadata) -> Self {
226 self.metadata = metadata;
227 self
228 }
229
230 #[must_use]
232 pub const fn with_flags(mut self, flags: OpenFlags) -> Self {
233 self.flags = flags;
234 self
235 }
236}
237
238#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
240pub struct OpenResponse {
241 pub request_id: u64,
243 pub status: OpenStatus,
245 pub reason: Option<String>,
247 pub logical_stream_id: Option<u64>,
249}
250
251impl OpenResponse {
252 #[must_use]
254 pub const fn accepted(request_id: u64, logical_stream_id: u64) -> Self {
255 Self {
256 request_id,
257 status: OpenStatus::Accepted,
258 reason: None,
259 logical_stream_id: Some(logical_stream_id),
260 }
261 }
262
263 #[must_use]
265 pub const fn rejected(request_id: u64, code: RejectCode, reason: Option<String>) -> Self {
266 Self {
267 request_id,
268 status: OpenStatus::Rejected(code),
269 reason,
270 logical_stream_id: None,
271 }
272 }
273}
274
275#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
277pub enum OpenStatus {
278 Accepted,
280 Rejected(RejectCode),
282}
283
284#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
286pub enum RejectCode {
287 ServiceUnavailable,
289 UnsupportedService,
291 LimitExceeded,
293 Unauthorized,
295 InternalError,
297}
298
299impl std::fmt::Display for RejectCode {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 match self {
302 Self::ServiceUnavailable => write!(f, "service unavailable"),
303 Self::UnsupportedService => write!(f, "unsupported service"),
304 Self::LimitExceeded => write!(f, "limit exceeded"),
305 Self::Unauthorized => write!(f, "unauthorized"),
306 Self::InternalError => write!(f, "internal error"),
307 }
308 }
309}
310
311#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
313pub struct StreamClose {
314 pub logical_stream_id: u64,
316 pub code: CloseCode,
318 pub reason: Option<String>,
320}
321
322impl StreamClose {
323 #[must_use]
325 pub const fn normal(logical_stream_id: u64) -> Self {
326 Self {
327 logical_stream_id,
328 code: CloseCode::Normal,
329 reason: None,
330 }
331 }
332
333 #[must_use]
335 pub fn error(logical_stream_id: u64, reason: impl Into<String>) -> Self {
336 Self {
337 logical_stream_id,
338 code: CloseCode::Error,
339 reason: Some(reason.into()),
340 }
341 }
342}
343
344#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
346pub enum CloseCode {
347 Normal,
349 Error,
351 Timeout,
353 Reset,
355}
356
357impl CloseCode {
358 #[must_use]
360 pub const fn as_u8(self) -> u8 {
361 match self {
362 Self::Normal => 0,
363 Self::Error => 1,
364 Self::Timeout => 2,
365 Self::Reset => 3,
366 }
367 }
368}
369
370#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
372pub struct Ping {
373 pub sequence: u64,
375}
376
377#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
379pub struct Pong {
380 pub sequence: u64,
382}
383
/// Fixed-size binary header that carries a logical stream id.
///
/// Wire layout (`ENCODED_SIZE` = 13 bytes):
///
/// | bytes   | content                                   |
/// |---------|-------------------------------------------|
/// | `0..4`  | [`StreamBind::MAGIC`] (ASCII `QRBV`)      |
/// | `4`     | [`StreamBind::VERSION`]                   |
/// | `5..13` | `logical_stream_id` as a big-endian `u64` |
///
/// NOTE(review): presumably written on a data stream to associate it with the id from
/// `OpenResponse::logical_stream_id`; confirm against the transport code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamBind {
    /// Logical stream identifier, stored in bytes `5..13` of the encoding.
    pub logical_stream_id: u64,
}

impl StreamBind {
    /// Magic prefix of every encoded header: the ASCII bytes `QRBV` (0x51 0x52 0x42 0x56).
    pub const MAGIC: [u8; 4] = *b"QRBV";

    /// Format version written by [`StreamBind::encode`] and required by [`StreamBind::decode`].
    pub const VERSION: u8 = 1;

    /// Encoded length in bytes: 4 magic + 1 version + 8 stream id.
    pub const ENCODED_SIZE: usize = 13;

    /// Wraps `logical_stream_id` in a header value.
    #[must_use]
    pub const fn new(logical_stream_id: u64) -> Self {
        Self { logical_stream_id }
    }

    /// Serializes the header into its fixed 13-byte wire form.
    #[must_use]
    pub fn encode(&self) -> [u8; Self::ENCODED_SIZE] {
        let [m0, m1, m2, m3] = Self::MAGIC;
        let [i0, i1, i2, i3, i4, i5, i6, i7] = self.logical_stream_id.to_be_bytes();
        [m0, m1, m2, m3, Self::VERSION, i0, i1, i2, i3, i4, i5, i6, i7]
    }

    /// Parses a 13-byte header.
    ///
    /// Returns `None` when the magic prefix differs from [`StreamBind::MAGIC`] or the version
    /// byte differs from [`StreamBind::VERSION`]. The array-typed argument makes a short read
    /// impossible at this point.
    #[must_use]
    pub fn decode(buf: &[u8; Self::ENCODED_SIZE]) -> Option<Self> {
        let [m0, m1, m2, m3, version, id @ ..] = *buf;
        if [m0, m1, m2, m3] == Self::MAGIC && version == Self::VERSION {
            Some(Self::new(u64::from_be_bytes(id)))
        } else {
            None
        }
    }
}
448
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn service_id_from_str() {
        let ssh: ServiceId = "ssh".into();
        assert_eq!("ssh", ssh.as_str());
    }

    #[test]
    fn service_id_display() {
        let http = ServiceId::new("http");
        assert_eq!("http", http.to_string());
    }

    #[test]
    fn hello_with_agent() {
        let greeting = Hello::new(Features::PING_PONG).with_agent("test/1.0");
        assert_eq!(PROTOCOL_VERSION, greeting.protocol_version);
        assert_eq!(Features::PING_PONG, greeting.features);
        assert_eq!(Some("test/1.0"), greeting.agent.as_deref());
    }

    #[test]
    fn open_request_builder() {
        let payload: Vec<u8> = vec![1, 2, 3];
        // Builder calls touch independent fields, so their order does not matter.
        let request = OpenRequest::new(42, "tcp")
            .with_flags(OpenFlags::HIGH_PRIORITY)
            .with_metadata(Metadata::Bytes(payload.clone()));

        assert_eq!(42, request.request_id);
        assert_eq!("tcp", request.service.as_str());
        assert_eq!(Metadata::Bytes(payload), request.metadata);
        assert!(request.flags.contains(OpenFlags::HIGH_PRIORITY));
    }

    #[test]
    fn open_response_accepted() {
        let response = OpenResponse::accepted(42, 100);
        assert_eq!(
            (response.request_id, response.status, response.logical_stream_id),
            (42, OpenStatus::Accepted, Some(100))
        );
    }

    #[test]
    fn open_response_rejected() {
        let response =
            OpenResponse::rejected(42, RejectCode::Unauthorized, Some(String::from("denied")));
        assert_eq!(42, response.request_id);
        assert_eq!(OpenStatus::Rejected(RejectCode::Unauthorized), response.status);
        assert_eq!(Some("denied"), response.reason.as_deref());
        assert!(response.logical_stream_id.is_none());
    }

    #[test]
    fn stream_close_normal() {
        let close = StreamClose::normal(99);
        assert_eq!(99, close.logical_stream_id);
        assert_eq!(CloseCode::Normal, close.code);
        assert_eq!(None, close.reason);
    }

    #[test]
    fn features_intersection() {
        let local = Features::STRUCTURED_METADATA | Features::PING_PONG;
        let remote = Features::STREAM_PRIORITY | Features::PING_PONG;
        assert_eq!(Features::PING_PONG, local & remote);
    }

    #[test]
    fn stream_bind_encode_decode() {
        const ID: u64 = 0x0102_0304_0506_0708;
        let wire = StreamBind::new(ID).encode();

        // Header: magic prefix, then the version byte.
        assert_eq!(wire[..4], StreamBind::MAGIC);
        assert_eq!(wire[4], StreamBind::VERSION);
        // Payload: the stream id in big-endian byte order.
        assert_eq!(wire[5..], [0x01_u8, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);

        let decoded = StreamBind::decode(&wire).expect("decode should succeed");
        assert_eq!(ID, decoded.logical_stream_id);
    }

    #[test]
    fn stream_bind_invalid_magic() {
        // Magic bytes stay zero; only the version byte is valid.
        let mut frame = [0u8; StreamBind::ENCODED_SIZE];
        frame[4] = StreamBind::VERSION;
        assert!(StreamBind::decode(&frame).is_none());
    }

    #[test]
    fn stream_bind_invalid_version() {
        let mut frame = [0u8; StreamBind::ENCODED_SIZE];
        frame[..4].copy_from_slice(&StreamBind::MAGIC);
        frame[4] = u8::MAX;
        assert_eq!(None, StreamBind::decode(&frame));
    }
}