mmids_core/net/tcp/socket_manager.rs

1use super::listener::{start as start_listener, ListenerParams};
2use super::{TcpSocketRequest, TcpSocketResponse};
3use crate::net::tcp::{RequestFailureReason, TlsOptions};
4use futures::future::BoxFuture;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::FutureExt;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
10use tracing::{debug, error, info};
11
12/// Starts a new instance of a socket manager task.  A socket manager can be requested to open
13/// ports on behalf of another system.  If the port is successfully opened it will begin listening
14/// for TCP connections on that port, and then manage the reading and writing of network traffic
15/// for that connection.
16pub fn start(tls_options: Option<TlsOptions>) -> UnboundedSender<TcpSocketRequest> {
17    let (request_sender, request_receiver) = unbounded_channel();
18
19    let manager = SocketManager::new();
20    tokio::spawn(manager.run(request_receiver, tls_options));
21
22    request_sender
23}
24
25enum SocketManagerFutureResult {
26    IncomingRequest {
27        request: Option<TcpSocketRequest>,
28        receiver: UnboundedReceiver<TcpSocketRequest>,
29    },
30    ListenerShutdown {
31        port: u16,
32    },
33}
34
35struct OpenPort {
36    response_channel: UnboundedSender<TcpSocketResponse>,
37}
38
39struct SocketManager {
40    open_ports: HashMap<u16, OpenPort>,
41    futures: FuturesUnordered<BoxFuture<'static, SocketManagerFutureResult>>,
42}
43
44impl SocketManager {
45    fn new() -> Self {
46        SocketManager {
47            open_ports: HashMap::new(),
48            futures: FuturesUnordered::new(),
49        }
50    }
51
52    async fn run(
53        mut self,
54        request_receiver: UnboundedReceiver<TcpSocketRequest>,
55        tls_options: Option<TlsOptions>,
56    ) {
57        info!("Starting TCP socket manager");
58        let tls_options = Arc::new(tls_options);
59
60        self.futures
61            .push(request_receiver_future(request_receiver).boxed());
62
63        while let Some(future_result) = self.futures.next().await {
64            match future_result {
65                SocketManagerFutureResult::IncomingRequest { request, receiver } => {
66                    self.futures.push(request_receiver_future(receiver).boxed());
67
68                    match request {
69                        Some(request) => self.handle_request(request, tls_options.clone()),
70                        None => break, // no more senders of requests
71                    }
72                }
73
74                SocketManagerFutureResult::ListenerShutdown { port } => {
75                    match self.open_ports.remove(&port) {
76                        None => (),
77                        Some(details) => {
78                            let _ = details
79                                .response_channel
80                                .send(TcpSocketResponse::PortForciblyClosed { port });
81                        }
82                    }
83                }
84            }
85        }
86
87        info!("Socket manager closing");
88    }
89
90    fn handle_request(&mut self, request: TcpSocketRequest, tls_options: Arc<Option<TlsOptions>>) {
91        match request {
92            TcpSocketRequest::OpenPort {
93                port,
94                response_channel,
95                use_tls,
96            } => {
97                if use_tls && tls_options.as_ref().is_none() {
98                    error!(
99                        port = port,
100                        "Request to open port with tls, but we have no tls options"
101                    );
102                    let _ = response_channel.send(TcpSocketResponse::RequestDenied {
103                        reason: RequestFailureReason::NoTlsDetailsGiven,
104                    });
105
106                    return;
107                }
108
109                if self.open_ports.contains_key(&port) {
110                    debug!(port = port, "Port is already in use!");
111                    let message = TcpSocketResponse::RequestDenied {
112                        reason: RequestFailureReason::PortInUse,
113                    };
114
115                    let _ = response_channel.send(message);
116                } else {
117                    debug!(port = port, use_tls = use_tls, "TCP port being opened");
118                    let details = OpenPort {
119                        response_channel: response_channel.clone(),
120                    };
121
122                    self.open_ports.insert(port, details);
123
124                    let listener_shutdown = start_listener(ListenerParams {
125                        port,
126                        response_channel: response_channel.clone(),
127                        use_tls,
128                        tls_options: tls_options.clone(),
129                    });
130
131                    self.futures
132                        .push(listener_shutdown_future(port, listener_shutdown).boxed());
133
134                    let _ = response_channel.send(TcpSocketResponse::RequestAccepted {});
135                }
136            }
137        }
138    }
139}
140
141async fn request_receiver_future(
142    mut receiver: UnboundedReceiver<TcpSocketRequest>,
143) -> SocketManagerFutureResult {
144    let result = receiver.recv().await;
145
146    SocketManagerFutureResult::IncomingRequest {
147        request: result,
148        receiver,
149    }
150}
151
152async fn listener_shutdown_future(
153    port: u16,
154    signal: UnboundedSender<()>,
155) -> SocketManagerFutureResult {
156    signal.closed().await;
157
158    SocketManagerFutureResult::ListenerShutdown { port }
159}