1#![allow(clippy::unwrap_used, clippy::missing_panics_doc)]
20
21use crate::control::{ControlReader, ControlWriter};
44use crate::error::TimeoutKind;
45use crate::registry::OpenResult;
46use crate::session::{PendingPing, SessionHandle, SessionInner};
47use crate::{Error, State};
48use quic_reverse_control::{
49 CloseCode, Metadata, OpenRequest, OpenResponse, OpenStatus, ProtocolMessage, RejectCode,
50 ServiceId, StreamBind, StreamClose,
51};
52use quic_reverse_transport::Connection;
53use std::sync::atomic::Ordering;
54use std::sync::Arc;
55use std::time::{Duration, Instant};
56use tokio::io::AsyncReadExt;
57use tokio::sync::{mpsc, oneshot, Mutex};
58use tokio::task::JoinHandle;
59use tokio::time::timeout;
60use tracing::{debug, error, info, trace, warn};
61
62#[derive(Debug, Clone)]
64pub enum ClientEvent {
65 OpenRequest {
67 request_id: u64,
69 service: ServiceId,
71 metadata: Metadata,
73 },
74 StreamClosed {
76 logical_stream_id: u64,
78 code: CloseCode,
80 },
81 PingReceived {
83 sequence: u64,
85 },
86 Closing {
88 code: CloseCode,
90 reason: Option<String>,
92 },
93}
94
95pub struct SessionClient<C: Connection> {
125 inner: Arc<SessionInner<C>>,
127 writer: Arc<Mutex<ControlWriter<C::SendStream>>>,
129 processor_handle: Arc<JoinHandle<()>>,
131}
132
133impl<C: Connection> Clone for SessionClient<C> {
134 fn clone(&self) -> Self {
135 Self {
136 inner: Arc::clone(&self.inner),
137 writer: Arc::clone(&self.writer),
138 processor_handle: Arc::clone(&self.processor_handle),
139 }
140 }
141}
142
143impl<C: Connection> SessionClient<C> {
144 pub fn new(handle: SessionHandle<C>) -> Self {
150 let (client, _events) = Self::with_events(handle);
151 client
152 }
153
154 pub fn with_events(handle: SessionHandle<C>) -> (Self, mpsc::Receiver<ClientEvent>) {
160 let (event_tx, event_rx) = mpsc::channel(64);
162
163 let inner = handle.inner;
165 let writer = Arc::new(Mutex::new(handle.writer));
166 let reader = handle.reader;
167
168 let processor_inner = Arc::clone(&inner);
170 let processor_writer = Arc::clone(&writer);
171 let processor_handle = tokio::spawn(async move {
172 run_message_processor(processor_inner, processor_writer, reader, event_tx).await;
173 });
174
175 let client = Self {
176 inner,
177 writer,
178 processor_handle: Arc::new(processor_handle),
179 };
180
181 (client, event_rx)
182 }
183
184 #[must_use]
186 pub fn state(&self) -> State {
187 State::from_u8(self.inner.state.load(Ordering::SeqCst))
188 }
189
190 #[must_use]
192 pub fn is_ready(&self) -> bool {
193 self.state() == State::Ready
194 }
195
196 #[must_use]
198 pub fn connection(&self) -> &C {
199 &self.inner.connection
200 }
201
202 pub async fn open(
215 &self,
216 service: impl Into<ServiceId>,
217 metadata: Metadata,
218 ) -> Result<(C::SendStream, C::RecvStream), Error> {
219 if !self.is_ready() {
220 return Err(Error::SessionClosed);
221 }
222
223 let service = service.into();
224
225 let (response_tx, response_rx) = oneshot::channel();
227
228 let request_id = {
230 let mut registry = self.inner.registry.lock().unwrap();
231 let request_id = registry.next_request_id();
232 let request =
233 OpenRequest::new(request_id, service.clone()).with_metadata(metadata.clone());
234 if registry.register_pending(&request, response_tx).is_none() {
235 return Err(Error::CapacityExceeded("too many pending open requests"));
236 }
237 request_id
238 };
239
240 debug!(request_id, service = %service.as_str(), "sending open request");
241
242 {
244 let mut writer = self.writer.lock().await;
245 let request = OpenRequest::new(request_id, service).with_metadata(metadata);
246 writer
247 .write_message(&ProtocolMessage::OpenRequest(request))
248 .await?;
249 writer.flush().await?;
250 }
251
252 let open_timeout = self.inner.config.open_timeout;
254 let result = match timeout(open_timeout, response_rx).await {
255 Ok(Ok(result)) => result,
256 Ok(Err(_)) => {
257 let mut registry = self.inner.registry.lock().unwrap();
259 registry.take_pending(request_id);
260 return Err(Error::SessionClosed);
261 }
262 Err(_) => {
263 let mut registry = self.inner.registry.lock().unwrap();
265 registry.take_pending(request_id);
266 return Err(Error::Timeout(TimeoutKind::OpenRequest));
267 }
268 };
269
270 match result {
271 OpenResult::Accepted { logical_stream_id } => {
272 debug!(request_id, logical_stream_id, "open request accepted");
273
274 let bind_timeout = self.inner.config.stream_bind_timeout;
276 let stream_result = timeout(bind_timeout, self.inner.connection.accept_bi()).await;
277
278 match stream_result {
279 Ok(Ok(Some((send, mut recv)))) => {
280 let mut bind_buf = [0u8; StreamBind::ENCODED_SIZE];
282 let read_result =
283 timeout(bind_timeout, recv.read_exact(&mut bind_buf)).await;
284
285 match read_result {
286 Ok(Ok(_)) => {
287 match StreamBind::decode(&bind_buf) {
289 Some(bind) if bind.logical_stream_id == logical_stream_id => {
290 info!(
291 request_id,
292 logical_stream_id, "stream bound successfully"
293 );
294 Ok((send, recv))
295 }
296 Some(bind) => {
297 warn!(
298 request_id,
299 expected = logical_stream_id,
300 received = bind.logical_stream_id,
301 "stream bind ID mismatch"
302 );
303 Err(Error::protocol_violation(format!(
304 "stream bind ID mismatch: expected {}, got {}",
305 logical_stream_id, bind.logical_stream_id
306 )))
307 }
308 None => {
309 warn!(request_id, "invalid stream bind frame");
310 Err(Error::protocol_violation("invalid stream bind frame"))
311 }
312 }
313 }
314 Ok(Err(e)) => {
315 warn!(request_id, error = %e, "failed to read stream bind");
316 Err(Error::Transport(Box::new(e)))
317 }
318 Err(_) => {
319 warn!(request_id, "timeout reading stream bind");
320 Err(Error::Timeout(TimeoutKind::StreamBind))
321 }
322 }
323 }
324 Ok(Ok(None)) => Err(Error::protocol_violation(
325 "connection closed while waiting for stream",
326 )),
327 Ok(Err(e)) => Err(Error::Transport(Box::new(e))),
328 Err(_) => Err(Error::Timeout(TimeoutKind::StreamBind)),
329 }
330 }
331 OpenResult::Rejected { code, reason } => {
332 warn!(request_id, ?code, ?reason, "open request rejected");
333 Err(Error::StreamRejected { code, reason })
334 }
335 }
336 }
337
338 pub async fn accept_open(&self, request_id: u64, logical_stream_id: u64) -> Result<(), Error> {
348 let mut writer = self.writer.lock().await;
349 let response = OpenResponse::accepted(request_id, logical_stream_id);
350 writer
351 .write_message(&ProtocolMessage::OpenResponse(response))
352 .await?;
353 writer.flush().await
354 }
355
356 pub async fn reject_open(
365 &self,
366 request_id: u64,
367 code: RejectCode,
368 reason: Option<String>,
369 ) -> Result<(), Error> {
370 let mut writer = self.writer.lock().await;
371 let response = OpenResponse::rejected(request_id, code, reason);
372 writer
373 .write_message(&ProtocolMessage::OpenResponse(response))
374 .await?;
375 writer.flush().await
376 }
377
378 pub async fn bind_stream<S: tokio::io::AsyncWriteExt + Unpin>(
402 &self,
403 send: &mut S,
404 logical_stream_id: u64,
405 ) -> Result<(), Error> {
406 let bind_frame = StreamBind::new(logical_stream_id);
407 send.write_all(&bind_frame.encode())
408 .await
409 .map_err(|e| Error::Transport(Box::new(e)))?;
410 send.flush()
411 .await
412 .map_err(|e| Error::Transport(Box::new(e)))?;
413 Ok(())
414 }
415
416 pub async fn ping(&self) -> Result<Duration, Error> {
424 if !self.is_ready() {
425 return Err(Error::SessionClosed);
426 }
427
428 let sequence = self.inner.next_ping_seq.fetch_add(1, Ordering::SeqCst);
429 let (response_tx, response_rx) = oneshot::channel();
430 let sent_at = Instant::now();
431
432 {
434 let mut pending = self.inner.pending_pings.lock().unwrap();
435 pending.insert(
436 sequence,
437 PendingPing {
438 sent_at,
439 response_tx,
440 },
441 );
442 }
443
444 {
446 let mut writer = self.writer.lock().await;
447 let ping = quic_reverse_control::Ping { sequence };
448 writer.write_message(&ProtocolMessage::Ping(ping)).await?;
449 writer.flush().await?;
450 }
451
452 let ping_timeout = self.inner.config.ping_timeout;
454 match timeout(ping_timeout, response_rx).await {
455 Ok(Ok(())) => {
456 let rtt = sent_at.elapsed();
457 debug!(sequence, ?rtt, "ping completed");
458 Ok(rtt)
459 }
460 Ok(Err(_)) => Err(Error::SessionClosed),
461 Err(_) => {
462 let mut pending = self.inner.pending_pings.lock().unwrap();
463 pending.remove(&sequence);
464 Err(Error::Timeout(TimeoutKind::Ping))
465 }
466 }
467 }
468
469 pub async fn close(&self, code: CloseCode, reason: Option<String>) -> Result<(), Error> {
475 self.inner
476 .state
477 .store(State::Closing as u8, Ordering::SeqCst);
478
479 let mut writer = self.writer.lock().await;
480 let close_msg = StreamClose {
481 logical_stream_id: 0,
482 code,
483 reason,
484 };
485 writer
486 .write_message(&ProtocolMessage::StreamClose(close_msg))
487 .await?;
488 writer.flush().await
489 }
490
491 pub async fn close_stream(
497 &self,
498 logical_stream_id: u64,
499 code: CloseCode,
500 reason: Option<String>,
501 ) -> Result<(), Error> {
502 let mut writer = self.writer.lock().await;
503 let close_msg = StreamClose {
504 logical_stream_id,
505 code,
506 reason,
507 };
508 writer
509 .write_message(&ProtocolMessage::StreamClose(close_msg))
510 .await?;
511 writer.flush().await
512 }
513}
514
515impl<C: Connection> Drop for SessionClient<C> {
516 fn drop(&mut self) {
517 if Arc::strong_count(&self.processor_handle) == 1 {
519 self.processor_handle.abort();
520 }
521 }
522}
523
524async fn run_message_processor<C: Connection>(
526 inner: Arc<SessionInner<C>>,
527 writer: Arc<Mutex<ControlWriter<C::SendStream>>>,
528 mut reader: ControlReader<C::RecvStream>,
529 event_tx: mpsc::Sender<ClientEvent>,
530) {
531 debug!("message processor started");
532 loop {
533 debug!("message processor: waiting for next message");
534 match reader.read_message().await {
535 Ok(Some(msg)) => {
536 debug!(
537 "message processor: received message {:?}",
538 message_type(&msg)
539 );
540 if let Err(should_break) = handle_message(&inner, &writer, msg, &event_tx).await {
541 if should_break {
542 debug!("message processor: breaking loop");
543 break;
544 }
545 }
546 }
547 Ok(None) => {
548 debug!("control stream closed");
549 inner.state.store(State::Closed as u8, Ordering::SeqCst);
550 break;
551 }
552 Err(e) => {
553 error!("message read error: {}", e);
554 inner
555 .state
556 .store(State::Disconnected as u8, Ordering::SeqCst);
557 break;
558 }
559 }
560 }
561 debug!("message processor exited");
562}
563
564const fn message_type(msg: &ProtocolMessage) -> &'static str {
565 match msg {
566 ProtocolMessage::Hello(_) => "Hello",
567 ProtocolMessage::HelloAck(_) => "HelloAck",
568 ProtocolMessage::OpenRequest(_) => "OpenRequest",
569 ProtocolMessage::OpenResponse(_) => "OpenResponse",
570 ProtocolMessage::StreamClose(_) => "StreamClose",
571 ProtocolMessage::Ping(_) => "Ping",
572 ProtocolMessage::Pong(_) => "Pong",
573 }
574}
575
576async fn handle_message<C: Connection>(
579 inner: &Arc<SessionInner<C>>,
580 writer: &Arc<Mutex<ControlWriter<C::SendStream>>>,
581 msg: ProtocolMessage,
582 event_tx: &mpsc::Sender<ClientEvent>,
583) -> Result<(), bool> {
584 match msg {
585 ProtocolMessage::OpenRequest(req) => {
586 let _ = event_tx
588 .send(ClientEvent::OpenRequest {
589 request_id: req.request_id,
590 service: req.service,
591 metadata: req.metadata,
592 })
593 .await;
594 Ok(())
595 }
596
597 ProtocolMessage::OpenResponse(resp) => {
598 let mut registry = inner.registry.lock().unwrap();
600 if let Some(pending) = registry.take_pending(resp.request_id) {
601 let result = match resp.status {
602 OpenStatus::Accepted => OpenResult::Accepted {
603 logical_stream_id: resp.logical_stream_id.unwrap_or(0),
604 },
605 OpenStatus::Rejected(code) => OpenResult::Rejected {
606 code,
607 reason: resp.reason,
608 },
609 };
610 let _ = pending.response_tx.send(result);
611 }
612 Ok(())
613 }
614
615 ProtocolMessage::Ping(ping) => {
616 trace!(sequence = ping.sequence, "received ping, sending pong");
618 let mut w = writer.lock().await;
619 let pong = quic_reverse_control::Pong {
620 sequence: ping.sequence,
621 };
622 if let Err(e) = w.write_message(&ProtocolMessage::Pong(pong)).await {
623 warn!("failed to send pong: {}", e);
624 }
625 let _ = w.flush().await;
626
627 let _ = event_tx
628 .send(ClientEvent::PingReceived {
629 sequence: ping.sequence,
630 })
631 .await;
632 Ok(())
633 }
634
635 ProtocolMessage::Pong(pong) => {
636 let mut pending = inner.pending_pings.lock().unwrap();
638 if let Some(pending_ping) = pending.remove(&pong.sequence) {
639 let _ = pending_ping.response_tx.send(());
640 }
641 Ok(())
642 }
643
644 ProtocolMessage::StreamClose(sc) => {
645 if sc.logical_stream_id == 0 {
646 info!(code = ?sc.code, reason = ?sc.reason, "peer closed session");
648 inner.state.store(State::Closing as u8, Ordering::SeqCst);
649 let _ = event_tx
650 .send(ClientEvent::Closing {
651 code: sc.code,
652 reason: sc.reason,
653 })
654 .await;
655 Err(true) } else {
657 let _ = event_tx
659 .send(ClientEvent::StreamClosed {
660 logical_stream_id: sc.logical_stream_id,
661 code: sc.code,
662 })
663 .await;
664 Ok(())
665 }
666 }
667
668 ProtocolMessage::Hello(_) | ProtocolMessage::HelloAck(_) => {
669 warn!("received unexpected Hello/HelloAck after negotiation");
670 Ok(())
671 }
672 }
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Config, Role, Session};
    use quic_reverse_transport::mock_connection_pair;

    /// Negotiates a client/server session pair over an in-memory mock connection.
    async fn create_session_pair() -> (
        SessionHandle<quic_reverse_transport::MockConnection>,
        SessionHandle<quic_reverse_transport::MockConnection>,
    ) {
        let (conn_client, conn_server) = mock_connection_pair();

        let client_session = Session::new(conn_client, Role::Client, Config::new());
        let server_session = Session::new(conn_server, Role::Server, Config::new());

        // Both sides must run concurrently for the handshake to complete.
        let client_start = tokio::spawn(async move { client_session.start().await });
        let server_start = tokio::spawn(async move { server_session.start().await });

        let client_handle = client_start.await.unwrap().unwrap();
        let server_handle = server_start.await.unwrap().unwrap();

        (client_handle, server_handle)
    }

    #[tokio::test]
    async fn client_creation() {
        let (client_handle, _server_handle) = create_session_pair().await;

        let client = SessionClient::new(client_handle);
        assert!(client.is_ready());
    }

    #[tokio::test]
    async fn client_is_cloneable() {
        let (client_handle, _server_handle) = create_session_pair().await;

        let client = SessionClient::new(client_handle);
        let client2 = client.clone();

        assert!(client.is_ready());
        assert!(client2.is_ready());
    }

    #[tokio::test]
    async fn ping_pong_via_client() {
        let (client_handle, server_handle) = create_session_pair().await;

        let client = SessionClient::new(client_handle);
        // The server side needs a running processor to answer the ping.
        let _server = SessionClient::new(server_handle);

        tokio::time::sleep(Duration::from_millis(10)).await;

        let rtt = client.ping().await.expect("ping should succeed");
        assert!(rtt < Duration::from_secs(1));
    }

    #[tokio::test]
    async fn open_and_accept_stream() {
        let (client_handle, server_handle) = create_session_pair().await;

        let (server_client, mut server_events) = SessionClient::with_events(server_handle);
        let server_conn = server_client.connection().clone();

        // Server: accept the "echo" request, open and bind the data stream,
        // then echo one read back.
        let server_task = tokio::spawn(async move {
            while let Some(event) = server_events.recv().await {
                if let ClientEvent::OpenRequest {
                    request_id,
                    service,
                    ..
                } = event
                {
                    if service.as_str() == "echo" {
                        let logical_stream_id = 1;
                        server_client
                            .accept_open(request_id, logical_stream_id)
                            .await
                            .unwrap();

                        let (mut send, mut recv) = server_conn.open_bi().await.unwrap();
                        server_client
                            .bind_stream(&mut send, logical_stream_id)
                            .await
                            .unwrap();

                        use tokio::io::{AsyncReadExt, AsyncWriteExt};
                        let mut buf = [0u8; 32];
                        let n = recv.read(&mut buf).await.unwrap();
                        send.write_all(&buf[..n]).await.unwrap();
                        send.flush().await.unwrap();
                        break;
                    }
                }
            }
        });

        let client = SessionClient::new(client_handle);

        tokio::time::sleep(Duration::from_millis(10)).await;

        let (mut send, mut recv) = client
            .open("echo", Metadata::Empty)
            .await
            .expect("open should succeed");

        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        send.write_all(b"hello").await.unwrap();
        send.flush().await.unwrap();

        let mut buf = [0u8; 32];
        let n = recv.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"hello");

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn reject_unknown_service() {
        let (client_handle, server_handle) = create_session_pair().await;

        let (server_client, mut server_events) = SessionClient::with_events(server_handle);

        let server_task = tokio::spawn(async move {
            while let Some(event) = server_events.recv().await {
                if let ClientEvent::OpenRequest {
                    request_id,
                    service,
                    ..
                } = event
                {
                    server_client
                        .reject_open(
                            request_id,
                            RejectCode::UnsupportedService,
                            Some(format!("unknown: {}", service.as_str())),
                        )
                        .await
                        .unwrap();
                    break;
                }
            }
        });

        let client = SessionClient::new(client_handle);
        tokio::time::sleep(Duration::from_millis(10)).await;

        let result = client.open("foobar", Metadata::Empty).await;
        assert!(matches!(result, Err(Error::StreamRejected { .. })));

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn graceful_close() {
        let (client_handle, server_handle) = create_session_pair().await;

        let (_server_client, mut server_events) = SessionClient::with_events(server_handle);
        let client = SessionClient::new(client_handle);

        // Return the code of the first `Closing` event, or `None` if the event
        // stream ends without one, so the test cannot pass vacuously.
        let server_task = tokio::spawn(async move {
            while let Some(event) = server_events.recv().await {
                if let ClientEvent::Closing { code, .. } = event {
                    return Some(code);
                }
            }
            None
        });

        tokio::time::sleep(Duration::from_millis(10)).await;

        client
            .close(CloseCode::Normal, Some("goodbye".into()))
            .await
            .unwrap();

        let observed = server_task.await.unwrap();
        assert_eq!(observed, Some(CloseCode::Normal));
    }

    #[tokio::test]
    async fn stream_bind_mismatch_rejected() {
        let (client_handle, server_handle) = create_session_pair().await;

        let (server_client, mut server_events) = SessionClient::with_events(server_handle);
        let server_conn = server_client.connection().clone();

        // Server accepts with logical id 1 but binds the stream to 99.
        let server_task = tokio::spawn(async move {
            while let Some(event) = server_events.recv().await {
                if let ClientEvent::OpenRequest {
                    request_id,
                    service,
                    ..
                } = event
                {
                    if service.as_str() == "test" {
                        server_client.accept_open(request_id, 1).await.unwrap();

                        let (mut send, _recv) = server_conn.open_bi().await.unwrap();
                        server_client.bind_stream(&mut send, 99).await.unwrap();
                        break;
                    }
                }
            }
        });

        let client = SessionClient::new(client_handle);
        tokio::time::sleep(Duration::from_millis(10)).await;

        let result = client.open("test", Metadata::Empty).await;
        assert!(matches!(result, Err(Error::ProtocolViolation { .. })));

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn stream_bind_invalid_magic_rejected() {
        let (client_handle, server_handle) = create_session_pair().await;

        let (server_client, mut server_events) = SessionClient::with_events(server_handle);
        let server_conn = server_client.connection().clone();

        // Server accepts, then writes an all-zero header instead of a valid bind frame.
        let server_task = tokio::spawn(async move {
            while let Some(event) = server_events.recv().await {
                if let ClientEvent::OpenRequest {
                    request_id,
                    service,
                    ..
                } = event
                {
                    if service.as_str() == "test" {
                        server_client.accept_open(request_id, 1).await.unwrap();

                        let (mut send, _recv) = server_conn.open_bi().await.unwrap();
                        use tokio::io::AsyncWriteExt;
                        send.write_all(&[0u8; StreamBind::ENCODED_SIZE])
                            .await
                            .unwrap();
                        send.flush().await.unwrap();
                        break;
                    }
                }
            }
        });

        let client = SessionClient::new(client_handle);
        tokio::time::sleep(Duration::from_millis(10)).await;

        let result = client.open("test", Metadata::Empty).await;
        assert!(matches!(result, Err(Error::ProtocolViolation { .. })));

        server_task.await.unwrap();
    }
}