use std::{
collections::BTreeMap,
net::SocketAddr,
sync::{Arc, RwLock, RwLockReadGuard},
};
use anyhow::{Context, Ok};
use serde::{Deserialize, Serialize};
use snap_control::server::state::ControlPlaneIoConfig;
use snap_dataplane::tunnel_gateway::state::TunnelGatewayIoConfig;
use crate::{
authorization_server::api::IoAuthServerConfig,
dto::{IoConfigDto, IoSnapConfigDto},
endhost_api::EndhostApiId,
state::{RouterId, snap::SnapId},
};
/// Network-address (I/O) configuration for the SNAPs, the authorization server,
/// the routers and the endhost APIs.
///
/// Maps are `BTreeMap`s, so iteration (and serialization) is ordered by id.
#[derive(Default, Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct IoConfig {
    /// Per-SNAP control-plane and data-plane I/O configuration.
    snaps: BTreeMap<SnapId, SnapIoConfig>,
    /// I/O configuration of the authorization server.
    auth_server: IoAuthServerConfig,
    /// Socket address of each router.
    router_sockets: BTreeMap<RouterId, SocketAddr>,
    /// Socket address of each endhost API.
    endhost_apis: BTreeMap<EndhostApiId, SocketAddr>,
}
impl AsRef<IoConfig> for RwLockReadGuard<'_, IoConfig> {
    /// Borrows the [`IoConfig`] protected by this read guard.
    fn as_ref(&self) -> &IoConfig {
        // `*self` is the guard, `**self` the guarded config; reborrow it explicitly.
        &**self
    }
}
impl TryFrom<IoConfigDto> for IoConfig {
type Error = anyhow::Error;
fn try_from(value: IoConfigDto) -> Result<Self, Self::Error> {
let snaps = value
.snaps
.into_iter()
.map(|(snap_id, snap_io_config)| {
Ok((
snap_id,
snap_io_config
.try_into()
.with_context(|| format!("invalid SNAP I/O config ({snap_id})"))?,
))
})
.collect::<Result<_, Self::Error>>()?;
let router_sockets = value
.router_sockets
.into_iter()
.map(|(router_socket_id, addr)| {
Ok((
router_socket_id,
addr.parse().context("invalid router socket address")?,
))
})
.collect::<Result<_, Self::Error>>()?;
let endhost_apis = value
.endhost_apis
.into_iter()
.map(|(id, addr)| {
Ok((
id,
addr.parse().context("invalid endhost api socket address")?,
))
})
.collect::<Result<_, Self::Error>>()?;
Ok(Self {
snaps,
router_sockets,
auth_server: value
.auth_server
.try_into()
.context("invalid auth server I/O config")?,
endhost_apis,
})
}
}
impl From<&IoConfig> for IoConfigDto {
fn from(config: &IoConfig) -> Self {
Self {
auth_server: (&config.auth_server).into(),
snaps: config
.snaps
.iter()
.map(|(snap_id, snap_io_config)| (*snap_id, snap_io_config.into()))
.collect(),
router_sockets: config
.router_sockets
.iter()
.map(|(router_socket_id, addr)| (*router_socket_id, addr.to_string()))
.collect(),
endhost_apis: config
.endhost_apis
.iter()
.map(|(id, addr)| (*id, addr.to_string()))
.collect(),
}
}
}
/// Cloneable handle to an [`IoConfig`] shared behind `Arc<RwLock<_>>`.
///
/// All clones point at the same locked state, so an update made through one
/// clone is visible through every other.
#[derive(Clone, Default)]
pub struct SharedPocketScionIoConfig {
    /// The shared configuration, guarded by a reader-writer lock.
    state: Arc<RwLock<IoConfig>>,
}
impl SharedPocketScionIoConfig {
    /// Creates a shared handle around an empty (default) [`IoConfig`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Wraps an existing [`IoConfig`] in a new shared handle.
    pub fn from_state(state: IoConfig) -> Self {
        Self {
            state: Arc::new(RwLock::new(state)),
        }
    }

    /// Builds a shared handle from the DTO representation.
    ///
    /// # Errors
    ///
    /// Returns an error if the DTO cannot be converted into an [`IoConfig`]
    /// (for example, an address string that does not parse).
    pub fn from_dto(state: IoConfigDto) -> Result<Self, anyhow::Error> {
        IoConfig::try_from(state).map(Self::from_state)
    }

    /// Returns a DTO snapshot of the current configuration.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn to_dto(&self) -> IoConfigDto {
        self.get_state().as_ref().into()
    }

    /// Acquires a read guard on the current configuration.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn get_state(&self) -> RwLockReadGuard<'_, IoConfig> {
        self.state.read().unwrap()
    }

    /// Consumes the handle and returns the wrapped [`IoConfig`].
    ///
    /// # Panics
    ///
    /// Panics if any other clone of this handle is still alive (the `Arc` is
    /// still shared), or if the lock is poisoned.
    pub fn into_state(self) -> IoConfig {
        Arc::into_inner(self.state)
            .expect("into_state called while other clones of the shared I/O config are alive")
            .into_inner()
            .expect("I/O config lock poisoned")
    }
}
impl SharedPocketScionIoConfig {
    /// Registers a new SNAP with the given control-plane API address and a default
    /// data-plane config.
    ///
    /// # Panics
    ///
    /// Panics if a SNAP with `snap_id` is already registered, or if the lock is poisoned.
    pub fn set_snap_control_addr(&self, snap_id: SnapId, control_plane_api_addr: SocketAddr) {
        let mut sstate = self.state.write().unwrap();
        if sstate.snaps.contains_key(&snap_id) {
            // Release the write guard before panicking: a panic while it is held would
            // poison the lock, and every accessor of this type panics on a poisoned lock.
            drop(sstate);
            panic!("SNAP already exists");
        }
        sstate.snaps.insert(
            snap_id,
            SnapIoConfig {
                control_plane: ControlPlaneIoConfig {
                    api_addr: Some(control_plane_api_addr),
                },
                data_plane: Default::default(),
            },
        );
    }

    /// Sets the data-plane (tunnel gateway) listen address of an already registered
    /// SNAP, replacing its previous data-plane config.
    ///
    /// # Panics
    ///
    /// Panics if no SNAP with `snap_id` is registered, or if the lock is poisoned.
    pub fn set_snap_data_plane_addr(&self, snap_id: SnapId, listen_addr: SocketAddr) {
        let mut sstate = self.state.write().unwrap();
        let Some(snap_io_config) = sstate.snaps.get_mut(&snap_id) else {
            // Release the write guard first so a precondition failure does not poison
            // the lock for every other user of the shared config.
            drop(sstate);
            panic!("SNAP doesn't exist");
        };
        snap_io_config.data_plane = TunnelGatewayIoConfig::new(listen_addr);
    }

    /// Lists every registered SNAP (ordered by id) with its control-plane API
    /// address, if one is set.
    pub fn snaps(&self) -> Vec<(SnapId, Option<SocketAddr>)> {
        let rstate = self.state.read().expect("no fail");
        rstate
            .snaps
            .iter()
            .map(|(snap_id, snap_state)| (*snap_id, snap_state.control_plane.api_addr))
            .collect()
    }

    /// Returns the control-plane API address of `snap_id`, or `None` if the SNAP is
    /// unknown or has no address set.
    pub fn snap_control_addr(&self, snap_id: SnapId) -> Option<SocketAddr> {
        let rstate = self.state.read().expect("no fail");
        rstate
            .snaps
            .get(&snap_id)
            .and_then(|snap| snap.control_plane.api_addr)
    }

    /// Returns the data-plane listen address of `snap_id`, or `None` if the SNAP is
    /// unknown or has no listen address set.
    pub fn snap_data_plane_addr(&self, snap_id: SnapId) -> Option<SocketAddr> {
        let rstate = self.state.read().expect("no fail");
        let snap = rstate.snaps.get(&snap_id)?;
        snap.data_plane.listen_addr
    }

    /// Returns the socket address registered for `router_socket_id`, if any.
    pub fn router_socket_addr(&self, router_socket_id: RouterId) -> Option<SocketAddr> {
        let rstate = self.state.read().expect("no fail");
        rstate.router_sockets.get(&router_socket_id).copied()
    }

    /// Sets (or overwrites) the socket address of `router_socket_id`.
    pub fn set_router_socket_addr(&self, router_socket_id: RouterId, addr: SocketAddr) {
        let mut sstate = self.state.write().unwrap();
        sstate.router_sockets.insert(router_socket_id, addr);
    }

    /// Returns the authorization server address, if set.
    pub fn auth_server_addr(&self) -> Option<SocketAddr> {
        let rstate = self.state.read().expect("no fail");
        rstate.auth_server.addr
    }

    /// Sets (or overwrites) the authorization server address.
    pub fn set_auth_server_addr(&self, addr: SocketAddr) {
        let mut sstate = self.state.write().unwrap();
        sstate.auth_server.addr = Some(addr);
    }

    /// Returns the socket address registered for endhost API `id`, if any.
    pub fn endhost_api_addr(&self, id: EndhostApiId) -> Option<SocketAddr> {
        self.state.read().unwrap().endhost_apis.get(&id).copied()
    }

    /// Sets (or overwrites) the socket address of endhost API `id`.
    pub fn set_endhost_api_addr(&self, id: EndhostApiId, addr: SocketAddr) {
        self.state.write().unwrap().endhost_apis.insert(id, addr);
    }
}
/// I/O configuration of a single SNAP.
#[derive(Debug, PartialEq, Clone, Default, Serialize, Deserialize)]
pub struct SnapIoConfig {
    /// Control-plane settings (including the control-plane API address).
    pub control_plane: ControlPlaneIoConfig,
    /// Data-plane (tunnel gateway) settings (including the listen address).
    pub data_plane: TunnelGatewayIoConfig,
}
impl From<&SnapIoConfig> for IoSnapConfigDto {
fn from(value: &SnapIoConfig) -> Self {
IoSnapConfigDto {
control_plane: (&value.control_plane).into(),
data_plane: (&value.data_plane).into(),
}
}
}
impl TryFrom<IoSnapConfigDto> for SnapIoConfig {
    type Error = anyhow::Error;

    /// Converts a SNAP I/O config DTO into a [`SnapIoConfig`].
    ///
    /// # Errors
    ///
    /// Returns an error if the data-plane or control-plane config cannot be converted.
    /// The data plane is converted first, so its error wins if both are invalid.
    fn try_from(value: IoSnapConfigDto) -> Result<Self, Self::Error> {
        let data_plane = value
            .data_plane
            .try_into()
            .context("Invalid data plane config")?;
        let control_plane = value
            .control_plane
            .try_into()
            .context("invalid control plane config")?;
        Ok(Self {
            control_plane,
            data_plane,
        })
    }
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddr};

    use snap_dataplane::state::Id;

    use super::*;

    /// Round-trips a populated config through `IoConfigDto` and checks nothing is lost.
    #[test]
    fn should_convert_to_dto_and_back_without_data_loss() {
        let shared = SharedPocketScionIoConfig::new();
        let snap = SnapId::from_usize(1);
        let control_api = SocketAddr::from((Ipv4Addr::LOCALHOST, 9002));
        let tunnel_listen = SocketAddr::from((Ipv4Addr::LOCALHOST, 9000));

        shared.set_snap_control_addr(snap, control_api);
        shared.set_snap_data_plane_addr(snap, tunnel_listen);

        // Snapshot the state; the read guard is released at the end of this statement.
        let original = (*shared.get_state()).clone();
        let round_tripped = IoConfig::try_from(shared.to_dto()).expect("failed to convert back");

        assert_eq!(original, round_tripped);
    }
}