1use super::{AddrEventManager, EventManager};
2use crate::core::Msg;
3
4use log::error;
5use crate::core::transport::{
6 Authenticator, Bicrypter, Decrypter, Encrypter, Signer,
7 TcpStreamInboundWire, TcpStreamOutboundWire, Verifier, Wire,
8};
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::{
13 net::{TcpListener, TcpStream},
14 runtime::Handle,
15 sync::{mpsc, Mutex},
16};
17
18impl EventManager {
20 pub fn for_tcp_stream<A, B>(
21 handle: Handle,
22 max_outbound_queue: usize,
23 stream: TcpStream,
24 remote_addr: SocketAddr,
25 wire: Wire<A, B>,
26 on_inbound_tx: mpsc::Sender<(Msg, SocketAddr, mpsc::Sender<Vec<u8>>)>,
27 ) -> EventManager
28 where
29 A: Authenticator + Send + Sync + 'static,
30 B: Bicrypter + Send + Sync + 'static,
31 {
32 let (reader, writer) =
33 wire.with_tcp_stream(stream, remote_addr).arc_split();
34
35 let (tx, rx) = mpsc::channel::<Vec<u8>>(max_outbound_queue);
36
37 let inbound_handle = handle.spawn(tcp_stream_outbound_loop(rx, writer));
38 let outbound_handle = handle.spawn(tcp_stream_inbound_loop(
39 tx.clone(),
40 reader,
41 on_inbound_tx,
42 ));
43
44 EventManager {
45 inbound_handle,
46 outbound_handle,
47 tx,
48 }
49 }
50}
51
52impl AddrEventManager {
55 pub fn for_tcp_listener<A, B>(
56 handle: Handle,
57 max_outbound_queue: usize,
58 listener: TcpListener,
59 wire: Wire<A, B>,
60 on_inbound_tx: mpsc::Sender<(Msg, SocketAddr, mpsc::Sender<Vec<u8>>)>,
61 ) -> AddrEventManager
62 where
63 A: Authenticator + Send + Sync + Clone + 'static,
64 B: Bicrypter + Send + Sync + Clone + 'static,
65 {
66 let connections: Arc<
67 Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>,
68 > = Arc::new(Mutex::new(HashMap::new()));
69 let (tx, rx) =
70 mpsc::channel::<(Vec<u8>, SocketAddr)>(max_outbound_queue);
71
72 let outbound_handle = handle
73 .spawn(tcp_listener_outbound_loop(rx, Arc::clone(&connections)));
74
75 let inbound_handle = handle.spawn(tcp_listener_inbound_loop(
76 handle.clone(),
77 listener,
78 wire,
79 connections,
80 on_inbound_tx,
81 max_outbound_queue,
82 ));
83
84 AddrEventManager {
85 outbound_handle,
86 inbound_handle,
87 tx,
88 }
89 }
90}
91
92async fn tcp_listener_outbound_loop(
95 mut rx: mpsc::Receiver<(Vec<u8>, SocketAddr)>,
96 connections: Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>,
97) {
98 while let Some((msg, addr)) = rx.recv().await {
99 if let Some(stream) = connections.lock().await.get_mut(&addr) {
100 if stream.send(msg).await.is_err() {
101 error!("Failed to send to {}", addr);
102 }
103 }
104 }
105}
106
107async fn tcp_listener_inbound_loop<A, B>(
111 handle: Handle,
112 mut listener: TcpListener,
113 wire: Wire<A, B>,
114 connections: Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>,
115 on_inbound_tx: mpsc::Sender<(Msg, SocketAddr, mpsc::Sender<Vec<u8>>)>,
116 max_outbound_queue: usize,
117) where
118 A: Authenticator + Send + Sync + Clone + 'static,
119 B: Bicrypter + Send + Sync + Clone + 'static,
120{
121 loop {
122 match listener.accept().await {
123 Ok((stream, addr)) => {
124 handle.spawn(tcp_listener_spawn_stream(
125 stream,
126 addr,
127 handle.clone(),
128 wire.clone(),
129 Arc::clone(&connections),
130 on_inbound_tx.clone(),
131 max_outbound_queue,
132 ));
133 }
134 Err(x) => {
135 error!("Listening for connections encountered error: {}", x);
136 break;
137 }
138 }
139 }
140}
141
142async fn tcp_listener_spawn_stream<A, B>(
146 stream: TcpStream,
147 addr: SocketAddr,
148 handle: Handle,
149 wire: Wire<A, B>,
150 connections: Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>,
151 on_inbound_tx: mpsc::Sender<(Msg, SocketAddr, mpsc::Sender<Vec<u8>>)>,
152 max_outbound_queue: usize,
153) where
154 A: Authenticator + Send + Sync + 'static,
155 B: Bicrypter + Send + Sync + 'static,
156{
157 let event_manager = EventManager::for_tcp_stream(
158 handle,
159 max_outbound_queue,
160 stream,
161 addr,
162 wire,
163 on_inbound_tx,
164 );
165
166 connections
167 .lock()
168 .await
169 .insert(addr, event_manager.tx.clone());
170
171 if let Err(x) = event_manager.wait().await {
174 error!("Event manager exited badly: {}", x);
175 }
176
177 connections.lock().await.remove(&addr);
178}
179
180async fn tcp_stream_outbound_loop<S, E>(
182 mut rx: mpsc::Receiver<Vec<u8>>,
183 mut writer: TcpStreamOutboundWire<S, E>,
184) where
185 S: Signer,
186 E: Encrypter,
187{
188 while let Some(msg) = rx.recv().await {
189 if let Err(x) = writer.write(&msg).await {
190 error!("Failed to send: {}", x);
191 }
192 }
193}
194
195async fn tcp_stream_inbound_loop<V, D>(
198 tx: mpsc::Sender<Vec<u8>>,
199 mut reader: TcpStreamInboundWire<V, D>,
200 on_inbound_tx: mpsc::Sender<(Msg, SocketAddr, mpsc::Sender<Vec<u8>>)>,
201) where
202 V: Verifier,
203 D: Decrypter,
204{
205 loop {
206 let tx_2 = tx.clone();
207 let result = reader.read().await;
208 if !super::process_inbound(result, tx_2, on_inbound_tx.clone()).await {
209 break;
210 }
211 }
212}