1use std::{
51 net::SocketAddr,
52 pin::Pin,
53 sync::{
54 Arc, RwLock,
55 atomic::{AtomicBool, AtomicUsize, Ordering},
56 },
57 time::SystemTime,
58 vec,
59};
60
61use bytes::Bytes;
62use chrono::{DateTime, Utc};
63use http::StatusCode;
64use ipnet::IpNet;
65use prost::Message;
66use quinn::{RecvStream, SendStream, VarInt};
67use scion_proto::address::{EndhostAddr, IsdAsn};
68use scion_sdk_token_validator::validator::{Token, TokenValidator, TokenValidatorError};
69use serde::Deserialize;
70use tokio::sync::watch;
71use tracing::{debug, error, info, instrument, warn};
72
73use crate::{
74 AUTH_HEADER, AddressAllocation, AddressAllocator, IPV4_WILDCARD, IPV6_WILDCARD,
75 PATH_ADDR_ASSIGNMENT, PATH_SESSION_RENEWAL,
76 metrics::{Metrics, ReceiverMetrics, SenderMetrics},
77 requests::{
78 AddrError, AddressAssignRequest, AddressAssignResponse, SessionRenewalResponse,
79 unix_epoch_from_system_time,
80 },
81};
82
/// Application error codes used when closing a snaptun QUIC connection.
///
/// The discriminant of each variant is the numeric QUIC application error code
/// sent to the peer; it is obtained through the
/// `From<SnaptunConnErrors> for quinn::VarInt` conversion.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SnaptunConnErrors {
    /// The peer sent a malformed, unexpected or rejected control request.
    InvalidRequest = 1,
    /// The tunnel handshake did not complete in time.
    Timeout = 2,
    /// The control request did not carry usable authentication.
    Unauthenticated = 3,
    /// The tunnel session expired.
    SessionExpired = 4,
    /// A server-side failure (I/O, stream closed, ...).
    InternalError = 5,
}
97
98impl From<SnaptunConnErrors> for quinn::VarInt {
99 fn from(e: SnaptunConnErrors) -> Self {
100 VarInt::from_u32(e as u32)
101 }
102}
103
104pub trait SnapTunToken: for<'de> Deserialize<'de> + Token + Clone {}
106impl<T> SnapTunToken for T where T: for<'de> Deserialize<'de> + Token + Clone {}
107
/// Upper bound for the whole accept handshake (session renewal followed by
/// address assignment) performed by `Server::accept_with_timeout`.
pub const ACCEPT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(3_000);

/// Timeout intended for send operations.
/// NOTE(review): not referenced by the code in this file; confirm its users.
pub const SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2_000);

/// Size of the buffer one control request (headers and body) must fit into.
const CTRL_REQUEST_BUF_SIZE: usize = 4 * 1024;
118
/// Snaptun server: runs the control handshake on accepted QUIC connections and
/// hands out a [`Sender`], [`Receiver`] and [`Control`] per tunnel.
pub struct Server<T> {
    /// Metrics cloned into every tunnel's sender and receiver.
    metrics: Metrics,
    /// Validates the bearer tokens carried by control requests.
    validator: Arc<dyn TokenValidator<T>>,
    /// Allocates tunnel addresses for address assignment requests.
    allocator: Arc<dyn AddressAllocator<T>>,
}
126
/// Errors returned while accepting a snaptun connection.
#[derive(Debug, thiserror::Error)]
pub enum AcceptError {
    /// The handshake did not finish within [`ACCEPT_TIMEOUT`].
    #[error("timeout reached.")]
    Timeout,
    /// The QUIC connection failed, e.g. while accepting a stream.
    #[error("quinn connection error: {0}")]
    ConnectionError(#[from] quinn::ConnectionError),
    /// A control request could not be parsed.
    #[error("parse control request error: {0}")]
    ParseControlRequestError(#[from] ParseControlRequestError),
    /// Writing the response to a control request failed.
    #[error("send control response error: {0}")]
    SendControlResponseError(#[from] SendControlResponseError),
    /// The peer sent a request other than the one expected at this handshake
    /// step, or the expected request was answered with a non-2xx status.
    #[error("unexpected control request")]
    UnexpectedControlRequest,
}
146
147impl<T: SnapTunToken> Server<T> {
148 pub fn new(
151 allocator: Arc<dyn AddressAllocator<T>>,
152 validator: Arc<dyn TokenValidator<T>>,
153 metrics: Metrics,
154 ) -> Self {
155 Self {
156 allocator,
157 validator,
158 metrics,
159 }
160 }
161
162 pub async fn accept_with_timeout(
170 &self,
171 conn: quinn::Connection,
172 ) -> Result<(Sender<T>, Receiver<T>, Control), AcceptError> {
173 match tokio::time::timeout(ACCEPT_TIMEOUT, self.accept(conn.clone())).await {
174 Ok(res) => res,
175 Err(_elapsed) => {
176 conn.close(
177 SnaptunConnErrors::Timeout.into(),
178 b"timeout establishing snaptun",
179 );
180 Err(AcceptError::Timeout)
181 }
182 }
183 }
184
185 #[instrument(name = "SnapTunServer::accept", skip_all, fields(conn_id = conn.stable_id()))]
192 async fn accept(
193 &self,
194 conn: quinn::Connection,
195 ) -> Result<(Sender<T>, Receiver<T>, Control), AcceptError> {
196 let state_machine = Arc::new(TunnelStateMachine::new(
197 self.validator.clone(),
198 self.allocator.clone(),
199 ));
200
201 let (address_assign_request, mut snd, _rcv) = receive_expected_control_request(
204 &conn,
205 |r| matches!(r, ControlRequest::SessionRenewal(_)),
206 b"expected session renewal request",
207 )
208 .await?;
209
210 let now = SystemTime::now();
211 debug!(?now, request=?address_assign_request, "Process expected session renewal request");
212 let (code, body) = state_machine.process_control_request(now, address_assign_request);
213 let send_res = send_http_response(&mut snd, code, &body).await;
214 if !code.is_success() {
215 conn.close(
216 SnaptunConnErrors::InvalidRequest.into(),
217 b"handling session renewal request",
218 );
219 return Err(AcceptError::UnexpectedControlRequest);
220 }
221 if let Err(e) = send_res {
222 conn.close(
223 SnaptunConnErrors::InternalError.into(),
224 b"send control response error",
225 );
226 return Err(AcceptError::SendControlResponseError(e));
227 }
228
229 let (address_assign_request, mut snd, _rcv) = receive_expected_control_request(
232 &conn,
233 |r| matches!(r, ControlRequest::AddressAssignment { .. }),
234 b"expected address assignment request",
235 )
236 .await?;
237
238 let now = SystemTime::now();
239 debug!(?now, request=?address_assign_request, "Process expected address assignment request");
240 let (code, body) = state_machine.process_control_request(now, address_assign_request);
241 let send_res = send_http_response(&mut snd, code, &body).await;
242 if !code.is_success() {
243 conn.close(
244 SnaptunConnErrors::InvalidRequest.into(),
245 b"handling address assignment request",
246 );
247 return Err(AcceptError::UnexpectedControlRequest);
248 }
249 if let Err(e) = send_res {
250 conn.close(
251 SnaptunConnErrors::InternalError.into(),
252 b"send control response error",
253 );
254 return Err(AcceptError::SendControlResponseError(e));
255 }
256
257 let initial_state_version = state_machine.state_version();
258 Ok((
259 Sender::new(
260 state_machine.get_addresses().expect("assigned state"),
261 conn.clone(),
262 state_machine.clone(),
263 initial_state_version,
264 self.metrics.sender_metrics.clone(),
265 ),
266 Receiver::new(
267 conn.clone(),
268 state_machine.clone(),
269 initial_state_version,
270 self.metrics.receiver_metrics.clone(),
271 ),
272 Control::new(conn, state_machine.clone()),
273 ))
274 }
275}
276
277async fn receive_expected_control_request(
278 conn: &quinn::Connection,
279 expected: fn(&ControlRequest) -> bool,
280 wrong_request_conn_close_reason: &'static [u8],
281) -> Result<(ControlRequest, SendStream, RecvStream), AcceptError> {
282 let (snd, mut rcv) = conn
283 .accept_bi()
284 .await
285 .map_err(AcceptError::ConnectionError)?;
286 let mut buf = vec![0u8; CTRL_REQUEST_BUF_SIZE];
287 let req = match parse_http_request(&mut buf, &mut rcv).await {
288 Ok(req) if expected(&req) => req,
289 Ok(_) => {
290 conn.close(
291 SnaptunConnErrors::InvalidRequest.into(),
292 wrong_request_conn_close_reason,
293 );
294 return Err(AcceptError::UnexpectedControlRequest);
295 }
296 Err(err) => {
297 handle_invalid_request(conn, &err);
298 return Err(err.into());
299 }
300 };
301 Ok((req, snd, rcv))
302}
303
/// Sending half of an established tunnel.
///
/// The sender remembers the tunnel state version it last observed so that a
/// send can report an address re-assignment or a shutdown.
pub struct Sender<T: SnapTunToken> {
    /// Sender metrics (datagram counter).
    metrics: SenderMetrics,
    /// Addresses assigned to the tunnel when this sender was created.
    addresses: Vec<EndhostAddr>,
    /// QUIC connection datagrams are sent on.
    conn: quinn::Connection,
    /// Shared tunnel state.
    state_machine: Arc<TunnelStateMachine<T>>,
    /// Last tunnel state version observed by this sender.
    last_state_version: AtomicUsize,
    /// Set once this sender has observed that the tunnel is closed.
    is_closed: AtomicBool,
}
316
317impl<T: SnapTunToken> Sender<T> {
318 fn new(
319 addresses: Vec<EndhostAddr>,
320 conn: quinn::Connection,
321 state_machine: Arc<TunnelStateMachine<T>>,
322 initial_state_version: usize,
323 metrics: SenderMetrics,
324 ) -> Self {
325 Self {
326 addresses,
327 conn,
328 state_machine,
329 last_state_version: AtomicUsize::new(initial_state_version),
330 is_closed: AtomicBool::new(false),
331 metrics,
332 }
333 }
334
335 pub fn assigned_addresses(&self) -> Vec<EndhostAddr> {
337 self.addresses.clone()
338 }
339
340 pub fn remote_underlay_address(&self) -> SocketAddr {
342 self.conn.remote_address()
343 }
344
345 pub fn send(&self, pkt: Bytes) -> Result<(), SendPacketError<T>> {
355 let pkt = self.validate_tun(pkt)?;
356 self.conn.send_datagram(pkt)?;
357 self.metrics.datagrams_sent_total.inc();
358 Ok(())
359 }
360
361 pub async fn send_wait(&self, pkt: Bytes) -> Result<(), SendPacketError<T>> {
366 let pkt = self.validate_tun(pkt)?;
367 self.conn.send_datagram_wait(pkt).await?;
368 Ok(())
369 }
370
371 fn validate_tun(&self, pkt: Bytes) -> Result<Bytes, SendPacketError<T>> {
372 if self.is_closed.load(Ordering::Acquire) {
374 return Err(SendPacketError::ConnectionClosed);
375 }
376 let current_state_version = self.state_machine.state_version();
378 if self
379 .last_state_version
380 .compare_exchange(
381 current_state_version - 1,
382 current_state_version,
383 Ordering::AcqRel,
384 Ordering::Acquire,
385 )
386 .is_ok()
387 {
388 if self.state_machine.is_closed() {
391 self.is_closed.store(true, Ordering::Release);
392 return Err(SendPacketError::ConnectionClosed);
393 }
394 let addresses = self.state_machine.get_addresses()?;
396
397 return Err(SendPacketError::NewAssignedAddress((
399 Box::new(Sender::new(
400 addresses,
401 self.conn.clone(),
402 self.state_machine.clone(),
403 current_state_version,
404 self.metrics.clone(),
405 )),
406 pkt,
407 )));
408 }
409
410 Ok(pkt)
411 }
412}
413
414impl<T: SnapTunToken> std::fmt::Debug for Sender<T> {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.debug_struct("Sender")
417 .field("addresses", &self.addresses)
418 .field("conn", &self.conn.stable_id())
419 .field("last_state_version", &self.last_state_version)
420 .finish()
421 }
422}
423
/// Errors returned by [`Sender::send`] and [`Sender::send_wait`].
#[derive(Debug, thiserror::Error)]
pub enum SendPacketError<T: SnapTunToken> {
    /// The tunnel has been shut down; the packet was not sent.
    #[error("connection closed")]
    ConnectionClosed,
    /// The tunnel state changed since this sender was last used.
    ///
    /// Carries a replacement sender for the current state and the packet,
    /// which was NOT sent.
    #[error("address was re-assigned")]
    NewAssignedAddress((Box<Sender<T>>, Bytes)),
    /// The tunnel has no address assigned.
    #[error("address assignment error: {0}")]
    AddressAssignmentError(#[from] AddressAssignmentError),
    /// The underlying QUIC datagram send failed.
    #[error("underlying send error")]
    SendDatagramError(#[from] quinn::SendDatagramError),
}
440
/// Receiving half of an established tunnel.
pub struct Receiver<T: SnapTunToken> {
    /// Receiver metrics (datagram counter).
    metrics: ReceiverMetrics,
    /// QUIC connection datagrams are read from.
    conn: quinn::Connection,
    /// Shared tunnel state, consulted to detect shutdown.
    state_machine: Arc<TunnelStateMachine<T>>,
    /// Last tunnel state version observed by this receiver.
    last_state_version: AtomicUsize,
    /// Set once this receiver has observed that the tunnel is closed.
    is_closed: AtomicBool,
}
450
/// Errors returned by [`Receiver::receive`].
#[derive(Debug, thiserror::Error)]
pub enum ReceivePacketError {
    /// Reading a datagram from the QUIC connection failed.
    #[error("quinn error: {0}")]
    ConnectionError(#[from] quinn::ConnectionError),
    /// The tunnel has been shut down.
    #[error("connection closed")]
    ConnectionClosed,
}
461
462impl<T: SnapTunToken> Receiver<T> {
463 fn new(
464 conn: quinn::Connection,
465 state_machine: Arc<TunnelStateMachine<T>>,
466 initial_state_version: usize,
467 metrics: ReceiverMetrics,
468 ) -> Self {
469 Self {
470 conn,
471 state_machine,
472 last_state_version: AtomicUsize::new(initial_state_version),
473 is_closed: AtomicBool::new(false),
474 metrics,
475 }
476 }
477
478 pub async fn receive(&self) -> Result<Bytes, ReceivePacketError> {
480 let current_state_version = self.state_machine.state_version();
482 if self
483 .last_state_version
484 .compare_exchange(
485 current_state_version - 1,
486 current_state_version,
487 Ordering::AcqRel,
488 Ordering::Acquire,
489 )
490 .is_ok()
491 {
492 if self.state_machine.is_closed() {
494 self.is_closed.store(true, Ordering::Release);
495 }
496 }
497 if self.is_closed.load(Ordering::Acquire) {
498 return Err(ReceivePacketError::ConnectionClosed);
499 }
500 let p = self.conn.read_datagram().await?;
501 self.metrics.datagrams_received_total.inc();
502 Ok(p)
503 }
504}
505
/// Errors that end the [`Control`] driver.
#[derive(Debug, thiserror::Error)]
pub enum ControlError {
    /// A control request received after the handshake could not be parsed.
    #[error("parse control request error: {0}")]
    ParseError(#[from] ParseControlRequestError),
    /// Writing the response to a control request failed.
    #[error("send control response error: {0}")]
    SendError(#[from] SendControlResponseError),
    /// Waiting on the response stream (`SendStream::stopped`) failed.
    #[error("wait for completion error: {0}")]
    StoppedError(#[from] quinn::StoppedError),
    /// The tunnel session expired.
    #[error("session expired")]
    SessionExpired,
    /// Accepting the next control stream failed for a reason other than an
    /// application close by the peer.
    #[error("connection closed prematurely")]
    ClosedPrematurely,
}
525
/// Future that drives a tunnel's control channel (control requests after the
/// handshake, and session expiry). It only makes progress while it is polled.
pub struct Control {
    // Boxed control loop built by `Control::new`.
    driver_fut: Pin<Box<dyn Future<Output = Result<(), ControlError>> + Send>>,
}
531
532impl Control {
533 fn new<T>(conn: quinn::Connection, tunnel_state: Arc<TunnelStateMachine<T>>) -> Self
534 where
535 T: for<'de> Deserialize<'de> + Token + Clone,
536 {
537 let fut = async move {
538 loop {
539 tokio::select! {
540 _ = tunnel_state.await_session_expiry() => {
541 tunnel_state.shutdown();
543 conn.close(SnaptunConnErrors::SessionExpired.into(), b"session expired");
544 return Err(ControlError::SessionExpired)
545 }
546 res = conn.accept_bi() => {
547 let (mut snd, mut rcv) = match res {
548 Ok(v) => v,
549 Err(quinn::ConnectionError::ApplicationClosed(_)) => {
550 tunnel_state.shutdown();
551 return Ok(());
552 }
553 Err(_) => {
554 tunnel_state.shutdown();
555 return Err(ControlError::ClosedPrematurely);
556 }
557 };
558
559 let mut buf = vec![0u8; CTRL_REQUEST_BUF_SIZE];
560 let control_request = parse_http_request(&mut buf, &mut rcv).await.inspect_err(|err| {
561 handle_invalid_request(&conn, err);
562 tunnel_state.shutdown();
563 })?;
564
565 let (code, body) = tunnel_state.process_control_request(SystemTime::now(), control_request);
566 send_http_response(&mut snd, code, &body).await
567 .inspect_err(|_| {
568 tunnel_state.shutdown();
569 conn.close(SnaptunConnErrors::InternalError.into(), b"send control response error");
570 })?;
571
572 snd.stopped().await?;
573 }
574 }
575 }
576 };
577 let driver_fut = Box::pin(fut);
578 Self { driver_fut }
579 }
580}
581
582impl Future for Control {
583 type Output = Result<(), ControlError>;
584
585 fn poll(
586 mut self: std::pin::Pin<&mut Self>,
587 cx: &mut std::task::Context<'_>,
588 ) -> std::task::Poll<Self::Output> {
589 self.driver_fut.as_mut().poll(cx)
590 }
591}
592
/// Error returned when the tunnel addresses are requested but the tunnel is
/// not in the `Assigned` state.
#[derive(Debug, thiserror::Error)]
pub enum AddressAssignmentError {
    /// The tunnel currently has no address assigned.
    #[error("no address assigned")]
    NoAddressAssigned,
}
600
/// Shared, lock-protected state of one tunnel.
///
/// Every transition made through `locked_update_state` increments
/// `state_version`. `Sender` and `Receiver` compare it against the version
/// they last saw.
pub struct TunnelStateMachine<T: SnapTunToken> {
    /// Validates tokens carried by control requests.
    validator: Arc<dyn TokenValidator<T>>,
    /// Allocates addresses; allocations are put on hold on shutdown.
    allocator: Arc<dyn AddressAllocator<T>>,
    /// Current tunnel state.
    inner_state: RwLock<TunnelState>,
    /// Counter incremented on every state transition (only ever increases).
    state_version: AtomicUsize,
    /// Notifies `await_session_expiry` waiters of state updates.
    sender: watch::Sender<()>,
    /// Cloned by `await_session_expiry`. Presumably also kept so the channel
    /// always has a live receiver — confirm.
    receiver: watch::Receiver<()>,
}
619
620impl<T: SnapTunToken> Drop for TunnelStateMachine<T> {
621 fn drop(&mut self) {
622 self.shutdown();
624 }
625}
626
627impl<T: SnapTunToken> TunnelStateMachine<T> {
628 pub(crate) fn new(
629 validator: Arc<dyn TokenValidator<T>>,
630 allocator: Arc<dyn AddressAllocator<T>>,
631 ) -> Self {
632 let (sender, receiver) = watch::channel(());
633
634 Self {
635 validator,
636 allocator,
637 inner_state: Default::default(),
638 state_version: AtomicUsize::new(0),
639 sender,
640 receiver,
641 }
642 }
643
644 fn process_control_request(
647 &self,
648 now: SystemTime,
649 control_request: ControlRequest,
650 ) -> (http::StatusCode, Vec<u8>) {
651 let mut inner_state = self.inner_state.write().expect("no fail");
652 if let TunnelState::Closed = *inner_state {
653 return (http::StatusCode::BAD_REQUEST, vec![]);
654 }
655 match control_request {
656 ControlRequest::AddressAssignment(token, address_assign_request) => {
657 self.locked_process_addr_assignment_request(
658 &mut inner_state,
659 now,
660 token,
661 address_assign_request,
662 )
663 }
664 ControlRequest::SessionRenewal(token) => {
665 self.locked_process_session_renewal(&mut inner_state, now, token)
666 }
667 }
668 }
669
670 fn locked_process_session_renewal(
671 &self,
672 inner_state: &mut TunnelState,
673 now: SystemTime,
674 token: String,
675 ) -> (http::StatusCode, Vec<u8>) {
676 let mut resp_body = vec![];
677 let resp_code = match self.validator.validate(now, &token) {
678 Ok(claims) => {
679 let token_expiry = claims.exp_time();
680
681 self.locked_update_tunnel_session(inner_state, token_expiry);
683
684 let resp = SessionRenewalResponse {
685 valid_until: unix_epoch_from_system_time(token_expiry),
686 };
687 resp.encode(&mut resp_body).expect("no fail");
688 StatusCode::OK
689 }
690 Err(TokenValidatorError::JwtSignatureInvalid()) => {
691 info!("Invalid signature");
692 StatusCode::UNAUTHORIZED
693 }
694 Err(TokenValidatorError::JwtError(err)) => {
695 info!(?err, "Token validation failed");
696 StatusCode::BAD_REQUEST
697 }
698 Err(TokenValidatorError::TokenExpired(err)) => {
699 info!(?err, "Token validation failed: token expired");
700 StatusCode::UNAUTHORIZED
701 }
702 };
703 (resp_code, resp_body)
704 }
705
706 fn locked_process_addr_assignment_request(
709 &self,
710 inner_state: &mut TunnelState,
711 now: SystemTime,
712 token: String,
713 addr_assignments: AddressAssignRequest,
714 ) -> (http::StatusCode, Vec<u8>) {
715 let mut resp_body = vec![];
716 let resp_code = match self.validator.validate(now, &token) {
717 Ok(claims) => {
718 if addr_assignments.requested_addresses.len() > 1 {
719 warn!("Address assignment failed, multiple address assignments not supported");
721 return (StatusCode::NOT_IMPLEMENTED, resp_body);
722 }
723
724 let mut requests: Vec<(IsdAsn, IpNet)> = match addr_assignments
725 .requested_addresses
726 .iter()
727 .map(|range| range.try_into())
728 .collect::<Result<Vec<_>, AddrError>>()
729 {
730 Ok(reqs) => reqs,
731 Err(_) => return (StatusCode::BAD_REQUEST, vec![]),
732 };
733
734 if requests
736 .iter()
737 .any(|(_, net)| net.prefix_len() != net.max_prefix_len())
738 {
739 warn!("Address assignment failed, prefix assignments are not supported");
740 return (StatusCode::NOT_IMPLEMENTED, resp_body);
741 }
742
743 if requests.is_empty() {
745 requests.push((IsdAsn::WILDCARD, IPV4_WILDCARD));
746 requests.push((IsdAsn::WILDCARD, IPV6_WILDCARD));
747 }
748
749 let session_expiry = match inner_state.session_validity() {
751 Ok(v) => v,
752 Err(err) => {
753 error!(
754 ?err,
755 "Failed to get session validity when processing address assignment request"
756 );
757 return (StatusCode::INTERNAL_SERVER_ERROR, vec![]);
758 }
759 };
760
761 let mut assigned_address: Option<AddressAllocation> = None;
763 for (requested_isd_as, requested_net) in requests.iter() {
764 match self
765 .allocator
766 .allocate(*requested_isd_as, *requested_net, claims.clone())
767 {
768 Ok(allocation) => {
769 assigned_address = Some(allocation);
770 break;
771 }
772 Err(err) => {
773 debug!(
774 ?err,
775 "Address allocation failed for ISD-AS {requested_isd_as} and net {requested_net}"
776 );
777 }
778 }
779 }
780
781 let Some(assigned_address) = assigned_address else {
783 warn!("Address assignment failed - no available addresses for: {requests:?}",);
784 return (StatusCode::BAD_REQUEST, vec![]);
785 };
786
787 self.locked_update_state(
788 inner_state,
789 TunnelState::Assigned {
790 session_expiry,
791 address: assigned_address.clone(),
792 },
793 );
794
795 let resp = AddressAssignResponse {
796 assigned_addresses: vec![(&assigned_address.address).into()],
797 };
798
799 resp.encode(&mut resp_body).expect("no fail");
800 StatusCode::OK
801 }
802 Err(TokenValidatorError::JwtSignatureInvalid()) => {
803 info!("Invalid JWT Signature");
804 StatusCode::UNAUTHORIZED
805 }
806 Err(TokenValidatorError::JwtError(err)) => {
807 info!(?err, "Token validation failed");
808 StatusCode::BAD_REQUEST
809 }
810 Err(TokenValidatorError::TokenExpired(err)) => {
811 info!(?err, "Token validation failed: token expired");
812 StatusCode::UNAUTHORIZED
813 }
814 };
815 (resp_code, resp_body)
816 }
817
818 fn locked_update_tunnel_session(
819 &self,
820 inner_state: &mut TunnelState,
821 session_expiry: SystemTime,
822 ) {
823 match inner_state {
824 TunnelState::Unassigned => {
825 *inner_state = TunnelState::SessionEstablished { session_expiry };
826 }
827 TunnelState::SessionEstablished { .. } => {
828 *inner_state = TunnelState::SessionEstablished { session_expiry };
829 }
830 TunnelState::Assigned { address, .. } => {
831 *inner_state = TunnelState::Assigned {
832 session_expiry,
833 address: address.clone(),
834 };
835 }
836 TunnelState::Closed => tracing::error!("Updating tunnel session but in closed state"),
838 };
839 }
840
841 fn locked_update_state(&self, inner_state: &mut TunnelState, new_state: TunnelState) {
842 tracing::debug!(%new_state, "Updating tunnel state");
843 *inner_state = new_state;
844
845 self.state_version.fetch_add(1, Ordering::AcqRel);
846
847 if self.sender.send(()).is_err() {
848 debug!("Failed to notify session expiry update");
851 }
852 }
853
854 fn get_addresses(&self) -> Result<Vec<EndhostAddr>, AddressAssignmentError> {
855 let guard = self.inner_state.read().expect("no fail");
856 if let TunnelState::Assigned {
857 address,
858 session_expiry: _,
859 } = &*guard
860 {
861 return Ok(vec![address.address]);
862 }
863 Err(AddressAssignmentError::NoAddressAssigned)
864 }
865
866 async fn await_session_expiry(&self) {
867 let mut expiry_notifier = self.receiver.clone();
868 loop {
869 let valid_duration = {
870 let res = {
871 let guard = self.inner_state.read().expect("no fail");
872 guard.session_validity()
873 };
874 match res {
875 Ok(session_validity) => {
876 match session_validity.duration_since(SystemTime::now()) {
877 Ok(dur) => dur,
878 Err(_) => return, }
880 }
881 Err(err) => {
882 tracing::warn!(%err, "Tunnel in an invalid state");
885 return;
886 }
887 }
888 };
889
890 tokio::select! {
891 _ = expiry_notifier.changed() => {
892 continue;
894 }
895 _ = tokio::time::sleep(valid_duration) => {
896 return;
898 }
899 }
900 }
901 }
902
903 fn state_version(&self) -> usize {
904 self.state_version.load(Ordering::Acquire)
905 }
906
907 fn is_closed(&self) -> bool {
908 if let TunnelState::Closed = *self.inner_state.read().expect("no fail") {
909 return true;
910 }
911 false
912 }
913
914 fn shutdown(&self) {
915 let mut inner_state = self.inner_state.write().expect("no fail");
916
917 if let TunnelState::Assigned {
919 session_expiry: _,
920 address,
921 } = &*inner_state
922 {
923 if !self.allocator.put_on_hold(address.id.clone()) {
924 error!(addr=?address.address, "Could not set address to hold during shutdown - address was released while tunnel was still assigned");
925 }
926 }
927
928 self.locked_update_state(&mut inner_state, TunnelState::Closed);
929 }
930}
931
/// Error for operations that are not valid in the tunnel's current state.
#[derive(Debug, thiserror::Error)]
enum TunnelStateError {
    /// Carries a copy of the offending state.
    #[error("invalid state: {0}")]
    InvalidState(TunnelState),
}
937
/// Lifecycle of a tunnel.
#[derive(Debug, Clone)]
enum TunnelState {
    /// Initial state: no session and no address.
    Unassigned,
    /// A session renewal was accepted; no address is assigned yet.
    SessionEstablished {
        /// When the session expires (taken from the token's expiry).
        session_expiry: SystemTime,
    },
    /// A session exists and an address has been allocated.
    Assigned {
        /// When the session expires.
        session_expiry: SystemTime,
        /// The allocation backing the tunnel's address.
        address: AddressAllocation,
    },
    /// Set by shutdown; further control requests are rejected.
    Closed,
}
950
951impl TunnelState {
952 fn session_validity(&self) -> Result<SystemTime, TunnelStateError> {
953 match self {
954 TunnelState::SessionEstablished { session_expiry } => Ok(*session_expiry),
955 TunnelState::Assigned { session_expiry, .. } => Ok(*session_expiry),
956 _ => Err(TunnelStateError::InvalidState(self.clone())),
957 }
958 }
959}
960
961impl Default for TunnelState {
962 fn default() -> Self {
963 Self::Unassigned
964 }
965}
966
967impl std::fmt::Display for TunnelState {
968 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
969 match self {
970 TunnelState::Unassigned => write!(f, "Unassigned"),
971 TunnelState::SessionEstablished { session_expiry } => {
972 write!(
973 f,
974 "SessionEstablished ({})",
975 DateTime::<Utc>::from(*session_expiry)
976 )
977 }
978 TunnelState::Assigned {
979 session_expiry,
980 address,
981 } => {
982 write!(
983 f,
984 "Assigned (valid until: {}, addresses: [{}])",
985 DateTime::<Utc>::from(*session_expiry),
986 address.address
987 )
988 }
989 TunnelState::Closed => write!(f, "Closed"),
990 }
991 }
992}
993
/// A parsed control request. The `String` is the bearer token from the
/// authorization header.
#[derive(Debug)]
enum ControlRequest {
    /// `POST` to `PATH_ADDR_ASSIGNMENT`, with a protobuf-decoded body.
    AddressAssignment(String, AddressAssignRequest),
    /// `POST` to `PATH_SESSION_RENEWAL`.
    SessionRenewal(String),
}
999
1000fn handle_invalid_request(conn: &quinn::Connection, err: &ParseControlRequestError) {
1001 match err {
1002 ParseControlRequestError::ClosedPrematurely => {
1003 conn.close(
1004 SnaptunConnErrors::InternalError.into(),
1005 b"closed prematurely",
1006 );
1007 }
1008 ParseControlRequestError::ReadError(_) => {
1009 conn.close(SnaptunConnErrors::InternalError.into(), b"read error");
1010 }
1011 ParseControlRequestError::InvalidRequest(reason) => {
1012 conn.close(SnaptunConnErrors::InvalidRequest.into(), reason.as_bytes());
1013 }
1014 ParseControlRequestError::Unauthenticated(reason) => {
1015 conn.close(SnaptunConnErrors::Unauthenticated.into(), reason.as_bytes());
1016 }
1017 }
1018}
1019
/// Errors produced while reading and parsing a control request.
#[derive(Debug, thiserror::Error)]
pub enum ParseControlRequestError {
    /// The request is malformed or unsupported. The message is also used as
    /// the connection close reason.
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    /// Reading from the request stream failed.
    #[error("read error: {0}")]
    ReadError(#[from] quinn::ReadError),
    /// The authorization header is missing or malformed. The message is also
    /// used as the connection close reason.
    #[error("unauthenticated: {0}")]
    Unauthenticated(String),
    /// The stream ended before a complete request header was received.
    #[error("closed prematurely")]
    ClosedPrematurely,
}
1036
1037async fn parse_http_request(
1049 buf: &mut [u8],
1050 rcv: &mut RecvStream,
1051) -> Result<ControlRequest, ParseControlRequestError> {
1052 use ParseControlRequestError::*;
1053 let mut cursor = 0;
1054 while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
1055 cursor += n;
1056 let mut headers = [httparse::EMPTY_HEADER; 16];
1057 let mut req = httparse::Request::new(&mut headers);
1058 if let Ok(httparse::Status::Complete(body_offset)) = req.parse(&buf[..cursor]) {
1059 if !matches!(req.method, Some("POST")) {
1060 return Err(InvalidRequest("invalid method".into()));
1061 }
1062 match req.path {
1065 Some(PATH_ADDR_ASSIGNMENT) => {}
1066 Some(PATH_SESSION_RENEWAL) => {}
1067 Some(_) | None => return Err(InvalidRequest("invalid path".into())),
1068 }
1069 let Some(h) = req.headers.iter().find(|h| h.name == AUTH_HEADER) else {
1070 return Err(Unauthenticated("no auth header".into()));
1071 };
1072 let t = h
1073 .value
1074 .strip_prefix(b"Bearer ")
1075 .ok_or(Unauthenticated(
1076 "bearer not found in authorization header".into(),
1077 ))
1078 .map(|x| String::from_utf8_lossy(x).to_string())?;
1079 let path = req.path.unwrap();
1081 match path {
1082 PATH_ADDR_ASSIGNMENT => {
1083 while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
1086 cursor += n;
1087 }
1088 let Ok(addr_req) = AddressAssignRequest::decode(&buf[body_offset..cursor])
1090 else {
1091 return Err(InvalidRequest(
1092 "error when parsing address assignment request".into(),
1093 ));
1094 };
1095 return Ok(ControlRequest::AddressAssignment(t, addr_req));
1096 }
1097 PATH_SESSION_RENEWAL => return Ok(ControlRequest::SessionRenewal(t)),
1098 path => unreachable!("invalid path: {path}"),
1099 }
1100 }
1101 if cursor == buf.len() {
1103 return Err(InvalidRequest("request too big".into()));
1104 }
1105 }
1106 Err(ClosedPrematurely)
1107}
1108
/// Errors produced while writing a control response.
#[derive(Debug, thiserror::Error)]
pub enum SendControlResponseError {
    /// Writing to the stream failed (the write error converted to `io::Error`).
    #[error("i/o error: {0}")]
    IoError(#[from] std::io::Error),
    /// Finishing the stream failed because it was already closed.
    #[error("stream closed: {0}")]
    ClosedStream(#[from] quinn::ClosedStream),
}
1119
1120async fn send_http_response(
1122 stream: &mut SendStream,
1123 code: http::StatusCode,
1124 body: &[u8],
1125) -> Result<(), SendControlResponseError> {
1126 async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
1128 let mut cursor = 0;
1129 while cursor < data.len() {
1130 cursor += stream.write(&data[cursor..]).await?;
1131 }
1132 Ok(())
1133 }
1134
1135 write_all(
1136 stream,
1137 format!(
1138 "HTTP/1.1 {} {}\r\nContent-Length: {}\r\n\r\n",
1139 code.as_str(),
1140 code.canonical_reason().unwrap_or(""),
1141 body.len(),
1142 )
1143 .as_bytes(),
1144 )
1145 .await?;
1146 write_all(stream, body).await?;
1147
1148 stream.finish()?;
1150 Ok(())
1151}
1152
#[cfg(test)]
mod tests {
    use std::time::{Duration, UNIX_EPOCH};

    use snap_tokens::{Pssid, snap_token::SnapTokenClaims};

    use super::*;

    mod address_allocation {
        use snap_tokens::snap_token::SnapTokenClaims;

        use super::*;

        /// State machine wired to the mocks, with a session already established.
        fn setup() -> (TunnelStateMachine<SnapTokenClaims>, Arc<MockAllocator>) {
            let alloc = Arc::new(MockAllocator {
                is_allocated: AtomicBool::new(false),
                is_on_hold: AtomicBool::new(false),
            });
            let tun = TunnelStateMachine::new(Arc::new(MockValidator), alloc.clone());

            let renewal = ControlRequest::SessionRenewal("valid_token".into());
            let (status, body) = tun.process_control_request(SystemTime::now(), renewal);
            assert_eq!(
                status,
                http::StatusCode::OK,
                "failed to renew session - body: {body:?}"
            );

            (tun, alloc)
        }

        /// Sends an address assignment request with no explicit ranges and
        /// checks that it succeeds and reaches the allocator.
        fn assign_default_address(
            tun: &TunnelStateMachine<SnapTokenClaims>,
            alloc: &MockAllocator,
        ) {
            let request = ControlRequest::AddressAssignment(
                "valid_token".into(),
                AddressAssignRequest {
                    requested_addresses: vec![],
                },
            );
            let (status, body) = tun.process_control_request(SystemTime::now(), request);
            assert_eq!(status, http::StatusCode::OK, "failed - body: {body:?}");
            assert!(alloc.is_allocated.load(Ordering::Acquire));
        }

        #[test]
        fn should_put_on_hold_after_shutdown() {
            let (tun, alloc) = setup();
            assign_default_address(&tun, &alloc);

            tun.shutdown();
            assert!(alloc.is_on_hold.load(Ordering::Acquire));
        }

        #[test]
        fn should_put_on_hold_after_drop() {
            let (tun, alloc) = setup();
            assign_default_address(&tun, &alloc);

            drop(tun);
            assert!(alloc.is_on_hold.load(Ordering::Acquire));
        }
    }

    /// Accepts every token; the claims expire one hour after `now`.
    struct MockValidator;
    impl TokenValidator<SnapTokenClaims> for MockValidator {
        fn validate(
            &self,
            now: SystemTime,
            _: &str,
        ) -> Result<SnapTokenClaims, TokenValidatorError> {
            let expiry = now.duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(3600);
            Ok(SnapTokenClaims {
                pssid: Pssid::new(),
                exp: expiry.as_secs(),
            })
        }
    }

    /// Allocator that hands out a single address and records hold requests.
    struct MockAllocator {
        is_allocated: AtomicBool,
        is_on_hold: AtomicBool,
    }
    impl AddressAllocator<SnapTokenClaims> for MockAllocator {
        fn allocate(
            &self,
            isd_as: IsdAsn,
            prefix: IpNet,
            claims: SnapTokenClaims,
        ) -> Result<AddressAllocation, crate::AddressAllocationError> {
            if self.is_allocated.load(Ordering::Acquire) {
                return Err(crate::AddressAllocationError::NoAddressesAvailable);
            }
            self.is_allocated.store(true, Ordering::Release);

            let id = crate::AddressAllocationId {
                isd_as,
                id: claims.id(),
            };
            Ok(AddressAllocation {
                id,
                address: EndhostAddr::new(isd_as, prefix.addr()),
            })
        }

        fn put_on_hold(&self, _id: crate::AddressAllocationId) -> bool {
            self.is_on_hold.store(true, Ordering::Release);
            true
        }

        fn deallocate(&self, _id: crate::AddressAllocationId) -> bool {
            false
        }
    }
}