1#![allow(clippy::unwrap_used, clippy::missing_panics_doc)]
20
21use crate::control::{ControlReader, ControlStream, ControlWriter};
27use crate::error::TimeoutKind;
28use crate::negotiation::{negotiate_client, negotiate_server, NegotiatedParams};
29use crate::registry::{OpenResult, StreamRegistry};
30use crate::state::State;
31use crate::{Config, Error, Role};
32use quic_reverse_control::{
33 CloseCode, Metadata, OpenRequest, OpenResponse, OpenStatus, ProtocolMessage, RejectCode,
34 ServiceId, StreamClose,
35};
36use quic_reverse_transport::Connection;
37use std::collections::HashMap;
38use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
39use std::sync::{Arc, Mutex};
40use std::time::Instant;
41use tokio::sync::oneshot;
42use tokio::time::timeout;
43use tracing::{debug, error, info, instrument, trace, warn};
44
45pub(crate) struct SessionInner<C: Connection> {
47 pub(crate) connection: C,
49 pub(crate) config: Config,
51 pub(crate) role: Role,
53 pub(crate) state: AtomicU8,
55 pub(crate) negotiated: Mutex<Option<NegotiatedParams>>,
57 pub(crate) registry: Mutex<StreamRegistry>,
59 pub(crate) next_ping_seq: AtomicU64,
61 pub(crate) pending_pings: Mutex<HashMap<u64, PendingPing>>,
63}
64
65pub(crate) struct PendingPing {
67 pub(crate) sent_at: Instant,
69 pub(crate) response_tx: oneshot::Sender<()>,
71}
72
73pub struct Session<C: Connection> {
95 inner: Arc<SessionInner<C>>,
96}
97
98impl<C: Connection> Clone for Session<C> {
99 fn clone(&self) -> Self {
100 Self {
101 inner: Arc::clone(&self.inner),
102 }
103 }
104}
105
106impl<C: Connection> Session<C> {
107 #[must_use]
112 pub fn new(connection: C, role: Role, config: Config) -> Self {
113 let registry =
114 StreamRegistry::new(config.max_inflight_opens, config.max_concurrent_streams);
115
116 debug!(
117 %role,
118 max_inflight = config.max_inflight_opens,
119 max_concurrent = config.max_concurrent_streams,
120 "session created"
121 );
122
123 Self {
124 inner: Arc::new(SessionInner {
125 connection,
126 config,
127 role,
128 state: AtomicU8::new(State::Init as u8),
129 negotiated: Mutex::new(None),
130 registry: Mutex::new(registry),
131 next_ping_seq: AtomicU64::new(1),
132 pending_pings: Mutex::new(HashMap::new()),
133 }),
134 }
135 }
136
137 #[must_use]
139 pub fn state(&self) -> State {
140 State::from_u8(self.inner.state.load(Ordering::SeqCst))
141 }
142
143 #[must_use]
145 pub fn role(&self) -> Role {
146 self.inner.role
147 }
148
149 #[must_use]
151 pub fn negotiated_params(&self) -> Option<NegotiatedParams> {
152 self.inner.negotiated.lock().unwrap().clone()
153 }
154
155 #[must_use]
157 pub fn is_ready(&self) -> bool {
158 self.state() == State::Ready
159 }
160
161 #[must_use]
163 pub fn is_disconnected(&self) -> bool {
164 self.state() == State::Disconnected
165 }
166
167 #[must_use]
169 pub fn connection(&self) -> &C {
170 &self.inner.connection
171 }
172
173 #[instrument(skip(self), fields(role = %self.inner.role))]
186 pub async fn start(&self) -> Result<SessionHandle<C>, Error> {
187 if self.state() != State::Init {
189 warn!(state = %self.state(), "cannot start session in non-init state");
190 return Err(Error::protocol_violation(format!(
191 "cannot start session in {} state",
192 self.state()
193 )));
194 }
195
196 self.inner.config.validate()?;
198
199 self.set_state(State::Negotiating);
201 debug!("transitioning to negotiating state");
202
203 let (control_send, control_recv) = match self.inner.role {
205 Role::Client => {
206 debug!("opening control stream");
208 self.inner.connection.open_bi().await.map_err(|e| {
209 error!(error = %e, "failed to open control stream");
210 Error::Transport(Box::new(e))
211 })?
212 }
213 Role::Server => {
214 debug!("waiting for control stream");
216 self.inner
217 .connection
218 .accept_bi()
219 .await
220 .map_err(|e| {
221 error!(error = %e, "failed to accept control stream");
222 Error::Transport(Box::new(e))
223 })?
224 .ok_or_else(|| {
225 error!("connection closed before control stream");
226 Error::protocol_violation("connection closed before control stream")
227 })?
228 }
229 };
230
231 let mut control = ControlStream::new(control_send, control_recv);
232 debug!("control stream established");
233
234 let negotiation_timeout = self.inner.config.negotiation_timeout;
236 debug!(?negotiation_timeout, "starting negotiation");
237
238 let negotiate_result = match self.inner.role {
239 Role::Client => {
240 timeout(
241 negotiation_timeout,
242 negotiate_client(&mut control, &self.inner.config),
243 )
244 .await
245 }
246 Role::Server => {
247 timeout(
248 negotiation_timeout,
249 negotiate_server(&mut control, &self.inner.config),
250 )
251 .await
252 }
253 };
254
255 let params = if let Ok(result) = negotiate_result {
256 result?
257 } else {
258 warn!("negotiation timed out");
259 self.set_state(State::Closed);
260 return Err(Error::Timeout(TimeoutKind::Negotiation));
261 };
262
263 info!(
265 version = params.version,
266 features = ?params.features,
267 remote_agent = ?params.remote_agent,
268 "negotiation complete"
269 );
270 *self.inner.negotiated.lock().unwrap() = Some(params);
271
272 self.set_state(State::Ready);
274 info!("session ready");
275
276 let (writer, reader) = control.split();
278
279 Ok(SessionHandle {
280 inner: Arc::clone(&self.inner),
281 writer,
282 reader,
283 })
284 }
285
286 fn set_state(&self, state: State) {
288 self.inner.state.store(state as u8, Ordering::SeqCst);
289 }
290}
291
292pub struct SessionHandle<C: Connection> {
301 pub(crate) inner: Arc<SessionInner<C>>,
302 pub(crate) writer: ControlWriter<C::SendStream>,
303 pub(crate) reader: ControlReader<C::RecvStream>,
304}
305
306impl<C: Connection> SessionHandle<C> {
307 #[must_use]
309 pub fn state(&self) -> State {
310 State::from_u8(self.inner.state.load(Ordering::SeqCst))
311 }
312
313 #[must_use]
315 pub fn negotiated_params(&self) -> Option<NegotiatedParams> {
316 self.inner.negotiated.lock().unwrap().clone()
317 }
318
319 #[must_use]
321 pub fn is_ready(&self) -> bool {
322 self.state() == State::Ready
323 }
324
325 #[must_use]
327 pub fn is_disconnected(&self) -> bool {
328 self.state() == State::Disconnected
329 }
330
331 #[instrument(skip(self, metadata), fields(service = %service.as_ref()))]
349 pub async fn open(
350 &mut self,
351 service: impl Into<ServiceId> + AsRef<str>,
352 metadata: Metadata,
353 ) -> Result<(C::SendStream, C::RecvStream), Error> {
354 if !self.is_ready() {
355 warn!("cannot open stream: session not ready");
356 return Err(Error::SessionClosed);
357 }
358
359 let service = service.into();
360
361 let (response_tx, response_rx) = oneshot::channel();
363 let request_id = {
364 let mut registry = self.inner.registry.lock().unwrap();
365 let request_id = registry.next_request_id();
366 let request =
367 OpenRequest::new(request_id, service.clone()).with_metadata(metadata.clone());
368
369 if registry.register_pending(&request, response_tx).is_none() {
370 warn!(
371 request_id,
372 "capacity exceeded: too many pending open requests"
373 );
374 return Err(Error::CapacityExceeded("too many pending open requests"));
375 }
376
377 request_id
378 };
379
380 debug!(request_id, service = %service.as_str(), "sending open request");
381
382 let request = OpenRequest::new(request_id, service).with_metadata(metadata);
384 self.writer
385 .write_message(&ProtocolMessage::OpenRequest(request))
386 .await?;
387 self.writer.flush().await?;
388
389 let open_timeout = self.inner.config.open_timeout;
391 let result = match timeout(open_timeout, response_rx).await {
392 Ok(Ok(result)) => result,
393 Ok(Err(_)) => {
394 warn!(request_id, "session closed while waiting for response");
397 let mut registry = self.inner.registry.lock().unwrap();
398 registry.take_pending(request_id);
399 return Err(Error::SessionClosed);
400 }
401 Err(_) => {
402 warn!(request_id, ?open_timeout, "open request timed out");
404 let mut registry = self.inner.registry.lock().unwrap();
405 registry.take_pending(request_id);
406 return Err(Error::Timeout(TimeoutKind::OpenRequest));
407 }
408 };
409
410 match result {
411 OpenResult::Accepted { logical_stream_id } => {
412 debug!(request_id, logical_stream_id, "open request accepted");
413
414 let bind_timeout = self.inner.config.stream_bind_timeout;
416 let stream_result = timeout(bind_timeout, self.inner.connection.accept_bi()).await;
417
418 let (send, recv) = match stream_result {
419 Ok(Ok(Some(streams))) => streams,
420 Ok(Ok(None)) => {
421 error!(request_id, "connection closed while waiting for stream");
422 return Err(Error::protocol_violation(
423 "connection closed while waiting for stream",
424 ));
425 }
426 Ok(Err(e)) => {
427 error!(request_id, error = %e, "transport error while binding stream");
428 return Err(Error::Transport(Box::new(e)));
429 }
430 Err(_) => {
431 warn!(request_id, ?bind_timeout, "stream bind timed out");
432 return Err(Error::Timeout(TimeoutKind::StreamBind));
433 }
434 };
435
436 {
438 let mut registry = self.inner.registry.lock().unwrap();
439 registry.register_active(
440 logical_stream_id,
441 ServiceId::from(""),
442 Metadata::Empty,
443 request_id,
444 );
445 }
446
447 info!(request_id, logical_stream_id, "stream opened successfully");
448 Ok((send, recv))
449 }
450 OpenResult::Rejected { code, reason } => {
451 warn!(request_id, ?code, ?reason, "open request rejected");
452 Err(Error::StreamRejected { code, reason })
453 }
454 }
455 }
456
457 pub async fn process_message(&mut self) -> Result<Option<ControlEvent>, Error> {
466 let Some(message) = self.reader.read_message().await? else {
467 debug!("control stream closed");
468 return Ok(None);
469 };
470
471 match message {
472 ProtocolMessage::OpenRequest(req) => {
473 debug!(
475 request_id = req.request_id,
476 service = %req.service.as_str(),
477 "received open request"
478 );
479 Ok(Some(ControlEvent::OpenRequest {
480 request_id: req.request_id,
481 service: req.service,
482 metadata: req.metadata,
483 }))
484 }
485
486 ProtocolMessage::OpenResponse(resp) => {
487 let accepted = matches!(resp.status, OpenStatus::Accepted);
489 debug!(
490 request_id = resp.request_id,
491 accepted,
492 logical_stream_id = ?resp.logical_stream_id,
493 "received open response"
494 );
495 let mut registry = self.inner.registry.lock().unwrap();
496 if let Some(pending) = registry.take_pending(resp.request_id) {
497 let result = match resp.status {
498 OpenStatus::Accepted => OpenResult::Accepted {
499 logical_stream_id: resp.logical_stream_id.unwrap_or(0),
500 },
501 OpenStatus::Rejected(code) => OpenResult::Rejected {
502 code,
503 reason: resp.reason,
504 },
505 };
506 let _ = pending.response_tx.send(result);
507 }
508 Ok(Some(ControlEvent::OpenResponseReceived {
509 request_id: resp.request_id,
510 accepted,
511 }))
512 }
513
514 ProtocolMessage::Ping(ping_msg) => {
515 trace!(sequence = ping_msg.sequence, "received ping, sending pong");
517 let pong_msg = quic_reverse_control::Pong {
518 sequence: ping_msg.sequence,
519 };
520 self.writer
521 .write_message(&ProtocolMessage::Pong(pong_msg))
522 .await?;
523 self.writer.flush().await?;
524 Ok(Some(ControlEvent::Ping {
525 sequence: ping_msg.sequence,
526 }))
527 }
528
529 ProtocolMessage::Pong(pong) => {
530 trace!(sequence = pong.sequence, "received pong");
532 let mut pending = self.inner.pending_pings.lock().unwrap();
533 if let Some(pending_ping) = pending.remove(&pong.sequence) {
534 let rtt = pending_ping.sent_at.elapsed();
535 trace!(sequence = pong.sequence, ?rtt, "ping resolved");
536 let _ = pending_ping.response_tx.send(());
537 }
538 Ok(Some(ControlEvent::Pong {
539 sequence: pong.sequence,
540 }))
541 }
542
543 ProtocolMessage::Hello(_) | ProtocolMessage::HelloAck(_) => {
544 warn!("received unexpected Hello/HelloAck after negotiation");
546 Err(Error::protocol_violation(
547 "unexpected Hello/HelloAck after negotiation",
548 ))
549 }
550
551 ProtocolMessage::StreamClose(sc) => {
552 if sc.logical_stream_id == 0 {
554 info!(code = ?sc.code, reason = ?sc.reason, "received session close");
555 self.set_state(State::Closing);
556 Ok(Some(ControlEvent::CloseReceived {
557 code: sc.code,
558 reason: sc.reason,
559 }))
560 } else {
561 debug!(
562 logical_stream_id = sc.logical_stream_id,
563 code = ?sc.code,
564 "received stream close"
565 );
566 Ok(Some(ControlEvent::StreamClose {
567 logical_stream_id: sc.logical_stream_id,
568 code: sc.code,
569 }))
570 }
571 }
572 }
573 }
574
575 #[instrument(skip(self))]
581 pub async fn accept_open(
582 &mut self,
583 request_id: u64,
584 logical_stream_id: u64,
585 ) -> Result<(), Error> {
586 debug!(request_id, logical_stream_id, "accepting open request");
587 let response = OpenResponse::accepted(request_id, logical_stream_id);
588 self.writer
589 .write_message(&ProtocolMessage::OpenResponse(response))
590 .await?;
591 self.writer.flush().await
592 }
593
594 #[instrument(skip(self))]
600 pub async fn reject_open(
601 &mut self,
602 request_id: u64,
603 code: RejectCode,
604 reason: Option<String>,
605 ) -> Result<(), Error> {
606 debug!(request_id, ?code, ?reason, "rejecting open request");
607 let response = OpenResponse::rejected(request_id, code, reason);
608 self.writer
609 .write_message(&ProtocolMessage::OpenResponse(response))
610 .await?;
611 self.writer.flush().await
612 }
613
614 #[instrument(skip(self))]
623 pub async fn close_stream(
624 &mut self,
625 logical_stream_id: u64,
626 code: CloseCode,
627 reason: Option<String>,
628 ) -> Result<(), Error> {
629 if !self.is_ready() {
630 warn!(logical_stream_id, "cannot close stream: session not ready");
631 return Err(Error::SessionClosed);
632 }
633
634 debug!(logical_stream_id, ?code, ?reason, "closing stream");
635
636 {
638 let mut registry = self.inner.registry.lock().unwrap();
639 registry.remove_active(logical_stream_id);
640 }
641
642 let close_msg = StreamClose {
643 logical_stream_id,
644 code,
645 reason,
646 };
647 self.writer
648 .write_message(&ProtocolMessage::StreamClose(close_msg))
649 .await?;
650 self.writer.flush().await
651 }
652
653 #[instrument(skip(self))]
663 pub async fn ping(&mut self) -> Result<std::time::Duration, Error> {
664 if !self.is_ready() {
665 warn!("cannot ping: session not ready");
666 return Err(Error::SessionClosed);
667 }
668
669 let sequence = self.inner.next_ping_seq.fetch_add(1, Ordering::SeqCst);
671 trace!(sequence, "sending ping");
672
673 let (response_tx, response_rx) = oneshot::channel();
675 let sent_at = Instant::now();
676
677 {
679 let mut pending = self.inner.pending_pings.lock().unwrap();
680 pending.insert(
681 sequence,
682 PendingPing {
683 sent_at,
684 response_tx,
685 },
686 );
687 }
688
689 let ping_msg = quic_reverse_control::Ping { sequence };
691 self.writer
692 .write_message(&ProtocolMessage::Ping(ping_msg))
693 .await?;
694 self.writer.flush().await?;
695
696 let ping_timeout = self.inner.config.ping_timeout;
698 match timeout(ping_timeout, response_rx).await {
699 Ok(Ok(())) => {
700 let rtt = sent_at.elapsed();
701 debug!(sequence, ?rtt, "ping completed");
702 Ok(rtt)
703 }
704 Ok(Err(_)) => {
705 warn!(sequence, "session closed while waiting for pong");
707 Err(Error::SessionClosed)
708 }
709 Err(_) => {
710 warn!(sequence, ?ping_timeout, "ping timed out");
712 let mut pending = self.inner.pending_pings.lock().unwrap();
713 pending.remove(&sequence);
714 Err(Error::Timeout(TimeoutKind::Ping))
715 }
716 }
717 }
718
719 #[instrument(skip(self))]
728 pub async fn close(&mut self, code: CloseCode, reason: Option<String>) -> Result<(), Error> {
729 if !self.is_ready() && self.state() != State::Closing {
730 warn!("cannot close: session already closed");
731 return Err(Error::SessionClosed);
732 }
733
734 info!(?code, ?reason, "closing session");
735 self.set_state(State::Closing);
736
737 let close_msg = StreamClose {
739 logical_stream_id: 0,
740 code,
741 reason,
742 };
743 self.writer
744 .write_message(&ProtocolMessage::StreamClose(close_msg))
745 .await?;
746 self.writer.flush().await
747 }
748
749 fn set_state(&self, state: State) {
751 self.inner.state.store(state as u8, Ordering::SeqCst);
752 }
753}
754
755#[derive(Debug, Clone)]
757pub enum ControlEvent {
758 OpenRequest {
760 request_id: u64,
762 service: ServiceId,
764 metadata: Metadata,
766 },
767 OpenResponseReceived {
769 request_id: u64,
771 accepted: bool,
773 },
774 CloseReceived {
776 code: CloseCode,
778 reason: Option<String>,
780 },
781 Ping {
783 sequence: u64,
785 },
786 Pong {
788 sequence: u64,
790 },
791 StreamClose {
793 logical_stream_id: u64,
795 code: CloseCode,
797 },
798}
799
800#[cfg(test)]
801mod tests {
802 use super::*;
803 use quic_reverse_control::Features;
804 use quic_reverse_transport::mock_connection_pair;
805
806 #[tokio::test]
807 async fn session_creation() {
808 let (conn_client, _conn_server) = mock_connection_pair();
809
810 let config = Config::new()
811 .with_features(Features::PING_PONG)
812 .with_agent("test/1.0");
813
814 let session = Session::new(conn_client, Role::Client, config);
815
816 assert_eq!(session.state(), State::Init);
817 assert_eq!(session.role(), Role::Client);
818 assert!(session.negotiated_params().is_none());
819 }
820
821 #[tokio::test]
822 async fn session_start_and_negotiate() {
823 let (conn_client, conn_server) = mock_connection_pair();
824
825 let client_config = Config::new()
826 .with_features(Features::PING_PONG)
827 .with_agent("client/1.0");
828
829 let server_config = Config::new()
830 .with_features(Features::PING_PONG)
831 .with_agent("server/1.0");
832
833 let client_session = Session::new(conn_client, Role::Client, client_config);
834 let server_session = Session::new(conn_server, Role::Server, server_config);
835
836 let client_session_ref = client_session.clone();
838 let server_session_ref = server_session.clone();
839
840 let client_handle = tokio::spawn(async move { client_session.start().await });
842 let server_handle = tokio::spawn(async move { server_session.start().await });
843
844 let client_result = client_handle.await.expect("client task");
846 let server_result = server_handle.await.expect("server task");
847
848 assert!(client_result.is_ok(), "client failed");
850 assert!(server_result.is_ok(), "server failed");
851
852 assert_eq!(client_session_ref.state(), State::Ready);
854 assert_eq!(server_session_ref.state(), State::Ready);
855
856 let client_params = client_session_ref
858 .negotiated_params()
859 .expect("client params");
860 let server_params = server_session_ref
861 .negotiated_params()
862 .expect("server params");
863
864 assert_eq!(client_params.version, server_params.version);
865 assert_eq!(client_params.features, Features::PING_PONG);
866
867 assert_eq!(client_params.remote_agent.as_deref(), Some("server/1.0"));
869 assert_eq!(server_params.remote_agent.as_deref(), Some("client/1.0"));
870 }
871
872 #[tokio::test]
873 async fn cannot_start_twice() {
874 let (conn_client, conn_server) = mock_connection_pair();
875
876 let client_session = Session::new(conn_client, Role::Client, Config::new());
877 let server_session = Session::new(conn_server, Role::Server, Config::new());
878
879 let client_session_ref = client_session.clone();
881
882 let client_handle = tokio::spawn(async move { client_session.start().await });
884 let server_handle = tokio::spawn(async move { server_session.start().await });
885
886 let _ = client_handle.await;
887 let _ = server_handle.await;
888
889 let result = client_session_ref.start().await;
891 assert!(result.is_err());
892 }
893
894 #[tokio::test]
895 async fn ping_pong_exchange() {
896 let (conn_client, conn_server) = mock_connection_pair();
897
898 let config = Config::new().with_features(Features::PING_PONG);
899
900 let client_session = Session::new(conn_client, Role::Client, config.clone());
901 let server_session = Session::new(conn_server, Role::Server, config);
902
903 let client_start = tokio::spawn(async move { client_session.start().await });
905 let server_start = tokio::spawn(async move { server_session.start().await });
906
907 let mut client_handle = client_start.await.unwrap().unwrap();
908 let mut server_handle = server_start.await.unwrap().unwrap();
909
910 let ping = quic_reverse_control::Ping { sequence: 42 };
912 client_handle
913 .writer
914 .write_message(&ProtocolMessage::Ping(ping))
915 .await
916 .unwrap();
917 client_handle.writer.flush().await.unwrap();
918
919 let event = server_handle.process_message().await.unwrap().unwrap();
921 assert!(matches!(event, ControlEvent::Ping { sequence: 42 }));
922
923 let event = client_handle.process_message().await.unwrap().unwrap();
925 assert!(matches!(event, ControlEvent::Pong { sequence: 42 }));
926 }
927
928 #[tokio::test]
929 async fn close_session() {
930 let (conn_client, conn_server) = mock_connection_pair();
931
932 let client_session = Session::new(conn_client, Role::Client, Config::new());
933 let server_session = Session::new(conn_server, Role::Server, Config::new());
934
935 let client_session_ref = client_session.clone();
937 let server_session_ref = server_session.clone();
938
939 let client_start = tokio::spawn(async move { client_session.start().await });
941 let server_start = tokio::spawn(async move { server_session.start().await });
942
943 let mut client_handle = client_start.await.unwrap().unwrap();
944 let mut server_handle = server_start.await.unwrap().unwrap();
945
946 client_handle
948 .close(CloseCode::Normal, Some("goodbye".into()))
949 .await
950 .unwrap();
951
952 let event = server_handle.process_message().await.unwrap().unwrap();
954 match event {
955 ControlEvent::CloseReceived { code, reason } => {
956 assert_eq!(code, CloseCode::Normal);
957 assert_eq!(reason.as_deref(), Some("goodbye"));
958 }
959 _ => panic!("expected CloseReceived"),
960 }
961
962 assert_eq!(client_session_ref.state(), State::Closing);
963 assert_eq!(server_session_ref.state(), State::Closing);
964 }
965
966 #[tokio::test]
967 async fn stream_open_and_accept() {
968 use tokio::io::{AsyncReadExt, AsyncWriteExt};
969 use tokio::sync::mpsc;
970
971 let (conn_client, conn_server) = mock_connection_pair();
972
973 let client_session = Session::new(conn_client, Role::Client, Config::new());
974 let server_session = Session::new(conn_server, Role::Server, Config::new());
975
976 let client_start = tokio::spawn(async move { client_session.start().await });
978 let server_start = tokio::spawn(async move { server_session.start().await });
979
980 let client_handle = client_start.await.unwrap().unwrap();
981 let mut server_handle = server_start.await.unwrap().unwrap();
982
983 let (open_done_tx, mut open_done_rx) = mpsc::channel::<(
986 quic_reverse_transport::MockSendStream,
987 quic_reverse_transport::MockRecvStream,
988 )>(1);
989
990 let client_inner = Arc::clone(&client_handle.inner);
992 let mut client_reader = client_handle.reader;
993 let client_msg_processor = tokio::spawn(async move {
994 let msg = client_reader.read_message().await.unwrap().unwrap();
996 if let ProtocolMessage::OpenResponse(resp) = msg {
997 let accepted = matches!(resp.status, OpenStatus::Accepted);
998 let mut registry = client_inner.registry.lock().unwrap();
999 if let Some(pending) = registry.take_pending(resp.request_id) {
1000 let result = match resp.status {
1001 OpenStatus::Accepted => OpenResult::Accepted {
1002 logical_stream_id: resp.logical_stream_id.unwrap_or(0),
1003 },
1004 OpenStatus::Rejected(code) => OpenResult::Rejected {
1005 code,
1006 reason: resp.reason,
1007 },
1008 };
1009 let _ = pending.response_tx.send(result);
1010 }
1011 accepted
1012 } else {
1013 panic!("expected OpenResponse");
1014 }
1015 });
1016
1017 let client_inner2 = Arc::clone(&client_handle.inner);
1019 let mut client_writer = client_handle.writer;
1020 let client_open = tokio::spawn(async move {
1021 let (response_tx, response_rx) = oneshot::channel();
1023 let request_id = {
1024 let mut registry = client_inner2.registry.lock().unwrap();
1025 let request_id = registry.next_request_id();
1026 let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1027 registry.register_pending(&request, response_tx).unwrap();
1028 request_id
1029 };
1030
1031 let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1033 client_writer
1034 .write_message(&ProtocolMessage::OpenRequest(request))
1035 .await
1036 .unwrap();
1037 client_writer.flush().await.unwrap();
1038
1039 let result = response_rx.await.unwrap();
1041
1042 match result {
1043 OpenResult::Accepted { .. } => {
1044 let (send, recv) = client_inner2.connection.accept_bi().await.unwrap().unwrap();
1046 open_done_tx.send((send, recv)).await.unwrap();
1047 }
1048 OpenResult::Rejected { code, reason } => {
1049 panic!("rejected: {code:?} {reason:?}");
1050 }
1051 }
1052 });
1053
1054 let event = server_handle.process_message().await.unwrap().unwrap();
1056 let (request_id, service) = match event {
1057 ControlEvent::OpenRequest {
1058 request_id,
1059 service,
1060 ..
1061 } => (request_id, service),
1062 _ => panic!("expected OpenRequest, got {event:?}"),
1063 };
1064 assert_eq!(service.as_str(), "ssh");
1065
1066 let logical_stream_id = 1;
1068 server_handle
1069 .accept_open(request_id, logical_stream_id)
1070 .await
1071 .unwrap();
1072
1073 let (mut server_send, mut server_recv) =
1075 server_handle.inner.connection.open_bi().await.unwrap();
1076
1077 client_msg_processor.await.unwrap();
1079 client_open.await.unwrap();
1080
1081 let (mut client_send, mut client_recv) = open_done_rx.recv().await.unwrap();
1083
1084 server_send.write_all(b"hello from server").await.unwrap();
1086 server_send.flush().await.unwrap();
1087
1088 let mut buf = [0u8; 32];
1089 let n = client_recv.read(&mut buf).await.unwrap();
1090 assert_eq!(&buf[..n], b"hello from server");
1091
1092 client_send.write_all(b"hello from client").await.unwrap();
1093 client_send.flush().await.unwrap();
1094
1095 let n = server_recv.read(&mut buf).await.unwrap();
1096 assert_eq!(&buf[..n], b"hello from client");
1097 }
1098
1099 #[tokio::test]
1100 async fn stream_open_rejected() {
1101 use tokio::sync::mpsc;
1102
1103 let (conn_client, conn_server) = mock_connection_pair();
1104
1105 let client_session = Session::new(conn_client, Role::Client, Config::new());
1106 let server_session = Session::new(conn_server, Role::Server, Config::new());
1107
1108 let client_start = tokio::spawn(async move { client_session.start().await });
1110 let server_start = tokio::spawn(async move { server_session.start().await });
1111
1112 let client_handle = client_start.await.unwrap().unwrap();
1113 let mut server_handle = server_start.await.unwrap().unwrap();
1114
1115 let (result_tx, mut result_rx) = mpsc::channel::<Result<(), Error>>(1);
1117
1118 let client_inner = Arc::clone(&client_handle.inner);
1120 let mut client_reader = client_handle.reader;
1121 let client_msg_processor = tokio::spawn(async move {
1122 let msg = client_reader.read_message().await.unwrap().unwrap();
1123 if let ProtocolMessage::OpenResponse(resp) = msg {
1124 let mut registry = client_inner.registry.lock().unwrap();
1125 if let Some(pending) = registry.take_pending(resp.request_id) {
1126 let result = match resp.status {
1127 OpenStatus::Accepted => OpenResult::Accepted {
1128 logical_stream_id: resp.logical_stream_id.unwrap_or(0),
1129 },
1130 OpenStatus::Rejected(code) => OpenResult::Rejected {
1131 code,
1132 reason: resp.reason,
1133 },
1134 };
1135 let _ = pending.response_tx.send(result);
1136 }
1137 }
1138 });
1139
1140 let client_inner2 = Arc::clone(&client_handle.inner);
1142 let mut client_writer = client_handle.writer;
1143 let client_open = tokio::spawn(async move {
1144 let (response_tx, response_rx) = oneshot::channel();
1145 let request_id = {
1146 let mut registry = client_inner2.registry.lock().unwrap();
1147 let request_id = registry.next_request_id();
1148 let request =
1149 OpenRequest::new(request_id, "unknown").with_metadata(Metadata::Empty);
1150 registry.register_pending(&request, response_tx).unwrap();
1151 request_id
1152 };
1153
1154 let request = OpenRequest::new(request_id, "unknown").with_metadata(Metadata::Empty);
1155 client_writer
1156 .write_message(&ProtocolMessage::OpenRequest(request))
1157 .await
1158 .unwrap();
1159 client_writer.flush().await.unwrap();
1160
1161 let result = response_rx.await.unwrap();
1162 match result {
1163 OpenResult::Accepted { .. } => {
1164 result_tx.send(Ok(())).await.unwrap();
1165 }
1166 OpenResult::Rejected { code, reason } => {
1167 result_tx
1168 .send(Err(Error::StreamRejected { code, reason }))
1169 .await
1170 .unwrap();
1171 }
1172 }
1173 });
1174
1175 let event = server_handle.process_message().await.unwrap().unwrap();
1177 let request_id = match event {
1178 ControlEvent::OpenRequest { request_id, .. } => request_id,
1179 _ => panic!("expected OpenRequest"),
1180 };
1181
1182 server_handle
1183 .reject_open(
1184 request_id,
1185 RejectCode::UnsupportedService,
1186 Some("not available".into()),
1187 )
1188 .await
1189 .unwrap();
1190
1191 client_msg_processor.await.unwrap();
1193 client_open.await.unwrap();
1194
1195 let result = result_rx.recv().await.unwrap();
1197 match result {
1198 Err(Error::StreamRejected { code, reason }) => {
1199 assert_eq!(code, RejectCode::UnsupportedService);
1200 assert_eq!(reason.as_deref(), Some("not available"));
1201 }
1202 other => panic!("expected StreamRejected, got {other:?}"),
1203 }
1204 }
1205
1206 #[tokio::test]
1207 async fn stream_close_notification() {
1208 use tokio::sync::mpsc;
1209
1210 let (conn_client, conn_server) = mock_connection_pair();
1211
1212 let client_session = Session::new(conn_client, Role::Client, Config::new());
1213 let server_session = Session::new(conn_server, Role::Server, Config::new());
1214
1215 let client_start = tokio::spawn(async move { client_session.start().await });
1217 let server_start = tokio::spawn(async move { server_session.start().await });
1218
1219 let client_handle = client_start.await.unwrap().unwrap();
1220 let mut server_handle = server_start.await.unwrap().unwrap();
1221
1222 let (open_done_tx, mut open_done_rx) = mpsc::channel::<u64>(1);
1224
1225 let client_inner = Arc::clone(&client_handle.inner);
1227 let mut client_reader = client_handle.reader;
1228 let client_msg_processor = tokio::spawn(async move {
1229 let msg = client_reader.read_message().await.unwrap().unwrap();
1231 if let ProtocolMessage::OpenResponse(resp) = msg {
1232 let mut registry = client_inner.registry.lock().unwrap();
1233 if let Some(pending) = registry.take_pending(resp.request_id) {
1234 let result = match resp.status {
1235 OpenStatus::Accepted => OpenResult::Accepted {
1236 logical_stream_id: resp.logical_stream_id.unwrap_or(0),
1237 },
1238 OpenStatus::Rejected(code) => OpenResult::Rejected {
1239 code,
1240 reason: resp.reason,
1241 },
1242 };
1243 let _ = pending.response_tx.send(result);
1244 }
1245 }
1246
1247 let msg = client_reader.read_message().await.unwrap().unwrap();
1249 if let ProtocolMessage::StreamClose(sc) = msg {
1250 (sc.logical_stream_id, sc.code, sc.reason)
1251 } else {
1252 panic!("expected StreamClose");
1253 }
1254 });
1255
1256 let client_inner2 = Arc::clone(&client_handle.inner);
1258 let mut client_writer = client_handle.writer;
1259 let client_open = tokio::spawn(async move {
1260 let (response_tx, response_rx) = oneshot::channel();
1261 let request_id = {
1262 let mut registry = client_inner2.registry.lock().unwrap();
1263 let request_id = registry.next_request_id();
1264 let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1265 registry.register_pending(&request, response_tx).unwrap();
1266 request_id
1267 };
1268
1269 let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1270 client_writer
1271 .write_message(&ProtocolMessage::OpenRequest(request))
1272 .await
1273 .unwrap();
1274 client_writer.flush().await.unwrap();
1275
1276 let result = response_rx.await.unwrap();
1277 if let OpenResult::Accepted { logical_stream_id } = result {
1278 open_done_tx.send(logical_stream_id).await.unwrap();
1279 }
1280 });
1281
1282 let event = server_handle.process_message().await.unwrap().unwrap();
1284 let request_id = match event {
1285 ControlEvent::OpenRequest { request_id, .. } => request_id,
1286 _ => panic!("expected OpenRequest"),
1287 };
1288
1289 let logical_stream_id = 42;
1291 server_handle
1292 .accept_open(request_id, logical_stream_id)
1293 .await
1294 .unwrap();
1295
1296 client_open.await.unwrap();
1298 let received_id = open_done_rx.recv().await.unwrap();
1299 assert_eq!(received_id, logical_stream_id);
1300
1301 server_handle
1303 .close_stream(logical_stream_id, CloseCode::Normal, Some("done".into()))
1304 .await
1305 .unwrap();
1306
1307 let (close_id, close_code, close_reason) = client_msg_processor.await.unwrap();
1309 assert_eq!(close_id, logical_stream_id);
1310 assert_eq!(close_code, CloseCode::Normal);
1311 assert_eq!(close_reason.as_deref(), Some("done"));
1312 }
1313
1314 #[tokio::test]
1315 async fn open_respects_inflight_limit() {
1316 let (conn_client, conn_server) = mock_connection_pair();
1317
1318 let client_config = Config::new().with_max_inflight_opens(2);
1320 let server_config = Config::new();
1321
1322 let client_session = Session::new(conn_client, Role::Client, client_config);
1323 let server_session = Session::new(conn_server, Role::Server, server_config);
1324
1325 let client_start = tokio::spawn(async move { client_session.start().await });
1327 let server_start = tokio::spawn(async move { server_session.start().await });
1328
1329 let client_handle = client_start.await.unwrap().unwrap();
1330 let _server_handle = server_start.await.unwrap().unwrap();
1331
1332 let inner = Arc::clone(&client_handle.inner);
1334 let mut writer = client_handle.writer;
1335
1336 {
1338 let mut registry = inner.registry.lock().unwrap();
1339 let (tx1, _rx1) = oneshot::channel();
1340 let (tx2, _rx2) = oneshot::channel();
1341 let req1 = OpenRequest::new(1, "service1");
1342 let req2 = OpenRequest::new(2, "service2");
1343 assert!(registry.register_pending(&req1, tx1).is_some());
1344 assert!(registry.register_pending(&req2, tx2).is_some());
1345 }
1346
1347 let (response_tx, _response_rx) = oneshot::channel();
1349 let result = {
1350 let mut registry = inner.registry.lock().unwrap();
1351 let request_id = registry.next_request_id();
1352 let request = OpenRequest::new(request_id, "service3");
1353 registry.register_pending(&request, response_tx)
1354 };
1355
1356 assert!(result.is_none(), "should fail due to limit");
1357
1358 let _ = writer
1360 .write_message(&ProtocolMessage::Ping(quic_reverse_control::Ping {
1361 sequence: 0,
1362 }))
1363 .await;
1364 }
1365
1366 #[tokio::test]
1367 async fn open_request_timeout() {
1368 use std::time::Duration;
1369
1370 let (conn_client, conn_server) = mock_connection_pair();
1371
1372 let client_config = Config::new().with_open_timeout(Duration::from_millis(50));
1374 let server_config = Config::new();
1375
1376 let client_session = Session::new(conn_client, Role::Client, client_config);
1377 let server_session = Session::new(conn_server, Role::Server, server_config);
1378
1379 let client_start = tokio::spawn(async move { client_session.start().await });
1381 let server_start = tokio::spawn(async move { server_session.start().await });
1382
1383 let mut client_handle = client_start.await.unwrap().unwrap();
1384 let _server_handle = server_start.await.unwrap().unwrap();
1385
1386 let result = client_handle.open("ssh", Metadata::Empty).await;
1388
1389 match result {
1391 Err(Error::Timeout(TimeoutKind::OpenRequest)) => {}
1392 other => panic!("expected OpenRequest timeout, got: {other:?}"),
1393 }
1394 }
1395
1396 #[tokio::test]
1397 async fn stream_bind_timeout() {
1398 use std::time::Duration;
1399
1400 let (conn_client, conn_server) = mock_connection_pair();
1401
1402 let client_config = Config::new()
1404 .with_open_timeout(Duration::from_secs(5))
1405 .with_stream_bind_timeout(Duration::from_millis(50));
1406 let server_config = Config::new();
1407
1408 let client_session = Session::new(conn_client, Role::Client, client_config);
1409 let server_session = Session::new(conn_server, Role::Server, server_config);
1410
1411 let client_start = tokio::spawn(async move { client_session.start().await });
1413 let server_start = tokio::spawn(async move { server_session.start().await });
1414
1415 let client_handle = client_start.await.unwrap().unwrap();
1416 let mut server_handle = server_start.await.unwrap().unwrap();
1417
1418 let client_inner_open = Arc::clone(&client_handle.inner);
1420 let client_inner_msg = Arc::clone(&client_handle.inner);
1421 let mut client_writer = client_handle.writer;
1422 let mut client_reader = client_handle.reader;
1423
1424 let client_open = tokio::spawn(async move {
1426 let (response_tx, response_rx) = oneshot::channel();
1428 let request_id = {
1429 let mut registry = client_inner_open.registry.lock().unwrap();
1430 let request_id = registry.next_request_id();
1431 let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1432 registry.register_pending(&request, response_tx).unwrap();
1433 request_id
1434 };
1435
1436 let request = OpenRequest::new(request_id, "ssh").with_metadata(Metadata::Empty);
1437 client_writer
1438 .write_message(&ProtocolMessage::OpenRequest(request))
1439 .await
1440 .unwrap();
1441 client_writer.flush().await.unwrap();
1442
1443 let result = response_rx.await.unwrap();
1445
1446 match result {
1448 OpenResult::Accepted { .. } => {
1449 let bind_timeout = Duration::from_millis(50);
1450 match timeout(bind_timeout, client_inner_open.connection.accept_bi()).await {
1451 Ok(Ok(Some(streams))) => Ok(streams),
1452 Ok(Ok(None)) => Err(Error::SessionClosed),
1453 Ok(Err(e)) => Err(Error::Transport(Box::new(e))),
1454 Err(_) => Err(Error::Timeout(TimeoutKind::StreamBind)),
1455 }
1456 }
1457 OpenResult::Rejected { code, reason } => {
1458 Err(Error::StreamRejected { code, reason })
1459 }
1460 }
1461 });
1462
1463 let client_msg_processor = tokio::spawn(async move {
1465 let msg = client_reader.read_message().await.unwrap().unwrap();
1466 if let ProtocolMessage::OpenResponse(resp) = msg {
1467 let mut registry = client_inner_msg.registry.lock().unwrap();
1468 if let Some(pending) = registry.take_pending(resp.request_id) {
1469 let result = match resp.status {
1470 OpenStatus::Accepted => OpenResult::Accepted {
1471 logical_stream_id: resp.logical_stream_id.unwrap_or(0),
1472 },
1473 OpenStatus::Rejected(code) => OpenResult::Rejected {
1474 code,
1475 reason: resp.reason,
1476 },
1477 };
1478 let _ = pending.response_tx.send(result);
1479 }
1480 }
1481 });
1482
1483 let event = server_handle.process_message().await.unwrap().unwrap();
1485 let request_id = match event {
1486 ControlEvent::OpenRequest { request_id, .. } => request_id,
1487 _ => panic!("expected OpenRequest"),
1488 };
1489
1490 server_handle.accept_open(request_id, 1).await.unwrap();
1492
1493 let _ = client_msg_processor.await;
1495
1496 let result = client_open.await.unwrap();
1498 match result {
1499 Err(Error::Timeout(TimeoutKind::StreamBind)) => {}
1500 other => panic!("expected StreamBind timeout, got: {:?}", other),
1501 }
1502 }
1503
1504 #[tokio::test]
1505 async fn negotiation_timeout() {
1506 use std::time::Duration;
1507
1508 let (conn_client, _conn_server) = mock_connection_pair();
1509
1510 let client_config = Config::new().with_negotiation_timeout(Duration::from_millis(50));
1512
1513 let client_session = Session::new(conn_client, Role::Client, client_config);
1514
1515 let result = client_session.start().await;
1517
1518 assert!(
1520 matches!(result, Err(Error::Timeout(TimeoutKind::Negotiation))),
1521 "expected Negotiation timeout, got: {:?}",
1522 result.as_ref().map(|_| "Ok(SessionHandle)")
1523 );
1524
1525 assert_eq!(client_session.state(), State::Closed);
1527 }
1528
1529 #[tokio::test]
1530 async fn ping_returns_rtt() {
1531 use std::time::Duration;
1532
1533 let (conn_client, conn_server) = mock_connection_pair();
1534
1535 let client_session = Session::new(conn_client, Role::Client, Config::new());
1536 let server_session = Session::new(conn_server, Role::Server, Config::new());
1537
1538 let client_start = tokio::spawn(async move { client_session.start().await });
1540 let server_start = tokio::spawn(async move { server_session.start().await });
1541
1542 let client_handle = client_start.await.unwrap().unwrap();
1543 let mut server_handle = server_start.await.unwrap().unwrap();
1544
1545 let client_inner = Arc::clone(&client_handle.inner);
1547 let mut client_writer = client_handle.writer;
1548 let mut client_reader = client_handle.reader;
1549
1550 let ping_task = tokio::spawn(async move {
1552 let sequence = client_inner.next_ping_seq.fetch_add(1, Ordering::SeqCst);
1554
1555 let (response_tx, response_rx) = oneshot::channel();
1557 let sent_at = Instant::now();
1558
1559 {
1561 let mut pending = client_inner.pending_pings.lock().unwrap();
1562 pending.insert(
1563 sequence,
1564 PendingPing {
1565 sent_at,
1566 response_tx,
1567 },
1568 );
1569 }
1570
1571 let ping_msg = quic_reverse_control::Ping { sequence };
1573 client_writer
1574 .write_message(&ProtocolMessage::Ping(ping_msg))
1575 .await
1576 .unwrap();
1577 client_writer.flush().await.unwrap();
1578
1579 response_rx.await.unwrap();
1581 sent_at.elapsed()
1582 });
1583
1584 let client_inner2 = Arc::clone(&client_handle.inner);
1586 let client_msg_processor = tokio::spawn(async move {
1587 let msg = client_reader.read_message().await.unwrap().unwrap();
1588 if let ProtocolMessage::Pong(pong) = msg {
1589 let mut pending = client_inner2.pending_pings.lock().unwrap();
1590 if let Some(pending_ping) = pending.remove(&pong.sequence) {
1591 let _ = pending_ping.response_tx.send(());
1592 }
1593 }
1594 });
1595
1596 let event = server_handle.process_message().await.unwrap().unwrap();
1598 assert!(matches!(event, ControlEvent::Ping { sequence: 1 }));
1599
1600 let _ = client_msg_processor.await;
1602 let rtt = ping_task.await.unwrap();
1603
1604 assert!(rtt < Duration::from_secs(1));
1606 }
1607
1608 #[tokio::test]
1609 async fn ping_timeout() {
1610 use std::time::Duration;
1611
1612 let (conn_client, conn_server) = mock_connection_pair();
1613
1614 let client_config = Config::new().with_ping_timeout(Duration::from_millis(50));
1616 let server_config = Config::new();
1617
1618 let client_session = Session::new(conn_client, Role::Client, client_config);
1619 let server_session = Session::new(conn_server, Role::Server, server_config);
1620
1621 let client_start = tokio::spawn(async move { client_session.start().await });
1623 let server_start = tokio::spawn(async move { server_session.start().await });
1624
1625 let mut client_handle = client_start.await.unwrap().unwrap();
1626 let _server_handle = server_start.await.unwrap().unwrap();
1627
1628 let result = client_handle.ping().await;
1630
1631 match result {
1633 Err(Error::Timeout(TimeoutKind::Ping)) => {}
1634 other => panic!("expected Ping timeout, got: {other:?}"),
1635 }
1636 }
1637}