use super::*;
use ractor::SupervisionEvent;
use ractor::{Actor, ActorProcessingErr, ActorRef};
use std::marker::PhantomData;
#[cfg(not(feature = "async-trait"))]
use std::marker::Send;
use tokio::io::ErrorKind;
use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use TcpSessionMessage::*;
/// One message payload carried over a session, held as raw bytes.
///
/// On the wire each frame is preceded by a `u64` length prefix (see the writer and reader
/// actors later in this file); the prefix itself is not part of the `Frame`.
pub type Frame = Vec<u8>;
/// Consumer of inbound frames for a [`TcpSession`].
///
/// The session calls [`FrameReceiver::frame_ready`] for every `FrameReady` message it handles.
/// An `Err` returned from `frame_ready` is propagated out of the session's message handler.
///
/// The method has two cfg-gated signatures: with the `async-trait` feature it is an
/// `async fn` expanded by the `async_trait` macro, otherwise it returns an `impl Future + Send`.
#[cfg_attr(feature = "async-trait", ractor::async_trait)]
pub trait FrameReceiver: ractor::State {
/// Handles one complete inbound frame `f` (variant used when `async-trait` is disabled).
#[cfg(not(feature = "async-trait"))]
fn frame_ready(
&self,
f: Frame,
) -> impl std::future::Future<Output = Result<(), ActorProcessingErr>> + Send;
/// Handles one complete inbound frame `f` (variant used when `async-trait` is enabled).
#[cfg(feature = "async-trait")]
async fn frame_ready(&self, f: Frame) -> Result<(), ActorProcessingErr>;
}
/// Messages handled by the [`TcpSession`] actor.
pub enum TcpSessionMessage {
/// Send this frame to the peer; the session forwards it to its writer child actor.
Send(Frame),
/// A frame was read from the peer; the session hands it to its [`FrameReceiver`].
FrameReady(Frame),
}
/// Runtime state of a [`TcpSession`], built in `pre_start`.
pub struct TcpSessionState<R>
where
R: FrameReceiver,
{
// Child actor that owns the write half of the stream.
writer: ActorRef<SessionWriterMessage>,
// Child actor that owns the read half of the stream.
reader: ActorRef<SessionReaderMessage>,
// User-supplied sink for inbound frames.
receiver: R,
// Connection metadata captured before the stream is split; its `local_addr` and
// `peer_addr` fields are used in log output.
stream_info: NetworkStreamInfo,
}
/// Arguments required to start a [`TcpSession`] actor.
pub struct TcpSessionStartupArguments<R>
where
R: FrameReceiver,
{
/// Receiver that will be handed every inbound frame.
pub receiver: R,
/// The established connection (raw TCP or TLS) that the session takes ownership of.
pub tcp_session: NetworkStream,
}
/// Actor handler for a framed session over a [`NetworkStream`], generic over the
/// [`FrameReceiver`] that consumes inbound frames.
///
/// The handler itself is stateless; all per-connection data lives in [`TcpSessionState`].
pub struct TcpSession<R>
where
R: FrameReceiver,
{
// Marker tying the handler to `R` without storing a value of type `R`.
_r: PhantomData<fn() -> R>,
}
impl<R> Default for TcpSession<R>
where
R: FrameReceiver,
{
fn default() -> Self {
Self::new()
}
}
impl<R> TcpSession<R>
where
    R: FrameReceiver,
{
    /// Creates a new session handler. The handler holds no data of its own; the
    /// connection and receiver are supplied later through the startup arguments.
    pub fn new() -> Self {
        let _r = PhantomData;
        Self { _r }
    }
}
#[cfg_attr(feature = "async-trait", async_trait::async_trait)]
impl<R> Actor for TcpSession<R>
where
R: FrameReceiver,
{
type Msg = TcpSessionMessage;
type State = TcpSessionState<R>;
type Arguments = TcpSessionStartupArguments<R>;
async fn pre_start(
&self,
myself: ActorRef<Self::Msg>,
args: Self::Arguments,
) -> Result<Self::State, ActorProcessingErr> {
let stream_info = args.tcp_session.info();
let (read, write) = match args.tcp_session {
NetworkStream::Raw { stream, .. } => {
let (read, write) = stream.into_split();
(ActorReadHalf::Regular(read), ActorWriteHalf::Regular(write))
}
NetworkStream::TlsClient { stream, .. } => {
let (read_half, write_half) = tokio::io::split(stream);
(
ActorReadHalf::ClientTls(read_half),
ActorWriteHalf::ClientTls(write_half),
)
}
NetworkStream::TlsServer { stream, .. } => {
let (read_half, write_half) = tokio::io::split(stream);
(
ActorReadHalf::ServerTls(read_half),
ActorWriteHalf::ServerTls(write_half),
)
}
};
let (writer, _) =
Actor::spawn_linked(None, SessionWriter, write, myself.get_cell()).await?;
let (reader, _) = Actor::spawn_linked(
None,
SessionReader {
session: myself.clone(),
},
read,
myself.get_cell(),
)
.await?;
Ok(Self::State {
writer,
reader,
receiver: args.receiver,
stream_info,
})
}
async fn post_stop(
&self,
_myself: ActorRef<Self::Msg>,
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
tracing::info!("TCP Session closed for {}", state.stream_info.peer_addr);
Ok(())
}
async fn handle(
&self,
_myself: ActorRef<Self::Msg>,
message: Self::Msg,
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
match message {
Send(msg) => {
tracing::trace!(
"SEND: {} -> {} - '{msg:?}'",
state.stream_info.local_addr,
state.stream_info.peer_addr
);
let _ = state.writer.cast(SessionWriterMessage::Write(msg));
}
FrameReady(msg) => {
tracing::trace!(
"RECEIVE {} <- {} - '{msg:?}'",
state.stream_info.local_addr,
state.stream_info.peer_addr,
);
state.receiver.frame_ready(msg).await?;
}
}
Ok(())
}
async fn handle_supervisor_evt(
&self,
myself: ActorRef<Self::Msg>,
message: SupervisionEvent,
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
match message {
SupervisionEvent::ActorFailed(actor, panic_msg) => {
if actor.get_id() == state.reader.get_id() {
tracing::error!("TCP Session's reader panicked with '{}'", panic_msg);
} else if actor.get_id() == state.writer.get_id() {
tracing::error!("TCP Session's writer panicked with '{}'", panic_msg);
} else {
tracing::error!("TCP Session received a child panic from an unknown child actor ({}) - '{}'", actor.get_id(), panic_msg);
}
myself.stop_children(Some("session_stop_panic".to_string()));
myself.stop(Some("child_panic".to_string()));
}
SupervisionEvent::ActorTerminated(actor, _, exit_reason) => {
if actor.get_id() == state.reader.get_id() {
tracing::debug!("TCP Session's reader exited");
} else if actor.get_id() == state.writer.get_id() {
tracing::debug!("TCP Session's writer exited");
} else {
tracing::warn!("TCP Session received a child exit from an unknown child actor ({}) - '{:?}'", actor.get_id(), exit_reason);
}
myself.stop_children(Some("session_stop_terminated".to_string()));
myself.stop(Some("child_terminate".to_string()));
}
_ => {
}
}
Ok(())
}
}
/// Write side of a split connection, covering each transport the session supports.
enum ActorWriteHalf {
// Write half of a TLS stream where this side is the TLS server.
ServerTls(WriteHalf<tokio_rustls::server::TlsStream<TcpStream>>),
// Write half of a TLS stream where this side is the TLS client.
ClientTls(WriteHalf<tokio_rustls::client::TlsStream<TcpStream>>),
// Owned write half of a plain TCP stream.
Regular(OwnedWriteHalf),
}
impl ActorWriteHalf {
async fn write_u64(&mut self, n: u64) -> tokio::io::Result<()> {
use tokio::io::AsyncWriteExt;
match self {
Self::ServerTls(t) => t.write_u64(n).await,
Self::ClientTls(t) => t.write_u64(n).await,
Self::Regular(t) => t.write_u64(n).await,
}
}
async fn write_all(&mut self, data: &[u8]) -> tokio::io::Result<()> {
use tokio::io::AsyncWriteExt;
match self {
Self::ServerTls(t) => t.write_all(data).await,
Self::ClientTls(t) => t.write_all(data).await,
Self::Regular(t) => t.write_all(data).await,
}
}
async fn flush(&mut self) -> tokio::io::Result<()> {
use tokio::io::AsyncWriteExt;
match self {
Self::ServerTls(t) => t.flush().await,
Self::ClientTls(t) => t.flush().await,
Self::Regular(t) => t.flush().await,
}
}
}
/// Read side of a split connection, covering each transport the session supports.
enum ActorReadHalf {
// Read half of a TLS stream where this side is the TLS server.
ServerTls(ReadHalf<tokio_rustls::server::TlsStream<TcpStream>>),
// Read half of a TLS stream where this side is the TLS client.
ClientTls(ReadHalf<tokio_rustls::client::TlsStream<TcpStream>>),
// Owned read half of a plain TCP stream.
Regular(OwnedReadHalf),
}
impl ActorReadHalf {
async fn read_u64(&mut self) -> tokio::io::Result<u64> {
match self {
Self::ServerTls(t) => t.read_u64().await,
Self::ClientTls(t) => t.read_u64().await,
Self::Regular(t) => t.read_u64().await,
}
}
}
/// Reads exactly `len` bytes from `stream` and returns them.
///
/// A `len` of zero returns an empty buffer immediately, without waiting on the socket.
///
/// # Errors
///
/// Returns an error of kind [`ErrorKind::UnexpectedEof`] if the stream ends before `len`
/// bytes have been read, or any other I/O error reported by the underlying transport.
async fn read_n_bytes(stream: &mut ActorReadHalf, len: usize) -> Result<Vec<u8>, tokio::io::Error> {
    // NOTE(review): `len` comes straight from the peer's length prefix, so a hostile peer
    // can request an arbitrarily large allocation here — consider enforcing a maximum.
    let mut buf = vec![0u8; len];
    // `read_exact` waits for readiness itself, so no separate `readable()` call is needed.
    // The previous up-front wait made zero-length frames block until unrelated data arrived.
    match stream {
        ActorReadHalf::ServerTls(t) => t.read_exact(&mut buf).await?,
        ActorReadHalf::ClientTls(t) => t.read_exact(&mut buf).await?,
        ActorReadHalf::Regular(t) => t.read_exact(&mut buf).await?,
    };
    Ok(buf)
}
/// Actor handler that owns the write half of the connection and writes outbound frames to it.
struct SessionWriter;
/// State of the [`SessionWriter`] actor.
struct SessionWriterState {
// The write half; set to `None` (dropping the half) when the actor stops.
writer: Option<ActorWriteHalf>,
}
/// Messages handled by the [`SessionWriter`] actor.
enum SessionWriterMessage {
// Write this frame to the stream, preceded by its length as a `u64`.
Write(Frame),
}
/// Actor behaviour of the writer: each `Write` message is emitted on the stream as a `u64`
/// length prefix followed by the payload, then flushed.
#[cfg_attr(feature = "async-trait", ractor::async_trait)]
impl Actor for SessionWriter {
    type Msg = SessionWriterMessage;
    type State = SessionWriterState;
    type Arguments = ActorWriteHalf;

    /// Takes ownership of the write half supplied by the session.
    async fn pre_start(
        &self,
        _myself: ActorRef<Self::Msg>,
        writer: ActorWriteHalf,
    ) -> Result<Self::State, ActorProcessingErr> {
        Ok(Self::State {
            writer: Some(writer),
        })
    }

    /// Drops the write half so the connection's write side is released.
    async fn post_stop(
        &self,
        _myself: ActorRef<Self::Msg>,
        state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        drop(state.writer.take());
        Ok(())
    }

    /// Writes one frame (length prefix, payload, flush).
    ///
    /// A failure while writing the prefix or the payload is logged and stops this actor
    /// with reason `channel_closed`. Readiness and flush errors are returned to the runtime.
    async fn handle(
        &self,
        myself: ActorRef<Self::Msg>,
        message: Self::Msg,
        state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        let SessionWriterMessage::Write(msg) = message;

        // Without a write half there is nothing to write to; ignore the message.
        let stream = match state.writer.as_mut() {
            Some(stream) => stream,
            None => return Ok(()),
        };

        if let ActorWriteHalf::Regular(w) = stream {
            w.writable().await?;
        }

        if let Err(write_err) = stream.write_u64(msg.len() as u64).await {
            tracing::warn!("Error writing to the stream '{}'", write_err);
            // A failed (possibly partial) length prefix leaves the framing unrecoverable,
            // so stop exactly as for a failed payload write instead of carrying on.
            myself.stop(Some("channel_closed".to_string()));
            return Ok(());
        }

        tracing::trace!("Wrote length, writing payload (len={})", msg.len());
        if let Err(write_err) = stream.write_all(&msg).await {
            tracing::warn!("Error writing to the stream '{}'", write_err);
            myself.stop(Some("channel_closed".to_string()));
            return Ok(());
        }

        stream.flush().await?;
        Ok(())
    }
}
/// Actor handler that owns the read half of the connection and reads inbound frames from it.
struct SessionReader {
// The owning session, which is sent a `FrameReady` message for every frame read.
session: ActorRef<TcpSessionMessage>,
}
/// Messages the [`SessionReader`] actor sends to itself to drive its read loop.
pub enum SessionReaderMessage {
/// Read the next `u64` length prefix from the stream.
WaitForFrame,
/// Read a payload of the given length (in bytes) from the stream.
ReadFrame(u64),
}
/// State of the [`SessionReader`] actor.
struct SessionReaderState {
// The read half; set to `None` (dropping the half) on EOF or when the actor stops.
reader: Option<ActorReadHalf>,
}
#[cfg_attr(feature = "async-trait", ractor::async_trait)]
impl Actor for SessionReader {
type Msg = SessionReaderMessage;
type State = SessionReaderState;
type Arguments = ActorReadHalf;
async fn pre_start(
&self,
myself: ActorRef<Self::Msg>,
reader: ActorReadHalf,
) -> Result<Self::State, ActorProcessingErr> {
let _ = myself.cast(SessionReaderMessage::WaitForFrame);
Ok(Self::State {
reader: Some(reader),
})
}
async fn post_stop(
&self,
_myself: ActorRef<Self::Msg>,
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
drop(state.reader.take());
Ok(())
}
async fn handle(
&self,
myself: ActorRef<Self::Msg>,
message: Self::Msg,
state: &mut Self::State,
) -> Result<(), ActorProcessingErr> {
match message {
Self::Msg::WaitForFrame if state.reader.is_some() => {
if let Some(stream) = &mut state.reader {
match stream.read_u64().await {
Ok(length) => {
tracing::trace!("Payload length message ({}) received", length);
let _ = myself.cast(SessionReaderMessage::ReadFrame(length));
return Ok(());
}
Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
tracing::trace!("Error (EOF) on stream");
drop(state.reader.take());
myself.stop(Some("channel_closed".to_string()));
}
Err(_other_err) => {
tracing::trace!("Error ({:?}) on stream", _other_err);
}
}
}
let _ = myself.cast(SessionReaderMessage::WaitForFrame);
}
Self::Msg::ReadFrame(length) if state.reader.is_some() => {
if let Some(stream) = &mut state.reader {
match read_n_bytes(stream, length as usize).await {
Ok(buf) => {
tracing::trace!("Payload of length({}) received", buf.len());
self.session.cast(FrameReady(buf))?;
}
Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
drop(state.reader.take());
myself.stop(Some("channel_closed".to_string()));
return Ok(());
}
Err(_other_err) => {
}
}
}
let _ = myself.cast(SessionReaderMessage::WaitForFrame);
}
_ => {
let _ = myself.cast(SessionReaderMessage::WaitForFrame);
}
}
Ok(())
}
}