use std::{
sync::Arc,
time::{Instant, SystemTime},
};
use axum::{
Extension, Router,
extract::{ConnectInfo, State},
};
use scion_sdk_axum_connect_rpc::{
error::{CrpcError, CrpcErrorCode},
extractor::ConnectRpc,
};
use scion_sdk_token_validator::validator::Token;
use snap_tokens::AnyClaims;
use x25519_dalek::PublicKey;
use crate::{
crpc_api::api_service::model::{SnapDataPlaneResolver, SnapTunIdentityRegistry},
protobuf::anapaya::snap::v1::api_service::{
GetSnapDataPlaneRequest, GetSnapDataPlaneResponse, RegisterSnapTunIdentityRequest,
RegisterSnapTunIdentityResponse,
},
};
pub mod model {
    //! Types and traits the SNAP control API handlers are written against.

    use std::net::{IpAddr, SocketAddr};
    use std::time::{Duration, Instant};

    use axum::http::StatusCode;
    use url::Url;
    use x25519_dalek::PublicKey;

    /// The data plane assigned to an end host, as returned by a [`SnapDataPlaneResolver`].
    pub struct SnapDataPlane {
        /// Socket address of the data plane.
        pub address: SocketAddr,
        /// Control endpoint of the snaptun server, when one is advertised.
        pub snap_tun_control_address: Option<Url>,
        /// Static X25519 public key of the snaptun server, when one is advertised.
        pub snap_static_x25519: Option<PublicKey>,
    }

    /// Maps an end host to the data plane it should use.
    pub trait SnapDataPlaneResolver: Send + Sync {
        /// Returns the data plane assigned to the end host with address `endhost_ip`.
        ///
        /// # Errors
        ///
        /// On failure, returns the HTTP status to report to the caller together with the
        /// underlying error.
        fn get_data_plane_address(
            &self,
            endhost_ip: IpAddr,
        ) -> Result<SnapDataPlane, (StatusCode, anyhow::Error)>;
    }

    /// Stores snaptun initiator identities, keyed by an opaque string.
    pub trait SnapTunIdentityRegistry: Send + Sync {
        /// Registers `initiator_identity` and an optional `psk_share` under `key`.
        ///
        /// `now` and `lifetime` are passed through for the implementation; presumably the
        /// entry is valid for `lifetime` starting at `now` — confirm against implementors.
        ///
        /// The control API handler treats a `false` return as "an existing registration for
        /// `key` was replaced" and logs it as a re-registration.
        fn register(
            &self,
            now: Instant,
            key: &str,
            initiator_identity: [u8; 32],
            psk_share: Option<[u8; 32]>,
            lifetime: Duration,
        ) -> bool;
    }
}
pub(crate) mod convert {
use std::net::{AddrParseError, SocketAddr};
use url::Url;
use x25519_dalek::PublicKey;
use crate::{
crpc_api::api_service::model::SnapDataPlane,
protobuf::anapaya::snap::v1::api_service as rpc,
};
#[derive(thiserror::Error, Debug)]
pub enum ConvertError {
#[error("failed to parse data plane address: {0}")]
ParseAddr(AddrParseError),
#[error("failed to parse server control address: {0}")]
ParseSnapTunControlAddr(AddrParseError),
#[error("server static identity is not 32 bytes")]
InvalidServerStaticIdentityLength,
}
impl TryFrom<rpc::GetSnapDataPlaneResponse> for SnapDataPlane {
type Error = ConvertError;
fn try_from(value: rpc::GetSnapDataPlaneResponse) -> Result<Self, Self::Error> {
let snap_tun_control_address = value
.snap_tun_control_address
.map(|address| {
if let Ok(url) = Url::parse(&address) {
return Ok(url);
}
match address.parse::<SocketAddr>() {
Ok(addr) => {
let mut u = Url::parse("http://.").unwrap();
let _ = u.set_ip_host(addr.ip());
let _ = u.set_port(Some(addr.port()));
Ok(u)
}
Err(e) => Err(ConvertError::ParseSnapTunControlAddr(e)),
}
})
.transpose()?;
let snap_static_x25519 = value
.snap_static_x25519
.map(|key| {
TryInto::<[u8; 32]>::try_into(key.as_slice())
.map_err(|_| ConvertError::InvalidServerStaticIdentityLength)
.map(PublicKey::from)
})
.transpose()?;
Ok(SnapDataPlane {
address: value.address.parse().map_err(ConvertError::ParseAddr)?,
snap_tun_control_address,
snap_static_x25519,
})
}
}
}
/// Path prefix under which the SNAP control service's routes are nested.
pub(crate) const SERVICE_PATH: &str = "/anapaya.snap.v1.SnapControl";
/// Route, relative to [`SERVICE_PATH`], of the data plane address lookup.
pub(crate) const GET_SNAP_DATA_PLANE_ADDRESS: &str = "/GetSnapDataPlaneAddress";
/// Route, relative to [`SERVICE_PATH`], of the snaptun identity registration.
pub(crate) const REGISTER_SNAPTUN_IDENTITY: &str = "/RegisterSnapTunIdentity";
/// Mounts the SNAP control service under [`SERVICE_PATH`] on `router`.
///
/// Two POST routes are added:
/// - [`GET_SNAP_DATA_PLANE_ADDRESS`], served with `snap_resolver` as its state;
/// - [`REGISTER_SNAPTUN_IDENTITY`], served with `identity_registrar` as its state.
pub fn nest_snap_control_api(
    router: axum::Router,
    snap_resolver: Arc<dyn SnapDataPlaneResolver>,
    identity_registrar: Arc<dyn SnapTunIdentityRegistry>,
) -> axum::Router {
    use axum::routing::post;

    // Each `with_state` call provides the state for the routes added before it.
    let with_resolver = Router::new()
        .route(
            GET_SNAP_DATA_PLANE_ADDRESS,
            post(get_snap_data_plane_address_handler),
        )
        .with_state(snap_resolver);
    let control_service = with_resolver
        .route(
            REGISTER_SNAPTUN_IDENTITY,
            post(register_snaptun_identity_handler),
        )
        .with_state(identity_registrar);

    router.nest(SERVICE_PATH, control_service)
}
/// Handles `GetSnapDataPlaneAddress` by resolving the data plane for the calling end host.
///
/// The end host is identified by the IP address of the peer connection. The request body is
/// not inspected. The SNAP token claims are extracted but not read, so the handler only runs
/// when that request extension is present.
async fn get_snap_data_plane_address_handler(
    State(resolver): State<Arc<dyn SnapDataPlaneResolver>>,
    _snap_token: Extension<AnyClaims>,
    ConnectInfo(peer): ConnectInfo<std::net::SocketAddr>,
    ConnectRpc(_request): ConnectRpc<GetSnapDataPlaneRequest>,
) -> Result<ConnectRpc<GetSnapDataPlaneResponse>, CrpcError> {
    // Resolver failures carry their own status and are converted into a `CrpcError` by `?`.
    let data_plane = resolver.get_data_plane_address(peer.ip())?;

    let address = data_plane.address.to_string();
    let snap_tun_control_address = data_plane
        .snap_tun_control_address
        .as_ref()
        .map(ToString::to_string);
    let snap_static_x25519 = data_plane
        .snap_static_x25519
        .map(|key| key.as_bytes().to_vec());

    Ok(ConnectRpc(GetSnapDataPlaneResponse {
        address,
        snap_tun_control_address,
        snap_static_x25519,
    }))
}
/// Handles `RegisterSnapTunIdentity`.
///
/// Registers the caller's static X25519 identity, plus an optional PSK share, under the `jti`
/// of the SNAP token presented with the request. The registration lifetime is the time left
/// until that token's expiration. If the registry reports that an existing entry was replaced
/// (`register` returns `false`), the re-registration is logged.
///
/// A PSK share is treated as absent when it is empty (an unset `bytes` field) or all zeros.
///
/// # Errors
///
/// Returns an `InvalidArgument` error when:
/// - the token's expiration time is already in the past;
/// - the initiator identity is not exactly 32 bytes;
/// - a present PSK share is not exactly 32 bytes.
async fn register_snaptun_identity_handler(
    State(identity_registry): State<Arc<dyn SnapTunIdentityRegistry>>,
    snap_token: Extension<AnyClaims>,
    ConnectInfo(_): ConnectInfo<std::net::SocketAddr>,
    ConnectRpc(request): ConnectRpc<RegisterSnapTunIdentityRequest>,
) -> Result<ConnectRpc<RegisterSnapTunIdentityResponse>, CrpcError> {
    /// Builds an `InvalidArgument` Connect RPC error carrying `message`.
    fn invalid_argument(message: &str) -> CrpcError {
        CrpcError::new(CrpcErrorCode::InvalidArgument, message.to_string())
    }

    // The registration must not outlive the token that authorized it.
    let lifetime = snap_token
        .0
        .exp_time()
        .duration_since(SystemTime::now())
        .map_err(|_| invalid_argument("expiration time is in the past"))?;

    let initiator_identity = {
        let key_bytes: [u8; 32] = request
            .initiator_static_x25519
            .as_slice()
            .try_into()
            .map_err(|_| invalid_argument("initiator identity is not 32 bytes"))?;
        PublicKey::from(key_bytes)
    };

    // A client that does not use a PSK may leave the `bytes` field unset, which decodes as an
    // empty buffer. Treat that the same as the explicit all-zero "no PSK" share instead of
    // rejecting it as a malformed share.
    let psk_bytes = request.psk_share.as_slice();
    let psk_share: Option<[u8; 32]> = if psk_bytes.is_empty() || psk_bytes == [0u8; 32] {
        None
    } else {
        Some(
            psk_bytes
                .try_into()
                .map_err(|_| invalid_argument("psk share is not 32 bytes"))?,
        )
    };

    let key = &snap_token.jti();
    if !identity_registry.register(
        Instant::now(),
        key,
        *initiator_identity.as_bytes(),
        psk_share,
        lifetime,
    ) {
        tracing::info!(key, "re-registered identity");
    }

    // NOTE(review): the response always carries an all-zero PSK share; confirm whether the
    // server is meant to contribute its own share here.
    Ok(ConnectRpc(RegisterSnapTunIdentityResponse {
        psk_share: [0u8; 32].to_vec(),
    }))
}