use crate::error::NetworkError;
use crate::kernel::kernel_communicator::{KernelAsyncCallbackHandler, KernelStreamSubscription};
use crate::prelude::{NodeRequest, NodeResult};
use crate::proto::node::HdpServerRemoteInner;
use crate::proto::outbound_sender::BoundedSender;
use citadel_user::account_manager::AccountManager;
use citadel_wire::hypernode_type::NodeType;
use futures::channel::mpsc::TrySendError;
use futures::{Sink, SinkExt};
use serde::{Deserialize, Serialize};
use std::fmt::{Debug, Display, Formatter};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
#[derive(Clone)]
pub struct NodeRemote {
outbound_send_request_tx: BoundedSender<(NodeRequest, Ticket)>,
inner: Arc<HdpServerRemoteInner>,
}
#[async_trait::async_trait]
#[auto_impl::auto_impl(Box, &mut)]
pub trait Remote: Clone + Send {
async fn send(&mut self, request: NodeRequest) -> Result<Ticket, NetworkError> {
let ticket = self.get_next_ticket();
self.send_with_custom_ticket(ticket, request)
.await
.map(|_| ticket)
}
async fn send_with_custom_ticket(
&mut self,
ticket: Ticket,
request: NodeRequest,
) -> Result<(), NetworkError>;
async fn send_callback_subscription(
&mut self,
request: NodeRequest,
) -> Result<KernelStreamSubscription, NetworkError>;
async fn send_callback(&mut self, request: NodeRequest) -> Result<NodeResult, NetworkError>;
fn account_manager(&self) -> &AccountManager;
fn get_next_ticket(&self) -> Ticket;
}
#[async_trait::async_trait]
impl Remote for NodeRemote {
async fn send(&mut self, request: NodeRequest) -> Result<Ticket, NetworkError> {
NodeRemote::send(self, request).await
}
async fn send_with_custom_ticket(
&mut self,
ticket: Ticket,
request: NodeRequest,
) -> Result<(), NetworkError> {
NodeRemote::send_with_custom_ticket(self, ticket, request).await
}
async fn send_callback_subscription(
&mut self,
request: NodeRequest,
) -> Result<KernelStreamSubscription, NetworkError> {
NodeRemote::send_callback_subscription(self, request).await
}
async fn send_callback(&mut self, request: NodeRequest) -> Result<NodeResult, NetworkError> {
NodeRemote::send_callback(self, request).await
}
fn account_manager(&self) -> &AccountManager {
NodeRemote::account_manager(self)
}
fn get_next_ticket(&self) -> Ticket {
NodeRemote::get_next_ticket(self)
}
}
impl Debug for NodeRemote {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "HdpServerRemote")
}
}
impl NodeRemote {
pub(crate) fn new(
outbound_send_request_tx: BoundedSender<(NodeRequest, Ticket)>,
callback_handler: KernelAsyncCallbackHandler,
account_manager: AccountManager,
node_type: NodeType,
) -> Self {
Self {
outbound_send_request_tx,
inner: Arc::new(HdpServerRemoteInner {
callback_handler,
account_manager,
node_type,
}),
}
}
pub async fn send_with_custom_ticket(
&mut self,
ticket: Ticket,
request: NodeRequest,
) -> Result<(), NetworkError> {
self.outbound_send_request_tx.send((request, ticket)).await
}
pub async fn send(&mut self, request: NodeRequest) -> Result<Ticket, NetworkError> {
let ticket = self.get_next_ticket();
self.send_with_custom_ticket(ticket, request)
.await
.map(|_| ticket)
}
pub async fn send_callback_custom_ticket(
&mut self,
request: NodeRequest,
ticket: Ticket,
) -> Result<NodeResult, NetworkError> {
let rx = self.inner.callback_handler.register_future(ticket)?;
match self.send_with_custom_ticket(ticket, request).await {
Ok(_) => rx
.await
.map_err(|err| NetworkError::Generic(err.to_string())),
Err(err) => {
self.inner.callback_handler.remove_listener(ticket);
Err(err)
}
}
}
pub(crate) async fn send_callback_subscription_custom_ticket(
&mut self,
request: NodeRequest,
ticket: Ticket,
) -> Result<KernelStreamSubscription, NetworkError> {
let rx = self.inner.callback_handler.register_stream(ticket)?;
match self.send_with_custom_ticket(ticket, request).await {
Ok(_) => Ok(rx),
Err(err) => {
self.inner.callback_handler.remove_listener(ticket);
Err(err)
}
}
}
pub async fn send_callback_subscription(
&mut self,
request: NodeRequest,
) -> Result<KernelStreamSubscription, NetworkError> {
let ticket = self.get_next_ticket();
self.send_callback_subscription_custom_ticket(request, ticket)
.await
}
pub async fn send_callback(
&mut self,
request: NodeRequest,
) -> Result<NodeResult, NetworkError> {
let ticket = self.get_next_ticket();
self.send_callback_custom_ticket(request, ticket).await
}
pub async fn send_callback_timeout(
&mut self,
request: NodeRequest,
timeout: Duration,
) -> Result<NodeResult, NetworkError> {
tokio::time::timeout(timeout, self.send_callback(request))
.await
.map_err(|_| NetworkError::Timeout(0))?
}
pub async fn shutdown(&mut self) -> Result<(), NetworkError> {
let _ = self.send(NodeRequest::Shutdown).await?;
self.outbound_send_request_tx.close().await
}
pub fn get_next_ticket(&self) -> Ticket {
uuid::Uuid::new_v4().as_u128().into()
}
#[allow(clippy::result_large_err)]
pub fn try_send_with_custom_ticket(
&mut self,
ticket: Ticket,
request: NodeRequest,
) -> Result<(), TrySendError<(NodeRequest, Ticket)>> {
self.outbound_send_request_tx.try_send((request, ticket))
}
#[allow(clippy::result_large_err)]
pub fn try_send(
&mut self,
request: NodeRequest,
) -> Result<(), TrySendError<(NodeRequest, Ticket)>> {
let ticket = self.get_next_ticket();
self.try_send_with_custom_ticket(ticket, request)
}
pub fn local_node_type(&self) -> &NodeType {
&self.inner.node_type
}
pub fn account_manager(&self) -> &AccountManager {
&self.inner.account_manager
}
}
impl Unpin for NodeRemote {}
impl Sink<(Ticket, NodeRequest)> for NodeRemote {
type Error = NetworkError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
<Self as Sink<NodeRequest>>::poll_ready(self, cx)
}
fn start_send(
mut self: Pin<&mut Self>,
item: (Ticket, NodeRequest),
) -> Result<(), Self::Error> {
Pin::new(&mut self.outbound_send_request_tx).start_send((item.1, item.0))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
<Self as Sink<NodeRequest>>::poll_flush(self, cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
<Self as Sink<NodeRequest>>::poll_close(self, cx)
}
}
impl Sink<NodeRequest> for NodeRemote {
type Error = NetworkError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.outbound_send_request_tx).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: NodeRequest) -> Result<(), Self::Error> {
let ticket = self.get_next_ticket();
Pin::new(&mut self.outbound_send_request_tx).start_send((item, ticket))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.outbound_send_request_tx)
.poll_flush(cx)
.map_err(|err| NetworkError::Generic(err.to_string()))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.outbound_send_request_tx)
.poll_close(cx)
.map_err(|err| NetworkError::Generic(err.to_string()))
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub struct Ticket(pub u128);
impl From<u128> for Ticket {
fn from(val: u128) -> Self {
Ticket(val)
}
}
impl From<usize> for Ticket {
fn from(val: usize) -> Self {
(val as u128).into()
}
}
impl Display for Ticket {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}