1use std::{
2 io,
3 net::SocketAddr,
4 sync::Arc,
5};
6
7use mid_net::{
8 prelude::{
9 impl_::interface::{
10 ICompressor,
11 IDecompressor,
12 },
13 *,
14 },
15 proto::{
16 PacketType,
17 Protocol,
18 ProtocolError,
19 },
20 utils::flags,
21};
22use tokio::net::TcpListener;
23
24use super::{
25 message_types::SlaveMessage,
26 utils::send_slave_message_to,
27};
28use crate::{
29 config::base::{
30 Config,
31 ProtocolPermissionsCfg,
32 },
33 tcp::{
34 slave::listener,
35 state::{
36 Permissions,
37 State,
38 },
39 views::MasterStateView,
40 },
41};
42
43pub async fn on_forward<W, R, C, D>(
45 writer: &mut MidWriter<W, C>,
46 reader: &mut MidReader<R, D>,
47 state: &State,
48 from: &SocketAddr,
49 flags: u8,
50 constraint: DecompressionConstraint,
51) -> io::Result<()>
52where
53 W: WriterUnderlyingExt,
54 R: ReaderUnderlyingExt,
55 C: ICompressor,
56 D: IDecompressor,
57{
58 match state.server {
59 Some(ref server) => {
60 let client_id = reader.read_client_id(flags).await?;
62 let length = reader.read_length(flags).await?;
63 let buffer = if flags::is_compressed(flags) {
64 reader
65 .read_compressed(
66 length as usize,
67 DecompressionStrategy::ConstrainedConst { constraint },
68 )
69 .await
70 } else {
71 reader
72 .read_buffer(length as usize)
73 .await
74 .map_err(|e| e.into())
75 };
76 let buffer = match buffer {
77 Ok(b) => b,
78 Err(CompressedReadError::Io(error)) => return Err(error),
79 Err(e) => {
80 tracing::error!(
81 %from,
82 "Failed to decompress forward packet: {e}"
83 );
84 return Ok(());
85 }
86 };
87
88 if server.forward(client_id, buffer).await.is_err() {
89 writer
90 .server()
91 .write_failure(ProtocolError::ClientDoesNotExists)
92 .await
93 } else {
94 Ok(())
95 }
96 }
97
98 None => {
99 writer
100 .server()
101 .write_failure(ProtocolError::ServerIsNotCreated)
102 .await
103 }
104 }
105}
106
107pub async fn on_disconnect<W, R, C, D>(
109 writer: &mut MidWriter<W, C>,
110 reader: &mut MidReader<R, D>,
111 state: &mut State,
112 flags: u8,
113) -> io::Result<()>
114where
115 W: WriterUnderlyingExt,
116 R: ReaderUnderlyingExt,
117{
118 let client_id = reader.read_client_id(flags).await?;
119 send_slave_message_to(writer, client_id, state, SlaveMessage::Disconnect)
120 .await?;
121 Ok(())
122}
123
124pub async fn on_create_server<W, R, C, D>(
127 writer: &mut MidWriter<W, C>,
128 reader: &mut MidReader<R, D>,
129 state: &mut State,
130 from: &SocketAddr,
131 packet_flags: u8,
132) -> io::Result<()>
133where
134 W: WriterUnderlyingExt,
135 R: ReaderUnderlyingExt,
136{
137 if state.has_server() {
138 return writer
139 .server()
140 .write_failure(ProtocolError::AlreadyCreated)
141 .await;
142 }
143
144 let protocol = if flags::is_compressed(packet_flags) {
145 Protocol::Tcp
146 } else {
147 Protocol::Udp
148 };
149
150 match protocol {
151 Protocol::Tcp if state.permissions.can(Permissions::CREATE_TCP) => {
152 let port = if flags::is_compressed(packet_flags) {
153 0
154 } else {
155 let port = reader.read_u16().await?;
156 if state
157 .permissions
158 .can(Permissions::SELECT_TCP_PORT)
159 {
160 port
161 } else {
162 tracing::error!(
163 %from,
164 port,
165 "Create server with custom port failed: access denied"
166 );
167 return writer
168 .server()
169 .write_failure(ProtocolError::AccessDenied)
170 .await;
171 }
172 };
173 let listener = match TcpListener::bind(("0.0.0.0", port)).await {
174 Ok(l) => l,
175 Err(error) => {
176 tracing::error!(
177 %error,
178 %from,
179 "Failed to create TCP listener"
180 );
181
182 return writer
183 .server()
184 .write_failure(ProtocolError::FailedToCreateListener)
185 .await;
186 }
187 };
188 let listening_at_port = if port == 0 {
189 match listener.local_addr().map(|a| a.port()) {
190 Ok(p) => p,
191 Err(error) => {
192 tracing::error!(
193 %from,
194 %error,
195 "Failed to retrieve TCP port from the system"
196 );
197
198 return writer
199 .server()
200 .write_failure(ProtocolError::FailedToRetrievePort)
201 .await;
202 }
203 }
204 } else {
205 port
206 };
207
208 let (shutdown_token, master_tx, created_server) =
209 state.create_server(listening_at_port);
210 tracing::info!(%from, "Started server at 0.0.0.0:{listening_at_port}");
211
212 tokio::spawn(listener::run_slave_tcp_listener(
213 listener,
214 *from,
215 shutdown_token,
216 MasterStateView {
217 pool: Arc::clone(&created_server.pool),
218 master: master_tx,
219 },
220 ));
221
222 writer
223 .server()
224 .write_server(listening_at_port)
225 .await
226 }
227
228 Protocol::Udp if state.permissions.can(Permissions::CREATE_UDP) => {
229 writer
230 .server()
231 .write_failure(ProtocolError::Unimplemented)
232 .await
233 }
234
235 tried_proto => {
236 tracing::error!(
237 %from,
238 ?tried_proto,
239 "Create server with custom protocol failed: access denied"
240 );
241 writer
242 .server()
243 .write_failure(ProtocolError::AccessDenied)
244 .await
245 }
246 }
247}
248
249pub async fn on_authorize<W, R, C, D>(
253 writer: &mut MidWriter<W, C>,
254 reader: &mut MidReader<R, D>,
255 state: &mut State,
256 from: &SocketAddr,
257 success_perms: &ProtocolPermissionsCfg,
258 actual_password: &Option<String>,
259) -> io::Result<()>
260where
261 W: WriterUnderlyingExt,
262 R: ReaderUnderlyingExt,
263{
264 let supplied_password = reader.read_string_prefixed().await?;
265 if let Some(actual_password) = actual_password {
266 if &supplied_password == actual_password {
267 state.permissions = Permissions::from_cfg(success_perms);
268 tracing::info!(
269 %from,
270 supplied_password,
271 "Universal password authorization request: access granted"
272 );
273 writer
274 .server()
275 .write_update_rights(state.permissions.bits())
276 .await
277 } else {
278 tracing::error!(
279 %from,
280 supplied_password,
281 "Universal password authorization request: wrong password"
282 );
283 writer
284 .server()
285 .write_failure(ProtocolError::AccessDenied)
286 .await
287 }
288 } else {
289 tracing::error!(
290 %from,
291 supplied_password,
292 "Universal password authorization request: feature is disabled"
293 );
294
295 writer
296 .server()
297 .write_failure(ProtocolError::Disabled)
298 .await
299 }
300}
301
302pub async fn on_ping<W: WriterUnderlyingExt, C>(
306 writer: &mut MidWriter<W, C>,
307 config: &Config,
308) -> io::Result<()> {
309 writer
310 .server()
311 .write_ping(
312 &config.server.name,
313 config.compression.tcp.algorithm,
314 config
315 .server
316 .bufferization
317 .read
318 .try_into()
319 .unwrap_or_else(|e| {
320 let fallback_maximum = u16::MAX;
321 tracing::error!(
322 fallback_maximum,
323 "Failed to write bufferization value ({e}), writing \
324 back fallback maximum"
325 );
326
327 fallback_maximum
328 }),
329 )
330 .await
331}
332
333pub async fn on_unexpected<W: WriterUnderlyingExt, C>(
336 writer: &mut MidWriter<W, C>,
337 from: &SocketAddr,
338 packet_type: PacketType,
339) -> io::Result<()> {
340 tracing::error!(?packet_type, %from, "Sent unexpected packet");
341 writer
342 .server()
343 .write_failure(ProtocolError::UnexpectedPacket)
344 .await
345}
346
347pub async fn on_unknown_packet<W: WriterUnderlyingExt, C>(
350 writer: &mut MidWriter<W, C>,
351 from: SocketAddr,
352 packet_type: u8,
353 packet_flags: u8,
354) -> io::Result<()> {
355 tracing::error!(
356 packet_type,
357 packet_flags,
358 %from,
359 "Unknown packet type received"
360 );
361 writer
362 .server()
363 .write_failure(ProtocolError::UnknownPacket)
364 .await
365}