1use std::{future::Future, io, marker::PhantomData, pin::Pin, task::Poll};
2
3use tokio::{
4 net::{self, tcp::OwnedReadHalf, TcpStream, ToSocketAddrs},
5 sync::{mpsc, oneshot},
6};
7
8use crate::{
9 executor::{Executor, RawExecutor},
10 Actor, Ask, AsyncAsk, Ctx, Message, TcpRequest,
11};
12
13use super::{
14 create_reader_actor, create_writer_actor, Component, ComponentFuture, DataFrameReceiver,
15 IoRead, Reader, Writer,
16};
17
18pub struct TcpAcceptFut<
19 'a,
20 RouterAct: Actor,
21 ConnAct: Actor,
22 Reader: IoRead<OwnedReadHalf> + Send,
23 O: DataFrameReceiver<Frame = Reader>,
24> {
25 listener: &'a net::TcpListener,
26 executor: &'a mut RawExecutor<RouterAct>,
27 _actor: PhantomData<ConnAct>,
28 _reader: PhantomData<Reader>,
29 _payload: PhantomData<O>,
30}
31
32impl<'a, O, RouterAct, ConnAct, Reader> Unpin for TcpAcceptFut<'a, RouterAct, ConnAct, Reader, O>
33where
34 RouterAct: Actor + Ask<TcpRequest, Result = ConnAct>,
35 ConnAct: Actor,
36 Reader: IoRead<OwnedReadHalf> + Default + Send + 'static,
37 O: DataFrameReceiver<Frame = Reader>,
38{
39}
40
41impl<'a, O, RouterAct, ConnAct, Reader> Future for TcpAcceptFut<'a, RouterAct, ConnAct, Reader, O>
42where
43 RouterAct: Actor + Ask<TcpRequest, Result = ConnAct>,
44 ConnAct: Actor,
45 Reader: IoRead<OwnedReadHalf> + Default + Send + 'static,
46 O: DataFrameReceiver<Frame = Reader>,
47{
48 type Output = io::Result<(crate::io::Reader<Reader, O::Request>, ConnAct)>;
49
50 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
51 if let Poll::Ready(result) = self.listener.poll_accept(cx) {
52 let result = match result {
53 Ok((stream, address)) => {
54 let this = self.get_mut();
55 let (read, write) = this
56 .executor
57 .with_ctx(move |ctx| tcp_actors::<RouterAct, Reader, O>(ctx, stream));
58 let request = TcpRequest(write, address);
59 let actor = this.executor.ask(request);
60 Ok((read, actor))
61 }
62 Err(err) => Err(err),
63 };
64 Poll::Ready(result)
65 } else {
66 Poll::Pending
67 }
68 }
69}
70
71pub struct TcpListener<
72 P: Actor,
73 A: Actor,
74 R: IoRead<OwnedReadHalf>,
75 O: DataFrameReceiver<Frame = R>,
76> {
77 executor: RawExecutor<P>,
78 listener: net::TcpListener,
79 _actor: PhantomData<A>,
80 _reader: PhantomData<R>,
81 _payload: PhantomData<O>,
82}
83
84impl<'a, P: Actor, A: Actor, R: IoRead<OwnedReadHalf>, O: DataFrameReceiver<Frame = R>>
85 TcpListener<P, A, R, O>
86{
87 pub async fn new(
88 address: impl ToSocketAddrs,
89 parent: P,
90 ) -> io::Result<TcpListener<P, A, R, O>> {
91 let listener = net::TcpListener::bind(address).await?;
92 let mut executor = Executor::new(parent, Ctx::<P>::new()).into_raw();
93 executor.raw_start();
94 Ok(Self {
95 executor,
96 listener,
97 _actor: PhantomData,
98 _reader: PhantomData,
99 _payload: PhantomData,
100 })
101 }
102}
103
104impl<P, A, R, O> ComponentFuture for TcpListener<P, A, R, O>
105where
106 P: Actor + Ask<TcpRequest, Result = A>,
107 A: Actor + AsyncAsk<O::Request>,
108 R: IoRead<OwnedReadHalf> + Default + Message + std::fmt::Debug + Send + Sync + 'static,
109 O: DataFrameReceiver<Frame = R>,
110{
111 type Payload = O;
112 type Reader = crate::io::Reader<R, O::Request>;
113 type Actor = A;
114 type Error = std::io::Error;
115 type Future<'a> = TcpAcceptFut<'a, P, A, R, O>;
116}
117
118impl<P, A, R, O> Component for TcpListener<P, A, R, O>
119where
120 P: Actor + Ask<TcpRequest, Result = A>,
121 A: Actor + AsyncAsk<O::Request>,
122 R: IoRead<OwnedReadHalf> + Default + Message + std::fmt::Debug + Send + Sync + 'static,
123 O: DataFrameReceiver<Frame = R>,
124{
125 type Shutdown = Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>;
126
127 #[allow(clippy::needless_lifetimes)]
128 fn accept<'a>(&'a mut self) -> Self::Future<'a> {
129 TcpAcceptFut {
130 listener: &self.listener,
131 executor: &mut self.executor,
132 _actor: PhantomData,
133 _reader: PhantomData,
134 _payload: PhantomData,
135 }
136 }
137
138 fn shutdown(self) -> Self::Shutdown {
139 Box::pin(async move { self.executor.raw_shutdown().await })
140 }
141}
142
143fn tcp_actors<
144 A: Actor,
145 R: IoRead<OwnedReadHalf> + Default + Send + 'static,
146 Payload: DataFrameReceiver<Frame = R>,
147>(
148 ctx: &Ctx<A>,
149 stream: TcpStream,
150) -> (Reader<R, Payload::Request>, Writer) {
151 let (read, write) = stream.into_split();
152 let (reader_tx, reader_rx) =
153 mpsc::channel::<(R, oneshot::Sender<std::io::Result<Payload::Request>>)>(10);
154 let (writer_tx, writer_rx) =
155 mpsc::channel::<(Vec<u8>, oneshot::Sender<std::io::Result<()>>)>(10);
156 let (shutdown_tx, shutdown_rx) = oneshot::channel();
157
158 let parent_rx = ctx.notifier.subscribe();
159 tokio::spawn(create_reader_actor::<OwnedReadHalf, R, Payload>(
160 read,
161 reader_rx,
162 parent_rx,
163 shutdown_tx,
164 ));
165
166 let parent_rx = ctx.notifier.subscribe();
167 tokio::spawn(create_writer_actor(
168 write,
169 writer_rx,
170 parent_rx,
171 shutdown_rx,
172 ));
173
174 let reader = Reader::<R, Payload::Request>::new(reader_tx);
175 let writer = Writer::new(writer_tx);
176
177 (reader, writer)
178}