use crate::error::{Error, Result};
use crate::error_event::ErrorEventData;
use crate::event_handler::Response;
use crate::events::error_event::ERROR_EVENT_NAME;
use crate::events::event::Event;
use crate::events::event_handler::EventHandler;
use crate::ipc::client::IPCClient;
use crate::ipc::context::{Context, PooledContext, ReplyListeners};
use crate::ipc::server::IPCServer;
use crate::namespaces::builder::NamespaceBuilder;
use crate::namespaces::namespace::Namespace;
#[cfg(feature = "serialize")]
use crate::payload::DynamicSerializer;
use crate::prelude::AsyncProtocolStream;
use crate::protocol::AsyncStreamProtocolListener;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use trait_bound_typemap::{KeyCanExtend, SendSyncTypeMap, TypeMap, TypeMapEntry, TypeMapKey};
/// Builder that collects configuration for an IPC server or client using the
/// listener type `L`.
///
/// Configure it with the chained setter methods, then finish with one of the
/// `build_*` methods.
pub struct IPCBuilder<L>
where
    L: AsyncStreamProtocolListener,
{
    /// Top-level event callbacks (registered via `on`).
    handler: EventHandler,
    /// Address to bind (server) or connect to (client); required before building.
    address: Option<L::AddressType>,
    /// Registered namespaces, keyed by namespace name.
    namespaces: HashMap<String, Namespace>,
    /// User data handed to the server or client when it is built.
    data: SendSyncTypeMap,
    /// Timeout value forwarded to the built server or client.
    timeout: Duration,
    /// Serializer forwarded to the built server or client.
    #[cfg(feature = "serialize")]
    default_serializer: DynamicSerializer,
    /// Options used when starting a server listener.
    listener_options: L::ListenerOptions,
    /// Options used when opening client streams.
    stream_options: <L::Stream as AsyncProtocolStream>::StreamOptions,
}
impl<L: AsyncStreamProtocolListener> Default for IPCBuilder<L> {
fn default() -> Self {
Self::new()
}
}
impl<L> IPCBuilder<L>
where
L: AsyncStreamProtocolListener,
{
pub fn new() -> Self {
let mut handler = EventHandler::new();
handler.on(ERROR_EVENT_NAME, |_, event| {
Box::pin(async move {
let error_data = event.payload::<ErrorEventData>()?;
tracing::warn!(error_data.code);
tracing::warn!("error_data.message = '{}'", error_data.message);
Ok(Response::empty())
})
});
Self {
handler,
address: None,
namespaces: HashMap::new(),
data: SendSyncTypeMap::new(),
timeout: Duration::from_secs(60),
#[cfg(feature = "serialize")]
default_serializer: DynamicSerializer::first_available(),
listener_options: L::ListenerOptions::default(),
stream_options: <L::Stream as AsyncProtocolStream>::StreamOptions::default(),
}
}
pub fn insert<K: TypeMapKey>(mut self, value: K::Value) -> Self
where
<K as TypeMapKey>::Value: Send + Sync,
{
self.data.insert::<K>(value);
self
}
pub fn insert_all<I: IntoIterator<Item = TypeMapEntry<K>>, K: KeyCanExtend<SendSyncTypeMap>>(
mut self,
value: I,
) -> Self {
self.data.extend(value);
self
}
pub fn on<F: 'static>(mut self, event: &str, callback: F) -> Self
where
F: for<'a> Fn(
&'a Context,
Event,
) -> Pin<Box<(dyn Future<Output = Result<Response>> + Send + 'a)>>
+ Send
+ Sync,
{
self.handler.on(event, callback);
self
}
pub fn address(mut self, address: L::AddressType) -> Self {
self.address = Some(address);
self
}
pub fn namespace<S: ToString>(self, name: S) -> NamespaceBuilder<L> {
NamespaceBuilder::new(self, name.to_string())
}
pub fn add_namespace(mut self, namespace: Namespace) -> Self {
self.namespaces
.insert(namespace.name().to_owned(), namespace);
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
#[cfg(feature = "serialize")]
pub fn default_serializer(mut self, serializer: DynamicSerializer) -> Self {
self.default_serializer = serializer;
self
}
pub fn server_options(mut self, options: L::ListenerOptions) -> Self {
self.listener_options = options;
self
}
pub fn client_options(
mut self,
options: <L::Stream as AsyncProtocolStream>::StreamOptions,
) -> Self {
self.stream_options = options;
self
}
#[tracing::instrument(skip(self))]
pub async fn build_server(self) -> Result<()> {
self.validate()?;
let server = IPCServer {
namespaces: self.namespaces,
handler: self.handler,
data: self.data,
timeout: self.timeout,
#[cfg(feature = "serialize")]
default_serializer: self.default_serializer,
};
server
.start::<L>(self.address.unwrap(), self.listener_options)
.await?;
Ok(())
}
#[tracing::instrument(skip(self))]
pub async fn build_client(self) -> Result<Context> {
self.validate()?;
let data = Arc::new(RwLock::new(self.data));
let reply_listeners = ReplyListeners::default();
let client = IPCClient {
namespaces: self.namespaces,
handler: self.handler,
data,
reply_listeners,
timeout: self.timeout,
#[cfg(feature = "serialize")]
default_serializer: self.default_serializer,
};
let ctx = client
.connect::<L::Stream>(self.address.unwrap(), self.stream_options.clone())
.await?;
Ok(ctx)
}
#[tracing::instrument(skip(self))]
pub async fn build_pooled_client(self, pool_size: usize) -> Result<PooledContext> {
if pool_size == 0 {
return Err(Error::BuildError("Pool size must be greater than 0".to_string()));
}
self.validate()?;
let data = Arc::new(RwLock::new(self.data));
let mut contexts = Vec::new();
let address = self.address.unwrap();
let reply_listeners = ReplyListeners::default();
for _ in 0..pool_size {
let client = IPCClient {
namespaces: self.namespaces.clone(),
handler: self.handler.clone(),
data: Arc::clone(&data),
reply_listeners: reply_listeners.clone(),
timeout: self.timeout,
#[cfg(feature = "serialize")]
default_serializer: self.default_serializer.clone(),
};
let ctx = client
.connect::<L::Stream>(address.clone(), self.stream_options.clone())
.await?;
contexts.push(ctx);
}
Ok(PooledContext::new(contexts))
}
#[tracing::instrument(skip(self))]
fn validate(&self) -> Result<()> {
if self.address.is_none() {
Err(Error::BuildError("Missing Address".to_string()))
} else {
Ok(())
}
}
}