use std::{
io,
net::{Ipv4Addr, SocketAddr},
path::{Path, PathBuf},
sync::{Arc, atomic::AtomicBool},
time::{Duration, SystemTime},
};
use jsonwebtoken::DecodingKey;
use scion_sdk_observability::metrics::registry::MetricsRegistry;
use scion_sdk_token_validator::validator::Validator;
use scion_sdk_utils::{
io::{get_tmp_path, read_file, write_file},
task_handler::{CancelTaskSet, InProcess},
};
use snap_dataplane::{
session::state::insecure_const_session_key_pair,
state::Id,
tunnel_gateway::{
dispatcher::TunnelGatewayDispatcher, metrics::TunnelGatewayDispatcherMetrics,
start_tunnel_gateway, state::SharedTunnelGatewayState,
},
};
use thiserror::Error;
use tokio::{net::TcpListener, time::sleep};
use tracing::{debug, info};
use crate::{
addr_to_http_url,
api::admin,
authorization_server,
dto::{self},
endhost_api::PsEndhostApi,
io_config::{IoConfig, SharedPocketScionIoConfig},
management_api,
network::local::receivers::{
router_socket::{RouterSocket, SharedRouterSocket},
tunnel_gateway::TunnelGatewayReceiver,
},
state::{
SharedPocketScionState, SystemState,
address_allocator::StateSnapAddressAllocator,
endhost_segment_lister::StateEndhostSegmentLister,
simulation_dispatcher::{AsNetSimDispatcher, NetSimDispatcher},
},
};
pub const DEFAULT_MGMT_PORT: u16 = 9000;
pub struct PocketScionRuntimeBuilder {
system_state: PathOrObject<SystemState>,
io_config: PathOrObject<IoConfig>,
mgmt_listen_addr: Option<SocketAddr>,
start_time: TimestampOrNow,
}
impl PocketScionRuntimeBuilder {
pub fn new() -> Self {
Self {
system_state: PathOrObject::Unspecified,
io_config: PathOrObject::Unspecified,
mgmt_listen_addr: None,
start_time: TimestampOrNow::Now,
}
}
pub fn with_mgmt_listen_addr(mut self, mgmt_listen_addr: SocketAddr) -> Self {
self.mgmt_listen_addr = Some(mgmt_listen_addr);
self
}
pub fn with_io_config(mut self, io_config: IoConfig) -> Self {
self.io_config = PathOrObject::Object(io_config);
self
}
pub fn with_io_config_path<P: AsRef<Path>>(mut self, path: P) -> Self {
self.io_config = PathOrObject::Path(path.as_ref().into());
self
}
pub fn with_system_state(mut self, system_state: SystemState) -> Self {
self.system_state = PathOrObject::Object(system_state);
self
}
pub fn with_system_state_path<P: AsRef<Path>>(mut self, path: P) -> Self {
self.system_state = PathOrObject::Path(path.as_ref().into());
self
}
pub fn with_start_time(mut self, time: SystemTime) -> Self {
self.start_time = TimestampOrNow::Timestamp(time);
self
}
pub async fn start(self) -> Result<PocketScionRuntime, PocketScionRuntimeError> {
self.start_with_task_set(CancelTaskSet::new()).await
}
pub async fn start_with_task_set(
self,
mut task_set: CancelTaskSet,
) -> Result<PocketScionRuntime, PocketScionRuntimeError> {
let ready_state = Arc::new(AtomicBool::new(false));
let start_time = match self.start_time {
TimestampOrNow::Now => SystemTime::now(),
TimestampOrNow::Timestamp(system_time) => system_time,
};
let system_state = self.system_state.load(start_time).await?;
let pstate = SharedPocketScionState::from_system_state(system_state);
let io_config = self.io_config.load().await?;
let io_config = SharedPocketScionIoConfig::from_state(io_config);
for (snap_id, snap_state) in pstate.snaps() {
let token = task_set.cancellation_token();
let listener = match io_config.snap_control_addr(snap_id) {
Some(addr) => {
TcpListener::bind(&addr).await.map_err(|e| {
std::io::Error::new(
e.kind(),
format!("Failed to bind to SNAP CP addr {addr}: {e}"),
)
})?
}
None => {
debug!(snap=%snap_id, "No control plane API port for SNAP specified");
let listener =
TcpListener::bind(&SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await?;
io_config.set_snap_control_addr(snap_id, listener.local_addr()?);
listener
}
};
let dp_discovery = pstate.snap_data_plane_discovery(snap_id, io_config.clone());
let session_manager = pstate.snap_session_manager(snap_id, io_config.clone());
let decoding_key =
DecodingKey::from_ed_pem(pem::encode(&pstate.snap_token_public_key()).as_bytes())
.unwrap();
let local_ases = snap_state.isd_ases();
let segment_lister = StateEndhostSegmentLister::new(pstate.clone(), local_ases);
task_set.join_set.spawn(async move {
snap_control::server::start(
token,
listener,
dp_discovery,
session_manager,
segment_lister,
decoding_key,
snap_control::server::metrics::Metrics::new(&MetricsRegistry::new()),
)
.await
});
}
for (id, _) in pstate.endhost_apis() {
let pstate = pstate.clone();
let io_config = io_config.clone();
task_set.join_set.spawn(async move {
PsEndhostApi::start(id, pstate, io_config)
.await
.map_err(|e| io::Error::other(format!("{e:?}")))
});
}
for snap_id in pstate.snaps_ids() {
let (_, session_decoding_key) = insecure_const_session_key_pair(snap_id.as_usize());
let validator = Arc::new(Validator::new(session_decoding_key, None));
for snap_dp_id in pstate.snap_data_planes(snap_id) {
let address_allocator =
Arc::new(StateSnapAddressAllocator::new(pstate.clone(), snap_dp_id));
let metrics_registry = MetricsRegistry::new();
let (_cert_der, server_config) = scion_sdk_utils::test::generate_cert(
[42u8; 32],
vec!["localhost".into()],
vec![b"snaptun".to_vec()],
);
let server_endpoint = match io_config.snap_data_plane_addr(snap_dp_id) {
Some(addr) => quinn::Endpoint::server(server_config, addr)?,
None => {
debug!(data_plane_id=%snap_dp_id, "No listen address specified for SNAP data plane");
let server_endpoint = quinn::Endpoint::server(
server_config,
SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
)?;
io_config
.set_snap_data_plane_addr(snap_dp_id, server_endpoint.local_addr()?);
server_endpoint
}
};
let shared_tunnel_gw_state = SharedTunnelGatewayState::new();
let tunnel_gw_dispatcher = TunnelGatewayDispatcher::new(
shared_tunnel_gw_state.clone(),
TunnelGatewayDispatcherMetrics::new(&metrics_registry),
);
task_set.spawn_cancellable_task({
let dispatcher = tunnel_gw_dispatcher.clone();
async move { dispatcher.start_dispatching().await }
});
let registrations = pstate
.snap_data_plane_prefixes(snap_dp_id)
.expect("Data plane registrations should exist");
for (ias, ipnets) in registrations {
for ipnet in ipnets {
pstate
.add_sim_receiver(
ias,
ipnet,
Arc::new(TunnelGatewayReceiver::new(tunnel_gw_dispatcher.clone())),
)
.expect("Failed to add dispatcher");
}
}
start_tunnel_gateway(
&mut task_set,
shared_tunnel_gw_state,
address_allocator,
validator.clone(),
server_endpoint,
Arc::new(NetSimDispatcher::new(pstate.clone())),
metrics_registry,
);
}
}
for sock_id in pstate.router_ids() {
let udp_socket = {
let bind_addr = match io_config.router_socket_addr(sock_id) {
Some(addr) => addr,
None => {
let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
io_config.set_router_socket_addr(sock_id, bind_addr);
bind_addr
}
};
let socket = tokio::net::UdpSocket::bind(bind_addr).await?;
io_config.set_router_socket_addr(sock_id, socket.local_addr()?);
socket
};
let ias = pstate
.router_ias(sock_id)
.expect("We iterate through existing routers, should exist");
let router_dispatcher = AsNetSimDispatcher::new(ias, pstate.clone());
let router_socket = RouterSocket::new(udp_socket, Arc::new(router_dispatcher)).await?;
let router_socket = SharedRouterSocket::new(router_socket);
pstate
.add_wildcard_sim_receiver(ias, Arc::new(router_socket.clone()))
.expect("Failed to add wildcard receiver");
task_set.spawn_cancellable_task(async move { router_socket.run().await });
}
ready_state.store(true, std::sync::atomic::Ordering::Relaxed);
let mgmt_listen_addr = {
let ready_state_clone = ready_state.clone();
let token = task_set.cancellation_token();
let system_state = pstate.clone();
let io_config = io_config.clone();
let listener = TcpListener::bind(
self.mgmt_listen_addr
.unwrap_or(SocketAddr::from((Ipv4Addr::LOCALHOST, DEFAULT_MGMT_PORT))),
)
.await?;
let listen_address = listener.local_addr()?;
info!(addr=%listen_address, "Starting management API");
task_set.join_set.spawn(async move {
management_api::start(token, ready_state_clone, system_state, io_config, listener)
.await
});
io::Result::Ok(listen_address)
}?;
if pstate.has_auth_server() {
let auth_server = pstate.auth_server();
let io_config = io_config.clone();
let token = task_set.cancellation_token();
task_set.join_set.spawn(async move {
authorization_server::api::start(token, auth_server, io_config).await
});
}
let client = admin::client::ApiClient::new(&addr_to_http_url(mgmt_listen_addr))
.expect("create client");
Ok(PocketScionRuntime {
handle: InProcess::new(task_set),
client,
})
}
}
impl Default for PocketScionRuntimeBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct PocketScionRuntime {
handle: InProcess,
client: admin::client::ApiClient,
}
const MAX_ATTEMPTS: i32 = 5;
const ATTEMPT_WAIT: Duration = Duration::from_millis(200);
#[derive(Error, Debug)]
pub enum PocketScionRuntimeError {
#[error("client error: {0:?}")]
ClientError(#[from] admin::client::ClientError),
#[error("pocket-scion not ready: {0}")]
PocketScionNotReady(String),
#[error("i/o error {0}")]
IoError(#[from] std::io::Error),
#[error("startup error: {0}")]
StartupError(String),
}
impl PocketScionRuntime {
pub async fn stop_and_join(&mut self) {
self.handle.task_set.cancellation_token().cancel();
self.join().await;
}
pub async fn join(&mut self) {
self.handle.task_set.join_all().await;
}
pub async fn wait_for_ready(&self) -> Result<(), PocketScionRuntimeError> {
let mut err = PocketScionRuntimeError::PocketScionNotReady("Unknown state".to_string());
for _ in 1..=MAX_ATTEMPTS {
err = match self.client.get_status().await {
Ok(status) => {
if status.state == admin::api::ReadyState::Ready {
return Ok(());
}
PocketScionRuntimeError::PocketScionNotReady(format!("{status:?}"))
}
Err(e) => PocketScionRuntimeError::ClientError(e),
};
debug!("waiting for Pocket SCION to be ready: {:?}", err);
sleep(ATTEMPT_WAIT).await;
}
Err(err)
}
pub fn api_client(&self) -> admin::client::ApiClient {
self.client.clone()
}
}
#[derive(Debug)]
pub(crate) enum PathOrObject<T> {
Unspecified,
Path(PathBuf),
Object(T),
}
impl PathOrObject<SystemState> {
#[allow(unused)]
pub(crate) async fn sync_to_file(self) -> Option<PathBuf> {
let state = match self {
PathOrObject::Unspecified => return None,
PathOrObject::Path(path_buf) => return Some(path_buf),
PathOrObject::Object(s) => s,
};
let path = get_tmp_path("system_state.json");
let dto = crate::dto::SystemStateDto::from(&state);
write_file(path.clone(), &dto).await.expect("failed");
Some(path)
}
pub(crate) async fn load(self, start_time: SystemTime) -> Result<SystemState, std::io::Error> {
match self {
PathOrObject::Unspecified => Ok(SystemState::default_from_start_time(start_time)),
PathOrObject::Path(path_buf) => {
let dto: dto::SystemStateDto = read_file(path_buf).await?;
SystemState::try_from(dto).map_err(io::Error::other)
}
PathOrObject::Object(t) => Ok(t),
}
}
}
impl PathOrObject<IoConfig> {
#[allow(unused)]
pub(crate) async fn sync_to_file(self) -> Option<PathBuf> {
let state = match self {
PathOrObject::Unspecified => return None,
PathOrObject::Path(path_buf) => return Some(path_buf),
PathOrObject::Object(s) => s,
};
let path = get_tmp_path("io_config.json");
let dto = crate::dto::IoConfigDto::from(&state);
write_file(path.clone(), &dto).await.expect("failed");
Some(path)
}
pub(crate) async fn load(self) -> Result<IoConfig, std::io::Error> {
match self {
PathOrObject::Unspecified => Ok(IoConfig::default()),
PathOrObject::Path(path_buf) => {
let dto: dto::IoConfigDto = read_file(path_buf).await?;
IoConfig::try_from(dto).map_err(io::Error::other)
}
PathOrObject::Object(t) => Ok(t),
}
}
}
impl PathOrObject<IoConfig> {
#[allow(unused)]
pub(crate) async fn write_to_temp_file(&self) -> PathBuf {
todo!()
}
}
impl<T> Default for PathOrObject<T> {
fn default() -> Self {
Self::Unspecified
}
}
#[derive(Debug, Clone)]
enum TimestampOrNow {
Now,
Timestamp(SystemTime),
}