use citadel_proto::prelude::*;
use citadel_io::ServerMode;
use citadel_proto::kernel::KernelExecutorArguments;
use citadel_proto::macros::{ContextRequirements, LocalContextRequirements};
use citadel_types::crypto::{HeaderObfuscatorSettings, PreSharedKey};
use futures::Future;
use std::fmt::{Debug, Formatter};
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Collects the configuration for a node and turns it into a runnable [`NodeFuture`]
/// via [`NodeBuilder::build`].
///
/// `R` is the ratchet type and `T` the platform transport. Every setting is optional:
/// at build time each one is either replaced by a default or forwarded as-is
/// (possibly `None`) to the kernel executor / account manager.
pub struct NodeBuilder<R: Ratchet = StackedRatchet, T: PlatformOps = DefaultTransport> {
    // Node role; `build` falls back to `NodeType::default()` when unset.
    hypernode_type: Option<NodeType>,
    // Server transport mode; `build` falls back to `T::default_server_config()` when unset.
    underlying_protocol: Option<ServerMode<T>>,
    // Account storage backend; `build` falls back to `BackendType::default()` when unset.
    backend_type: Option<BackendType>,
    // Forwarded to `AccountManager::new`.
    server_argon_settings: Option<ArgonDefaultServerSettings>,
    // Google services configuration (RTDB url/api key, services JSON path);
    // forwarded to `AccountManager::new`.
    #[cfg(feature = "google-services")]
    services: Option<ServicesConfig>,
    // Forwarded to `AccountManager::new`.
    server_misc_settings: Option<ServerMiscSettings>,
    // Client-side TLS configuration forwarded to the kernel executor. On wasm with a
    // serverless config it is passed through `T::setup_serverless_transport` first.
    client_tls_config: Option<T::ClientConfig>,
    // Kernel executor settings; `build` falls back to the default when unset.
    kernel_executor_settings: Option<KernelExecutorSettings>,
    // When set, `check` requires exactly three entries.
    stun_servers: Option<Vec<String>>,
    // Forwarded to the kernel executor.
    turn_servers: Option<Vec<TurnServerConfig>>,
    // Server-only session init settings (declared pre-shared key, header obfuscation).
    local_only_server_settings: Option<ServerOnlySessionInitSettings>,
    // Forwarded to the kernel executor as `websocket_listen_addr`.
    websocket_listen_addr: Option<std::net::SocketAddr>,
    // wasm only: when set, `build` establishes a serverless connection before starting.
    #[cfg(target_family = "wasm")]
    serverless_config: Option<ServerlessConfig>,
    // Markers: the builder is generic over `R` and `T` without storing values of them.
    _ratchet: PhantomData<R>,
    _transport: PhantomData<T>,
}
/// [`NodeBuilder`] using [`StackedRatchet`] over the [`DefaultTransport`].
pub type DefaultNodeBuilder = NodeBuilder<StackedRatchet, DefaultTransport>;
/// [`NodeBuilder`] using [`MonoRatchet`] over the [`DefaultTransport`].
pub type LightweightNodeBuilder = NodeBuilder<MonoRatchet, DefaultTransport>;
impl<R: Ratchet, T: PlatformOps> Default for NodeBuilder<R, T> {
fn default() -> Self {
Self {
hypernode_type: None,
underlying_protocol: None,
backend_type: None,
server_argon_settings: None,
#[cfg(feature = "google-services")]
services: None,
server_misc_settings: None,
client_tls_config: None,
kernel_executor_settings: None,
stun_servers: None,
turn_servers: None,
local_only_server_settings: None,
websocket_listen_addr: None,
#[cfg(target_family = "wasm")]
serverless_config: None,
_ratchet: Default::default(),
_transport: Default::default(),
}
}
}
/// Future returned by [`NodeBuilder::build`].
///
/// Polling it drives node start-up and then the kernel executor. It resolves to the
/// executor's result.
pub struct NodeFuture<'a, K> {
    // The boxed, pinned `async move` block created in `build`.
    inner: Pin<Box<dyn FutureContextRequirements<'a, Result<K, NetworkError>>>>,
    // Mentions `K` without storing one. A `fn() -> K` marker keeps this struct's
    // auto traits (Send/Sync) independent of `K`'s own.
    _pd: PhantomData<fn() -> K>,
}
#[cfg(feature = "multi-threaded")]
trait FutureContextRequirements<'a, Output>:
Future<Output = Output> + Send + LocalContextRequirements<'a>
{
}
#[cfg(feature = "multi-threaded")]
impl<'a, T: Future<Output = Output> + Send + LocalContextRequirements<'a>, Output>
FutureContextRequirements<'a, Output> for T
{
}
#[cfg(not(feature = "multi-threaded"))]
trait FutureContextRequirements<'a, Output>:
Future<Output = Output> + LocalContextRequirements<'a>
{
}
#[cfg(not(feature = "multi-threaded"))]
impl<'a, T: Future<Output = Output> + LocalContextRequirements<'a>, Output>
crate::builder::node_builder::FutureContextRequirements<'a, Output> for T
{
}
impl<K> Debug for NodeFuture<'_, K> {
    /// Writes just the type name; the boxed inner future has no useful representation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("NodeFuture")
    }
}
impl<K> Future for NodeFuture<'_, K> {
    type Output = Result<K, NetworkError>;

    /// Forwards the poll to the boxed future created by `NodeBuilder::build`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `inner` is already pinned on the heap, so a pinned re-borrow is all we need.
        let pinned_inner = self.inner.as_mut();
        pinned_inner.poll(cx)
    }
}
impl<R: Ratchet + ContextRequirements, T: PlatformOps> NodeBuilder<R, T> {
pub fn build<'a, 'b: 'a, K: NetKernel<R> + 'b>(
&'a mut self,
kernel: K,
) -> anyhow::Result<NodeFuture<'b, K>> {
self.check()?;
let hypernode_type = self.hypernode_type.take().unwrap_or_default();
let backend_type = self.backend_type.take().unwrap_or_default();
let server_argon_settings = self.server_argon_settings.take();
#[cfg(feature = "google-services")]
let server_services_cfg = self.services.take();
#[cfg(not(feature = "google-services"))]
let server_services_cfg = None;
let server_misc_settings = self.server_misc_settings.take();
let client_config = self.client_tls_config.take();
let kernel_executor_settings = self.kernel_executor_settings.take().unwrap_or_default();
let stun_servers = self.stun_servers.take();
let turn_servers = self.turn_servers.take();
let underlying_proto = self.underlying_protocol.take();
let server_only_session_init_settings = self.local_only_server_settings.take();
let websocket_listen_addr = self.websocket_listen_addr.take();
#[cfg(target_family = "wasm")]
let serverless_config = self.serverless_config.take();
Ok(NodeFuture {
_pd: Default::default(),
inner: Box::pin(async move {
let underlying_proto = match underlying_proto {
Some(proto) => proto,
None => T::default_server_config().await.map_err(|err| {
citadel_io::error!(
citadel_io::ErrorCode::NodeDefaultServerConfigFailed,
err.to_string()
)
})?,
};
T::config_warnings(&underlying_proto);
#[cfg(target_family = "wasm")]
let (pre_built_listener, client_config, hypernode_type) = if let Some(sl_config) =
serverless_config
{
let conn = establish_serverless_connection(
sl_config.signaling.as_ref(),
&sl_config.room_token,
&sl_config.ice_servers,
sl_config.poll_interval_ms,
sl_config.timeout_ms,
)
.await
.map_err(|e: std::io::Error| NetworkError::generic(e.to_string()))?;
T::setup_serverless_transport(conn.stream, conn.is_server_role, client_config)
} else {
(None, client_config, hypernode_type)
};
#[cfg(not(target_family = "wasm"))]
let pre_built_listener = None;
log::trace!(target: "citadel", "[NodeBuilder] Checking Tokio runtime ...");
let rt = citadel_io::try_current_runtime().map_err(NetworkError::generic)?;
log::trace!(target: "citadel", "[NodeBuilder] Creating account manager ...");
let account_manager = AccountManager::new(
backend_type,
server_argon_settings,
server_services_cfg,
server_misc_settings,
)
.await?;
let args: KernelExecutorArguments<_, _, T> = KernelExecutorArguments {
rt,
hypernode_type,
account_manager,
kernel,
underlying_proto,
client_config,
kernel_executor_settings,
stun_servers,
turn_servers,
server_only_session_init_settings,
websocket_listen_addr,
pre_built_listener,
};
log::trace!(target: "citadel", "[NodeBuilder] Creating KernelExecutor ...");
let kernel_executor = KernelExecutor::<_, R>::new(args).await?;
log::trace!(target: "citadel", "[NodeBuilder] Executing kernel");
kernel_executor.execute().await
}),
})
}
pub fn with_node_type(&mut self, node_type: NodeType) -> &mut Self {
self.hypernode_type = Some(node_type);
self
}
pub fn with_backend(&mut self, backend_type: BackendType) -> &mut Self {
self.backend_type = Some(backend_type);
self
}
pub fn with_kernel_executor_settings(
&mut self,
kernel_executor_settings: KernelExecutorSettings,
) -> &mut Self {
self.kernel_executor_settings = Some(kernel_executor_settings);
self
}
pub fn with_server_argon_settings(
&mut self,
settings: ArgonDefaultServerSettings,
) -> &mut Self {
self.server_argon_settings = Some(settings);
self
}
#[cfg(feature = "google-services")]
pub fn with_google_services_json_path<V: Into<String>>(&mut self, path: V) -> &mut Self {
let cfg = self.get_or_create_services();
cfg.google_services_json_path = Some(path.into());
self
}
pub fn with_server_misc_settings(&mut self, misc_settings: ServerMiscSettings) -> &mut Self {
self.server_misc_settings = Some(misc_settings);
self
}
#[cfg(feature = "google-services")]
pub fn with_google_realtime_database_config<V: Into<String>, W: Into<String>>(
&mut self,
url: V,
api_key: W,
) -> &mut Self {
let cfg = self.get_or_create_services();
cfg.google_rtdb = Some(RtdbConfig {
url: url.into(),
api_key: api_key.into(),
});
self
}
pub fn with_underlying_protocol(&mut self, proto: ServerMode<T>) -> &mut Self {
self.underlying_protocol = Some(proto);
self
}
pub fn with_client_config(&mut self, config: T::ClientConfig) -> &mut Self {
self.client_tls_config = Some(config);
self
}
#[cfg(feature = "google-services")]
fn get_or_create_services(&mut self) -> &mut ServicesConfig {
if self.services.is_some() {
self.services.as_mut().unwrap()
} else {
let cfg = ServicesConfig::default();
self.services = Some(cfg);
self.services.as_mut().unwrap()
}
}
pub fn with_stun_servers<V: Into<String>, S: Into<Vec<V>>>(&mut self, servers: S) -> &mut Self {
self.stun_servers = Some(servers.into().into_iter().map(|t| t.into()).collect());
self
}
pub fn with_turn_servers<S: Into<Vec<TurnServerConfig>>>(&mut self, servers: S) -> &mut Self {
self.turn_servers = Some(servers.into());
self
}
pub fn with_websocket_listener(&mut self, addr: std::net::SocketAddr) -> &mut Self {
self.websocket_listen_addr = Some(addr);
self
}
#[cfg(target_family = "wasm")]
pub fn with_no_central_server(&mut self, config: ServerlessConfig) -> &mut Self {
self.serverless_config = Some(config);
self
}
pub fn with_server_password<V: Into<PreSharedKey>>(&mut self, password: V) -> &mut Self {
let mut server_only_settings = self.local_only_server_settings.clone().unwrap_or_default();
server_only_settings.declared_pre_shared_key = Some(password.into());
self.local_only_server_settings = Some(server_only_settings);
self
}
pub fn with_server_declared_header_obfuscation<V: Into<HeaderObfuscatorSettings>>(
&mut self,
header_obfuscator_settings: V,
) -> &mut Self {
let mut server_only_settings = self.local_only_server_settings.clone().unwrap_or_default();
server_only_settings.declared_header_obfuscation_setting =
header_obfuscator_settings.into();
self.local_only_server_settings = Some(server_only_settings);
self
}
fn check(&self) -> anyhow::Result<()> {
#[cfg(feature = "google-services")]
if let Some(svc) = self.services.as_ref() {
if svc.google_rtdb.is_some() && svc.google_services_json_path.is_none() {
return Err(anyhow::Error::msg(
"Google realtime database is enabled, yet, a services path is not provided",
));
}
}
if let Some(stun_servers) = self.stun_servers.as_ref() {
if stun_servers.len() != 3 {
return Err(anyhow::Error::msg(
"There must be exactly 3 specified STUN servers",
));
}
}
Ok(())
}
}
#[cfg(not(target_family = "wasm"))]
impl<R: Ratchet + ContextRequirements> NodeBuilder<R, NativeIO> {
    /// Installs `config` as the client-side rustls configuration.
    ///
    /// The `require_cert_verification` flag from any previously installed client config is
    /// carried over. It is `false` when no client config was set.
    fn set_client_rustls_config(
        &mut self,
        config: std::sync::Arc<citadel_proto::re_imports::RustlsClientConfig>,
    ) {
        let require_cert_verification = match &self.client_tls_config {
            Some(existing) => existing.require_cert_verification,
            None => false,
        };
        self.client_tls_config = Some(citadel_proto::re_imports::NativeClientConfig {
            config,
            require_cert_verification,
        });
    }

    /// Builds the client TLS config from the platform's native root certificates.
    ///
    /// # Errors
    /// Fails if the native certificates cannot be loaded or converted into a client config.
    pub async fn with_native_certs(&mut self) -> anyhow::Result<&mut Self> {
        let native_roots = citadel_proto::re_imports::load_native_certs_async().await?;
        let rustls_cfg =
            citadel_proto::re_imports::cert_vec_to_secure_client_config(&native_roots)?;
        self.set_client_rustls_config(std::sync::Arc::new(rustls_cfg));
        Ok(self)
    }

    /// Replaces the client TLS config with the insecure rustls client config from
    /// `citadel_proto::re_imports::insecure`.
    ///
    /// The previous `require_cert_verification` flag is not carried over. The new value is
    /// whatever `NativeClientConfig::new` sets (NOTE(review): confirm it is `false`).
    pub fn with_insecure_skip_cert_verification(&mut self) -> &mut Self {
        let insecure_rustls = citadel_proto::re_imports::insecure::rustls_client_config();
        let native_cfg =
            citadel_proto::re_imports::NativeClientConfig::new(std::sync::Arc::new(insecure_rustls));
        self.client_tls_config = Some(native_cfg);
        self
    }

    /// Marks the client TLS config as requiring certificate verification.
    ///
    /// If no client config exists yet, one is first built from the native root certificates.
    ///
    /// # Errors
    /// Propagates failures from loading the native certificates.
    pub async fn with_require_cert_verification(&mut self) -> anyhow::Result<&mut Self> {
        if self.client_tls_config.is_none() {
            self.with_native_certs().await?;
        }
        if let Some(client_cfg) = &mut self.client_tls_config {
            client_cfg.require_cert_verification = true;
        }
        Ok(self)
    }

    /// Builds the client TLS config from caller-supplied certificates.
    ///
    /// # Errors
    /// Fails if `create_rustls_client_config` rejects the certificates.
    pub fn with_custom_certs<V: AsRef<[u8]>>(
        &mut self,
        custom_certs: &[V],
    ) -> anyhow::Result<&mut Self> {
        let rustls_cfg = citadel_proto::re_imports::create_rustls_client_config(custom_certs)?;
        self.set_client_rustls_config(std::sync::Arc::new(rustls_cfg));
        Ok(self)
    }

    /// Reads every PEM certificate in the file at `path` and builds the client TLS config
    /// from them.
    ///
    /// # Errors
    /// Fails if the file cannot be read, a PEM section cannot be parsed, or the
    /// certificates are rejected by `create_rustls_client_config`.
    #[cfg(feature = "std")]
    pub async fn with_pem_file<P: AsRef<std::path::Path>>(
        &mut self,
        path: P,
    ) -> anyhow::Result<&mut Self> {
        use citadel_wire::exports::{Certificate, PemObject};
        let pem_bytes = citadel_io::tokio::fs::read(path).await?;
        let mut pem_reader = std::io::Cursor::new(pem_bytes);
        let parsed = Certificate::pem_reader_iter(&mut pem_reader)
            .collect::<Result<Vec<Certificate<'static>>, _>>()?;
        let rustls_cfg = citadel_proto::re_imports::create_rustls_client_config(&parsed)?;
        self.set_client_rustls_config(std::sync::Arc::new(rustls_cfg));
        Ok(self)
    }
}
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use crate::builder::node_builder::DefaultNodeBuilder;
    use crate::prefabs::server::empty::EmptyKernel;
    use crate::prelude::{BackendType, NodeType};
    use citadel_io::tokio;
    use citadel_proto::prelude::{
        KernelExecutorSettings, NativeIO, NativeP2PConfig, NativeSecureConfig, ServerMode,
    };
    use rstest::rstest;
    use std::str::FromStr;

    // RTDB config together with a services JSON path passes validation.
    #[test]
    #[cfg(feature = "google-services")]
    fn okay_config() {
        let _ = DefaultNodeBuilder::default()
            .with_google_realtime_database_config("123", "456")
            .with_google_services_json_path("abc")
            .build(EmptyKernel::default())
            .unwrap();
    }

    // RTDB config without a services JSON path must be rejected by `build`.
    #[test]
    #[cfg(feature = "google-services")]
    fn bad_config() {
        assert!(DefaultNodeBuilder::default()
            .with_google_realtime_database_config("123", "456")
            .build(EmptyKernel::default())
            .is_err());
    }

    // `build` requires exactly three STUN servers, so two must be rejected.
    #[test]
    fn bad_config2() {
        assert!(DefaultNodeBuilder::default()
            .with_stun_servers(["dummy1", "dummy2"])
            .build(EmptyKernel::default())
            .is_err());
    }

    // Every setter stores its value, and `build` accepts the combined configuration
    // for each combination of protocol, node type, executor settings and backend.
    #[rstest]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(60))]
    // The `&mut Self` produced by the setter chain is intentionally discarded; the
    // assertions below inspect `builder` directly.
    #[allow(clippy::let_underscore_must_use)]
    async fn test_options(
        #[values(
            ServerMode::P2P(NativeP2PConfig::self_signed()),
            ServerMode::OrderedReliableSecure(NativeSecureConfig::self_signed().unwrap())
        )]
        underlying_protocol: ServerMode<NativeIO>,
        #[values(
            NodeType::Peer,
            NodeType::Server(std::net::SocketAddr::from_str("127.0.0.1:9999").unwrap())
        )]
        node_type: NodeType,
        #[values(
            KernelExecutorSettings::default(),
            KernelExecutorSettings::default().with_max_concurrency(2)
        )]
        kernel_settings: KernelExecutorSettings,
        #[values(BackendType::InMemory, BackendType::new("file:/hello_world/path/").unwrap())]
        backend_type: BackendType,
    ) {
        let mut builder = DefaultNodeBuilder::default();
        let _ = builder
            .with_underlying_protocol(underlying_protocol.clone())
            .with_backend(backend_type.clone())
            .with_node_type(node_type)
            .with_kernel_executor_settings(kernel_settings.clone())
            .with_insecure_skip_cert_verification()
            // Three distinct placeholder servers: `check` requires exactly three.
            .with_stun_servers(["dummy1", "dummy2", "dummy3"])
            .with_native_certs()
            .await
            .unwrap();
        assert!(builder.underlying_protocol.is_some());
        assert_eq!(backend_type, builder.backend_type.clone().unwrap());
        assert_eq!(node_type, builder.hypernode_type.unwrap());
        assert_eq!(
            kernel_settings,
            builder.kernel_executor_settings.clone().unwrap()
        );
        // The returned future is dropped without being polled, so no node is started.
        drop(builder.build(EmptyKernel::default()).unwrap());
    }
}