tokactor/io/tcp.rs

1use std::{future::Future, io, marker::PhantomData, pin::Pin, task::Poll};
2
3use tokio::{
4    net::{self, tcp::OwnedReadHalf, TcpStream, ToSocketAddrs},
5    sync::{mpsc, oneshot},
6};
7
8use crate::{
9    executor::{Executor, RawExecutor},
10    Actor, Ask, AsyncAsk, Ctx, Message, TcpRequest,
11};
12
13use super::{
14    create_reader_actor, create_writer_actor, Component, ComponentFuture, DataFrameReceiver,
15    IoRead, Reader, Writer,
16};
17
18pub struct TcpAcceptFut<
19    'a,
20    RouterAct: Actor,
21    ConnAct: Actor,
22    Reader: IoRead<OwnedReadHalf> + Send,
23    O: DataFrameReceiver<Frame = Reader>,
24> {
25    listener: &'a net::TcpListener,
26    executor: &'a mut RawExecutor<RouterAct>,
27    _actor: PhantomData<ConnAct>,
28    _reader: PhantomData<Reader>,
29    _payload: PhantomData<O>,
30}
31
32impl<'a, O, RouterAct, ConnAct, Reader> Unpin for TcpAcceptFut<'a, RouterAct, ConnAct, Reader, O>
33where
34    RouterAct: Actor + Ask<TcpRequest, Result = ConnAct>,
35    ConnAct: Actor,
36    Reader: IoRead<OwnedReadHalf> + Default + Send + 'static,
37    O: DataFrameReceiver<Frame = Reader>,
38{
39}
40
41impl<'a, O, RouterAct, ConnAct, Reader> Future for TcpAcceptFut<'a, RouterAct, ConnAct, Reader, O>
42where
43    RouterAct: Actor + Ask<TcpRequest, Result = ConnAct>,
44    ConnAct: Actor,
45    Reader: IoRead<OwnedReadHalf> + Default + Send + 'static,
46    O: DataFrameReceiver<Frame = Reader>,
47{
48    type Output = io::Result<(crate::io::Reader<Reader, O::Request>, ConnAct)>;
49
50    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
51        if let Poll::Ready(result) = self.listener.poll_accept(cx) {
52            let result = match result {
53                Ok((stream, address)) => {
54                    let this = self.get_mut();
55                    let (read, write) = this
56                        .executor
57                        .with_ctx(move |ctx| tcp_actors::<RouterAct, Reader, O>(ctx, stream));
58                    let request = TcpRequest(write, address);
59                    let actor = this.executor.ask(request);
60                    Ok((read, actor))
61                }
62                Err(err) => Err(err),
63            };
64            Poll::Ready(result)
65        } else {
66            Poll::Pending
67        }
68    }
69}
70
71pub struct TcpListener<
72    P: Actor,
73    A: Actor,
74    R: IoRead<OwnedReadHalf>,
75    O: DataFrameReceiver<Frame = R>,
76> {
77    executor: RawExecutor<P>,
78    listener: net::TcpListener,
79    _actor: PhantomData<A>,
80    _reader: PhantomData<R>,
81    _payload: PhantomData<O>,
82}
83
84impl<'a, P: Actor, A: Actor, R: IoRead<OwnedReadHalf>, O: DataFrameReceiver<Frame = R>>
85    TcpListener<P, A, R, O>
86{
87    pub async fn new(
88        address: impl ToSocketAddrs,
89        parent: P,
90    ) -> io::Result<TcpListener<P, A, R, O>> {
91        let listener = net::TcpListener::bind(address).await?;
92        let mut executor = Executor::new(parent, Ctx::<P>::new()).into_raw();
93        executor.raw_start();
94        Ok(Self {
95            executor,
96            listener,
97            _actor: PhantomData,
98            _reader: PhantomData,
99            _payload: PhantomData,
100        })
101    }
102}
103
104impl<P, A, R, O> ComponentFuture for TcpListener<P, A, R, O>
105where
106    P: Actor + Ask<TcpRequest, Result = A>,
107    A: Actor + AsyncAsk<O::Request>,
108    R: IoRead<OwnedReadHalf> + Default + Message + std::fmt::Debug + Send + Sync + 'static,
109    O: DataFrameReceiver<Frame = R>,
110{
111    type Payload = O;
112    type Reader = crate::io::Reader<R, O::Request>;
113    type Actor = A;
114    type Error = std::io::Error;
115    type Future<'a> = TcpAcceptFut<'a, P, A, R, O>;
116}
117
118impl<P, A, R, O> Component for TcpListener<P, A, R, O>
119where
120    P: Actor + Ask<TcpRequest, Result = A>,
121    A: Actor + AsyncAsk<O::Request>,
122    R: IoRead<OwnedReadHalf> + Default + Message + std::fmt::Debug + Send + Sync + 'static,
123    O: DataFrameReceiver<Frame = R>,
124{
125    type Shutdown = Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>;
126
127    #[allow(clippy::needless_lifetimes)]
128    fn accept<'a>(&'a mut self) -> Self::Future<'a> {
129        TcpAcceptFut {
130            listener: &self.listener,
131            executor: &mut self.executor,
132            _actor: PhantomData,
133            _reader: PhantomData,
134            _payload: PhantomData,
135        }
136    }
137
138    fn shutdown(self) -> Self::Shutdown {
139        Box::pin(async move { self.executor.raw_shutdown().await })
140    }
141}
142
143fn tcp_actors<
144    A: Actor,
145    R: IoRead<OwnedReadHalf> + Default + Send + 'static,
146    Payload: DataFrameReceiver<Frame = R>,
147>(
148    ctx: &Ctx<A>,
149    stream: TcpStream,
150) -> (Reader<R, Payload::Request>, Writer) {
151    let (read, write) = stream.into_split();
152    let (reader_tx, reader_rx) =
153        mpsc::channel::<(R, oneshot::Sender<std::io::Result<Payload::Request>>)>(10);
154    let (writer_tx, writer_rx) =
155        mpsc::channel::<(Vec<u8>, oneshot::Sender<std::io::Result<()>>)>(10);
156    let (shutdown_tx, shutdown_rx) = oneshot::channel();
157
158    let parent_rx = ctx.notifier.subscribe();
159    tokio::spawn(create_reader_actor::<OwnedReadHalf, R, Payload>(
160        read,
161        reader_rx,
162        parent_rx,
163        shutdown_tx,
164    ));
165
166    let parent_rx = ctx.notifier.subscribe();
167    tokio::spawn(create_writer_actor(
168        write,
169        writer_rx,
170        parent_rx,
171        shutdown_rx,
172    ));
173
174    let reader = Reader::<R, Payload::Request>::new(reader_tx);
175    let writer = Writer::new(writer_tx);
176
177    (reader, writer)
178}