quic_reverse_control/
messages.rs

1// Copyright 2024-2026 Farlight Networks, LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Protocol message definitions.
16//!
17//! This module defines all control plane messages exchanged between peers
18//! during a quic-reverse session.
19
20use bitflags::bitflags;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23
/// Version number of the control protocol spoken by this implementation.
///
/// [`Hello::new`] stamps this value into `Hello::protocol_version`.
pub const PROTOCOL_VERSION: u16 = 0x0001;
26
/// Identifies a logical service for multiplexing.
///
/// Services are identified by string names such as "ssh", "http", or "tcp".
/// The service ID is used to route incoming stream requests to the appropriate
/// handler.
///
/// The wrapped `String` is public, and the constructors and `From` impls in
/// this file perform no validation, so any string (including `""`) is accepted.
// NOTE(review): this derives serde impls; the exact wire representation of the
// newtype depends on the codec, which is not visible in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ServiceId(pub String);
34
35impl ServiceId {
36    /// Creates a new service identifier.
37    #[must_use]
38    pub fn new(name: impl Into<String>) -> Self {
39        Self(name.into())
40    }
41
42    /// Returns the service name as a string slice.
43    #[must_use]
44    pub fn as_str(&self) -> &str {
45        &self.0
46    }
47}
48
49impl std::fmt::Display for ServiceId {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        write!(f, "{}", self.0)
52    }
53}
54
55impl From<&str> for ServiceId {
56    fn from(s: &str) -> Self {
57        Self(s.to_owned())
58    }
59}
60
61impl From<String> for ServiceId {
62    fn from(s: String) -> Self {
63        Self(s)
64    }
65}
66
/// Metadata attached to stream open requests.
///
/// Metadata can be empty, raw bytes, or a structured key-value map.
/// The format is negotiated during the `Hello`/`HelloAck` exchange.
///
/// `Empty` is the [`Default`] variant.
// NOTE(review): nothing in this file checks that `Structured` is only used after
// `Features::STRUCTURED_METADATA` was negotiated — presumably the session layer
// enforces it; confirm.
// NOTE(review): variant order feeds the derived serde representation; if the
// codec is positional (e.g. bincode) reordering variants changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum Metadata {
    /// No metadata.
    #[default]
    Empty,
    /// Raw byte payload.
    Bytes(Vec<u8>),
    /// Structured key-value pairs.
    ///
    /// Backed by a std `HashMap`, so iteration (and therefore serialized key
    /// order) is not deterministic across processes.
    Structured(HashMap<String, MetadataValue>),
}
81
82/// A value within structured metadata.
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub enum MetadataValue {
85    /// String value.
86    String(String),
87    /// Integer value.
88    Integer(i64),
89    /// Boolean value.
90    Boolean(bool),
91    /// Binary data.
92    Bytes(Vec<u8>),
93}
94
95impl Eq for MetadataValue {}
96
bitflags! {
    /// Feature flags negotiated during `Hello`/`HelloAck` exchange.
    ///
    /// Each peer advertises its supported set in [`Hello::features`]; the
    /// agreed set in [`HelloAck::selected_features`] is documented as the
    /// intersection of both peers' sets.
    // NOTE(review): bit values are presumably part of the wire format — never
    // renumber an existing flag; only append new bits.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub struct Features: u32 {
        /// Support for structured metadata in `OpenRequest`.
        const STRUCTURED_METADATA = 0b0000_0001;
        /// Support for Ping/Pong keep-alive messages.
        const PING_PONG = 0b0000_0010;
        /// Support for stream priority hints.
        const STREAM_PRIORITY = 0b0000_0100;
    }
}
109
110impl Default for Features {
111    fn default() -> Self {
112        Self::empty()
113    }
114}
115
bitflags! {
    /// Flags for `OpenRequest` messages.
    // NOTE(review): HIGH_PRIORITY is presumably only meaningful once
    // `Features::STREAM_PRIORITY` has been negotiated — nothing in this file
    // enforces that; confirm in the session layer.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub struct OpenFlags: u8 {
        /// Request a unidirectional stream (send only).
        const UNIDIRECTIONAL = 0b0000_0001;
        /// High priority stream hint.
        const HIGH_PRIORITY = 0b0000_0010;
    }
}
126
127impl Default for OpenFlags {
128    fn default() -> Self {
129        Self::empty()
130    }
131}
132
/// All protocol messages that can be exchanged on the control stream.
///
/// Each variant wraps the struct of the same name defined in this module.
// NOTE(review): variant order feeds the derived serde representation; if the
// codec is positional (e.g. bincode) reordering or inserting variants in the
// middle changes the wire format. Append new variants at the end.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProtocolMessage {
    /// Initial handshake message.
    Hello(Hello),
    /// Handshake acknowledgment.
    HelloAck(HelloAck),
    /// Request to open a reverse stream.
    OpenRequest(OpenRequest),
    /// Response to an open request.
    OpenResponse(OpenResponse),
    /// Notification that a stream has closed.
    StreamClose(StreamClose),
    /// Keep-alive ping.
    Ping(Ping),
    /// Keep-alive pong.
    Pong(Pong),
}
151
/// Initial handshake message sent by both peers.
///
/// Each peer sends a Hello message after the QUIC connection is established.
/// The messages are used to negotiate protocol version and features.
///
/// [`Hello::new`] fills `protocol_version` with [`PROTOCOL_VERSION`] and
/// leaves `agent` unset.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Hello {
    /// Protocol version supported by this peer.
    pub protocol_version: u16,
    /// Feature flags supported by this peer.
    pub features: Features,
    /// Optional agent identifier (e.g., "quic-reverse/0.1.0").
    ///
    /// Free-form text; nothing in this file parses or validates it.
    pub agent: Option<String>,
}
165
166impl Hello {
167    /// Creates a new Hello message with the current protocol version.
168    #[must_use]
169    pub const fn new(features: Features) -> Self {
170        Self {
171            protocol_version: PROTOCOL_VERSION,
172            features,
173            agent: None,
174        }
175    }
176
177    /// Sets the agent identifier.
178    #[must_use]
179    pub fn with_agent(mut self, agent: impl Into<String>) -> Self {
180        self.agent = Some(agent.into());
181        self
182    }
183}
184
/// Handshake acknowledgment confirming negotiated parameters.
///
/// No constructor is defined in this file; the selection logic lives elsewhere.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HelloAck {
    /// Selected protocol version (highest mutually supported).
    pub selected_version: u16,
    /// Selected feature set (intersection of both peers' features).
    pub selected_features: Features,
}
193
/// Request to open a reverse stream.
///
/// Sent by the peer that wants to initiate a reverse stream. The receiving
/// peer will respond with an `OpenResponse` and, if accepted, open a new
/// QUIC stream back to the initiator.
///
/// Build with [`OpenRequest::new`] and the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OpenRequest {
    /// Unique identifier for this request, used to correlate responses.
    ///
    /// Echoed back in [`OpenResponse::request_id`].
    pub request_id: u64,
    /// Target service identifier.
    pub service: ServiceId,
    /// Optional metadata for the stream.
    pub metadata: Metadata,
    /// Request flags.
    pub flags: OpenFlags,
}
210
211impl OpenRequest {
212    /// Creates a new open request for the specified service.
213    #[must_use]
214    pub fn new(request_id: u64, service: impl Into<ServiceId>) -> Self {
215        Self {
216            request_id,
217            service: service.into(),
218            metadata: Metadata::Empty,
219            flags: OpenFlags::empty(),
220        }
221    }
222
223    /// Sets the metadata for this request.
224    #[must_use]
225    pub fn with_metadata(mut self, metadata: Metadata) -> Self {
226        self.metadata = metadata;
227        self
228    }
229
230    /// Sets the flags for this request.
231    #[must_use]
232    pub const fn with_flags(mut self, flags: OpenFlags) -> Self {
233        self.flags = flags;
234        self
235    }
236}
237
/// Response to an `OpenRequest`.
///
/// The [`OpenResponse::accepted`] and [`OpenResponse::rejected`] constructors
/// keep `logical_stream_id` as `Some` for accepted and `None` for rejected
/// responses; the public fields themselves do not enforce that pairing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OpenResponse {
    /// Request ID from the corresponding `OpenRequest`.
    pub request_id: u64,
    /// Result of the open request.
    pub status: OpenStatus,
    /// Optional reason message (typically for rejections).
    pub reason: Option<String>,
    /// Logical stream ID assigned to this stream (if accepted).
    pub logical_stream_id: Option<u64>,
}
250
251impl OpenResponse {
252    /// Creates an accepted response with the given logical stream ID.
253    #[must_use]
254    pub const fn accepted(request_id: u64, logical_stream_id: u64) -> Self {
255        Self {
256            request_id,
257            status: OpenStatus::Accepted,
258            reason: None,
259            logical_stream_id: Some(logical_stream_id),
260        }
261    }
262
263    /// Creates a rejected response with the given code and optional reason.
264    #[must_use]
265    pub const fn rejected(request_id: u64, code: RejectCode, reason: Option<String>) -> Self {
266        Self {
267            request_id,
268            status: OpenStatus::Rejected(code),
269            reason,
270            logical_stream_id: None,
271        }
272    }
273}
274
/// Status of an `OpenRequest`.
///
/// Carried in [`OpenResponse::status`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OpenStatus {
    /// Request accepted; stream will be opened.
    Accepted,
    /// Request rejected with the given code.
    Rejected(RejectCode),
}
283
/// Reason codes for rejecting an `OpenRequest`.
///
/// Implements `Display` with a short lowercase description of each code.
// NOTE(review): variant order feeds the derived serde representation; if the
// codec is positional, reordering variants changes the wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RejectCode {
    /// The requested service is not available.
    ServiceUnavailable,
    /// The requested service is not supported.
    UnsupportedService,
    /// Resource limits have been exceeded.
    LimitExceeded,
    /// The request is not authorized.
    Unauthorized,
    /// An internal error occurred.
    InternalError,
}
298
299impl std::fmt::Display for RejectCode {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        match self {
302            Self::ServiceUnavailable => write!(f, "service unavailable"),
303            Self::UnsupportedService => write!(f, "unsupported service"),
304            Self::LimitExceeded => write!(f, "limit exceeded"),
305            Self::Unauthorized => write!(f, "unauthorized"),
306            Self::InternalError => write!(f, "internal error"),
307        }
308    }
309}
310
/// Notification that a stream has closed.
///
/// Use [`StreamClose::normal`] or [`StreamClose::error`] to build one.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StreamClose {
    /// Logical stream ID of the closed stream.
    pub logical_stream_id: u64,
    /// Close code indicating the reason.
    pub code: CloseCode,
    /// Optional human-readable reason.
    pub reason: Option<String>,
}
321
322impl StreamClose {
323    /// Creates a normal close notification.
324    #[must_use]
325    pub const fn normal(logical_stream_id: u64) -> Self {
326        Self {
327            logical_stream_id,
328            code: CloseCode::Normal,
329            reason: None,
330        }
331    }
332
333    /// Creates an error close notification.
334    #[must_use]
335    pub fn error(logical_stream_id: u64, reason: impl Into<String>) -> Self {
336        Self {
337            logical_stream_id,
338            code: CloseCode::Error,
339            reason: Some(reason.into()),
340        }
341    }
342}
343
/// Close codes for `StreamClose` messages.
///
/// [`CloseCode::as_u8`] maps each variant to its numeric value explicitly,
/// so that mapping does not depend on declaration order.
// NOTE(review): the derived serde representation *does* depend on variant
// order if the codec is positional — keep the order stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CloseCode {
    /// Normal closure.
    Normal,
    /// Error condition.
    Error,
    /// Timeout expired.
    Timeout,
    /// Stream was reset.
    Reset,
}
356
357impl CloseCode {
358    /// Returns the numeric code for wire transmission.
359    #[must_use]
360    pub const fn as_u8(self) -> u8 {
361        match self {
362            Self::Normal => 0,
363            Self::Error => 1,
364            Self::Timeout => 2,
365            Self::Reset => 3,
366        }
367    }
368}
369
/// Keep-alive ping message.
///
/// The peer answers with a [`Pong`] carrying the same `sequence`.
// NOTE(review): presumably only sent once `Features::PING_PONG` has been
// negotiated — nothing in this file enforces that; confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Ping {
    /// Sequence number for matching with Pong responses.
    pub sequence: u64,
}
376
/// Keep-alive pong response.
///
/// Echoes the `sequence` of the [`Ping`] it answers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Pong {
    /// Sequence number from the corresponding Ping.
    pub sequence: u64,
}
383
/// Stream binding frame sent on data streams.
///
/// When a data stream is opened, the first frame sent must be a `StreamBind`
/// to identify which logical stream this QUIC stream belongs to. This allows
/// the receiving peer to match the data stream with the corresponding
/// `OpenRequest`/`OpenResponse` exchange.
///
/// # Wire Format
///
/// The stream bind frame is encoded as:
/// - 4 bytes: magic number (`0x51524256`, ASCII "QRBV" for "Quic Reverse Bind Version")
/// - 1 byte: version (currently 1)
/// - 8 bytes: `logical_stream_id` (big-endian u64)
///
/// Total: 13 bytes ([`StreamBind::ENCODED_SIZE`]).
///
/// Unlike the control messages, this frame is hand-encoded (no serde), so the
/// layout above is exactly what [`StreamBind::encode`] produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamBind {
    /// Logical stream ID assigned during `OpenResponse`.
    pub logical_stream_id: u64,
}

impl StreamBind {
    /// Magic number identifying the stream bind frame (ASCII `"QRBV"`, `0x51524256`).
    pub const MAGIC: [u8; 4] = *b"QRBV";

    /// Current stream bind version.
    pub const VERSION: u8 = 1;

    /// Size of the encoded stream bind frame: magic, one version byte, and a
    /// big-endian `u64` stream ID (13 bytes).
    // Derived from the field sizes rather than hard-coded so this constant,
    // `encode`, and `decode` cannot drift apart if the layout changes.
    pub const ENCODED_SIZE: usize = Self::MAGIC.len() + 1 + std::mem::size_of::<u64>();

    /// Creates a new stream bind frame.
    #[must_use]
    pub const fn new(logical_stream_id: u64) -> Self {
        Self { logical_stream_id }
    }

    /// Encodes the stream bind to bytes: magic, version, then the ID in big-endian order.
    #[must_use]
    pub fn encode(&self) -> [u8; Self::ENCODED_SIZE] {
        let mut buf = [0u8; Self::ENCODED_SIZE];
        // Carve the buffer into its three fields instead of using literal offsets.
        let (magic, rest) = buf.split_at_mut(Self::MAGIC.len());
        magic.copy_from_slice(&Self::MAGIC);
        let (version, id) = rest.split_at_mut(1);
        version[0] = Self::VERSION;
        id.copy_from_slice(&self.logical_stream_id.to_be_bytes());
        buf
    }

    /// Decodes a stream bind from bytes.
    ///
    /// Returns `None` if the magic number is invalid or the version is unsupported.
    #[must_use]
    pub fn decode(buf: &[u8; Self::ENCODED_SIZE]) -> Option<Self> {
        // Destructuring the fixed-size array checks the layout at compile time:
        // `id` is exactly the trailing `[u8; 8]`, with no index arithmetic and
        // no panicking slice conversions.
        let [m0, m1, m2, m3, version, id @ ..] = *buf;
        if [m0, m1, m2, m3] != Self::MAGIC || version != Self::VERSION {
            return None;
        }
        Some(Self {
            logical_stream_id: u64::from_be_bytes(id),
        })
    }
}
448
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn service_id_from_str() {
        let id: ServiceId = "ssh".into();
        assert_eq!(id.as_str(), "ssh");
    }

    #[test]
    fn service_id_display() {
        let id = ServiceId::new("http");
        assert_eq!(format!("{id}"), "http");
    }

    #[test]
    fn hello_with_agent() {
        let hello = Hello::new(Features::PING_PONG).with_agent("test/1.0");
        assert_eq!(hello.protocol_version, PROTOCOL_VERSION);
        assert_eq!(hello.features, Features::PING_PONG);
        assert_eq!(hello.agent.as_deref(), Some("test/1.0"));
    }

    #[test]
    fn hello_new_has_no_agent() {
        let hello = Hello::new(Features::empty());
        assert_eq!(hello.protocol_version, PROTOCOL_VERSION);
        assert!(hello.agent.is_none());
    }

    #[test]
    fn open_request_builder() {
        let req = OpenRequest::new(42, "tcp")
            .with_metadata(Metadata::Bytes(vec![1, 2, 3]))
            .with_flags(OpenFlags::HIGH_PRIORITY);

        assert_eq!(req.request_id, 42);
        assert_eq!(req.service.as_str(), "tcp");
        assert_eq!(req.metadata, Metadata::Bytes(vec![1, 2, 3]));
        assert!(req.flags.contains(OpenFlags::HIGH_PRIORITY));
    }

    #[test]
    fn open_request_defaults() {
        let req = OpenRequest::new(1, String::from("ssh"));
        assert_eq!(req.service, ServiceId::new("ssh"));
        assert_eq!(req.metadata, Metadata::Empty);
        assert!(req.flags.is_empty());
    }

    #[test]
    fn open_response_accepted() {
        let resp = OpenResponse::accepted(42, 100);
        assert_eq!(resp.request_id, 42);
        assert_eq!(resp.status, OpenStatus::Accepted);
        assert_eq!(resp.logical_stream_id, Some(100));
    }

    #[test]
    fn open_response_rejected() {
        let resp = OpenResponse::rejected(42, RejectCode::Unauthorized, Some("denied".into()));
        assert_eq!(resp.request_id, 42);
        assert_eq!(resp.status, OpenStatus::Rejected(RejectCode::Unauthorized));
        assert_eq!(resp.reason.as_deref(), Some("denied"));
        assert_eq!(resp.logical_stream_id, None);
    }

    #[test]
    fn reject_code_display() {
        assert_eq!(
            RejectCode::ServiceUnavailable.to_string(),
            "service unavailable"
        );
        assert_eq!(
            RejectCode::UnsupportedService.to_string(),
            "unsupported service"
        );
        assert_eq!(RejectCode::LimitExceeded.to_string(), "limit exceeded");
        assert_eq!(RejectCode::Unauthorized.to_string(), "unauthorized");
        assert_eq!(RejectCode::InternalError.to_string(), "internal error");
    }

    #[test]
    fn stream_close_normal() {
        let close = StreamClose::normal(99);
        assert_eq!(close.logical_stream_id, 99);
        assert_eq!(close.code, CloseCode::Normal);
        assert!(close.reason.is_none());
    }

    #[test]
    fn stream_close_error() {
        let close = StreamClose::error(7, "boom");
        assert_eq!(close.logical_stream_id, 7);
        assert_eq!(close.code, CloseCode::Error);
        assert_eq!(close.reason.as_deref(), Some("boom"));
    }

    #[test]
    fn close_code_wire_values_are_stable() {
        // These bytes are documented as the wire encoding; pin them.
        assert_eq!(CloseCode::Normal.as_u8(), 0);
        assert_eq!(CloseCode::Error.as_u8(), 1);
        assert_eq!(CloseCode::Timeout.as_u8(), 2);
        assert_eq!(CloseCode::Reset.as_u8(), 3);
    }

    #[test]
    fn defaults_are_empty() {
        assert!(Features::default().is_empty());
        assert!(OpenFlags::default().is_empty());
        assert_eq!(Metadata::default(), Metadata::Empty);
    }

    #[test]
    fn features_intersection() {
        let a = Features::PING_PONG | Features::STRUCTURED_METADATA;
        let b = Features::PING_PONG | Features::STREAM_PRIORITY;
        let intersection = a & b;
        assert_eq!(intersection, Features::PING_PONG);
    }

    #[test]
    fn stream_bind_encode_decode() {
        let bind = StreamBind::new(0x0102_0304_0506_0708);
        let encoded = bind.encode();

        // Check magic number
        assert_eq!(&encoded[0..4], &StreamBind::MAGIC);
        // Check version
        assert_eq!(encoded[4], StreamBind::VERSION);
        // Check logical_stream_id (big-endian)
        assert_eq!(
            &encoded[5..13],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
        );

        // Decode and verify
        let decoded = StreamBind::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.logical_stream_id, 0x0102_0304_0506_0708);
    }

    #[test]
    fn stream_bind_encoded_size_matches_layout() {
        // 4-byte magic + 1-byte version + 8-byte big-endian id.
        assert_eq!(StreamBind::ENCODED_SIZE, 13);
        assert_eq!(StreamBind::new(0).encode().len(), StreamBind::ENCODED_SIZE);
    }

    #[test]
    fn stream_bind_round_trips_extreme_ids() {
        for id in [0u64, 1, u64::MAX] {
            let bind = StreamBind::new(id);
            assert_eq!(StreamBind::decode(&bind.encode()), Some(bind));
        }
    }

    #[test]
    fn stream_bind_invalid_magic() {
        let mut buf = [0u8; StreamBind::ENCODED_SIZE];
        buf[0..4].copy_from_slice(&[0x00, 0x00, 0x00, 0x00]); // Wrong magic
        buf[4] = StreamBind::VERSION;
        assert!(StreamBind::decode(&buf).is_none());
    }

    #[test]
    fn stream_bind_invalid_version() {
        let mut buf = [0u8; StreamBind::ENCODED_SIZE];
        buf[0..4].copy_from_slice(&StreamBind::MAGIC);
        buf[4] = 0xFF; // Wrong version
        assert!(StreamBind::decode(&buf).is_none());
    }
}