use crate::prelude::{CitadelClientServerConnection, TargetLockedRemote};
use bytes::Bytes;
use citadel_io::tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use citadel_io::tokio_stream::wrappers::UnboundedReceiverStream;
use citadel_io::tokio_util::io::StreamReader;
use citadel_proto::prelude::NetworkError;
use citadel_proto::prelude::*;
use citadel_types::crypto::SecBuffer;
use futures::StreamExt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
pub async fn internal_service<F, Fut, R: Ratchet>(
connection: CitadelClientServerConnection<R>,
service: F,
) -> Result<(), NetworkError>
where
F: Send + Copy + Sync + FnOnce(InternalServerCommunicator) -> Fut,
Fut: Send + Sync + Future<Output = Result<(), NetworkError>>,
{
let remote = connection.remote.clone();
let (tx_to_service, rx_from_kernel) = citadel_io::tokio::sync::mpsc::unbounded_channel();
let (tx_to_kernel, mut rx_from_service) = citadel_io::tokio::sync::mpsc::unbounded_channel();
let internal_server_communicator = InternalServerCommunicator {
tx_to_kernel,
rx_from_kernel: StreamReader::new(rx_from_kernel.into()),
};
let internal_server = service(internal_server_communicator);
let (mut sink, mut stream) = connection.split();
let from_proto = async move {
while let Some(packet) = stream.next().await {
tx_to_service.send(Ok(packet.into_buffer().freeze()))?;
}
Ok(())
};
let from_webserver = async move {
while let Some(packet) = rx_from_service.recv().await {
sink.send(packet).await?;
}
Ok(())
};
let res = citadel_io::tokio::select! {
res0 = from_proto => {
res0
},
res1 = from_webserver => {
res1
},
res2 = internal_server => {
res2
}
};
citadel_logging::warn!(target: "citadel", "Internal Server Stopped: {res:?}");
remote.remote().shutdown().await?;
res
}
/// Byte-stream handle handed to the service run by [`internal_service`].
///
/// Reading ([`AsyncRead`]) yields the bytes of packets forwarded from the connection.
/// Writing ([`AsyncWrite`]) enqueues bytes to be sent back over the connection.
pub struct InternalServerCommunicator {
    /// Outbound queue: `poll_write` enqueues the written bytes here as a single `SecBuffer`.
    pub(crate) tx_to_kernel: citadel_io::tokio::sync::mpsc::UnboundedSender<SecBuffer>,
    /// Inbound bytes: an unbounded channel of packet payloads, adapted into a reader.
    pub(crate) rx_from_kernel:
        StreamReader<UnboundedReceiverStream<Result<Bytes, std::io::Error>>, Bytes>,
}
/// Forwards everything the service writes towards the connection.
///
/// Each non-empty `poll_write` enqueues the whole buffer as one `SecBuffer` on an unbounded
/// channel, so writes never return `Poll::Pending` and never accept only part of a buffer.
impl AsyncWrite for InternalServerCommunicator {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        // A zero-length write must not become an empty packet on the wire.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let len = buf.len();
        let result = self
            .tx_to_kernel
            .send(buf.into())
            .map(|()| len)
            // `send` only fails once the receiving half has been dropped. Report that as a
            // closed pipe so callers can tell it apart from other I/O failures.
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()));
        Poll::Ready(result)
    }

    /// Nothing is buffered locally: every write has already been handed to the channel.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// No-op: the outbound channel only closes when this communicator is dropped.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
impl AsyncRead for InternalServerCommunicator {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.rx_from_kernel).poll_read(cx, buf)
}
}
// Explicit `Unpin`: the `AsyncRead` impl mutably accesses `self.rx_from_kernel` through
// `Pin<&mut Self>`, which requires `Self: Unpin`. No field is structurally pinned in this
// file, so opting in is sound. NOTE(review): the fields look auto-`Unpin` already, which
// would make this impl redundant but harmless — confirm before removing.
impl Unpin for InternalServerCommunicator {}