mod discovery_context;
mod igd_manager;
mod network_state;
mod network_tcp;
mod network_udp;
mod protocol;
mod start_protocols;
mod tasks;
pub(super) use protocol::*;
use super::*;
use crate::routing_table::*;
use connection_manager::*;
use discovery_context::*;
use network_state::*;
use network_tcp::*;
use protocol::tcp::RawTcpProtocolHandler;
use protocol::udp::RawUdpProtocolHandler;
use protocol::ws::WebsocketProtocolHandler;
use start_protocols::*;
use futures_rustls::{
pki_types::{
pem::PemObject as _, CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer,
},
rustls::server::ServerConfig,
TlsAcceptor,
};
use futures_util::StreamExt;
use std::fs::File;
use std::io;
use std::io::BufReader;
use std::path::{Path, PathBuf};
impl_veilid_log_facility!("net");
/// Per-routing-domain failure count at which `record_dial_info_failure`
/// logs that the node may be offline.
pub const MAX_DIAL_INFO_FAILURE_COUNT: usize = 100;
/// NOTE(review): presumably the interval in seconds between outbound-only
/// dial info checks; not referenced in this part of the file -- confirm.
pub const UPDATE_OUTBOUND_ONLY_DIAL_INFO_PERIOD_SECS: u32 = 10;
/// Tick period handed to the `update_dial_info_task` tick task
/// (seconds, going by the constant's name).
pub const UPDATE_DIAL_INFO_TASK_TICK_PERIOD_SECS: u32 = 1;
/// Tick period handed to the `network_interfaces_task` tick task
/// (seconds, going by the constant's name).
pub const NETWORK_INTERFACES_TASK_TICK_PERIOD_SECS: u32 = 1;
/// Tick period handed to the `upnp_task` tick task
/// (seconds, going by the constant's name).
pub const UPNP_TASK_TICK_PERIOD_SECS: u32 = 1;
/// TTL value passed to the UDP protocol handler when sending a hole punch.
pub const HOLE_PUNCH_TTL: u32 = 3;
/// NOTE(review): presumably the number of bytes peeked from a new connection
/// for protocol detection; not referenced in this part of the file -- confirm.
pub const PEEK_DETECT_LEN: usize = 64;
/// Mutable network state. `Network` stores it behind `Arc<Mutex<..>>` and
/// `shutdown_internal` replaces it wholesale with `Network::new_inner()`.
struct NetworkInner {
// Set by `restart_network`, reported by `needs_restart`.
network_needs_restart: bool,
// Failure counter per routing domain, maintained by `record_dial_info_failure`.
dial_info_failure_count: BTreeMap<RoutingDomain, usize>,
// Set by `trigger_update_dial_info`, reported by `needs_update_dial_info`.
needs_update_dial_info: bool,
// Decided in `startup_internal`, either from config or from the interface addresses.
resolved_detect_address_changes: bool,
// NOTE(review): presumably the next time the outbound-only dial info check runs; not used in this part of the file.
next_outbound_only_dial_info_check: Timestamp,
// Handles added via `add_to_join_handles`; drained and awaited in `shutdown_internal`.
join_handles: Vec<MustJoinHandle<()>>,
// Created in `startup_internal`, taken and dropped in `shutdown_internal`.
stop_source: Option<StopSource>,
// NOTE(review): presumably the socket addresses bound for each protocol; populated outside this part of the file.
bound_address_per_protocol: BTreeMap<ProtocolType, Vec<SocketAddr>>,
// NOTE(review): presumably one UDP handler per bound local address; populated outside this part of the file.
udp_protocol_handlers: BTreeMap<SocketAddr, RawUdpProtocolHandler>,
// NOTE(review): presumably built from `load_server_config` for TLS listeners -- confirm.
tls_acceptor: Option<TlsAcceptor>,
// NOTE(review): presumably listener state keyed by bound address; populated outside this part of the file.
listener_states: BTreeMap<SocketAddr, Arc<RwLock<ListenerState>>>,
// Looked up by `get_preferred_local_address` / `get_preferred_local_address_by_key`.
preferred_local_addresses: BTreeMap<(ProtocolType, AddressType), SocketAddr>,
// NOTE(review): presumably the protocols that have statically configured public dial info -- confirm.
static_public_dial_info: ProtocolTypeSet,
// Read by `register_all_dial_info`, which fails when this is `None`.
network_state: Option<Arc<NetworkState>>,
}
/// State that `Network` shares through an `Arc` without the `inner` mutex.
/// `Network` derefs to this type, so these fields are reachable as `self.<field>`.
pub(super) struct NetworkUnlockedInner {
// Entered by the public send/query methods; taken by `startup` and `shutdown`.
startup_lock: StartupLock,
// NOTE(review): presumably the enumerated system network interfaces -- confirm.
interfaces: NetworkInterfaces,
// Tick task created with UPDATE_DIAL_INFO_TASK_TICK_PERIOD_SECS.
update_dial_info_task: TickTask<EyreReport>,
// Tick task created with NETWORK_INTERFACES_TASK_TICK_PERIOD_SECS.
network_interfaces_task: TickTask<EyreReport>,
// Tick task created with UPNP_TASK_TICK_PERIOD_SECS.
upnp_task: TickTask<EyreReport>,
// NOTE(review): presumably serializes the network tasks against each other; not used in this part of the file.
network_task_lock: AsyncMutex<()>,
// Constructed from the component registry in `new_unlocked_inner`.
igd_manager: igd_manager::IGDManager,
}
/// Handle to the native network component. Cloning copies the registry handle
/// and the two `Arc`s, so clones refer to the same underlying state.
#[derive(Clone)]
pub(super) struct Network {
// Component registry this network was created with.
registry: VeilidComponentRegistry,
// Mutable state guarded by a mutex.
inner: Arc<Mutex<NetworkInner>>,
// State reachable without taking the `inner` mutex (also exposed via `Deref`).
unlocked_inner: Arc<NetworkUnlockedInner>,
}
impl_veilid_component_accessors!(Network);
/// Exposes the fields of `NetworkUnlockedInner` directly on `Network`.
impl core::ops::Deref for Network {
    type Target = NetworkUnlockedInner;

    /// Borrows the shared, mutex-free part of the network state.
    fn deref(&self) -> &NetworkUnlockedInner {
        &*self.unlocked_inner
    }
}
impl Network {
/// Builds the pristine mutable state: every flag cleared, every collection
/// empty and every optional component absent.
fn new_inner() -> NetworkInner {
    NetworkInner {
        // Flags and counters
        network_needs_restart: false,
        needs_update_dial_info: false,
        resolved_detect_address_changes: false,
        dial_info_failure_count: BTreeMap::new(),
        next_outbound_only_dial_info_check: Timestamp::default(),
        // Background task bookkeeping
        join_handles: Vec::new(),
        stop_source: None,
        // Sockets, listeners and addresses
        bound_address_per_protocol: BTreeMap::new(),
        udp_protocol_handlers: BTreeMap::new(),
        listener_states: BTreeMap::new(),
        preferred_local_addresses: BTreeMap::new(),
        tls_acceptor: None,
        // Dial info and last known network state
        static_public_dial_info: ProtocolTypeSet::new(),
        network_state: None,
    }
}
/// Builds the mutex-free shared state: the startup lock, the interface
/// tracker, the three tick tasks, the task lock and the IGD manager.
fn new_unlocked_inner(registry: VeilidComponentRegistry) -> NetworkUnlockedInner {
    let startup_lock = StartupLock::new();
    let interfaces = NetworkInterfaces::new();

    // Periodic tasks, each with its own tick period constant.
    let update_dial_info_task = TickTask::new(
        "update_dial_info_task",
        UPDATE_DIAL_INFO_TASK_TICK_PERIOD_SECS,
    );
    let network_interfaces_task = TickTask::new(
        "network_interfaces_task",
        NETWORK_INTERFACES_TASK_TICK_PERIOD_SECS,
    );
    let upnp_task = TickTask::new("upnp_task", UPNP_TASK_TICK_PERIOD_SECS);

    let network_task_lock = AsyncMutex::new(());
    let igd_manager = igd_manager::IGDManager::new(registry);

    NetworkUnlockedInner {
        startup_lock,
        interfaces,
        update_dial_info_task,
        network_interfaces_task,
        upnp_task,
        network_task_lock,
        igd_manager,
    }
}
/// Creates a network component bound to `registry` and wires up its tasks
/// via `setup_tasks` before handing it back.
pub fn new(registry: VeilidComponentRegistry) -> Self {
    let inner = Arc::new(Mutex::new(Self::new_inner()));
    let unlocked_inner = Arc::new(Self::new_unlocked_inner(registry.clone()));
    let network = Self {
        registry,
        inner,
        unlocked_inner,
    };
    network.setup_tasks();
    network
}
/// Reads every PEM certificate found in the file at `path`.
///
/// # Errors
/// Returns the I/O error from opening the file, or the first PEM parse
/// error converted with `io::Error::other`.
fn load_certs(path: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
    let mut reader = BufReader::new(File::open(path)?);
    let mut certs = Vec::new();
    for parsed in CertificateDer::<'static>::pem_reader_iter(&mut reader) {
        // Stop at the first malformed entry, exactly like a fallible collect.
        certs.push(parsed.map_err(io::Error::other)?);
    }
    Ok(certs)
}
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKeyDer<'static>>> {
{
if let Ok(v) = PrivatePkcs1KeyDer::<'static>::pem_reader_iter(&mut BufReader::new(
File::open(path)?,
))
.collect::<Result<Vec<PrivatePkcs1KeyDer<'static>>, futures_rustls::pki_types::pem::Error>>()
{
if !v.is_empty() {
return Ok(v
.into_iter()
.map(PrivateKeyDer::Pkcs1)
.collect::<Vec<PrivateKeyDer<'static>>>());
}
}
}
{
if let Ok(v) = PrivatePkcs8KeyDer::<'static>::pem_reader_iter(&mut BufReader::new(
File::open(path)?,
))
.collect::<Result<Vec<PrivatePkcs8KeyDer<'static>>,futures_rustls::pki_types::pem::Error>>()
{
if !v.is_empty() {
return Ok(v.into_iter().map(PrivateKeyDer::Pkcs8).collect());
}
}
}
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid TLS private key",
))
}
/// Builds the TLS `ServerConfig` from the certificate and private key paths
/// in the network TLS configuration. Client authentication is disabled and
/// the first loaded key is paired with the full certificate chain.
///
/// # Errors
/// Returns an `io::Error` when either file cannot be read, when no
/// certificate or key is found, or when rustls rejects the pair.
fn load_server_config(&self) -> io::Result<ServerConfig> {
    let config = self.config();
    let tls = &config.network.tls;

    // Certificate chain
    veilid_log!(self trace
        "loading certificate from {}",
        tls.certificate_path
    );
    let certs = Self::load_certs(&PathBuf::from(&tls.certificate_path))?;
    veilid_log!(self trace "loaded {} certificates", certs.len());
    if certs.is_empty() {
        return Err(io::Error::new(io::ErrorKind::InvalidInput, format!("Certificates at {} could not be loaded.\nEnsure it is in PEM format, beginning with '-----BEGIN CERTIFICATE-----'",tls.certificate_path)));
    }

    // Private key
    veilid_log!(self trace
        "loading private key from {}",
        tls.private_key_path
    );
    let keys = Self::load_keys(&PathBuf::from(&tls.private_key_path))?;
    veilid_log!(self trace "loaded {} keys", keys.len());
    let Some(first_key) = keys.into_iter().next() else {
        return Err(io::Error::new(io::ErrorKind::InvalidInput, format!("Private key at {} could not be loaded.\nEnsure it is unencrypted and in RSA or PKCS8 format, beginning with '-----BEGIN RSA PRIVATE KEY-----' or '-----BEGIN PRIVATE KEY-----'",tls.private_key_path)));
    };

    ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, first_key)
        .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))
}
/// Registers a task handle so that `shutdown_internal` can await it later.
fn add_to_join_handles(&self, jh: MustJoinHandle<()>) {
    self.inner.lock().join_handles.push(jh);
}
/// Expands an unspecified (wildcard) socket address into one socket address
/// per known interface address of the same IP family, keeping the port.
/// A specified address is returned unchanged as a single-element vector.
fn translate_unspecified_address(&self, from: SocketAddr) -> Vec<SocketAddr> {
    if !from.ip().is_unspecified() {
        return vec![from];
    }

    let wants_ipv4 = from.ip().is_ipv4();
    let port = from.port();
    let network_state = self.last_network_state().unwrap_or_log();
    network_state
        .interface_address_state
        .interface_addresses
        .iter()
        .map(|ifaddr| ifaddr.ip())
        // An IP address is either v4 or v6, so matching the v4 flag matches the family.
        .filter(|ip| ip.is_ipv4() == wants_ipv4)
        .map(|ip| SocketAddr::new(ip, port))
        .collect()
}
/// Returns the preferred local socket address recorded for the protocol
/// type and address type of `dial_info`, if one exists.
pub fn get_preferred_local_address(&self, dial_info: &DialInfo) -> Option<SocketAddr> {
    self.get_preferred_local_address_by_key(dial_info.protocol_type(), dial_info.address_type())
}
/// Returns the preferred local socket address recorded for the
/// `(pt, at)` protocol/address-type pair, if one exists.
pub fn get_preferred_local_address_by_key(
    &self,
    pt: ProtocolType,
    at: AddressType,
) -> Option<SocketAddr> {
    self.inner
        .lock()
        .preferred_local_addresses
        .get(&(pt, at))
        .copied()
}
/// Awaits `fut` and keeps the per-routing-domain failure counter in step
/// with its outcome.
///
/// A `NoConnection` or `Timeout` result marks `dial_info` as failed in the
/// address filter and increments the counter for its routing domain, logging
/// once when the counter reaches `MAX_DIAL_INFO_FAILURE_COUNT`. Any other
/// result clears the counter for that domain. The network result is returned
/// unchanged; an `Err` from `fut` propagates without touching the counter.
async fn record_dial_info_failure<T, F: Future<Output = EyreResult<NetworkResult<T>>>>(
    &self,
    dial_info: DialInfo,
    fut: F,
) -> EyreResult<NetworkResult<T>> {
    // Resolve the routing domain up front, before `dial_info` is moved.
    let opt_routing_domain = self
        .routing_table()
        .routing_domain_for_address(dial_info.address());

    let network_result = pin_future_closure!(fut).await?;

    let is_failure = matches!(
        network_result,
        NetworkResult::NoConnection(_) | NetworkResult::Timeout
    );

    if !is_failure {
        // Anything other than a failure resets the streak for this domain.
        if let Some(rd) = opt_routing_domain {
            self.inner.lock().dial_info_failure_count.remove(&rd);
        }
        return Ok(network_result);
    }

    self.network_manager()
        .address_filter()
        .set_dial_info_failed(dial_info);

    if let Some(rd) = opt_routing_domain {
        let failures = {
            let mut inner = self.inner.lock();
            let counter = inner.dial_info_failure_count.entry(rd).or_insert(0);
            *counter += 1;
            *counter
        };
        if failures == MAX_DIAL_INFO_FAILURE_COUNT {
            veilid_log!(self debug "Node may be offline. Exceeded maximum dial info failure count for {:?}", rd);
        }
    }

    Ok(network_result)
}
/// Sends `data` to `dial_info` over a freshly created socket or connection
/// that is not tied to any existing flow, then discards it.
///
/// The send is wrapped in `record_dial_info_failure`, so the outcome updates
/// the dial info failure tracking. Punished addresses are rejected with a
/// `NoConnection` result before any socket is created. On success the sent
/// byte count is recorded in the network manager statistics.
///
/// # Errors
/// Returns `Err` when the network is not started, or when creating the
/// socket, connecting or sending fails with a hard error. Soft failures are
/// returned as a non-`Value` `NetworkResult`.
#[cfg_attr(feature = "instrument", instrument(level="trace", target="net", err, skip(self, data), fields(data.len = data.len())))]
pub async fn send_data_unbound_to_dial_info(
    &self,
    dial_info: DialInfo,
    data: Bytes,
) -> EyreResult<NetworkResult<()>> {
    let _guard = self.startup_lock.enter()?;
    self.record_dial_info_failure(
        dial_info.clone(),
        async move {
            let data_len = data.len();
            let connect_timeout_ms = self.config().network.connection_initial_timeout_ms;
            if self
                .network_manager()
                .address_filter()
                .is_ip_addr_punished(dial_info.address().ip_addr())
            {
                return Ok(NetworkResult::no_connection_other("punished"));
            }
            match dial_info.protocol_type() {
                ProtocolType::UDP => {
                    let peer_socket_addr = dial_info.to_socket_addr();
                    let h = RawUdpProtocolHandler::new_unspecified_bound_handler(
                        self.registry(),
                        &peer_socket_addr,
                    )
                    .wrap_err("create socket failure")?;
                    // Pass the handler's own NetworkResult to the macro so a
                    // failed send returns early instead of being wrapped in
                    // an extra `Value` and silently counted as sent.
                    network_result_try!(h
                        .send_message(data, peer_socket_addr)
                        .await
                        .wrap_err("send message failure")?);
                }
                ProtocolType::TCP => {
                    let peer_socket_addr = dial_info.to_socket_addr();
                    let pnc = network_result_try!(RawTcpProtocolHandler::connect(
                        self.registry(),
                        None,
                        peer_socket_addr,
                        connect_timeout_ms
                    )
                    .await
                    .wrap_err("connect failure")?);
                    network_result_try!(pnc.send(data).await.wrap_err("send failure")?);
                }
                ProtocolType::WS => {
                    let pnc = network_result_try!(WebsocketProtocolHandler::connect(
                        self.registry(),
                        None,
                        &dial_info,
                        connect_timeout_ms
                    )
                    .await
                    .wrap_err("connect failure")?);
                    network_result_try!(pnc.send(data).await.wrap_err("send failure")?);
                }
                #[cfg(feature = "enable-protocol-wss")]
                ProtocolType::WSS => {
                    let pnc = network_result_try!(WebsocketProtocolHandler::connect(
                        self.registry(),
                        None,
                        &dial_info,
                        connect_timeout_ms
                    )
                    .await
                    .wrap_err("connect failure")?);
                    network_result_try!(pnc.send(data).await.wrap_err("send failure")?);
                }
            }
            // Only reached when the send produced a `Value` result.
            self.network_manager()
                .stats_packet_sent(dial_info.ip_addr(), ByteCount::new(data_len as u64));
            Ok(NetworkResult::Value(()))
        }
        .in_current_span(),
    )
    .await
}
// Sends `data` to `dial_info` over a freshly created socket or connection and
// waits up to `timeout_ms` for a single response, which is returned.
//
// The whole exchange runs inside `record_dial_info_failure`, so its outcome
// updates the dial info failure tracking. Sent and received byte counts are
// recorded in the network manager statistics.
//
// Returns `Err` when the network is not started, on hard socket/connect/send/
// receive errors, or when a UDP response arrives from an unexpected address.
#[cfg_attr(feature = "instrument", instrument(level="trace", target="net", err, skip(self, data), fields(data.len = data.len())))]
pub async fn send_recv_data_unbound_to_dial_info(
&self,
dial_info: DialInfo,
data: Bytes,
timeout_ms: u32,
) -> EyreResult<NetworkResult<Bytes>> {
let _guard = self.startup_lock.enter()?;
self.record_dial_info_failure(
dial_info.clone(),
async move {
let data_len = data.len();
let connect_timeout_ms = self.config().network.connection_initial_timeout_ms;
// Refuse to talk to punished addresses before creating any socket.
if self
.network_manager()
.address_filter()
.is_ip_addr_punished(dial_info.address().ip_addr())
{
return Ok(NetworkResult::no_connection_other("punished"));
}
match dial_info.protocol_type() {
ProtocolType::UDP => {
let peer_socket_addr = dial_info.to_socket_addr();
let h = RawUdpProtocolHandler::new_unspecified_bound_handler(
self.registry(),
&peer_socket_addr,
)
.wrap_err("create socket failure")?;
network_result_try!(h
.send_message(data, peer_socket_addr)
.await
.wrap_err("send message failure")?);
self.network_manager().stats_packet_sent(
dial_info.ip_addr(),
ByteCount::new(data_len as u64),
);
// Receive buffer sized for the largest message; trimmed to the received length below.
let mut out = BytesMut::zeroed(MAX_MESSAGE_SIZE);
let (recv_len, recv_addr) = network_result_try!(timeout(
timeout_ms,
h.recv_message(&mut out).in_current_span()
)
.await
.into_network_result())
.wrap_err("recv_message failure")?;
let recv_socket_addr = recv_addr.remote_address().socket_addr();
self.network_manager().stats_packet_rcvd(
recv_socket_addr.ip(),
ByteCount::new(recv_len as u64),
);
// A datagram from anyone other than the peer we sent to is a hard error.
if recv_socket_addr != peer_socket_addr {
bail!("wrong address");
}
out.resize(recv_len, 0u8);
Ok(NetworkResult::Value(out.into()))
}
_ => {
// Connection-oriented protocols: connect first, then share the send/receive path.
let pnc = network_result_try!(match dial_info.protocol_type() {
ProtocolType::UDP => unreachable!(),
ProtocolType::TCP => {
let peer_socket_addr = dial_info.to_socket_addr();
RawTcpProtocolHandler::connect(
self.registry(),
None,
peer_socket_addr,
connect_timeout_ms,
)
.await
.wrap_err("connect failure")?
}
ProtocolType::WS => {
WebsocketProtocolHandler::connect(
self.registry(),
None,
&dial_info,
connect_timeout_ms,
)
.await
.wrap_err("connect failure")?
}
#[cfg(feature = "enable-protocol-wss")]
ProtocolType::WSS => {
WebsocketProtocolHandler::connect(
self.registry(),
None,
&dial_info,
connect_timeout_ms,
)
.await
.wrap_err("connect failure")?
}
});
network_result_try!(pnc.send(data).await.wrap_err("send failure")?);
self.network_manager().stats_packet_sent(
dial_info.ip_addr(),
ByteCount::new(data_len as u64),
);
// Outer macro call handles the timeout result, inner one the receive result.
let out = network_result_try!(network_result_try!(timeout(
timeout_ms,
pnc.recv().in_current_span()
)
.await
.into_network_result())
.wrap_err("recv failure")?);
self.network_manager().stats_packet_rcvd(
dial_info.ip_addr(),
ByteCount::new(out.len() as u64),
);
Ok(NetworkResult::Value(out))
}
}
}
.in_current_span(),
)
.await
}
// Tries to send `data` over an already-established flow without creating a
// new socket or connection.
//
// UDP flows go through the best matching bound UDP protocol handler; other
// flows (and UDP flows with no matching handler) go through an existing
// connection from the connection manager. When nothing can carry the data,
// it is handed back in `SendDataToExistingFlowResult::NotSent`.
//
// Returns `Err` when the network is not started or the UDP send fails hard.
#[cfg_attr(feature = "instrument", instrument(level="trace", target="net", err, skip(self, data), fields(data.len = data.len())))]
pub async fn send_data_to_existing_flow(
&self,
flow: Flow,
data: Bytes,
) -> EyreResult<SendDataToExistingFlowResult> {
let _guard = self.startup_lock.enter()?;
let data_len = data.len();
if flow.protocol_type() == ProtocolType::UDP {
let peer_socket_addr = flow.remote().socket_addr();
if let Some(ph) = self.find_best_udp_protocol_handler(
&peer_socket_addr,
&flow.local().map(|sa| sa.socket_addr()),
) {
// `data` is cloned for the send so the original can be returned in `NotSent`.
// NOTE(review): the macro presumably logs a non-value result and then runs the trailing block -- confirm against its definition.
network_result_value_or_log!(self ph.clone()
.send_message(data.clone(), peer_socket_addr)
.await
.wrap_err("sending data to existing flow")? => [ format!(": data.len={}, flow={:?}", data.len(), flow) ]
{ return Ok(SendDataToExistingFlowResult::NotSent(data)); } );
self.network_manager()
.stats_packet_sent(peer_socket_addr.ip(), ByteCount::new(data_len as u64));
// UDP has no connection, so the unique flow carries no connection id.
let unique_flow = UniqueFlow {
flow,
connection_id: None,
};
return Ok(SendDataToExistingFlowResult::Sent(unique_flow));
}
}
// Fall back to an existing connection in the connection manager for this flow.
if let Some(conn) = self
.network_manager()
.connection_manager()
.get_connection(flow)
{
match conn.send_async(data).await {
ConnectionHandleSendResult::Sent => {
self.network_manager().stats_packet_sent(
flow.remote().socket_addr().ip(),
ByteCount::new(data_len as u64),
);
return Ok(SendDataToExistingFlowResult::Sent(conn.unique_flow()));
}
ConnectionHandleSendResult::NotSent(data) => {
return Ok(SendDataToExistingFlowResult::NotSent(data));
}
}
}
// No handler and no connection: give the data back to the caller.
Ok(SendDataToExistingFlowResult::NotSent(data))
}
/// Sends `data` to `dial_info`, reusing bound sockets and managed
/// connections, and returns the unique flow the data went out on.
///
/// UDP goes through the best matching bound UDP protocol handler; every
/// other protocol goes through a connection obtained from the connection
/// manager. The send runs inside `record_dial_info_failure`, and a
/// successful send is recorded in the network manager statistics.
///
/// # Errors
/// Returns `Err` when the network is not started or on hard send/connect
/// errors; soft failures come back as a non-`Value` `NetworkResult`.
#[cfg_attr(feature = "instrument", instrument(level="trace", target="net", err, skip(self, data), fields(data.len = data.len())))]
pub async fn send_data_to_dial_info(
    &self,
    dial_info: DialInfo,
    data: Bytes,
) -> EyreResult<NetworkResult<UniqueFlow>> {
    let _guard = self.startup_lock.enter()?;
    self.record_dial_info_failure(
        dial_info.clone(),
        async move {
            let data_len = data.len();

            let unique_flow = if dial_info.protocol_type() == ProtocolType::UDP {
                // Connectionless path: pick a bound handler and fire the datagram.
                let peer_socket_addr = dial_info.to_socket_addr();
                let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None)
                else {
                    return Ok(NetworkResult::no_connection_other(
                        "no appropriate UDP protocol handler for dial_info",
                    ));
                };
                let flow = network_result_try!(ph
                    .send_message(data, peer_socket_addr)
                    .await
                    .wrap_err("failed to send data to dial info")?);
                UniqueFlow {
                    flow,
                    connection_id: None,
                }
            } else {
                // Connection-oriented path: reuse or establish a managed connection.
                let connmgr = self.network_manager().connection_manager();
                let conn = network_result_try!(
                    connmgr.get_or_create_connection(dial_info.clone()).await?
                );
                match conn.send_async(data).await {
                    ConnectionHandleSendResult::Sent => {}
                    ConnectionHandleSendResult::NotSent(_) => {
                        return Ok(NetworkResult::NoConnection(io::Error::new(
                            io::ErrorKind::ConnectionReset,
                            "failed to send",
                        )));
                    }
                }
                conn.unique_flow()
            };

            self.network_manager()
                .stats_packet_sent(dial_info.ip_addr(), ByteCount::new(data_len as u64));
            Ok(NetworkResult::value(unique_flow))
        }
        .in_current_span(),
    )
    .await
}
/// Sends a hole punch packet with TTL `HOLE_PUNCH_TTL` to `dial_info` and
/// returns the flow it was sent on.
///
/// Only dial info whose low-level protocol is UDP is supported; anything
/// else yields `ServiceUnavailable`. The send runs inside
/// `record_dial_info_failure`, and a zero-byte packet is recorded in the
/// network manager statistics on success.
///
/// # Errors
/// Returns `Err` when the network is not started or the send fails hard.
#[cfg_attr(
    feature = "instrument",
    instrument(level = "trace", target = "net", err, skip(self), fields(__VEILID_LOG_KEY = self.log_key()))
)]
pub async fn send_hole_punch(
    &self,
    dial_info: DialInfo,
) -> EyreResult<NetworkResult<UniqueFlow>> {
    let _guard = self.startup_lock.enter()?;
    self.record_dial_info_failure(
        dial_info.clone(),
        async move {
            // Guard clause: hole punching exists only for UDP-based protocols.
            if dial_info.protocol_type().low_level_protocol_type() != LowLevelProtocolType::UDP
            {
                return Ok(NetworkResult::ServiceUnavailable(
                    "unimplemented for this protocol".to_owned(),
                ));
            }

            let peer_socket_addr = dial_info.to_socket_addr();
            let Some(ph) = self.find_best_udp_protocol_handler(&peer_socket_addr, &None)
            else {
                return Ok(NetworkResult::no_connection_other(
                    "no appropriate UDP protocol handler for dial_info",
                ));
            };
            let flow = network_result_try!(ph
                .send_hole_punch(peer_socket_addr, HOLE_PUNCH_TTL)
                .await
                .wrap_err("failed to send hole punch to dial info")?);
            let unique_flow = UniqueFlow {
                flow,
                connection_id: None,
            };

            self.network_manager()
                .stats_packet_sent(dial_info.ip_addr(), ByteCount::new(0));
            Ok(NetworkResult::value(unique_flow))
        }
        .in_current_span(),
    )
    .await
}
// Performs the actual network startup:
//   1. refreshes the network state and creates the stop source,
//   2. resolves whether address change detection is on (config or auto),
//   3. configures both routing domains, starts the inbound listeners for each
//      enabled protocol and registers their dial info,
//   4. either triggers a public dial info update or warns when no public
//      dial info is available.
//
// Returns the first non-`Success` disposition (or error) produced by a
// listener start, otherwise `StartupDisposition::Success`. Cleanup after a
// non-success return is left to the caller (`startup` runs `shutdown_internal`).
pub async fn startup_internal(&self) -> EyreResult<StartupDisposition> {
let network_state = self.refresh_network_state().await?.unwrap_or_log();
let resolved_detect_address_changes = {
let mut inner = self.inner.lock();
inner.stop_source = Some(StopSource::new());
let detect_address_changes = self.config().network.detect_address_changes;
if let Some(detect_address_changes) = detect_address_changes {
// An explicit configuration value always wins.
inner.resolved_detect_address_changes = detect_address_changes;
if inner.resolved_detect_address_changes {
veilid_log!(self info "Manually-enabled detection of address changes");
} else {
veilid_log!(self info "Manually-disabled detection of address changes");
}
} else {
// Auto mode: look for global addresses of each IP family on the interfaces.
let mut global_ipv4 = false;
let mut global_ipv6 = false;
for siaddr in network_state
.interface_address_state
.interface_addresses
.iter()
{
if Address::from_ip_addr(siaddr.ip()).is_global() {
match siaddr {
IfAddr::V4(_) => {
global_ipv4 = true;
}
IfAddr::V6(_) => {
global_ipv6 = true;
}
}
}
}
// Detection is disabled only when both a global IPv4 and a global IPv6 address exist.
inner.resolved_detect_address_changes = !(global_ipv4 && global_ipv6);
if inner.resolved_detect_address_changes {
veilid_log!(self info "Auto-enabled detection of address changes: global_ipv4={}, global_ipv6={}", global_ipv4, global_ipv6);
} else {
veilid_log!(self info "Auto-disabled detection of address changes because this node has global IPv4 and IPv6 addresses");
}
}
inner.resolved_detect_address_changes
};
let routing_table = self.routing_table();
let confirmed_public_internet;
{
// Both editors live only inside this scope; an early `return res` below drops them uncommitted.
let mut editor_public_internet = routing_table.edit_public_internet_routing_domain();
let mut editor_local_network = routing_table.edit_local_network_routing_domain();
editor_local_network.set_interface_addresses(
network_state
.interface_address_state
.as_ref()
.interface_addresses
.clone(),
);
editor_local_network.setup_network(
network_state.protocol_config.outbound,
network_state.protocol_config.inbound,
network_state.protocol_config.family_local,
network_state
.protocol_config
.local_network_capabilities
.clone(),
true,
);
// The public internet domain counts as confirmed when address change detection is off
// or when the privacy config requires an inbound relay.
confirmed_public_internet = !resolved_detect_address_changes
|| self.config().network.privacy.require_inbound_relay;
editor_public_internet.set_interface_addresses(
network_state
.interface_address_state
.as_ref()
.interface_addresses
.clone(),
);
editor_public_internet.setup_network(
network_state.protocol_config.outbound,
network_state.protocol_config.inbound,
network_state.protocol_config.family_global,
network_state
.protocol_config
.public_internet_capabilities
.clone(),
confirmed_public_internet,
);
// Start listeners for each enabled inbound protocol; bail out on the first non-success.
if network_state
.protocol_config
.inbound
.contains(ProtocolType::UDP)
{
let res = self.bind_udp_protocol_handlers();
if !matches!(res, Ok(StartupDisposition::Success)) {
return res;
}
}
if network_state
.protocol_config
.inbound
.contains(ProtocolType::WS)
{
let res = self.start_ws_listeners();
if !matches!(res, Ok(StartupDisposition::Success)) {
return res;
}
}
#[cfg(feature = "enable-protocol-wss")]
if network_state
.protocol_config
.inbound
.contains(ProtocolType::WSS)
{
let res = self.start_wss_listeners();
if !matches!(res, Ok(StartupDisposition::Success)) {
return res;
}
}
if network_state
.protocol_config
.inbound
.contains(ProtocolType::TCP)
{
let res = self.start_tcp_listeners();
if !matches!(res, Ok(StartupDisposition::Success)) {
return res;
}
}
self.register_all_dial_info(&mut editor_public_internet, &mut editor_local_network)?;
// Publish each domain only when its commit returns true.
if editor_public_internet.commit(true).await {
editor_public_internet.publish();
}
if editor_local_network.commit(true).await {
editor_local_network.publish();
}
}
if !confirmed_public_internet {
// Public dial info still has to be discovered; flag it for the update task.
self.trigger_update_dial_info(RoutingDomain::PublicInternet);
} else {
let pi = routing_table.get_current_peer_info(RoutingDomain::PublicInternet);
if !pi.node_info().has_any_dial_info()
&& !self.config().network.privacy.require_inbound_relay
{
veilid_log!(self warn
"This node has no valid public dial info.\nConfigure this node with a static public IP address and correct firewall rules."
);
}
}
Ok(StartupDisposition::Success)
}
#[cfg_attr(feature = "instrument", instrument(level = "debug", err, skip_all, fields(__VEILID_LOG_KEY = self.log_key())))]
pub(super) fn register_all_dial_info(
&self,
editor_public_internet: &mut RoutingDomainEditorPublicInternet<'_>,
editor_local_network: &mut RoutingDomainEditorLocalNetwork<'_>,
) -> EyreResult<()> {
let Some(protocol_config) = ({
let inner = self.inner.lock();
inner
.network_state
.as_ref()
.map(|ns| ns.protocol_config.clone())
}) else {
bail!("can't register dial info without network state");
};
if protocol_config.inbound.contains(ProtocolType::UDP) {
self.register_udp_dial_info(editor_public_internet, editor_local_network)?;
}
if protocol_config.inbound.contains(ProtocolType::WS) {
self.register_ws_dial_info(editor_public_internet, editor_local_network)?;
}
#[cfg(feature = "enable-protocol-wss")]
if protocol_config.inbound.contains(ProtocolType::WSS) {
self.register_wss_dial_info(editor_public_internet, editor_local_network)?;
}
if protocol_config.inbound.contains(ProtocolType::TCP) {
self.register_tcp_dial_info(editor_public_internet, editor_local_network)?;
}
Ok(())
}
/// Starts the low level network.
///
/// Takes the startup lock, runs `startup_internal`, and marks the lock as
/// started only on `StartupDisposition::Success`. On a bind retry or an
/// error, everything started so far is torn down with `shutdown_internal`
/// and the disposition or error is handed back to the caller.
///
/// # Errors
/// Returns `Err` when the startup lock cannot be taken or when
/// `startup_internal` fails.
#[cfg_attr(feature = "instrument", instrument(level = "debug", err, skip_all, fields(__VEILID_LOG_KEY = self.log_key())))]
pub async fn startup(&self) -> EyreResult<StartupDisposition> {
    let guard = self.startup_lock.startup()?;
    match self.startup_internal().await {
        Ok(StartupDisposition::Success) => {
            veilid_log!(self debug "Network started");
            guard.success();
            Ok(StartupDisposition::Success)
        }
        Ok(StartupDisposition::BindRetry) => {
            // Log through the component facility like the rest of this file.
            veilid_log!(self debug "network bind retry");
            self.shutdown_internal().await;
            Ok(StartupDisposition::BindRetry)
        }
        Err(e) => {
            veilid_log!(self debug "network failed to start");
            self.shutdown_internal().await;
            Err(e)
        }
    }
}
/// Reports whether a restart has been requested through `restart_network`.
pub fn needs_restart(&self) -> bool {
    let inner = self.inner.lock();
    inner.network_needs_restart
}
/// Reports whether the startup lock currently considers the network started.
pub fn is_started(&self) -> bool {
    let startup_lock = &self.startup_lock;
    startup_lock.is_started()
}
/// Flags the network as needing a restart; the flag is read back by
/// `needs_restart`.
#[cfg_attr(feature = "instrument", instrument(level = "debug", skip_all, fields(__VEILID_LOG_KEY = self.log_key())))]
pub fn restart_network(&self) {
    let mut inner = self.inner.lock();
    inner.network_needs_restart = true;
}
// Tears down everything `startup_internal` set up: collects the task join
// handles, drops the stop source, awaits the tasks, resets both routing
// domains, and finally replaces the mutable state with a fresh
// `NetworkInner`. It does not touch the startup lock; callers manage it.
#[cfg_attr(feature = "instrument", instrument(level = "debug", skip_all, fields(__VEILID_LOG_KEY = self.log_key())))]
async fn shutdown_internal(&self) {
let routing_table = self.routing_table();
let mut unord = FuturesUnordered::new();
{
// Keep the lock scope tight: it must be released before awaiting below.
let mut inner = self.inner.lock();
for h in inner.join_handles.drain(..) {
veilid_log!(self trace "joining: {:?}", h);
unord.push(h);
}
// NOTE(review): dropping the stop source presumably signals the tasks to stop -- confirm.
drop(inner.stop_source.take());
}
veilid_log!(self debug "stopping {} low level network tasks", unord.len());
// Wait for every collected task to finish.
while unord.next().await.is_some() {}
veilid_log!(self debug "clearing dial info");
routing_table
.edit_public_internet_routing_domain()
.reset()
.await;
routing_table
.edit_local_network_routing_domain()
.reset()
.await;
// Start over from pristine mutable state.
*self.inner.lock() = Self::new_inner();
}
/// Shuts the low level network down.
///
/// Takes the startup lock for shutdown, runs `shutdown_internal`, and marks
/// the shutdown as successful. When the lock cannot be taken, an error is
/// logged and nothing else happens.
#[cfg_attr(feature = "instrument", instrument(level = "debug", skip_all, fields(__VEILID_LOG_KEY = self.log_key())))]
pub async fn shutdown(&self) {
    veilid_log!(self debug "starting low level network shutdown");
    let guard = match self.startup_lock.shutdown().await {
        Ok(guard) => guard,
        Err(_) => {
            veilid_log!(self error "low level network is already shut down");
            return;
        }
    };
    self.shutdown_internal().await;
    guard.success();
    veilid_log!(self debug "finished low level network shutdown");
}
/// Reports whether a dial info update has been requested. Returns `false`
/// (after a debug log) when the network is not started.
pub fn needs_update_dial_info(&self) -> bool {
    match self.startup_lock.enter() {
        // The guard stays alive for the whole read.
        Ok(_guard) => self.inner.lock().needs_update_dial_info,
        Err(_) => {
            veilid_log!(self debug "ignoring 'needs_update_dial_info' due to not started up");
            false
        }
    }
}
/// Returns the address change detection setting resolved at startup.
/// Returns `false` (after a debug log) when the network is not started.
pub fn resolved_detect_address_changes(&self) -> bool {
    match self.startup_lock.enter() {
        // The guard stays alive for the whole read.
        Ok(_guard) => self.inner.lock().resolved_detect_address_changes,
        Err(_) => {
            veilid_log!(self debug "ignoring 'resolved_detect_address_changes' due to not started up");
            false
        }
    }
}
/// Requests a dial info update. Only `RoutingDomain::PublicInternet` sets
/// the flag; other domains are ignored. When the network is not started the
/// request is dropped with a debug log.
pub fn trigger_update_dial_info(&self, routing_domain: RoutingDomain) {
    match self.startup_lock.enter() {
        Ok(_guard) => {
            if matches!(routing_domain, RoutingDomain::PublicInternet) {
                self.inner.lock().needs_update_dial_info = true;
            }
        }
        Err(_) => {
            veilid_log!(self debug "ignoring 'trigger_update_dial_info' due to not started up");
        }
    }
}
}