1use std::{
51 net::SocketAddr,
52 pin::Pin,
53 sync::{
54 Arc, RwLock,
55 atomic::{AtomicBool, AtomicUsize, Ordering},
56 },
57 time::SystemTime,
58 vec,
59};
60
61use bytes::Bytes;
62use chrono::{DateTime, Utc};
63use http::StatusCode;
64use prost::Message;
65use quinn::{RecvStream, SendStream, VarInt};
66use scion_proto::address::EndhostAddr;
67use scion_sdk_token_validator::validator::{Token, TokenValidator, TokenValidatorError};
68use serde::Deserialize;
69use tokio::sync::watch;
70
71use crate::{
72 AUTH_HEADER, PATH_SOCK_ADDR_ASSIGNMENT, PATH_UPDATE_TOKEN,
73 metrics::{Metrics, ReceiverMetrics, SenderMetrics},
74 requests::{SocketAddrAssignmentResponse, TokenUpdateResponse, unix_epoch_from_system_time},
75};
76
/// Application error codes used when this server closes a snaptun QUIC connection.
///
/// The discriminant of each variant is the numeric code placed on the wire
/// (converted via `VarInt::from_u32(e as u32)`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SnaptunConnErrors {
    /// The peer sent a malformed, unexpected or rejected control request.
    InvalidRequest = 1,
    /// The tunnel handshake did not complete within the accept timeout.
    Timeout = 2,
    /// A control request had a missing or malformed authentication header.
    Unauthenticated = 3,
    /// The token backing the tunnel ran out.
    TokenExpired = 4,
    /// A server-side failure, e.g. a read error or a failed response send.
    InternalError = 5,
}
91
92impl From<SnaptunConnErrors> for quinn::VarInt {
93 fn from(e: SnaptunConnErrors) -> Self {
94 VarInt::from_u32(e as u32)
95 }
96}
97
98pub trait SnapTunToken: for<'de> Deserialize<'de> + Token + Clone {}
100impl<T> SnapTunToken for T where T: for<'de> Deserialize<'de> + Token + Clone {}
101
/// Upper bound on the whole snaptun handshake performed by
/// `Server::accept_with_timeout`.
pub const ACCEPT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3);

/// Send timeout exposed for users of this module.
/// NOTE(review): not referenced in this part of the file — confirm where it is applied.
pub const SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

/// Size of the buffer a single control request is read into; a request head that
/// does not fit is rejected as "request too big".
const MAX_CTRL_MESSAGE_SIZE: usize = 4096;
111
112pub struct Server<T> {
115 metrics: Metrics,
116 validator: Arc<dyn TokenValidator<T>>,
117}
118
119#[derive(Debug, thiserror::Error)]
121pub enum AcceptError {
122 #[error("timeout reached.")]
124 Timeout,
125 #[error("quinn connection error: {0}")]
127 ConnectionError(#[from] quinn::ConnectionError),
128 #[error("parse control request error: {0}")]
130 ParseControlRequestError(#[from] ParseControlRequestError),
131 #[error("send control response error: {0}")]
133 SendControlResponseError(#[from] SendControlResponseError),
134 #[error("unexpected control request")]
136 UnexpectedControlRequest,
137}
138
139impl<T: SnapTunToken> Server<T> {
140 pub fn new(validator: Arc<dyn TokenValidator<T>>, metrics: Metrics) -> Self {
143 Self { validator, metrics }
144 }
145
146 pub async fn accept_with_timeout(
154 &self,
155 conn: quinn::Connection,
156 ) -> Result<(Sender<T>, Receiver<T>, Control), AcceptError> {
157 match tokio::time::timeout(ACCEPT_TIMEOUT, self.accept(conn.clone())).await {
158 Ok(res) => res,
159 Err(_elapsed) => {
160 conn.close(
161 SnaptunConnErrors::Timeout.into(),
162 b"timeout establishing snaptun",
163 );
164 Err(AcceptError::Timeout)
165 }
166 }
167 }
168
169 async fn accept(
176 &self,
177 conn: quinn::Connection,
178 ) -> Result<(Sender<T>, Receiver<T>, Control), AcceptError> {
179 let state_machine = Arc::new(TunnelStateMachine::new(
180 conn.remote_address(),
181 self.validator.clone(),
182 ));
183
184 let (token_update_req, mut snd, _rcv) = receive_expected_control_request(
187 &conn,
188 |r| matches!(r, ControlRequest::TokenUpdate(_)),
189 b"expected token update request",
190 )
191 .await?;
192
193 let now = SystemTime::now();
194 tracing::debug!(?now, request=?token_update_req, "Got token update request");
195
196 let (code, body) = state_machine.process_control_request(now, token_update_req);
197 let send_res = send_http_response(&mut snd, code, &body).await;
198 if !code.is_success() {
199 conn.close(SnaptunConnErrors::InvalidRequest.into(), &body);
200 return Err(AcceptError::UnexpectedControlRequest);
201 }
202 if let Err(e) = send_res {
203 conn.close(
204 SnaptunConnErrors::InternalError.into(),
205 b"failed to send control response",
206 );
207 return Err(AcceptError::SendControlResponseError(e));
208 }
209
210 let (address_assign_request, mut snd, _rcv) = receive_expected_control_request(
212 &conn,
213 |r| matches!(r, ControlRequest::SocketAddrAssignment { .. }),
214 b"expected socket addr assignment request",
215 )
216 .await?;
217
218 let now = SystemTime::now();
219
220 tracing::debug!(?now, request=?address_assign_request, "Got address assignment request");
221
222 let (code, body) = state_machine.process_control_request(now, address_assign_request);
223 let send_res = send_http_response(&mut snd, code, &body).await;
224 if !code.is_success() {
225 conn.close(SnaptunConnErrors::InvalidRequest.into(), &body);
226 return Err(AcceptError::UnexpectedControlRequest);
227 }
228 if let Err(e) = send_res {
229 conn.close(
230 SnaptunConnErrors::InternalError.into(),
231 b"failed to send control response",
232 );
233 return Err(AcceptError::SendControlResponseError(e));
234 }
235
236 let initial_state_version = state_machine.state_version();
237 Ok((
238 Sender::new(
239 state_machine.get_socket_addr(),
240 state_machine.get_addresses().expect("assigned state"),
241 conn.clone(),
242 state_machine.clone(),
243 initial_state_version,
244 self.metrics.sender_metrics.clone(),
245 ),
246 Receiver::new(
247 conn.clone(),
248 state_machine.clone(),
249 initial_state_version,
250 self.metrics.receiver_metrics.clone(),
251 ),
252 Control::new(conn, state_machine.clone()),
253 ))
254 }
255}
256
257async fn receive_expected_control_request(
258 conn: &quinn::Connection,
259 expected: fn(&ControlRequest) -> bool,
260 wrong_request_conn_close_reason: &'static [u8],
261) -> Result<(ControlRequest, SendStream, RecvStream), AcceptError> {
262 let (snd, mut rcv) = conn
263 .accept_bi()
264 .await
265 .map_err(AcceptError::ConnectionError)?;
266 let mut buf = vec![0u8; MAX_CTRL_MESSAGE_SIZE];
267 let req = match recv_request(&mut buf, &mut rcv).await {
268 Ok(req) if expected(&req) => req,
269 Ok(_) => {
270 conn.close(
271 SnaptunConnErrors::InvalidRequest.into(),
272 wrong_request_conn_close_reason,
273 );
274 return Err(AcceptError::UnexpectedControlRequest);
275 }
276 Err(err) => {
277 handle_invalid_request(conn, &err);
278 return Err(err.into());
279 }
280 };
281 Ok((req, snd, rcv))
282}
283
284pub struct Sender<T: SnapTunToken> {
289 assigned_socket_addr: Option<SocketAddr>,
290 metrics: SenderMetrics,
291 addresses: Vec<EndhostAddr>,
292 conn: quinn::Connection,
293 state_machine: Arc<TunnelStateMachine<T>>,
294 last_state_version: AtomicUsize,
295 is_closed: AtomicBool,
296}
297
298impl<T: SnapTunToken> Sender<T> {
299 fn new(
300 assigned_socket_addr: Option<SocketAddr>,
301 addresses: Vec<EndhostAddr>,
302 conn: quinn::Connection,
303 state_machine: Arc<TunnelStateMachine<T>>,
304 initial_state_version: usize,
305 metrics: SenderMetrics,
306 ) -> Self {
307 Self {
308 assigned_socket_addr,
309 addresses,
310 conn,
311 state_machine,
312 last_state_version: AtomicUsize::new(initial_state_version),
313 is_closed: AtomicBool::new(false),
314 metrics,
315 }
316 }
317
318 pub fn assigned_addresses(&self) -> Vec<EndhostAddr> {
320 self.addresses.clone()
321 }
322
323 pub fn assigned_socket_addr(&self) -> Option<SocketAddr> {
325 self.assigned_socket_addr
326 }
327
328 pub fn remote_underlay_address(&self) -> SocketAddr {
330 self.conn.remote_address()
331 }
332
333 pub fn send(&self, pkt: Bytes) -> Result<(), SendPacketError<T>> {
343 let pkt = self.validate_tun(pkt)?;
344 self.conn.send_datagram(pkt)?;
345 self.metrics.datagrams_sent_total.inc();
346 Ok(())
347 }
348
349 pub async fn send_wait(&self, pkt: Bytes) -> Result<(), SendPacketError<T>> {
354 let pkt = self.validate_tun(pkt)?;
355 self.conn.send_datagram_wait(pkt).await?;
356 Ok(())
357 }
358
359 pub fn close(&self, error_code: SnaptunConnErrors, reason: &[u8]) {
363 self.conn.close(error_code.into(), reason)
364 }
365
366 fn validate_tun(&self, pkt: Bytes) -> Result<Bytes, SendPacketError<T>> {
367 if self.is_closed.load(Ordering::Acquire) {
369 return Err(SendPacketError::ConnectionClosed);
370 }
371 let current_state_version = self.state_machine.state_version();
373 if self
374 .last_state_version
375 .compare_exchange(
376 current_state_version - 1,
377 current_state_version,
378 Ordering::AcqRel,
379 Ordering::Acquire,
380 )
381 .is_ok()
382 {
383 if self.state_machine.is_closed() {
386 self.is_closed.store(true, Ordering::Release);
387 return Err(SendPacketError::ConnectionClosed);
388 }
389 let addresses = self.state_machine.get_addresses()?;
391
392 return Err(SendPacketError::NewAssignedAddress((
394 Box::new(Sender::new(
395 self.state_machine.get_socket_addr(),
396 addresses,
397 self.conn.clone(),
398 self.state_machine.clone(),
399 current_state_version,
400 self.metrics.clone(),
401 )),
402 pkt,
403 )));
404 }
405
406 Ok(pkt)
407 }
408}
409
410impl<T: SnapTunToken> std::fmt::Debug for Sender<T> {
411 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412 f.debug_struct("Sender")
413 .field("addresses", &self.addresses)
414 .field("conn", &self.conn.stable_id())
415 .field("last_state_version", &self.last_state_version)
416 .finish()
417 }
418}
419
420#[derive(Debug, thiserror::Error)]
422pub enum SendPacketError<T: SnapTunToken> {
423 #[error("connection closed")]
425 ConnectionClosed,
426 #[error("address was re-assigned")]
428 NewAssignedAddress((Box<Sender<T>>, Bytes)),
429 #[error("address assignment error: {0}")]
431 AddressAssignmentError(#[from] AddressAssignmentError),
432 #[error("underlying send error")]
434 SendDatagramError(#[from] quinn::SendDatagramError),
435}
436
437pub struct Receiver<T: SnapTunToken> {
440 metrics: ReceiverMetrics,
441 conn: quinn::Connection,
442 state_machine: Arc<TunnelStateMachine<T>>,
443 last_state_version: AtomicUsize,
444 is_closed: AtomicBool,
445}
446
447#[derive(Debug, thiserror::Error)]
449pub enum ReceivePacketError {
450 #[error("quinn error: {0}")]
452 ConnectionError(#[from] quinn::ConnectionError),
453 #[error("connection closed")]
455 ConnectionClosed,
456}
457
458impl<T: SnapTunToken> Receiver<T> {
459 fn new(
460 conn: quinn::Connection,
461 state_machine: Arc<TunnelStateMachine<T>>,
462 initial_state_version: usize,
463 metrics: ReceiverMetrics,
464 ) -> Self {
465 Self {
466 conn,
467 state_machine,
468 last_state_version: AtomicUsize::new(initial_state_version),
469 is_closed: AtomicBool::new(false),
470 metrics,
471 }
472 }
473
474 pub async fn receive(&self) -> Result<Bytes, ReceivePacketError> {
476 let current_state_version = self.state_machine.state_version();
478 if self
479 .last_state_version
480 .compare_exchange(
481 current_state_version - 1,
482 current_state_version,
483 Ordering::AcqRel,
484 Ordering::Acquire,
485 )
486 .is_ok()
487 {
488 if self.state_machine.is_closed() {
490 self.is_closed.store(true, Ordering::Release);
491 }
492 }
493 if self.is_closed.load(Ordering::Acquire) {
494 return Err(ReceivePacketError::ConnectionClosed);
495 }
496 let p = self.conn.read_datagram().await?;
497 self.metrics.datagrams_received_total.inc();
498 Ok(p)
499 }
500}
501
502#[derive(Debug, thiserror::Error)]
504pub enum ControlError {
505 #[error("parse control request error: {0}")]
507 ParseError(#[from] ParseControlRequestError),
508 #[error("send control response error: {0}")]
510 SendError(#[from] SendControlResponseError),
511 #[error("wait for completion error: {0}")]
513 StoppedError(#[from] quinn::StoppedError),
514 #[error("token expired")]
516 TokenExpired,
517 #[error("connection closed prematurely")]
519 ClosedPrematurely,
520}
521
522pub struct Control {
525 driver_fut: Pin<Box<dyn Future<Output = Result<(), ControlError>> + Send>>,
526}
527
528impl Control {
529 fn new<T>(conn: quinn::Connection, tunnel_state: Arc<TunnelStateMachine<T>>) -> Self
530 where
531 T: for<'de> Deserialize<'de> + Token + Clone,
532 {
533 let fut = async move {
534 loop {
535 tokio::select! {
536 _ = tunnel_state.await_token_expiry() => {
537 tunnel_state.shutdown();
539 conn.close(SnaptunConnErrors::TokenExpired.into(), b"token expired");
540 return Err(ControlError::TokenExpired)
541 }
542 res = conn.accept_bi() => {
543 let (mut snd, mut rcv) = match res {
544 Ok(v) => v,
545 Err(quinn::ConnectionError::ApplicationClosed(_)) => {
546 tunnel_state.shutdown();
547 return Ok(());
548 }
549 Err(_) => {
550 tunnel_state.shutdown();
551 return Err(ControlError::ClosedPrematurely);
552 }
553 };
554
555 let mut buf = vec![0u8; MAX_CTRL_MESSAGE_SIZE];
556 let control_request = recv_request(&mut buf, &mut rcv).await.inspect_err(|err| {
557 handle_invalid_request(&conn, err);
558 tunnel_state.shutdown();
559 })?;
560
561 let (code, body) = tunnel_state.process_control_request(SystemTime::now(), control_request);
562 send_http_response(&mut snd, code, &body).await
563 .inspect_err(|_| {
564 tunnel_state.shutdown();
565 conn.close(SnaptunConnErrors::InternalError.into(), b"send control response error");
566 })?;
567
568 snd.stopped().await?;
569 }
570 }
571 }
572 };
573 let driver_fut = Box::pin(fut);
574 Self { driver_fut }
575 }
576}
577
578impl Future for Control {
579 type Output = Result<(), ControlError>;
580
581 fn poll(
582 mut self: std::pin::Pin<&mut Self>,
583 cx: &mut std::task::Context<'_>,
584 ) -> std::task::Poll<Self::Output> {
585 self.driver_fut.as_mut().poll(cx)
586 }
587}
588
589#[derive(Debug, thiserror::Error)]
591pub enum AddressAssignmentError {
592 #[error("no address assigned")]
594 NoAddressAssigned,
595}
596
597pub struct TunnelStateMachine<T: SnapTunToken> {
607 remote_sock_addr: SocketAddr,
608 validator: Arc<dyn TokenValidator<T>>,
609 inner_state: RwLock<TunnelState>,
610 state_version: AtomicUsize,
611 sender: watch::Sender<()>,
613 receiver: watch::Receiver<()>,
614}
615
616impl<T: SnapTunToken> Drop for TunnelStateMachine<T> {
617 fn drop(&mut self) {
618 self.shutdown();
620 }
621}
622
623impl<T: SnapTunToken> TunnelStateMachine<T> {
624 pub(crate) fn new(remote_sock_addr: SocketAddr, validator: Arc<dyn TokenValidator<T>>) -> Self {
625 let (sender, receiver) = watch::channel(());
626
627 Self {
628 remote_sock_addr,
629 validator,
630 inner_state: Default::default(),
631 state_version: AtomicUsize::new(0),
632 sender,
633 receiver,
634 }
635 }
636
637 fn process_control_request(
640 &self,
641 now: SystemTime,
642 control_request: ControlRequest,
643 ) -> (http::StatusCode, Vec<u8>) {
644 let mut inner_state = self.inner_state.write().expect("no fail");
645
646 if let TunnelState::Closed = *inner_state {
647 return (http::StatusCode::BAD_REQUEST, "tunnel is closed".into());
648 }
649 match control_request {
650 ControlRequest::SocketAddrAssignment(token) => {
651 self.locked_process_socket_addr_assignment_request(&mut inner_state, now, token)
652 }
653 ControlRequest::TokenUpdate(token) => {
654 self.locked_process_token_update(&mut inner_state, now, token)
655 }
656 }
657 }
658
659 fn locked_process_token_update(
660 &self,
661 inner_state: &mut TunnelState,
662 now: SystemTime,
663 token: String,
664 ) -> (http::StatusCode, Vec<u8>) {
665 match self.validator.validate(now, &token) {
666 Ok(claims) => {
667 let token_expiry = claims.exp_time();
668
669 self.locked_update_tunnel_expiry(inner_state, token_expiry);
671
672 let resp = TokenUpdateResponse {
673 valid_until: unix_epoch_from_system_time(token_expiry),
674 };
675
676 let mut resp_body = vec![];
677 resp.encode(&mut resp_body).expect("no fail");
678 (StatusCode::OK, resp_body)
679 }
680 Err(e) => map_token_validation_err_to_response(e),
681 }
682 }
683
684 fn locked_process_socket_addr_assignment_request(
685 &self,
686 inner_state: &mut TunnelState,
687 now: SystemTime,
688 token: String,
689 ) -> (http::StatusCode, Vec<u8>) {
690 let token_expiry = match inner_state.token_validity() {
694 Ok(v) => v,
695 Err(err) => {
696 tracing::error!(
697 ?err,
698 "Failed to get token validity when processing address assignment request"
699 );
700 return (
703 StatusCode::INTERNAL_SERVER_ERROR,
704 "invalid state transition".into(),
705 );
706 }
707 };
708 match self.validator.validate(now, &token) {
709 Ok(_claims) => {
710 self.locked_update_state(
711 inner_state,
712 TunnelState::SockAddrAssigned { token_expiry },
713 );
714 let resp = SocketAddrAssignmentResponse::from(self.remote_sock_addr);
715
716 let mut resp_body = vec![];
717 resp.encode(&mut resp_body).expect("no fail");
718 (StatusCode::OK, resp_body)
719 }
720 Err(e) => map_token_validation_err_to_response(e),
721 }
722 }
723
724 fn locked_update_tunnel_expiry(&self, inner_state: &mut TunnelState, token_expiry: SystemTime) {
725 match inner_state {
726 TunnelState::Unassigned => {
727 *inner_state = TunnelState::SessionEstablished { token_expiry };
728 }
729 TunnelState::SessionEstablished { .. } => {
730 *inner_state = TunnelState::SessionEstablished { token_expiry };
731 }
732 TunnelState::SockAddrAssigned { .. } => {
733 *inner_state = TunnelState::SockAddrAssigned { token_expiry }
734 }
735 TunnelState::Closed => {
736 tracing::error!("Updating tunnel token expiry but in closed state")
737 }
738 };
739 }
740
741 fn locked_update_state(&self, inner_state: &mut TunnelState, new_state: TunnelState) {
742 tracing::debug!(%new_state, "Updating tunnel state");
743 *inner_state = new_state;
744
745 self.state_version.fetch_add(1, Ordering::AcqRel);
746
747 if self.sender.send(()).is_err() {
748 tracing::debug!("Failed to notify token expiry update");
751 }
752 }
753
754 fn get_addresses(&self) -> Result<Vec<EndhostAddr>, AddressAssignmentError> {
755 let guard = self.inner_state.read().expect("no fail");
756
757 match &*guard {
758 TunnelState::SockAddrAssigned { .. } => Ok(vec![]),
759 _ => Err(AddressAssignmentError::NoAddressAssigned),
760 }
761 }
762
763 fn get_socket_addr(&self) -> Option<SocketAddr> {
764 let guard = self.inner_state.read().expect("no fail");
765 if let TunnelState::SockAddrAssigned { .. } = &*guard {
766 return Some(self.remote_sock_addr);
767 }
768 None
769 }
770
771 async fn await_token_expiry(&self) {
772 let mut expiry_notifier = self.receiver.clone();
773 loop {
774 let valid_duration = {
775 let res = {
776 let guard = self.inner_state.read().expect("no fail");
777 guard.token_validity()
778 };
779 match res {
780 Ok(token_validity) => {
781 match token_validity.duration_since(SystemTime::now()) {
782 Ok(dur) => dur,
783 Err(_) => return, }
785 }
786 Err(err) => {
787 tracing::warn!(%err, "Tunnel in an invalid state");
790 return;
791 }
792 }
793 };
794
795 tokio::select! {
796 _ = expiry_notifier.changed() => {
797 continue;
799 }
800 _ = tokio::time::sleep(valid_duration) => {
801 return;
803 }
804 }
805 }
806 }
807
808 fn state_version(&self) -> usize {
809 self.state_version.load(Ordering::Acquire)
810 }
811
812 fn is_closed(&self) -> bool {
813 if let TunnelState::Closed = *self.inner_state.read().expect("no fail") {
814 return true;
815 }
816 false
817 }
818
819 fn shutdown(&self) {
820 let mut inner_state = self.inner_state.write().expect("no fail");
821 self.locked_update_state(&mut inner_state, TunnelState::Closed);
822 }
823}
824
825fn map_token_validation_err_to_response(value: TokenValidatorError) -> (StatusCode, Vec<u8>) {
826 match value {
827 TokenValidatorError::JwtSignatureInvalid() => {
828 tracing::info!("Invalid JWT Signature");
829 (StatusCode::UNAUTHORIZED, "unauthorized".into())
830 }
831 TokenValidatorError::JwtError(err) => {
832 tracing::info!(?err, "Token validation failed");
833 (StatusCode::UNAUTHORIZED, "unauthorized".into())
834 }
835 TokenValidatorError::TokenExpired(err) => {
836 tracing::info!(?err, "Token validation failed: token expired");
837 (StatusCode::UNAUTHORIZED, "unauthorized".into())
838 }
839 }
840}
841
842#[derive(Debug, thiserror::Error)]
843enum TunnelStateError {
844 #[error("invalid state: {0}")]
845 InvalidState(TunnelState),
846}
847
/// Lifecycle of a tunnel as tracked by the state machine.
#[derive(Debug, Clone, Default)]
enum TunnelState {
    /// Initial state: no token has been accepted yet.
    #[default]
    Unassigned,
    /// A token was accepted; valid until `token_expiry`.
    SessionEstablished {
        token_expiry: SystemTime,
    },
    /// The peer's socket address was assigned; valid until `token_expiry`.
    SockAddrAssigned {
        token_expiry: SystemTime,
    },
    /// Terminal state: the tunnel has been shut down.
    Closed,
}
860
861impl TunnelState {
862 fn token_validity(&self) -> Result<SystemTime, TunnelStateError> {
863 match self {
864 TunnelState::SessionEstablished { token_expiry } => Ok(*token_expiry),
865 TunnelState::SockAddrAssigned { token_expiry, .. } => Ok(*token_expiry),
866 _ => Err(TunnelStateError::InvalidState(self.clone())),
867 }
868 }
869}
870
871impl std::fmt::Display for TunnelState {
872 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
873 match self {
874 TunnelState::Unassigned => write!(f, "Unassigned"),
875 TunnelState::SessionEstablished { token_expiry } => {
876 write!(
877 f,
878 "SessionEstablished ({})",
879 DateTime::<Utc>::from(*token_expiry)
880 )
881 }
882 TunnelState::Closed => write!(f, "Closed"),
883 TunnelState::SockAddrAssigned { token_expiry } => {
884 write!(
885 f,
886 "Remote socket address assigned (valid until: {}).",
887 DateTime::<Utc>::from(*token_expiry),
888 )
889 }
890 }
891 }
892}
893
/// A parsed control request. Each variant carries the bearer token it was sent with.
#[derive(Debug)]
enum ControlRequest {
    /// Request for the peer's socket address to be assigned.
    SocketAddrAssignment(String),
    /// Request to (re)validate the session with a fresh token.
    TokenUpdate(String),
}
899
900fn handle_invalid_request(conn: &quinn::Connection, err: &ParseControlRequestError) {
901 match err {
902 ParseControlRequestError::ClosedPrematurely => {
903 conn.close(
904 SnaptunConnErrors::InternalError.into(),
905 b"closed prematurely",
906 );
907 }
908 ParseControlRequestError::ReadError(_) => {
909 conn.close(SnaptunConnErrors::InternalError.into(), b"read error");
910 }
911 ParseControlRequestError::InvalidRequest(reason) => {
912 conn.close(SnaptunConnErrors::InvalidRequest.into(), reason.as_bytes());
913 }
914 ParseControlRequestError::Unauthenticated(reason) => {
915 conn.close(SnaptunConnErrors::Unauthenticated.into(), reason.as_bytes());
916 }
917 }
918}
919
920#[derive(Debug, thiserror::Error)]
922pub enum ParseControlRequestError {
923 #[error("invalid request: {0}")]
925 InvalidRequest(String),
926 #[error("read error: {0}")]
928 ReadError(#[from] quinn::ReadError),
929 #[error("unauthenticated: {0}")]
931 Unauthenticated(String),
932 #[error("closed prematurely")]
934 ClosedPrematurely,
935}
936
937async fn recv_request(
949 buf: &mut [u8],
950 rcv: &mut RecvStream,
951) -> Result<ControlRequest, ParseControlRequestError> {
952 use ParseControlRequestError::*;
953 let mut cursor = 0;
954
955 while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
957 cursor += n;
958 let mut headers = [httparse::EMPTY_HEADER; 16];
959 let mut req = httparse::Request::new(&mut headers);
960
961 let Ok(httparse::Status::Complete(_body_offset)) = req.parse(&buf[..cursor]) else {
963 if cursor >= buf.len() {
965 return Err(InvalidRequest("request too big".into()));
966 }
967 continue;
968 };
969
970 if !matches!(req.method, Some("POST")) {
972 return Err(InvalidRequest("invalid method".into()));
973 }
974
975 match req.path {
978 Some(PATH_SOCK_ADDR_ASSIGNMENT) => {}
979 Some(PATH_UPDATE_TOKEN) => {}
980 Some(_) | None => return Err(InvalidRequest("invalid path".into())),
981 }
982
983 let Some(auth_header) = req.headers.iter().find(|h| h.name == AUTH_HEADER) else {
985 return Err(Unauthenticated("no auth header".into()));
986 };
987 let bearer_token = auth_header
988 .value
989 .strip_prefix(b"Bearer ")
990 .ok_or(Unauthenticated(
991 "bearer not found in authorization header".into(),
992 ))
993 .map(|x| String::from_utf8_lossy(x).to_string())?;
994
995 let path = req.path.unwrap();
997 match path {
998 PATH_SOCK_ADDR_ASSIGNMENT => {
999 return Ok(ControlRequest::SocketAddrAssignment(bearer_token));
1000 }
1001 PATH_UPDATE_TOKEN => return Ok(ControlRequest::TokenUpdate(bearer_token)),
1002 path => unreachable!("invalid path: {path}"),
1003 }
1004 }
1005
1006 Err(ClosedPrematurely)
1007}
1008
1009#[derive(Debug, thiserror::Error)]
1011pub enum SendControlResponseError {
1012 #[error("i/o error: {0}")]
1014 IoError(#[from] std::io::Error),
1015 #[error("stream closed: {0}")]
1017 ClosedStream(#[from] quinn::ClosedStream),
1018}
1019
1020async fn send_http_response(
1022 stream: &mut SendStream,
1023 code: http::StatusCode,
1024 body: &[u8],
1025) -> Result<(), SendControlResponseError> {
1026 async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
1028 let mut cursor = 0;
1029 while cursor < data.len() {
1030 cursor += stream.write(&data[cursor..]).await?;
1031 }
1032 Ok(())
1033 }
1034
1035 write_all(
1036 stream,
1037 format!(
1038 "HTTP/1.1 {} {}\r\nContent-Length: {}\r\n\r\n",
1039 code.as_str(),
1040 code.canonical_reason().unwrap_or(""),
1041 body.len(),
1042 )
1043 .as_bytes(),
1044 )
1045 .await?;
1046 write_all(stream, body).await?;
1047
1048 stream.finish()?;
1050 Ok(())
1051}