neogrok/tcp/handlers/network.rs

1use std::{
2    io,
3    net::SocketAddr,
4    sync::Arc,
5};
6
7use mid_net::{
8    prelude::{
9        impl_::interface::{
10            ICompressor,
11            IDecompressor,
12        },
13        *,
14    },
15    proto::{
16        PacketType,
17        Protocol,
18        ProtocolError,
19    },
20    utils::flags,
21};
22use tokio::net::TcpListener;
23
24use super::{
25    message_types::SlaveMessage,
26    utils::send_slave_message_to,
27};
28use crate::{
29    config::base::{
30        Config,
31        ProtocolPermissionsCfg,
32    },
33    tcp::{
34        slave::listener,
35        state::{
36            Permissions,
37            State,
38        },
39        views::MasterStateView,
40    },
41};
42
43/// Triggered when `forward` packet arrived.
44pub async fn on_forward<W, R, C, D>(
45    writer: &mut MidWriter<W, C>,
46    reader: &mut MidReader<R, D>,
47    state: &State,
48    from: &SocketAddr,
49    flags: u8,
50    constraint: DecompressionConstraint,
51) -> io::Result<()>
52where
53    W: WriterUnderlyingExt,
54    R: ReaderUnderlyingExt,
55    C: ICompressor,
56    D: IDecompressor,
57{
58    match state.server {
59        Some(ref server) => {
60            // TODO: restrict maximum length
61            let client_id = reader.read_client_id(flags).await?;
62            let length = reader.read_length(flags).await?;
63            let buffer = if flags::is_compressed(flags) {
64                reader
65                    .read_compressed(
66                        length as usize,
67                        DecompressionStrategy::ConstrainedConst { constraint },
68                    )
69                    .await
70            } else {
71                reader
72                    .read_buffer(length as usize)
73                    .await
74                    .map_err(|e| e.into())
75            };
76            let buffer = match buffer {
77                Ok(b) => b,
78                Err(CompressedReadError::Io(error)) => return Err(error),
79                Err(e) => {
80                    tracing::error!(
81                        %from,
82                        "Failed to decompress forward packet: {e}"
83                    );
84                    return Ok(());
85                }
86            };
87
88            if server.forward(client_id, buffer).await.is_err() {
89                writer
90                    .server()
91                    .write_failure(ProtocolError::ClientDoesNotExists)
92                    .await
93            } else {
94                Ok(())
95            }
96        }
97
98        None => {
99            writer
100                .server()
101                .write_failure(ProtocolError::ServerIsNotCreated)
102                .await
103        }
104    }
105}
106
107/// Called when `disconnected` packet arrived
108pub async fn on_disconnect<W, R, C, D>(
109    writer: &mut MidWriter<W, C>,
110    reader: &mut MidReader<R, D>,
111    state: &mut State,
112    flags: u8,
113) -> io::Result<()>
114where
115    W: WriterUnderlyingExt,
116    R: ReaderUnderlyingExt,
117{
118    let client_id = reader.read_client_id(flags).await?;
119    send_slave_message_to(writer, client_id, state, SlaveMessage::Disconnect)
120        .await?;
121    Ok(())
122}
123
/// Handles a `create server` packet. It creates a listening server for the
/// protocol selected by `packet_flags` and reports the bound port to the peer.
///
/// Replies written to the peer:
/// - `AlreadyCreated` failure if this connection already has a server;
/// - `AccessDenied` failure if the permissions do not allow the requested
///   protocol, or if a custom TCP port is requested without `SELECT_TCP_PORT`;
/// - `FailedToCreateListener` / `FailedToRetrievePort` failures when binding
///   or `local_addr` fails;
/// - `Unimplemented` failure for UDP (when `CREATE_UDP` is granted);
/// - on success, the port the TCP listener is bound to (`write_server`).
///
/// On success a `listener::run_slave_tcp_listener` task is spawned. It gets
/// the listener, the peer address, and the shutdown token and master channel
/// returned by `State::create_server`.
///
/// NOTE(review): the protocol is chosen with `flags::is_compressed` (set →
/// TCP). Inside the TCP arm the same flag decides whether a port follows on
/// the wire. That arm is only entered when the flag is set, so `port` is
/// always 0 and the custom-port branch (`read_u16` + `SELECT_TCP_PORT` check)
/// is unreachable. The "port omitted" decision presumably needs a different
/// flag. TODO: confirm against the client's encoding of this packet.
///
/// NOTE(review): the early returns (`AlreadyCreated`, and `AccessDenied` in
/// the fallback arm) do not read a port that may follow the packet header. If
/// the connection keeps being read after a failure, this could misalign the
/// stream. TODO: confirm how the caller handles failures.
pub async fn on_create_server<W, R, C, D>(
    writer: &mut MidWriter<W, C>,
    reader: &mut MidReader<R, D>,
    state: &mut State,
    from: &SocketAddr,
    packet_flags: u8,
) -> io::Result<()>
where
    W: WriterUnderlyingExt,
    R: ReaderUnderlyingExt,
{
    // Refuse to create a second server on this connection.
    if state.has_server() {
        return writer
            .server()
            .write_failure(ProtocolError::AlreadyCreated)
            .await;
    }

    // The compressed flag doubles as the protocol selector here:
    // set → TCP, clear → UDP.
    let protocol = if flags::is_compressed(packet_flags) {
        Protocol::Tcp
    } else {
        Protocol::Udp
    };

    match protocol {
        Protocol::Tcp if state.permissions.can(Permissions::CREATE_TCP) => {
            // Port 0 asks the OS for any free port. NOTE(review): this branch
            // is always taken, because the arm is only reached when the
            // compressed flag is set (see the note on this function).
            let port = if flags::is_compressed(packet_flags) {
                0
            } else {
                // A specific port was requested; it needs its own permission.
                let port = reader.read_u16().await?;
                if state
                    .permissions
                    .can(Permissions::SELECT_TCP_PORT)
                {
                    port
                } else {
                    tracing::error!(
                        %from,
                        port,
                        "Create server with custom port failed: access denied"
                    );
                    return writer
                        .server()
                        .write_failure(ProtocolError::AccessDenied)
                        .await;
                }
            };
            let listener = match TcpListener::bind(("0.0.0.0", port)).await {
                Ok(l) => l,
                Err(error) => {
                    tracing::error!(
                        %error,
                        %from,
                        "Failed to create TCP listener"
                    );

                    return writer
                        .server()
                        .write_failure(ProtocolError::FailedToCreateListener)
                        .await;
                }
            };
            // When the OS picked the port, ask the listener which one it bound.
            let listening_at_port = if port == 0 {
                match listener.local_addr().map(|a| a.port()) {
                    Ok(p) => p,
                    Err(error) => {
                        tracing::error!(
                            %from,
                            %error,
                            "Failed to retrieve TCP port from the system"
                        );

                        return writer
                            .server()
                            .write_failure(ProtocolError::FailedToRetrievePort)
                            .await;
                    }
                }
            } else {
                port
            };

            // Register the server in the connection state before spawning the
            // accept loop that uses its pool and master channel.
            let (shutdown_token, master_tx, created_server) =
                state.create_server(listening_at_port);
            tracing::info!(%from, "Started server at 0.0.0.0:{listening_at_port}");

            // `listener::` resolves to the `slave::listener` module, not to
            // the local `listener` variable (paths use the module namespace).
            // The spawned task is detached; its JoinHandle is not kept.
            tokio::spawn(listener::run_slave_tcp_listener(
                listener,
                *from,
                shutdown_token,
                MasterStateView {
                    pool: Arc::clone(&created_server.pool),
                    master: master_tx,
                },
            ));

            writer
                .server()
                .write_server(listening_at_port)
                .await
        }

        // UDP servers are permitted by the permission model but not
        // implemented yet.
        Protocol::Udp if state.permissions.can(Permissions::CREATE_UDP) => {
            writer
                .server()
                .write_failure(ProtocolError::Unimplemented)
                .await
        }

        // Any protocol whose CREATE_* permission is missing ends up here.
        tried_proto => {
            tracing::error!(
                %from,
                ?tried_proto,
                "Create server with custom protocol failed: access denied"
            );
            writer
                .server()
                .write_failure(ProtocolError::AccessDenied)
                .await
        }
    }
}
248
249/// Tries to authorize user using supplied password. On
250/// success changes its permissions to the
251/// `universal_password` permissions level.
252pub async fn on_authorize<W, R, C, D>(
253    writer: &mut MidWriter<W, C>,
254    reader: &mut MidReader<R, D>,
255    state: &mut State,
256    from: &SocketAddr,
257    success_perms: &ProtocolPermissionsCfg,
258    actual_password: &Option<String>,
259) -> io::Result<()>
260where
261    W: WriterUnderlyingExt,
262    R: ReaderUnderlyingExt,
263{
264    let supplied_password = reader.read_string_prefixed().await?;
265    if let Some(actual_password) = actual_password {
266        if &supplied_password == actual_password {
267            state.permissions = Permissions::from_cfg(success_perms);
268            tracing::info!(
269                %from,
270                supplied_password,
271                "Universal password authorization request: access granted"
272            );
273            writer
274                .server()
275                .write_update_rights(state.permissions.bits())
276                .await
277        } else {
278            tracing::error!(
279                %from,
280                supplied_password,
281                "Universal password authorization request: wrong password"
282            );
283            writer
284                .server()
285                .write_failure(ProtocolError::AccessDenied)
286                .await
287        }
288    } else {
289        tracing::error!(
290            %from,
291            supplied_password,
292            "Universal password authorization request: feature is disabled"
293        );
294
295        writer
296            .server()
297            .write_failure(ProtocolError::Disabled)
298            .await
299    }
300}
301
302/// Reacts to the `ping`. Basically writes the server name,
303/// compression algorithm and the read bufferization
304/// settings
305pub async fn on_ping<W: WriterUnderlyingExt, C>(
306    writer: &mut MidWriter<W, C>,
307    config: &Config,
308) -> io::Result<()> {
309    writer
310        .server()
311        .write_ping(
312            &config.server.name,
313            config.compression.tcp.algorithm,
314            config
315                .server
316                .bufferization
317                .read
318                .try_into()
319                .unwrap_or_else(|e| {
320                    let fallback_maximum = u16::MAX;
321                    tracing::error!(
322                        fallback_maximum,
323                        "Failed to write bufferization value ({e}), writing \
324                         back fallback maximum"
325                    );
326
327                    fallback_maximum
328                }),
329        )
330        .await
331}
332
333/// Triggered if packet type is unexpected for the server
334/// side, for example: `error` packet is unexpected.
335pub async fn on_unexpected<W: WriterUnderlyingExt, C>(
336    writer: &mut MidWriter<W, C>,
337    from: &SocketAddr,
338    packet_type: PacketType,
339) -> io::Result<()> {
340    tracing::error!(?packet_type, %from, "Sent unexpected packet");
341    writer
342        .server()
343        .write_failure(ProtocolError::UnexpectedPacket)
344        .await
345}
346
347/// Called when router receives unknown packet type.
348/// Basically just logs & writes the error
349pub async fn on_unknown_packet<W: WriterUnderlyingExt, C>(
350    writer: &mut MidWriter<W, C>,
351    from: SocketAddr,
352    packet_type: u8,
353    packet_flags: u8,
354) -> io::Result<()> {
355    tracing::error!(
356        packet_type,
357        packet_flags,
358        %from,
359        "Unknown packet type received"
360    );
361    writer
362        .server()
363        .write_failure(ProtocolError::UnknownPacket)
364        .await
365}