quic_reverse/
session.rs

1// Copyright 2024-2026 Farlight Networks, LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Mutex::lock().unwrap() is the standard pattern in Rust. The lock only fails
16// if the mutex is poisoned (a thread panicked while holding it), which indicates
17// a bug elsewhere that should propagate. We also suppress the "missing # Panics"
18// warning since these are not user-actionable panics.
19#![allow(clippy::unwrap_used, clippy::missing_panics_doc)]
20
21//! Session management for quic-reverse.
22//!
23//! The `Session` type wraps a QUIC connection and provides the high-level API
24//! for reverse-initiated stream operations.
25
26use crate::control::{ControlReader, ControlStream, ControlWriter};
27use crate::error::TimeoutKind;
28use crate::negotiation::{negotiate_client, negotiate_server, NegotiatedParams};
29use crate::registry::{OpenResult, StreamRegistry};
30use crate::state::State;
31use crate::{Config, Error, Role};
32use quic_reverse_control::{
33    CloseCode, Metadata, OpenRequest, OpenResponse, OpenStatus, ProtocolMessage, RejectCode,
34    ServiceId, StreamClose,
35};
36use quic_reverse_transport::Connection;
37use std::collections::HashMap;
38use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
39use std::sync::{Arc, Mutex};
40use std::time::Instant;
41use tokio::sync::oneshot;
42use tokio::time::timeout;
43use tracing::{debug, error, info, instrument, trace, warn};
44
/// Inner session state shared between [`Session`] and [`SessionHandle`].
///
/// One instance is created per session and held behind an `Arc`; every
/// [`Session`] clone and the [`SessionHandle`] returned by `Session::start`
/// refer to the same value.
pub(crate) struct SessionInner<C: Connection> {
    /// The underlying QUIC connection.
    pub(crate) connection: C,
    /// Session configuration.
    pub(crate) config: Config,
    /// Our role in the session.
    pub(crate) role: Role,
    /// Current session state, stored as the `u8` form of [`State`]
    /// (written via `state as u8`, read back via `State::from_u8`).
    pub(crate) state: AtomicU8,
    /// Parameters from successful negotiation; `None` until negotiation
    /// has completed.
    pub(crate) negotiated: Mutex<Option<NegotiatedParams>>,
    /// Registry for tracking pending open requests and active streams.
    pub(crate) registry: Mutex<StreamRegistry>,
    /// Next ping sequence number. Starts at 1 and is advanced with
    /// `fetch_add` each time a ping is sent.
    pub(crate) next_ping_seq: AtomicU64,
    /// Pending pings awaiting pong responses, keyed by ping sequence number.
    pub(crate) pending_pings: Mutex<HashMap<u64, PendingPing>>,
}
64
/// A pending ping awaiting a pong response.
///
/// An entry is stored in `SessionInner::pending_pings` when a ping is sent
/// and is taken out again when the matching pong is processed (or when the
/// waiter gives up, e.g. on timeout).
pub(crate) struct PendingPing {
    /// When the ping was sent; the elapsed time since this instant is the
    /// round-trip time once the pong arrives.
    pub(crate) sent_at: Instant,
    /// Channel to notify when pong is received. If the sender is dropped
    /// without sending, the waiting side observes a closed channel.
    pub(crate) response_tx: oneshot::Sender<()>,
}
72
/// A quic-reverse session over a QUIC connection.
///
/// The session provides the main API for:
/// - Initiating reverse streams to the peer
/// - Accepting incoming stream requests from the peer
/// - Managing the session lifecycle
///
/// A `Session` is a thin wrapper around an `Arc` of the shared inner state,
/// so cloning it is cheap and all clones observe the same state, role and
/// negotiated parameters.
///
/// # Example
///
/// ```ignore
/// use quic_reverse::{Session, Config, Role};
///
/// // Create a session as the client
/// let session = Session::new(connection, Role::Client, Config::default());
///
/// // Start the session (performs negotiation)
/// let mut handle = session.start().await?;
///
/// // Open a reverse stream
/// let (send, recv) = handle.open("ssh", Metadata::Empty).await?;
/// ```
pub struct Session<C: Connection> {
    /// Shared state; also handed to the [`SessionHandle`] created by `start`.
    inner: Arc<SessionInner<C>>,
}
97
98impl<C: Connection> Clone for Session<C> {
99    fn clone(&self) -> Self {
100        Self {
101            inner: Arc::clone(&self.inner),
102        }
103    }
104}
105
106impl<C: Connection> Session<C> {
107    /// Creates a new session wrapping the given connection.
108    ///
109    /// The session starts in the `Init` state. Call [`start`](Self::start)
110    /// to begin negotiation.
111    #[must_use]
112    pub fn new(connection: C, role: Role, config: Config) -> Self {
113        let registry =
114            StreamRegistry::new(config.max_inflight_opens, config.max_concurrent_streams);
115
116        debug!(
117            %role,
118            max_inflight = config.max_inflight_opens,
119            max_concurrent = config.max_concurrent_streams,
120            "session created"
121        );
122
123        Self {
124            inner: Arc::new(SessionInner {
125                connection,
126                config,
127                role,
128                state: AtomicU8::new(State::Init as u8),
129                negotiated: Mutex::new(None),
130                registry: Mutex::new(registry),
131                next_ping_seq: AtomicU64::new(1),
132                pending_pings: Mutex::new(HashMap::new()),
133            }),
134        }
135    }
136
137    /// Returns the current session state.
138    #[must_use]
139    pub fn state(&self) -> State {
140        State::from_u8(self.inner.state.load(Ordering::SeqCst))
141    }
142
143    /// Returns the session role.
144    #[must_use]
145    pub fn role(&self) -> Role {
146        self.inner.role
147    }
148
149    /// Returns the negotiated parameters, if negotiation has completed.
150    #[must_use]
151    pub fn negotiated_params(&self) -> Option<NegotiatedParams> {
152        self.inner.negotiated.lock().unwrap().clone()
153    }
154
155    /// Returns true if the session is ready for stream operations.
156    #[must_use]
157    pub fn is_ready(&self) -> bool {
158        self.state() == State::Ready
159    }
160
161    /// Returns true if the connection was lost.
162    #[must_use]
163    pub fn is_disconnected(&self) -> bool {
164        self.state() == State::Disconnected
165    }
166
167    /// Returns a reference to the underlying connection.
168    #[must_use]
169    pub fn connection(&self) -> &C {
170        &self.inner.connection
171    }
172
173    /// Starts the session by performing negotiation.
174    ///
175    /// This opens the control stream and performs the `Hello`/`HelloAck`
176    /// handshake with the peer. On success, the session transitions
177    /// to the `Ready` state.
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if:
182    /// - The control stream cannot be opened/accepted
183    /// - Negotiation fails (version mismatch, timeout, etc.)
184    /// - The session is not in the `Init` state
185    #[instrument(skip(self), fields(role = %self.inner.role))]
186    pub async fn start(&self) -> Result<SessionHandle<C>, Error> {
187        // Validate we're in Init state
188        if self.state() != State::Init {
189            warn!(state = %self.state(), "cannot start session in non-init state");
190            return Err(Error::protocol_violation(format!(
191                "cannot start session in {} state",
192                self.state()
193            )));
194        }
195
196        // Validate config
197        self.inner.config.validate()?;
198
199        // Transition to Negotiating
200        self.set_state(State::Negotiating);
201        debug!("transitioning to negotiating state");
202
203        // Open or accept the control stream based on role
204        let (control_send, control_recv) = match self.inner.role {
205            Role::Client => {
206                // Client opens the control stream
207                debug!("opening control stream");
208                self.inner.connection.open_bi().await.map_err(|e| {
209                    error!(error = %e, "failed to open control stream");
210                    Error::Transport(Box::new(e))
211                })?
212            }
213            Role::Server => {
214                // Server accepts the control stream
215                debug!("waiting for control stream");
216                self.inner
217                    .connection
218                    .accept_bi()
219                    .await
220                    .map_err(|e| {
221                        error!(error = %e, "failed to accept control stream");
222                        Error::Transport(Box::new(e))
223                    })?
224                    .ok_or_else(|| {
225                        error!("connection closed before control stream");
226                        Error::protocol_violation("connection closed before control stream")
227                    })?
228            }
229        };
230
231        let mut control = ControlStream::new(control_send, control_recv);
232        debug!("control stream established");
233
234        // Perform negotiation with timeout
235        let negotiation_timeout = self.inner.config.negotiation_timeout;
236        debug!(?negotiation_timeout, "starting negotiation");
237
238        let negotiate_result = match self.inner.role {
239            Role::Client => {
240                timeout(
241                    negotiation_timeout,
242                    negotiate_client(&mut control, &self.inner.config),
243                )
244                .await
245            }
246            Role::Server => {
247                timeout(
248                    negotiation_timeout,
249                    negotiate_server(&mut control, &self.inner.config),
250                )
251                .await
252            }
253        };
254
255        let params = if let Ok(result) = negotiate_result {
256            result?
257        } else {
258            warn!("negotiation timed out");
259            self.set_state(State::Closed);
260            return Err(Error::Timeout(TimeoutKind::Negotiation));
261        };
262
263        // Store negotiated parameters
264        info!(
265            version = params.version,
266            features = ?params.features,
267            remote_agent = ?params.remote_agent,
268            "negotiation complete"
269        );
270        *self.inner.negotiated.lock().unwrap() = Some(params);
271
272        // Transition to Ready
273        self.set_state(State::Ready);
274        info!("session ready");
275
276        // Split the control stream for the session handle
277        let (writer, reader) = control.split();
278
279        Ok(SessionHandle {
280            inner: Arc::clone(&self.inner),
281            writer,
282            reader,
283        })
284    }
285
286    /// Sets the session state.
287    fn set_state(&self, state: State) {
288        self.inner.state.store(state as u8, Ordering::SeqCst);
289    }
290}
291
/// Active session handle with control stream access.
///
/// This handle is returned from [`Session::start`] and provides
/// methods for stream operations and message processing.
///
/// The handle owns both halves of the control stream, and its operations
/// take `&mut self`. Methods that wait for a reply (such as `open` and
/// `ping`) are completed by `process_message`, which also needs `&mut self`.
///
/// For a more convenient API that supports concurrent operations,
/// wrap this handle in a [`SessionClient`](crate::SessionClient) using
/// [`SessionClient::new`](crate::SessionClient::new).
pub struct SessionHandle<C: Connection> {
    /// State shared with the originating [`Session`] and its clones.
    pub(crate) inner: Arc<SessionInner<C>>,
    /// Write half of the control stream.
    pub(crate) writer: ControlWriter<C::SendStream>,
    /// Read half of the control stream.
    pub(crate) reader: ControlReader<C::RecvStream>,
}
305
306impl<C: Connection> SessionHandle<C> {
307    /// Returns the session state.
308    #[must_use]
309    pub fn state(&self) -> State {
310        State::from_u8(self.inner.state.load(Ordering::SeqCst))
311    }
312
313    /// Returns the negotiated parameters.
314    #[must_use]
315    pub fn negotiated_params(&self) -> Option<NegotiatedParams> {
316        self.inner.negotiated.lock().unwrap().clone()
317    }
318
319    /// Returns true if the session is ready for stream operations.
320    #[must_use]
321    pub fn is_ready(&self) -> bool {
322        self.state() == State::Ready
323    }
324
325    /// Returns true if the connection was lost.
326    #[must_use]
327    pub fn is_disconnected(&self) -> bool {
328        self.state() == State::Disconnected
329    }
330
331    /// Opens a reverse stream to the peer.
332    ///
333    /// Sends an `OpenRequest` for the specified service and waits for
334    /// the peer to accept and bind the stream.
335    ///
336    /// # Arguments
337    ///
338    /// * `service` - The service identifier
339    /// * `metadata` - Optional metadata to send with the request
340    ///
341    /// # Errors
342    ///
343    /// Returns an error if:
344    /// - The session is not ready
345    /// - The request limit has been reached
346    /// - The peer rejects the request
347    /// - The request times out
348    #[instrument(skip(self, metadata), fields(service = %service.as_ref()))]
349    pub async fn open(
350        &mut self,
351        service: impl Into<ServiceId> + AsRef<str>,
352        metadata: Metadata,
353    ) -> Result<(C::SendStream, C::RecvStream), Error> {
354        if !self.is_ready() {
355            warn!("cannot open stream: session not ready");
356            return Err(Error::SessionClosed);
357        }
358
359        let service = service.into();
360
361        // Generate request ID and create pending entry
362        let (response_tx, response_rx) = oneshot::channel();
363        let request_id = {
364            let mut registry = self.inner.registry.lock().unwrap();
365            let request_id = registry.next_request_id();
366            let request =
367                OpenRequest::new(request_id, service.clone()).with_metadata(metadata.clone());
368
369            if registry.register_pending(&request, response_tx).is_none() {
370                warn!(
371                    request_id,
372                    "capacity exceeded: too many pending open requests"
373                );
374                return Err(Error::CapacityExceeded("too many pending open requests"));
375            }
376
377            request_id
378        };
379
380        debug!(request_id, service = %service.as_str(), "sending open request");
381
382        // Send the open request
383        let request = OpenRequest::new(request_id, service).with_metadata(metadata);
384        self.writer
385            .write_message(&ProtocolMessage::OpenRequest(request))
386            .await?;
387        self.writer.flush().await?;
388
389        // Wait for the response with timeout
390        let open_timeout = self.inner.config.open_timeout;
391        let result = match timeout(open_timeout, response_rx).await {
392            Ok(Ok(result)) => result,
393            Ok(Err(_)) => {
394                // Channel closed - session closed
395                // Clean up the pending entry
396                warn!(request_id, "session closed while waiting for response");
397                let mut registry = self.inner.registry.lock().unwrap();
398                registry.take_pending(request_id);
399                return Err(Error::SessionClosed);
400            }
401            Err(_) => {
402                // Timeout - clean up the pending entry
403                warn!(request_id, ?open_timeout, "open request timed out");
404                let mut registry = self.inner.registry.lock().unwrap();
405                registry.take_pending(request_id);
406                return Err(Error::Timeout(TimeoutKind::OpenRequest));
407            }
408        };
409
410        match result {
411            OpenResult::Accepted { logical_stream_id } => {
412                debug!(request_id, logical_stream_id, "open request accepted");
413
414                // Accept the data stream with timeout
415                let bind_timeout = self.inner.config.stream_bind_timeout;
416                let stream_result = timeout(bind_timeout, self.inner.connection.accept_bi()).await;
417
418                let (send, recv) = match stream_result {
419                    Ok(Ok(Some(streams))) => streams,
420                    Ok(Ok(None)) => {
421                        error!(request_id, "connection closed while waiting for stream");
422                        return Err(Error::protocol_violation(
423                            "connection closed while waiting for stream",
424                        ));
425                    }
426                    Ok(Err(e)) => {
427                        error!(request_id, error = %e, "transport error while binding stream");
428                        return Err(Error::Transport(Box::new(e)));
429                    }
430                    Err(_) => {
431                        warn!(request_id, ?bind_timeout, "stream bind timed out");
432                        return Err(Error::Timeout(TimeoutKind::StreamBind));
433                    }
434                };
435
436                // Register the active stream
437                {
438                    let mut registry = self.inner.registry.lock().unwrap();
439                    registry.register_active(
440                        logical_stream_id,
441                        ServiceId::from(""),
442                        Metadata::Empty,
443                        request_id,
444                    );
445                }
446
447                info!(request_id, logical_stream_id, "stream opened successfully");
448                Ok((send, recv))
449            }
450            OpenResult::Rejected { code, reason } => {
451                warn!(request_id, ?code, ?reason, "open request rejected");
452                Err(Error::StreamRejected { code, reason })
453            }
454        }
455    }
456
457    /// Processes the next incoming control message.
458    ///
459    /// This should be called in a loop to handle incoming messages
460    /// from the peer. Returns `None` when the control stream closes.
461    ///
462    /// # Errors
463    ///
464    /// Returns an error if reading from the control stream fails.
465    pub async fn process_message(&mut self) -> Result<Option<ControlEvent>, Error> {
466        let Some(message) = self.reader.read_message().await? else {
467            debug!("control stream closed");
468            return Ok(None);
469        };
470
471        match message {
472            ProtocolMessage::OpenRequest(req) => {
473                // Peer wants to open a stream to us
474                debug!(
475                    request_id = req.request_id,
476                    service = %req.service.as_str(),
477                    "received open request"
478                );
479                Ok(Some(ControlEvent::OpenRequest {
480                    request_id: req.request_id,
481                    service: req.service,
482                    metadata: req.metadata,
483                }))
484            }
485
486            ProtocolMessage::OpenResponse(resp) => {
487                // Response to one of our open requests
488                let accepted = matches!(resp.status, OpenStatus::Accepted);
489                debug!(
490                    request_id = resp.request_id,
491                    accepted,
492                    logical_stream_id = ?resp.logical_stream_id,
493                    "received open response"
494                );
495                let mut registry = self.inner.registry.lock().unwrap();
496                if let Some(pending) = registry.take_pending(resp.request_id) {
497                    let result = match resp.status {
498                        OpenStatus::Accepted => OpenResult::Accepted {
499                            logical_stream_id: resp.logical_stream_id.unwrap_or(0),
500                        },
501                        OpenStatus::Rejected(code) => OpenResult::Rejected {
502                            code,
503                            reason: resp.reason,
504                        },
505                    };
506                    let _ = pending.response_tx.send(result);
507                }
508                Ok(Some(ControlEvent::OpenResponseReceived {
509                    request_id: resp.request_id,
510                    accepted,
511                }))
512            }
513
514            ProtocolMessage::Ping(ping_msg) => {
515                // Auto-respond with Pong
516                trace!(sequence = ping_msg.sequence, "received ping, sending pong");
517                let pong_msg = quic_reverse_control::Pong {
518                    sequence: ping_msg.sequence,
519                };
520                self.writer
521                    .write_message(&ProtocolMessage::Pong(pong_msg))
522                    .await?;
523                self.writer.flush().await?;
524                Ok(Some(ControlEvent::Ping {
525                    sequence: ping_msg.sequence,
526                }))
527            }
528
529            ProtocolMessage::Pong(pong) => {
530                // Resolve pending ping if any
531                trace!(sequence = pong.sequence, "received pong");
532                let mut pending = self.inner.pending_pings.lock().unwrap();
533                if let Some(pending_ping) = pending.remove(&pong.sequence) {
534                    let rtt = pending_ping.sent_at.elapsed();
535                    trace!(sequence = pong.sequence, ?rtt, "ping resolved");
536                    let _ = pending_ping.response_tx.send(());
537                }
538                Ok(Some(ControlEvent::Pong {
539                    sequence: pong.sequence,
540                }))
541            }
542
543            ProtocolMessage::Hello(_) | ProtocolMessage::HelloAck(_) => {
544                // These should only appear during negotiation
545                warn!("received unexpected Hello/HelloAck after negotiation");
546                Err(Error::protocol_violation(
547                    "unexpected Hello/HelloAck after negotiation",
548                ))
549            }
550
551            ProtocolMessage::StreamClose(sc) => {
552                // logical_stream_id 0 indicates session-level close
553                if sc.logical_stream_id == 0 {
554                    info!(code = ?sc.code, reason = ?sc.reason, "received session close");
555                    self.set_state(State::Closing);
556                    Ok(Some(ControlEvent::CloseReceived {
557                        code: sc.code,
558                        reason: sc.reason,
559                    }))
560                } else {
561                    debug!(
562                        logical_stream_id = sc.logical_stream_id,
563                        code = ?sc.code,
564                        "received stream close"
565                    );
566                    Ok(Some(ControlEvent::StreamClose {
567                        logical_stream_id: sc.logical_stream_id,
568                        code: sc.code,
569                    }))
570                }
571            }
572        }
573    }
574
575    /// Sends an `OpenResponse` accepting a stream request.
576    ///
577    /// # Errors
578    ///
579    /// Returns an error if sending the response fails.
580    #[instrument(skip(self))]
581    pub async fn accept_open(
582        &mut self,
583        request_id: u64,
584        logical_stream_id: u64,
585    ) -> Result<(), Error> {
586        debug!(request_id, logical_stream_id, "accepting open request");
587        let response = OpenResponse::accepted(request_id, logical_stream_id);
588        self.writer
589            .write_message(&ProtocolMessage::OpenResponse(response))
590            .await?;
591        self.writer.flush().await
592    }
593
594    /// Sends an `OpenResponse` rejecting a stream request.
595    ///
596    /// # Errors
597    ///
598    /// Returns an error if sending the response fails.
599    #[instrument(skip(self))]
600    pub async fn reject_open(
601        &mut self,
602        request_id: u64,
603        code: RejectCode,
604        reason: Option<String>,
605    ) -> Result<(), Error> {
606        debug!(request_id, ?code, ?reason, "rejecting open request");
607        let response = OpenResponse::rejected(request_id, code, reason);
608        self.writer
609            .write_message(&ProtocolMessage::OpenResponse(response))
610            .await?;
611        self.writer.flush().await
612    }
613
614    /// Notifies the peer that a stream has been closed.
615    ///
616    /// This should be called when the application finishes with a stream
617    /// to inform the peer. The peer will receive a `StreamClose` event.
618    ///
619    /// # Errors
620    ///
621    /// Returns an error if the session is closed or sending the message fails.
622    #[instrument(skip(self))]
623    pub async fn close_stream(
624        &mut self,
625        logical_stream_id: u64,
626        code: CloseCode,
627        reason: Option<String>,
628    ) -> Result<(), Error> {
629        if !self.is_ready() {
630            warn!(logical_stream_id, "cannot close stream: session not ready");
631            return Err(Error::SessionClosed);
632        }
633
634        debug!(logical_stream_id, ?code, ?reason, "closing stream");
635
636        // Remove from registry
637        {
638            let mut registry = self.inner.registry.lock().unwrap();
639            registry.remove_active(logical_stream_id);
640        }
641
642        let close_msg = StreamClose {
643            logical_stream_id,
644            code,
645            reason,
646        };
647        self.writer
648            .write_message(&ProtocolMessage::StreamClose(close_msg))
649            .await?;
650        self.writer.flush().await
651    }
652
653    /// Sends a ping and waits for the pong response.
654    ///
655    /// This can be used to check if the peer is still responsive and to
656    /// measure round-trip latency. Returns the round-trip time on success.
657    ///
658    /// # Errors
659    ///
660    /// Returns `Error::Timeout(TimeoutKind::Ping)` if no pong is received
661    /// within the configured `ping_timeout`.
662    #[instrument(skip(self))]
663    pub async fn ping(&mut self) -> Result<std::time::Duration, Error> {
664        if !self.is_ready() {
665            warn!("cannot ping: session not ready");
666            return Err(Error::SessionClosed);
667        }
668
669        // Generate sequence number
670        let sequence = self.inner.next_ping_seq.fetch_add(1, Ordering::SeqCst);
671        trace!(sequence, "sending ping");
672
673        // Create response channel
674        let (response_tx, response_rx) = oneshot::channel();
675        let sent_at = Instant::now();
676
677        // Register pending ping
678        {
679            let mut pending = self.inner.pending_pings.lock().unwrap();
680            pending.insert(
681                sequence,
682                PendingPing {
683                    sent_at,
684                    response_tx,
685                },
686            );
687        }
688
689        // Send the ping
690        let ping_msg = quic_reverse_control::Ping { sequence };
691        self.writer
692            .write_message(&ProtocolMessage::Ping(ping_msg))
693            .await?;
694        self.writer.flush().await?;
695
696        // Wait for pong with timeout
697        let ping_timeout = self.inner.config.ping_timeout;
698        match timeout(ping_timeout, response_rx).await {
699            Ok(Ok(())) => {
700                let rtt = sent_at.elapsed();
701                debug!(sequence, ?rtt, "ping completed");
702                Ok(rtt)
703            }
704            Ok(Err(_)) => {
705                // Channel closed - session closed
706                warn!(sequence, "session closed while waiting for pong");
707                Err(Error::SessionClosed)
708            }
709            Err(_) => {
710                // Timeout - clean up pending ping
711                warn!(sequence, ?ping_timeout, "ping timed out");
712                let mut pending = self.inner.pending_pings.lock().unwrap();
713                pending.remove(&sequence);
714                Err(Error::Timeout(TimeoutKind::Ping))
715            }
716        }
717    }
718
719    /// Closes the session.
720    ///
721    /// Sends a `StreamClose` message with `logical_stream_id` 0 to indicate
722    /// session close, and transitions to the `Closing` state.
723    ///
724    /// # Errors
725    ///
726    /// Returns an error if the session is already closed or sending the message fails.
727    #[instrument(skip(self))]
728    pub async fn close(&mut self, code: CloseCode, reason: Option<String>) -> Result<(), Error> {
729        if !self.is_ready() && self.state() != State::Closing {
730            warn!("cannot close: session already closed");
731            return Err(Error::SessionClosed);
732        }
733
734        info!(?code, ?reason, "closing session");
735        self.set_state(State::Closing);
736
737        // Use logical_stream_id 0 to indicate session-level close
738        let close_msg = StreamClose {
739            logical_stream_id: 0,
740            code,
741            reason,
742        };
743        self.writer
744            .write_message(&ProtocolMessage::StreamClose(close_msg))
745            .await?;
746        self.writer.flush().await
747    }
748
749    /// Sets the session state.
750    fn set_state(&self, state: State) {
751        self.inner.state.store(state as u8, Ordering::SeqCst);
752    }
753}
754
/// Events that can occur on the control stream.
///
/// One event is produced per control message handled by
/// `SessionHandle::process_message`.
#[derive(Debug, Clone)]
pub enum ControlEvent {
    /// Peer requested to open a stream.
    ///
    /// The application is expected to answer with `accept_open` or
    /// `reject_open` using the same `request_id`.
    OpenRequest {
        /// The request ID.
        request_id: u64,
        /// The requested service.
        service: ServiceId,
        /// Metadata from the request.
        metadata: Metadata,
    },
    /// Response to our open request was received.
    ///
    /// If a matching request was still pending, its waiter has already been
    /// handed the result when this event is returned.
    OpenResponseReceived {
        /// The request ID this is responding to.
        request_id: u64,
        /// Whether the request was accepted.
        accepted: bool,
    },
    /// Peer initiated close (a `StreamClose` with `logical_stream_id` 0).
    CloseReceived {
        /// The close code.
        code: CloseCode,
        /// Optional reason string.
        reason: Option<String>,
    },
    /// Ping received (pong auto-sent).
    Ping {
        /// The ping sequence number.
        sequence: u64,
    },
    /// Pong received in response to our ping.
    Pong {
        /// The pong sequence number.
        sequence: u64,
    },
    /// Stream close notification (a `StreamClose` with a non-zero
    /// `logical_stream_id`).
    StreamClose {
        /// The logical stream ID being closed.
        logical_stream_id: u64,
        /// The close code.
        code: CloseCode,
    },
}
799
800#[cfg(test)]
801mod tests {
802    use super::*;
803    use quic_reverse_control::Features;
804    use quic_reverse_transport::mock_connection_pair;
805
806    #[tokio::test]
807    async fn session_creation() {
808        let (conn_client, _conn_server) = mock_connection_pair();
809
810        let config = Config::new()
811            .with_features(Features::PING_PONG)
812            .with_agent("test/1.0");
813
814        let session = Session::new(conn_client, Role::Client, config);
815
816        assert_eq!(session.state(), State::Init);
817        assert_eq!(session.role(), Role::Client);
818        assert!(session.negotiated_params().is_none());
819    }
820
821    #[tokio::test]
822    async fn session_start_and_negotiate() {
823        let (conn_client, conn_server) = mock_connection_pair();
824
825        let client_config = Config::new()
826            .with_features(Features::PING_PONG)
827            .with_agent("client/1.0");
828
829        let server_config = Config::new()
830            .with_features(Features::PING_PONG)
831            .with_agent("server/1.0");
832
833        let client_session = Session::new(conn_client, Role::Client, client_config);
834        let server_session = Session::new(conn_server, Role::Server, server_config);
835
836        // Keep references for later assertions
837        let client_session_ref = client_session.clone();
838        let server_session_ref = server_session.clone();
839
840        // Start both sessions concurrently
841        let client_handle = tokio::spawn(async move { client_session.start().await });
842        let server_handle = tokio::spawn(async move { server_session.start().await });
843
844        // Wait for both to complete
845        let client_result = client_handle.await.expect("client task");
846        let server_result = server_handle.await.expect("server task");
847
848        // Both should succeed
849        assert!(client_result.is_ok(), "client failed");
850        assert!(server_result.is_ok(), "server failed");
851
852        // Both should be in Ready state
853        assert_eq!(client_session_ref.state(), State::Ready);
854        assert_eq!(server_session_ref.state(), State::Ready);
855
856        // Both should have negotiated params
857        let client_params = client_session_ref
858            .negotiated_params()
859            .expect("client params");
860        let server_params = server_session_ref
861            .negotiated_params()
862            .expect("server params");
863
864        assert_eq!(client_params.version, server_params.version);
865        assert_eq!(client_params.features, Features::PING_PONG);
866
867        // Should see each other's agent strings
868        assert_eq!(client_params.remote_agent.as_deref(), Some("server/1.0"));
869        assert_eq!(server_params.remote_agent.as_deref(), Some("client/1.0"));
870    }
871
872    #[tokio::test]
873    async fn cannot_start_twice() {
874        let (conn_client, conn_server) = mock_connection_pair();
875
876        let client_session = Session::new(conn_client, Role::Client, Config::new());
877        let server_session = Session::new(conn_server, Role::Server, Config::new());
878
879        // Keep reference for later
880        let client_session_ref = client_session.clone();
881
882        // Start both
883        let client_handle = tokio::spawn(async move { client_session.start().await });
884        let server_handle = tokio::spawn(async move { server_session.start().await });
885
886        let _ = client_handle.await;
887        let _ = server_handle.await;
888
889        // Try to start again - should fail
890        let result = client_session_ref.start().await;
891        assert!(result.is_err());
892    }
893
894    #[tokio::test]
895    async fn ping_pong_exchange() {
896        let (conn_client, conn_server) = mock_connection_pair();
897
898        let config = Config::new().with_features(Features::PING_PONG);
899
900        let client_session = Session::new(conn_client, Role::Client, config.clone());
901        let server_session = Session::new(conn_server, Role::Server, config);
902
903        // Start both and get handles
904        let client_start = tokio::spawn(async move { client_session.start().await });
905        let server_start = tokio::spawn(async move { server_session.start().await });
906
907        let mut client_handle = client_start.await.unwrap().unwrap();
908        let mut server_handle = server_start.await.unwrap().unwrap();
909
910        // Send a ping from client
911        let ping = quic_reverse_control::Ping { sequence: 42 };
912        client_handle
913            .writer
914            .write_message(&ProtocolMessage::Ping(ping))
915            .await
916            .unwrap();
917        client_handle.writer.flush().await.unwrap();
918
919        // Server should receive it and auto-respond
920        let event = server_handle.process_message().await.unwrap().unwrap();
921        assert!(matches!(event, ControlEvent::Ping { sequence: 42 }));
922
923        // Client should receive the pong
924        let event = client_handle.process_message().await.unwrap().unwrap();
925        assert!(matches!(event, ControlEvent::Pong { sequence: 42 }));
926    }
927
928    #[tokio::test]
929    async fn close_session() {
930        let (conn_client, conn_server) = mock_connection_pair();
931
932        let client_session = Session::new(conn_client, Role::Client, Config::new());
933        let server_session = Session::new(conn_server, Role::Server, Config::new());
934
935        // Keep references
936        let client_session_ref = client_session.clone();
937        let server_session_ref = server_session.clone();
938
939        // Start both
940        let client_start = tokio::spawn(async move { client_session.start().await });
941        let server_start = tokio::spawn(async move { server_session.start().await });
942
943        let mut client_handle = client_start.await.unwrap().unwrap();
944        let mut server_handle = server_start.await.unwrap().unwrap();
945
946        // Client initiates close
947        client_handle
948            .close(CloseCode::Normal, Some("goodbye".into()))
949            .await
950            .unwrap();
951
952        // Server receives close event
953        let event = server_handle.process_message().await.unwrap().unwrap();
954        match event {
955            ControlEvent::CloseReceived { code, reason } => {
956                assert_eq!(code, CloseCode::Normal);
957                assert_eq!(reason.as_deref(), Some("goodbye"));
958            }
959            _ => panic!("expected CloseReceived"),
960        }
961
962        assert_eq!(client_session_ref.state(), State::Closing);
963        assert_eq!(server_session_ref.state(), State::Closing);
964    }
965
966    #[tokio::test]
967    async fn stream_open_and_accept() {
968        use tokio::io::{AsyncReadExt, AsyncWriteExt};
969        use tokio::sync::mpsc;
970
971        let (conn_client, conn_server) = mock_connection_pair();
972
973        let client_session = Session::new(conn_client, Role::Client, Config::new());
974        let server_session = Session::new(conn_server, Role::Server, Config::new());
975
976        // Start both sessions
977        let client_start = tokio::spawn(async move { client_session.start().await });
978        let server_start = tokio::spawn(async move { server_session.start().await });
979
980        let client_handle = client_start.await.unwrap().unwrap();
981        let mut server_handle = server_start.await.unwrap().unwrap();
982
983        // Split client handle for concurrent message processing and open
984        // We need channels to coordinate between the two client tasks
985        let (open_done_tx, mut open_done_rx) = mpsc::channel::<(
986            quic_reverse_transport::MockSendStream,
987            quic_reverse_transport::MockRecvStream,
988        )>(1);
989
990        // Client: Spawn message processor that will receive the OpenResponse
991        let client_inner = Arc::clone(&client_handle.inner);
992        let mut client_reader = client_handle.reader;
993        let client_msg_processor = tokio::spawn(async move {
994            // Wait for OpenResponse
995            let msg = client_reader.read_message().await.unwrap().unwrap();
996            if let ProtocolMessage::OpenResponse(resp) = msg {
997                let accepted = matches!(resp.status, OpenStatus::Accepted);
998                let mut registry = client_inner.registry.lock().unwrap();
999                if let Some(pending) = registry.take_pending(resp.request_id) {
1000                    let result = match resp.status {
1001                        OpenStatus::Accepted => OpenResult::Accepted {
1002                            logical_stream_id: resp.logical_stream_id.unwrap_or(0),
1003                        },
1004                        OpenStatus::Rejected(code) => OpenResult::Rejected {
1005                            code,
1006                            reason: resp.reason,
1007                        },
1008                    };
1009                    let _ = pending.response_tx.send(result);
1010                }
1011                accepted
1012            } else {
1013                panic!("expected OpenResponse");
1014            }
1015        });
1016
1017        // Client: Spawn the open request
1018        let client_inner2 = Arc::clone(&client_handle.inner);
1019        let mut client_writer = client_handle.writer;
1020        let client_open = tokio::spawn(async move {
1021            // Generate request ID and create pending entry
1022            let (response_tx, response_rx) = oneshot::channel();
1023            let request_id = {
1024                let mut registry = client_inner2.registry.lock().unwrap();
1025                let request_id = registry.next_request_id();
1026                let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1027                registry.register_pending(&request, response_tx).unwrap();
1028                request_id
1029            };
1030
1031            // Send the open request
1032            let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1033            client_writer
1034                .write_message(&ProtocolMessage::OpenRequest(request))
1035                .await
1036                .unwrap();
1037            client_writer.flush().await.unwrap();
1038
1039            // Wait for the response (will be delivered by the message processor)
1040            let result = response_rx.await.unwrap();
1041
1042            match result {
1043                OpenResult::Accepted { .. } => {
1044                    // Accept the data stream
1045                    let (send, recv) = client_inner2.connection.accept_bi().await.unwrap().unwrap();
1046                    open_done_tx.send((send, recv)).await.unwrap();
1047                }
1048                OpenResult::Rejected { code, reason } => {
1049                    panic!("rejected: {code:?} {reason:?}");
1050                }
1051            }
1052        });
1053
1054        // Server: Process the open request
1055        let event = server_handle.process_message().await.unwrap().unwrap();
1056        let (request_id, service) = match event {
1057            ControlEvent::OpenRequest {
1058                request_id,
1059                service,
1060                ..
1061            } => (request_id, service),
1062            _ => panic!("expected OpenRequest, got {event:?}"),
1063        };
1064        assert_eq!(service.as_str(), "ssh");
1065
1066        // Server: Accept the request
1067        let logical_stream_id = 1;
1068        server_handle
1069            .accept_open(request_id, logical_stream_id)
1070            .await
1071            .unwrap();
1072
1073        // Server: Open the data stream back to client
1074        let (mut server_send, mut server_recv) =
1075            server_handle.inner.connection.open_bi().await.unwrap();
1076
1077        // Wait for client tasks to complete
1078        client_msg_processor.await.unwrap();
1079        client_open.await.unwrap();
1080
1081        // Get the client streams
1082        let (mut client_send, mut client_recv) = open_done_rx.recv().await.unwrap();
1083
1084        // Exchange data bidirectionally
1085        server_send.write_all(b"hello from server").await.unwrap();
1086        server_send.flush().await.unwrap();
1087
1088        let mut buf = [0u8; 32];
1089        let n = client_recv.read(&mut buf).await.unwrap();
1090        assert_eq!(&buf[..n], b"hello from server");
1091
1092        client_send.write_all(b"hello from client").await.unwrap();
1093        client_send.flush().await.unwrap();
1094
1095        let n = server_recv.read(&mut buf).await.unwrap();
1096        assert_eq!(&buf[..n], b"hello from client");
1097    }
1098
1099    #[tokio::test]
1100    async fn stream_open_rejected() {
1101        use tokio::sync::mpsc;
1102
1103        let (conn_client, conn_server) = mock_connection_pair();
1104
1105        let client_session = Session::new(conn_client, Role::Client, Config::new());
1106        let server_session = Session::new(conn_server, Role::Server, Config::new());
1107
1108        // Start both sessions
1109        let client_start = tokio::spawn(async move { client_session.start().await });
1110        let server_start = tokio::spawn(async move { server_session.start().await });
1111
1112        let client_handle = client_start.await.unwrap().unwrap();
1113        let mut server_handle = server_start.await.unwrap().unwrap();
1114
1115        // Channel to receive the rejection result
1116        let (result_tx, mut result_rx) = mpsc::channel::<Result<(), Error>>(1);
1117
1118        // Client: Spawn message processor
1119        let client_inner = Arc::clone(&client_handle.inner);
1120        let mut client_reader = client_handle.reader;
1121        let client_msg_processor = tokio::spawn(async move {
1122            let msg = client_reader.read_message().await.unwrap().unwrap();
1123            if let ProtocolMessage::OpenResponse(resp) = msg {
1124                let mut registry = client_inner.registry.lock().unwrap();
1125                if let Some(pending) = registry.take_pending(resp.request_id) {
1126                    let result = match resp.status {
1127                        OpenStatus::Accepted => OpenResult::Accepted {
1128                            logical_stream_id: resp.logical_stream_id.unwrap_or(0),
1129                        },
1130                        OpenStatus::Rejected(code) => OpenResult::Rejected {
1131                            code,
1132                            reason: resp.reason,
1133                        },
1134                    };
1135                    let _ = pending.response_tx.send(result);
1136                }
1137            }
1138        });
1139
1140        // Client: Spawn the open request
1141        let client_inner2 = Arc::clone(&client_handle.inner);
1142        let mut client_writer = client_handle.writer;
1143        let client_open = tokio::spawn(async move {
1144            let (response_tx, response_rx) = oneshot::channel();
1145            let request_id = {
1146                let mut registry = client_inner2.registry.lock().unwrap();
1147                let request_id = registry.next_request_id();
1148                let request =
1149                    OpenRequest::new(request_id, "unknown").with_metadata(Metadata::Empty);
1150                registry.register_pending(&request, response_tx).unwrap();
1151                request_id
1152            };
1153
1154            let request = OpenRequest::new(request_id, "unknown").with_metadata(Metadata::Empty);
1155            client_writer
1156                .write_message(&ProtocolMessage::OpenRequest(request))
1157                .await
1158                .unwrap();
1159            client_writer.flush().await.unwrap();
1160
1161            let result = response_rx.await.unwrap();
1162            match result {
1163                OpenResult::Accepted { .. } => {
1164                    result_tx.send(Ok(())).await.unwrap();
1165                }
1166                OpenResult::Rejected { code, reason } => {
1167                    result_tx
1168                        .send(Err(Error::StreamRejected { code, reason }))
1169                        .await
1170                        .unwrap();
1171                }
1172            }
1173        });
1174
1175        // Server: Process and reject the request
1176        let event = server_handle.process_message().await.unwrap().unwrap();
1177        let request_id = match event {
1178            ControlEvent::OpenRequest { request_id, .. } => request_id,
1179            _ => panic!("expected OpenRequest"),
1180        };
1181
1182        server_handle
1183            .reject_open(
1184                request_id,
1185                RejectCode::UnsupportedService,
1186                Some("not available".into()),
1187            )
1188            .await
1189            .unwrap();
1190
1191        // Wait for client tasks
1192        client_msg_processor.await.unwrap();
1193        client_open.await.unwrap();
1194
1195        // Get the rejection result
1196        let result = result_rx.recv().await.unwrap();
1197        match result {
1198            Err(Error::StreamRejected { code, reason }) => {
1199                assert_eq!(code, RejectCode::UnsupportedService);
1200                assert_eq!(reason.as_deref(), Some("not available"));
1201            }
1202            other => panic!("expected StreamRejected, got {other:?}"),
1203        }
1204    }
1205
1206    #[tokio::test]
1207    async fn stream_close_notification() {
1208        use tokio::sync::mpsc;
1209
1210        let (conn_client, conn_server) = mock_connection_pair();
1211
1212        let client_session = Session::new(conn_client, Role::Client, Config::new());
1213        let server_session = Session::new(conn_server, Role::Server, Config::new());
1214
1215        // Start both sessions
1216        let client_start = tokio::spawn(async move { client_session.start().await });
1217        let server_start = tokio::spawn(async move { server_session.start().await });
1218
1219        let client_handle = client_start.await.unwrap().unwrap();
1220        let mut server_handle = server_start.await.unwrap().unwrap();
1221
1222        // Set up the stream as before
1223        let (open_done_tx, mut open_done_rx) = mpsc::channel::<u64>(1);
1224
1225        // Client: Message processor
1226        let client_inner = Arc::clone(&client_handle.inner);
1227        let mut client_reader = client_handle.reader;
1228        let client_msg_processor = tokio::spawn(async move {
1229            // First: OpenResponse
1230            let msg = client_reader.read_message().await.unwrap().unwrap();
1231            if let ProtocolMessage::OpenResponse(resp) = msg {
1232                let mut registry = client_inner.registry.lock().unwrap();
1233                if let Some(pending) = registry.take_pending(resp.request_id) {
1234                    let result = match resp.status {
1235                        OpenStatus::Accepted => OpenResult::Accepted {
1236                            logical_stream_id: resp.logical_stream_id.unwrap_or(0),
1237                        },
1238                        OpenStatus::Rejected(code) => OpenResult::Rejected {
1239                            code,
1240                            reason: resp.reason,
1241                        },
1242                    };
1243                    let _ = pending.response_tx.send(result);
1244                }
1245            }
1246
1247            // Second: StreamClose
1248            let msg = client_reader.read_message().await.unwrap().unwrap();
1249            if let ProtocolMessage::StreamClose(sc) = msg {
1250                (sc.logical_stream_id, sc.code, sc.reason)
1251            } else {
1252                panic!("expected StreamClose");
1253            }
1254        });
1255
1256        // Client: Open request
1257        let client_inner2 = Arc::clone(&client_handle.inner);
1258        let mut client_writer = client_handle.writer;
1259        let client_open = tokio::spawn(async move {
1260            let (response_tx, response_rx) = oneshot::channel();
1261            let request_id = {
1262                let mut registry = client_inner2.registry.lock().unwrap();
1263                let request_id = registry.next_request_id();
1264                let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1265                registry.register_pending(&request, response_tx).unwrap();
1266                request_id
1267            };
1268
1269            let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1270            client_writer
1271                .write_message(&ProtocolMessage::OpenRequest(request))
1272                .await
1273                .unwrap();
1274            client_writer.flush().await.unwrap();
1275
1276            let result = response_rx.await.unwrap();
1277            if let OpenResult::Accepted { logical_stream_id } = result {
1278                open_done_tx.send(logical_stream_id).await.unwrap();
1279            }
1280        });
1281
1282        // Server: Process open request
1283        let event = server_handle.process_message().await.unwrap().unwrap();
1284        let request_id = match event {
1285            ControlEvent::OpenRequest { request_id, .. } => request_id,
1286            _ => panic!("expected OpenRequest"),
1287        };
1288
1289        // Server: Accept
1290        let logical_stream_id = 42;
1291        server_handle
1292            .accept_open(request_id, logical_stream_id)
1293            .await
1294            .unwrap();
1295
1296        // Wait for client to receive the stream ID
1297        client_open.await.unwrap();
1298        let received_id = open_done_rx.recv().await.unwrap();
1299        assert_eq!(received_id, logical_stream_id);
1300
1301        // Server: Close the stream
1302        server_handle
1303            .close_stream(logical_stream_id, CloseCode::Normal, Some("done".into()))
1304            .await
1305            .unwrap();
1306
1307        // Client should receive the close notification
1308        let (close_id, close_code, close_reason) = client_msg_processor.await.unwrap();
1309        assert_eq!(close_id, logical_stream_id);
1310        assert_eq!(close_code, CloseCode::Normal);
1311        assert_eq!(close_reason.as_deref(), Some("done"));
1312    }
1313
1314    #[tokio::test]
1315    async fn open_respects_inflight_limit() {
1316        let (conn_client, conn_server) = mock_connection_pair();
1317
1318        // Configure with a very low limit
1319        let client_config = Config::new().with_max_inflight_opens(2);
1320        let server_config = Config::new();
1321
1322        let client_session = Session::new(conn_client, Role::Client, client_config);
1323        let server_session = Session::new(conn_server, Role::Server, server_config);
1324
1325        // Start both sessions
1326        let client_start = tokio::spawn(async move { client_session.start().await });
1327        let server_start = tokio::spawn(async move { server_session.start().await });
1328
1329        let client_handle = client_start.await.unwrap().unwrap();
1330        let _server_handle = server_start.await.unwrap().unwrap();
1331
1332        // Access the inner directly to simulate multiple pending opens
1333        let inner = Arc::clone(&client_handle.inner);
1334        let mut writer = client_handle.writer;
1335
1336        // Register two pending opens (filling the limit)
1337        {
1338            let mut registry = inner.registry.lock().unwrap();
1339            let (tx1, _rx1) = oneshot::channel();
1340            let (tx2, _rx2) = oneshot::channel();
1341            let req1 = OpenRequest::new(1, "service1");
1342            let req2 = OpenRequest::new(2, "service2");
1343            assert!(registry.register_pending(&req1, tx1).is_some());
1344            assert!(registry.register_pending(&req2, tx2).is_some());
1345        }
1346
1347        // Try to open a third - should fail with capacity exceeded
1348        let (response_tx, _response_rx) = oneshot::channel();
1349        let result = {
1350            let mut registry = inner.registry.lock().unwrap();
1351            let request_id = registry.next_request_id();
1352            let request = OpenRequest::new(request_id, "service3");
1353            registry.register_pending(&request, response_tx)
1354        };
1355
1356        assert!(result.is_none(), "should fail due to limit");
1357
1358        // Clean up - send a dummy message to prevent hanging
1359        let _ = writer
1360            .write_message(&ProtocolMessage::Ping(quic_reverse_control::Ping {
1361                sequence: 0,
1362            }))
1363            .await;
1364    }
1365
1366    #[tokio::test]
1367    async fn open_request_timeout() {
1368        use std::time::Duration;
1369
1370        let (conn_client, conn_server) = mock_connection_pair();
1371
1372        // Configure with a very short timeout
1373        let client_config = Config::new().with_open_timeout(Duration::from_millis(50));
1374        let server_config = Config::new();
1375
1376        let client_session = Session::new(conn_client, Role::Client, client_config);
1377        let server_session = Session::new(conn_server, Role::Server, server_config);
1378
1379        // Start both sessions
1380        let client_start = tokio::spawn(async move { client_session.start().await });
1381        let server_start = tokio::spawn(async move { server_session.start().await });
1382
1383        let mut client_handle = client_start.await.unwrap().unwrap();
1384        let _server_handle = server_start.await.unwrap().unwrap();
1385
1386        // Client tries to open, but server never responds
1387        let result = client_handle.open("ssh", Metadata::Empty).await;
1388
1389        // Should timeout
1390        match result {
1391            Err(Error::Timeout(TimeoutKind::OpenRequest)) => {}
1392            other => panic!("expected OpenRequest timeout, got: {other:?}"),
1393        }
1394    }
1395
1396    #[tokio::test]
1397    async fn stream_bind_timeout() {
1398        use std::time::Duration;
1399
1400        let (conn_client, conn_server) = mock_connection_pair();
1401
1402        // Configure with a very short bind timeout
1403        let client_config = Config::new()
1404            .with_open_timeout(Duration::from_secs(5))
1405            .with_stream_bind_timeout(Duration::from_millis(50));
1406        let server_config = Config::new();
1407
1408        let client_session = Session::new(conn_client, Role::Client, client_config);
1409        let server_session = Session::new(conn_server, Role::Server, server_config);
1410
1411        // Start both sessions
1412        let client_start = tokio::spawn(async move { client_session.start().await });
1413        let server_start = tokio::spawn(async move { server_session.start().await });
1414
1415        let client_handle = client_start.await.unwrap().unwrap();
1416        let mut server_handle = server_start.await.unwrap().unwrap();
1417
1418        // Split client handle for concurrent operation
1419        let client_inner_open = Arc::clone(&client_handle.inner);
1420        let client_inner_msg = Arc::clone(&client_handle.inner);
1421        let mut client_writer = client_handle.writer;
1422        let mut client_reader = client_handle.reader;
1423
1424        // Client: Open request task
1425        let client_open = tokio::spawn(async move {
1426            // Send open request manually
1427            let (response_tx, response_rx) = oneshot::channel();
1428            let request_id = {
1429                let mut registry = client_inner_open.registry.lock().unwrap();
1430                let request_id = registry.next_request_id();
1431                let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1432                registry.register_pending(&request, response_tx).unwrap();
1433                request_id
1434            };
1435
1436            let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1437            client_writer
1438                .write_message(&ProtocolMessage::OpenRequest(request))
1439                .await
1440                .unwrap();
1441            client_writer.flush().await.unwrap();
1442
1443            // Wait for the response
1444            let result = response_rx.await.unwrap();
1445
1446            // Try to accept the stream with timeout
1447            match result {
1448                OpenResult::Accepted { .. } => {
1449                    let bind_timeout = Duration::from_millis(50);
1450                    match timeout(bind_timeout, client_inner_open.connection.accept_bi()).await {
1451                        Ok(Ok(Some(streams))) => Ok(streams),
1452                        Ok(Ok(None)) => Err(Error::SessionClosed),
1453                        Ok(Err(e)) => Err(Error::Transport(Box::new(e))),
1454                        Err(_) => Err(Error::Timeout(TimeoutKind::StreamBind)),
1455                    }
1456                }
1457                OpenResult::Rejected { code, reason } => {
1458                    Err(Error::StreamRejected { code, reason })
1459                }
1460            }
1461        });
1462
1463        // Client: Message processor
1464        let client_msg_processor = tokio::spawn(async move {
1465            let msg = client_reader.read_message().await.unwrap().unwrap();
1466            if let ProtocolMessage::OpenResponse(resp) = msg {
1467                let mut registry = client_inner_msg.registry.lock().unwrap();
1468                if let Some(pending) = registry.take_pending(resp.request_id) {
1469                    let result = match resp.status {
1470                        OpenStatus::Accepted => OpenResult::Accepted {
1471                            logical_stream_id: resp.logical_stream_id.unwrap_or(0),
1472                        },
1473                        OpenStatus::Rejected(code) => OpenResult::Rejected {
1474                            code,
1475                            reason: resp.reason,
1476                        },
1477                    };
1478                    let _ = pending.response_tx.send(result);
1479                }
1480            }
1481        });
1482
1483        // Server: Accept the request but DON'T open the data stream
1484        let event = server_handle.process_message().await.unwrap().unwrap();
1485        let request_id = match event {
1486            ControlEvent::OpenRequest { request_id, .. } => request_id,
1487            _ => panic!("expected OpenRequest"),
1488        };
1489
1490        // Accept but don't open data stream - client should timeout on bind
1491        server_handle.accept_open(request_id, 1).await.unwrap();
1492
1493        // Wait for client message processor
1494        let _ = client_msg_processor.await;
1495
1496        // The open task should fail with stream bind timeout
1497        let result = client_open.await.unwrap();
1498        match result {
1499            Err(Error::Timeout(TimeoutKind::StreamBind)) => {}
1500            other => panic!("expected StreamBind timeout, got: {:?}", other),
1501        }
1502    }
1503
1504    #[tokio::test]
1505    async fn negotiation_timeout() {
1506        use std::time::Duration;
1507
1508        let (conn_client, _conn_server) = mock_connection_pair();
1509
1510        // Configure with a very short negotiation timeout
1511        let client_config = Config::new().with_negotiation_timeout(Duration::from_millis(50));
1512
1513        let client_session = Session::new(conn_client, Role::Client, client_config);
1514
1515        // Client tries to start, but server never responds (no server started)
1516        let result = client_session.start().await;
1517
1518        // Should timeout during negotiation
1519        assert!(
1520            matches!(result, Err(Error::Timeout(TimeoutKind::Negotiation))),
1521            "expected Negotiation timeout, got: {:?}",
1522            result.as_ref().map(|_| "Ok(SessionHandle)")
1523        );
1524
1525        // Session should be in Closed state
1526        assert_eq!(client_session.state(), State::Closed);
1527    }
1528
1529    #[tokio::test]
1530    async fn ping_returns_rtt() {
1531        use std::time::Duration;
1532
1533        let (conn_client, conn_server) = mock_connection_pair();
1534
1535        let client_session = Session::new(conn_client, Role::Client, Config::new());
1536        let server_session = Session::new(conn_server, Role::Server, Config::new());
1537
1538        // Start both sessions
1539        let client_start = tokio::spawn(async move { client_session.start().await });
1540        let server_start = tokio::spawn(async move { server_session.start().await });
1541
1542        let client_handle = client_start.await.unwrap().unwrap();
1543        let mut server_handle = server_start.await.unwrap().unwrap();
1544
1545        // Split client handle for concurrent ping and message processing
1546        let client_inner = Arc::clone(&client_handle.inner);
1547        let mut client_writer = client_handle.writer;
1548        let mut client_reader = client_handle.reader;
1549
1550        // Client: Send ping task
1551        let ping_task = tokio::spawn(async move {
1552            // Generate sequence number
1553            let sequence = client_inner.next_ping_seq.fetch_add(1, Ordering::SeqCst);
1554
1555            // Create response channel
1556            let (response_tx, response_rx) = oneshot::channel();
1557            let sent_at = Instant::now();
1558
1559            // Register pending ping
1560            {
1561                let mut pending = client_inner.pending_pings.lock().unwrap();
1562                pending.insert(
1563                    sequence,
1564                    PendingPing {
1565                        sent_at,
1566                        response_tx,
1567                    },
1568                );
1569            }
1570
1571            // Send the ping
1572            let ping_msg = quic_reverse_control::Ping { sequence };
1573            client_writer
1574                .write_message(&ProtocolMessage::Ping(ping_msg))
1575                .await
1576                .unwrap();
1577            client_writer.flush().await.unwrap();
1578
1579            // Wait for pong
1580            response_rx.await.unwrap();
1581            sent_at.elapsed()
1582        });
1583
1584        // Client: Message processor that receives Pong
1585        let client_inner2 = Arc::clone(&client_handle.inner);
1586        let client_msg_processor = tokio::spawn(async move {
1587            let msg = client_reader.read_message().await.unwrap().unwrap();
1588            if let ProtocolMessage::Pong(pong) = msg {
1589                let mut pending = client_inner2.pending_pings.lock().unwrap();
1590                if let Some(pending_ping) = pending.remove(&pong.sequence) {
1591                    let _ = pending_ping.response_tx.send(());
1592                }
1593            }
1594        });
1595
1596        // Server: Process the ping (auto-responds with pong)
1597        let event = server_handle.process_message().await.unwrap().unwrap();
1598        assert!(matches!(event, ControlEvent::Ping { sequence: 1 }));
1599
1600        // Wait for client tasks
1601        let _ = client_msg_processor.await;
1602        let rtt = ping_task.await.unwrap();
1603
1604        // RTT should be positive but small (local mock connection)
1605        assert!(rtt < Duration::from_secs(1));
1606    }
1607
1608    #[tokio::test]
1609    async fn ping_timeout() {
1610        use std::time::Duration;
1611
1612        let (conn_client, conn_server) = mock_connection_pair();
1613
1614        // Configure with a very short ping timeout
1615        let client_config = Config::new().with_ping_timeout(Duration::from_millis(50));
1616        let server_config = Config::new();
1617
1618        let client_session = Session::new(conn_client, Role::Client, client_config);
1619        let server_session = Session::new(conn_server, Role::Server, server_config);
1620
1621        // Start both sessions
1622        let client_start = tokio::spawn(async move { client_session.start().await });
1623        let server_start = tokio::spawn(async move { server_session.start().await });
1624
1625        let mut client_handle = client_start.await.unwrap().unwrap();
1626        let _server_handle = server_start.await.unwrap().unwrap();
1627
1628        // Client sends ping, but server never processes it (no pong)
1629        let result = client_handle.ping().await;
1630
1631        // Should timeout
1632        match result {
1633            Err(Error::Timeout(TimeoutKind::Ping)) => {}
1634            other => panic!("expected Ping timeout, got: {other:?}"),
1635        }
1636    }
1637
1638    /// Stress tests for concurrent stream handling.
1639    mod stress_tests {
1640        use super::*;
1641
1642        /// Tests many sequential open requests to verify registry handling.
1643        #[tokio::test]
1644        async fn many_sequential_opens() {
1645            const NUM_OPENS: usize = 20;
1646
1647            let (conn_client, conn_server) = mock_connection_pair();
1648
1649            let client_config = Config::new().with_max_inflight_opens(NUM_OPENS);
1650            let server_config = Config::new();
1651
1652            let client_session = Session::new(conn_client, Role::Client, client_config);
1653            let server_session = Session::new(conn_server, Role::Server, server_config);
1654
1655            let client_start = tokio::spawn(async move { client_session.start().await });
1656            let server_start = tokio::spawn(async move { server_session.start().await });
1657
1658            let client_handle = client_start.await.unwrap().unwrap();
1659            let mut server_handle = server_start.await.unwrap().unwrap();
1660
1661            let client_inner = Arc::clone(&client_handle.inner);
1662            let mut client_writer = client_handle.writer;
1663            let mut client_reader = client_handle.reader;
1664
1665            // Client: send open requests
1666            let client_sender = tokio::spawn(async move {
1667                for i in 0..NUM_OPENS {
1668                    let (response_tx, _response_rx) = oneshot::channel();
1669                    let request_id = {
1670                        let mut registry = client_inner.registry.lock().unwrap();
1671                        let request_id = registry.next_request_id();
1672                        let request = OpenRequest::new(request_id, format!("service-{i}"));
1673                        let _ = registry.register_pending(&request, response_tx);
1674                        request_id
1675                    };
1676
1677                    let request = OpenRequest::new(request_id, format!("service-{i}"));
1678                    client_writer
1679                        .write_message(&ProtocolMessage::OpenRequest(request))
1680                        .await
1681                        .unwrap();
1682                }
1683                client_writer.flush().await.unwrap();
1684                NUM_OPENS
1685            });
1686
1687            // Server: process and accept all opens
1688            let server_processor = tokio::spawn(async move {
1689                let mut accepted = 0;
1690                for _ in 0..NUM_OPENS {
1691                    match server_handle.process_message().await {
1692                        Ok(Some(ControlEvent::OpenRequest { request_id, .. })) => {
1693                            server_handle
1694                                .accept_open(request_id, accepted as u64)
1695                                .await
1696                                .ok();
1697                            accepted += 1;
1698                        }
1699                        _ => break,
1700                    }
1701                }
1702                accepted
1703            });
1704
1705            // Client: receive responses
1706            let client_receiver = tokio::spawn(async move {
1707                let mut received = 0;
1708                for _ in 0..NUM_OPENS {
1709                    match client_reader.read_message().await {
1710                        Ok(Some(ProtocolMessage::OpenResponse(_))) => received += 1,
1711                        _ => break,
1712                    }
1713                }
1714                received
1715            });
1716
1717            let sent = client_sender.await.unwrap();
1718            let accepted = server_processor.await.unwrap();
1719            let received = client_receiver.await.unwrap();
1720
1721            assert_eq!(sent, NUM_OPENS);
1722            assert_eq!(accepted, NUM_OPENS);
1723            assert_eq!(received, NUM_OPENS);
1724        }
1725
1726        /// Tests rapid sequential ping/pong exchanges.
1727        #[tokio::test]
1728        async fn sequential_ping_pong() {
1729            const NUM_PINGS: usize = 10;
1730
1731            let (conn_client, conn_server) = mock_connection_pair();
1732
1733            let client_session = Session::new(conn_client, Role::Client, Config::new());
1734            let server_session = Session::new(conn_server, Role::Server, Config::new());
1735
1736            let client_start = tokio::spawn(async move { client_session.start().await });
1737            let server_start = tokio::spawn(async move { server_session.start().await });
1738
1739            let client_handle = client_start.await.unwrap().unwrap();
1740            let mut server_handle = server_start.await.unwrap().unwrap();
1741
1742            let client_inner = Arc::clone(&client_handle.inner);
1743            let mut client_writer = client_handle.writer;
1744            let mut client_reader = client_handle.reader;
1745
1746            // Client: send pings and track them
1747            let pings_sent = Arc::new(std::sync::atomic::AtomicUsize::new(0));
1748            let pings_sent_clone = Arc::clone(&pings_sent);
1749
1750            let client_sender = tokio::spawn(async move {
1751                for _ in 0..NUM_PINGS {
1752                    let sequence = client_inner.next_ping_seq.fetch_add(1, Ordering::SeqCst);
1753                    let (response_tx, _) = oneshot::channel();
1754                    {
1755                        let mut pending = client_inner.pending_pings.lock().unwrap();
1756                        pending.insert(
1757                            sequence,
1758                            PendingPing {
1759                                sent_at: Instant::now(),
1760                                response_tx,
1761                            },
1762                        );
1763                    }
1764                    let ping_msg = quic_reverse_control::Ping { sequence };
1765                    client_writer
1766                        .write_message(&ProtocolMessage::Ping(ping_msg))
1767                        .await
1768                        .unwrap();
1769                    pings_sent_clone.fetch_add(1, Ordering::SeqCst);
1770                }
1771                client_writer.flush().await.unwrap();
1772            });
1773
1774            // Server: process pings
1775            let server_processor = tokio::spawn(async move {
1776                let mut processed = 0;
1777                for _ in 0..NUM_PINGS {
1778                    match server_handle.process_message().await {
1779                        Ok(Some(ControlEvent::Ping { .. })) => processed += 1,
1780                        _ => break,
1781                    }
1782                }
1783                processed
1784            });
1785
1786            // Client: receive pongs
1787            let client_receiver = tokio::spawn(async move {
1788                let mut received = 0;
1789                for _ in 0..NUM_PINGS {
1790                    match client_reader.read_message().await {
1791                        Ok(Some(ProtocolMessage::Pong(_))) => received += 1,
1792                        _ => break,
1793                    }
1794                }
1795                received
1796            });
1797
1798            client_sender.await.unwrap();
1799            let processed = server_processor.await.unwrap();
1800            let received = client_receiver.await.unwrap();
1801
1802            assert_eq!(pings_sent.load(Ordering::SeqCst), NUM_PINGS);
1803            assert_eq!(processed, NUM_PINGS);
1804            assert_eq!(received, NUM_PINGS);
1805        }
1806
1807        /// Tests high-volume registry operations.
1808        #[tokio::test]
1809        async fn registry_stress() {
1810            use crate::registry::StreamRegistry;
1811            use quic_reverse_control::ServiceId;
1812
1813            let mut registry = StreamRegistry::new(100, 100);
1814            const NUM_OPS: usize = 100;
1815
1816            // Register many pending requests
1817            for i in 0..NUM_OPS {
1818                let (tx, _rx) = oneshot::channel();
1819                let request = OpenRequest::new(i as u64, format!("svc-{i}"));
1820                assert!(
1821                    registry.register_pending(&request, tx).is_some(),
1822                    "failed to register pending {i}"
1823                );
1824            }
1825
1826            // Take all pending and register as active
1827            for i in 0..NUM_OPS {
1828                let pending = registry.take_pending(i as u64);
1829                assert!(pending.is_some(), "failed to take pending {i}");
1830
1831                let service = ServiceId::new(format!("svc-{i}"));
1832                assert!(
1833                    registry
1834                        .register_active(i as u64, service, Metadata::Empty, i as u64)
1835                        .is_some(),
1836                    "failed to register active {i}"
1837                );
1838            }
1839
1840            // Remove all active
1841            for i in 0..NUM_OPS {
1842                registry.remove_active(i as u64);
1843            }
1844
1845            // Registry should be empty now
1846            assert!(registry.can_open());
1847        }
1848    }
1849}