mum_cli/network/
tcp.rs

1use crate::error::{ServerSendError, TcpError};
2use crate::network::ConnectionInfo;
3use crate::notifications;
4use crate::state::server::Server;
5use crate::state::{State, StatePhase};
6
7use futures_util::select;
8use futures_util::stream::{SplitSink, SplitStream, Stream};
9use futures_util::{FutureExt, SinkExt, StreamExt};
10use log::*;
11use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket};
12use mumble_protocol::crypt::ClientCryptState;
13use mumble_protocol::voice::VoicePacket;
14use mumble_protocol::{Clientbound, Serverbound};
15use mumlib::command::MumbleEventKind;
16use std::collections::HashMap;
17use std::convert::Into;
18use std::fmt::Debug;
19use std::net::SocketAddr;
20use std::sync::{Arc, RwLock};
21use tokio::net::TcpStream;
22use tokio::sync::{mpsc, watch, Mutex};
23use tokio::time::{self, Duration};
24use tokio_native_tls::{TlsConnector, TlsStream};
25use tokio_util::codec::{Decoder, Framed};
26
27use super::{run_until, VoiceStreamType};
28
29type TcpSender = SplitSink<
30    Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>,
31    ControlPacket<Serverbound>,
32>;
33type TcpReceiver =
34    SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>;
35
36pub(crate) type TcpEventCallback = Box<dyn FnOnce(TcpEventData<'_>)>;
37pub(crate) type TcpEventSubscriber = Box<dyn FnMut(TcpEventData<'_>) -> bool>; //the bool indicates if it should be kept or not
38
/// The reason the TCP connection ended (or never got established).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DisconnectedReason {
    /// The TLS handshake with the server failed.
    InvalidTls,
    /// The connection tasks finished without reporting an error.
    User,
    /// One of the connection tasks returned an error.
    TcpError,
}
46
47/// Something a callback can register to. Data is sent via a respective [TcpEventData].
48#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
49pub enum TcpEvent {
50    Connected,                        //fires when the client has connected to a server
51    Disconnected(DisconnectedReason), //fires when the client has disconnected from a server
52    TextMessage,                      //fires when a text message comes in
53}
54
55/// When a [TcpEvent] occurs, this contains the data for the event.
56///
57/// The events are picked up by a [crate::state::ExecutionContext].
58///
59/// Having two different types might feel a bit confusing. Essentially, a
60/// callback _registers_ to a [TcpEvent] but _takes_ a [TcpEventData] as
61/// parameter.
62#[derive(Clone, Debug)]
63pub enum TcpEventData<'a> {
64    Connected(Result<&'a msgs::ServerSync, mumlib::Error>),
65    Disconnected(DisconnectedReason),
66    TextMessage(&'a msgs::TextMessage),
67}
68
69impl From<&TcpEventData<'_>> for TcpEvent {
70    fn from(t: &TcpEventData<'_>) -> Self {
71        match t {
72            TcpEventData::Connected(_) => TcpEvent::Connected,
73            TcpEventData::Disconnected(reason) => TcpEvent::Disconnected(*reason),
74            TcpEventData::TextMessage(_) => TcpEvent::TextMessage,
75        }
76    }
77}
78
79#[derive(Clone, Default)]
80pub struct TcpEventQueue {
81    callbacks: Arc<RwLock<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
82    subscribers: Arc<RwLock<HashMap<TcpEvent, Vec<TcpEventSubscriber>>>>,
83}
84
85impl TcpEventQueue {
86    /// Creates a new `TcpEventQueue`.
87    pub fn new() -> Self {
88        Self {
89            callbacks: Arc::new(RwLock::new(HashMap::new())),
90            subscribers: Arc::new(RwLock::new(HashMap::new())),
91        }
92    }
93
94    /// Registers a new callback to be triggered when an event is fired.
95    pub fn register_callback(&self, at: TcpEvent, callback: TcpEventCallback) {
96        self.callbacks
97            .write()
98            .unwrap()
99            .entry(at)
100            .or_default()
101            .push(callback);
102    }
103
104    /// Registers a new callback to be triggered when an event is fired.
105    pub fn register_subscriber(&self, at: TcpEvent, callback: TcpEventSubscriber) {
106        self.subscribers
107            .write()
108            .unwrap()
109            .entry(at)
110            .or_default()
111            .push(callback);
112    }
113
114    /// Fires all callbacks related to a specific TCP event and removes them from the event queue.
115    /// Also calls all event subscribers, but keeps them in the queue
116    pub fn resolve(&self, data: TcpEventData<'_>) {
117        if let Some(vec) = self
118            .callbacks
119            .write()
120            .unwrap()
121            .get_mut(&TcpEvent::from(&data))
122        {
123            let old = std::mem::take(vec);
124            for handler in old {
125                handler(data.clone());
126            }
127        }
128        if let Some(vec) = self
129            .subscribers
130            .write()
131            .unwrap()
132            .get_mut(&TcpEvent::from(&data))
133        {
134            let old = std::mem::take(vec);
135            for mut e in old {
136                if e(data.clone()) {
137                    vec.push(e)
138                }
139            }
140        }
141    }
142}
143
144impl Debug for TcpEventQueue {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("TcpEventQueue").finish()
147    }
148}
149
150pub async fn handle(
151    state: Arc<RwLock<State>>,
152    mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
153    crypt_state_sender: mpsc::Sender<ClientCryptState>,
154    packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
155    mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
156    event_queue: TcpEventQueue,
157) -> Result<(), TcpError> {
158    loop {
159        let connection_info = loop {
160            if connection_info_receiver.changed().await.is_ok() {
161                if let Some(data) = connection_info_receiver.borrow().clone() {
162                    break data;
163                }
164            } else {
165                return Err(TcpError::NoConnectionInfoReceived);
166            }
167        };
168        let connect_result = connect(
169            connection_info.socket_addr,
170            connection_info.hostname,
171            connection_info.accept_invalid_cert,
172        )
173        .await;
174
175        let (mut sink, stream) = match connect_result {
176            Ok(ok) => ok,
177            Err(TcpError::TlsConnectError(_)) => {
178                warn!("Invalid TLS");
179                state
180                    .read()
181                    .unwrap()
182                    .broadcast_phase(StatePhase::Disconnected);
183                event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::InvalidTls));
184                continue;
185            }
186            Err(e) => {
187                return Err(e);
188            }
189        };
190
191        // Handshake (omitting `Version` message for brevity)
192        let (username, password) = {
193            let state_lock = state.read().unwrap();
194            (
195                state_lock.username().unwrap().to_string(),
196                state_lock.password().map(|x| x.to_string()),
197            )
198        };
199        authenticate(&mut sink, username, password).await?;
200        let (phase_watcher, input_receiver) = {
201            let state_lock = state.read().unwrap();
202            (
203                state_lock.phase_receiver(),
204                state_lock.audio_input().receiver(),
205            )
206        };
207
208        info!("Logging in...");
209
210        let phase_watcher_inner = phase_watcher.clone();
211
212        let result = run_until(
213            |phase| matches!(phase, StatePhase::Disconnected),
214            async {
215                select! {
216                    r = send_pings(packet_sender.clone(), 10).fuse() => r,
217                    r = listen(
218                        Arc::clone(&state),
219                        stream,
220                        crypt_state_sender.clone(),
221                        event_queue.clone(),
222                    ).fuse() => r,
223                    r = send_voice(
224                        packet_sender.clone(),
225                        Arc::clone(&input_receiver),
226                        phase_watcher_inner,
227                    ).fuse() => r,
228                    r = send_packets(sink, &mut packet_receiver).fuse() => r,
229                }
230            },
231            phase_watcher,
232        )
233        .await
234        .unwrap_or(Ok(()));
235
236        match result {
237            Ok(()) => event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::User)),
238            Err(_) => event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::TcpError)),
239        }
240
241        debug!("Fully disconnected TCP stream, waiting for new connection info");
242    }
243}
244
245async fn connect(
246    server_addr: SocketAddr,
247    server_host: String,
248    accept_invalid_cert: bool,
249) -> Result<(TcpSender, TcpReceiver), TcpError> {
250    let stream = TcpStream::connect(&server_addr).await?;
251    debug!("TCP connected");
252
253    let mut builder = native_tls::TlsConnector::builder();
254    builder.danger_accept_invalid_certs(accept_invalid_cert);
255    let connector: TlsConnector = builder
256        .build()
257        .map_err(TcpError::TlsConnectorBuilderError)?
258        .into();
259    let tls_stream = connector
260        .connect(&server_host, stream)
261        .await
262        .map_err(TcpError::TlsConnectError)?;
263    debug!("TLS connected");
264
265    // Wrap the TLS stream with Mumble's client-side control-channel codec
266    Ok(ClientControlCodec::new().framed(tls_stream).split())
267}
268
269async fn authenticate(
270    sink: &mut TcpSender,
271    username: String,
272    password: Option<String>,
273) -> Result<(), TcpError> {
274    let mut msg = msgs::Authenticate::new();
275    msg.set_username(username);
276    if let Some(password) = password {
277        msg.set_password(password);
278    }
279    msg.set_opus(true);
280    sink.send(msg.into()).await?;
281    Ok(())
282}
283
284async fn send_pings(
285    packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
286    delay_seconds: u64,
287) -> Result<(), TcpError> {
288    let mut interval = time::interval(Duration::from_secs(delay_seconds));
289    loop {
290        interval.tick().await;
291        trace!("Sending TCP ping");
292        let msg = msgs::Ping::new();
293        packet_sender.send(msg.into())?;
294    }
295}
296
297async fn send_packets(
298    mut sink: TcpSender,
299    packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
300) -> Result<(), TcpError> {
301    loop {
302        // Safe since we always have at least one sender alive.
303        let packet = packet_receiver.recv().await.unwrap();
304        sink.send(packet).await?;
305    }
306}
307
308async fn send_voice(
309    packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
310    receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
311    phase_watcher: watch::Receiver<StatePhase>,
312) -> Result<(), TcpError> {
313    loop {
314        let mut inner_phase_watcher = phase_watcher.clone();
315        loop {
316            inner_phase_watcher.changed().await.unwrap();
317            if matches!(
318                *inner_phase_watcher.borrow(),
319                StatePhase::Connected(VoiceStreamType::Tcp)
320            ) {
321                break;
322            }
323        }
324        run_until(
325            |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::Tcp)),
326            async {
327                loop {
328                    packet_sender.send(
329                        receiver
330                            .lock()
331                            .await
332                            .next()
333                            .await
334                            .expect("No audio stream")
335                            .into(),
336                    )?;
337                }
338            },
339            inner_phase_watcher.clone(),
340        )
341        .await
342        .unwrap_or(Ok::<(), ServerSendError>(()))?;
343    }
344}
345
346async fn listen(
347    state: Arc<RwLock<State>>,
348    mut stream: TcpReceiver,
349    crypt_state_sender: mpsc::Sender<ClientCryptState>,
350    event_queue: TcpEventQueue,
351) -> Result<(), TcpError> {
352    let mut crypt_state = None;
353    let mut crypt_state_sender = Some(crypt_state_sender);
354
355    let mut last_late = 0;
356    let mut last_lost = 0;
357    let mut last_resync = 0;
358
359    loop {
360        let packet = match stream.next().await {
361            Some(Ok(packet)) => packet,
362            Some(Err(e)) => {
363                error!("TCP error: {:?}", e);
364                continue; //TODO Break here? Maybe look at the error and handle it
365            }
366            None => {
367                // We end up here if the login was rejected. We probably want
368                // to exit before that.
369                warn!("TCP stream gone");
370                state
371                    .read()
372                    .unwrap()
373                    .broadcast_phase(StatePhase::Disconnected);
374                break;
375            }
376        };
377        match packet {
378            ControlPacket::TextMessage(msg) => {
379                let mut state = state.write().unwrap();
380                let server = state.server();
381                let user = (if let Server::Connected(s) = server {
382                    Some(s)
383                } else {
384                    None
385                })
386                .and_then(|server| server.users().get(&msg.get_actor()))
387                .map(|user| user.name());
388                if let Some(user) = user {
389                    notifications::send(format!("{}: {}", user, msg.get_message()));
390                    //TODO: probably want a config flag for this
391                    let user = user.to_string();
392                    state.push_event(MumbleEventKind::TextMessageReceived(user))
393                    //TODO also include message target
394                }
395                state.register_message((msg.get_message().to_owned(), msg.get_actor()));
396                drop(state);
397                event_queue.resolve(TcpEventData::TextMessage(&*msg));
398            }
399            ControlPacket::CryptSetup(msg) => {
400                debug!("Crypt setup");
401                // Wait until we're fully connected before initiating UDP voice
402                crypt_state = Some(ClientCryptState::new_from(
403                    msg.get_key()
404                        .try_into()
405                        .expect("Server sent private key with incorrect size"),
406                    msg.get_client_nonce()
407                        .try_into()
408                        .expect("Server sent client_nonce with incorrect size"),
409                    msg.get_server_nonce()
410                        .try_into()
411                        .expect("Server sent server_nonce with incorrect size"),
412                ));
413            }
414            ControlPacket::ServerSync(msg) => {
415                info!("Logged in");
416                if let Some(sender) = crypt_state_sender.take() {
417                    let _ = sender
418                        .send(
419                            crypt_state
420                                .take()
421                                .expect("Server didn't send us any CryptSetup packet!"),
422                        )
423                        .await;
424                }
425                let mut state = state.write().unwrap();
426                let server = state.server_mut();
427                if let Server::Connecting(sb) = server {
428                    let s = sb.clone().server_sync(*msg.clone());
429                    *server = Server::Connected(s);
430                    state.initialized();
431                } else {
432                    warn!(
433                        "Got a ServerSync packet while not connecting. Current state is:\n{:#?}",
434                        server
435                    );
436                }
437                drop(state);
438                event_queue.resolve(TcpEventData::Connected(Ok(&msg)));
439            }
440            ControlPacket::Reject(msg) => {
441                debug!("Login rejected: {:?}", msg);
442                match msg.get_field_type() {
443                    msgs::Reject_RejectType::WrongServerPW => {
444                        event_queue.resolve(TcpEventData::Connected(Err(
445                            mumlib::Error::InvalidServerPassword,
446                        )));
447                    }
448                    ty => {
449                        warn!("Unhandled reject type: {:?}", ty);
450                    }
451                }
452            }
453            ControlPacket::UserState(msg) => {
454                state.write().unwrap().user_state(*msg);
455            }
456            ControlPacket::UserRemove(msg) => {
457                state.write().unwrap().remove_user(*msg);
458            }
459            ControlPacket::ChannelState(msg) => {
460                if let Server::Connecting(sb) = state.write().unwrap().server_mut() {
461                    sb.channel_state(*msg);
462                }
463            }
464            ControlPacket::ChannelRemove(msg) => match state.write().unwrap().server_mut() {
465                Server::Connecting(sb) => sb.channel_remove(*msg),
466                Server::Connected(server) => server.channel_remove(*msg),
467                Server::Disconnected => warn!("Got ChannelRemove packet while disconnected"),
468            },
469            ControlPacket::UDPTunnel(msg) => {
470                match *msg {
471                    VoicePacket::Ping { .. } => {}
472                    VoicePacket::Audio {
473                        session_id,
474                        // seq_num,
475                        payload,
476                        // position_info,
477                        ..
478                    } => {
479                        state.read().unwrap().audio_output().decode_packet_payload(
480                            VoiceStreamType::Tcp,
481                            session_id,
482                            payload,
483                        );
484                    }
485                }
486            }
487            ControlPacket::Ping(msg) => {
488                trace!("Received Ping {:?}", *msg);
489
490                let late = msg.get_late();
491                let lost = msg.get_lost();
492                let resync = msg.get_resync();
493
494                let late = late - last_late;
495                let lost = lost - last_lost;
496                let resync = resync - last_resync;
497
498                last_late += late;
499                last_lost += lost;
500                last_resync += resync;
501
502                macro_rules! format_if_nonzero {
503                    ($value:expr) => {
504                        if $value != 0 {
505                            format!("\n  {}: {}", stringify!($value), $value)
506                        } else {
507                            String::new()
508                        }
509                    };
510                }
511
512                if late != 0 || lost != 0 || resync != 0 {
513                    debug!(
514                        "Ping:{}{}{}",
515                        format_if_nonzero!(late),
516                        format_if_nonzero!(lost),
517                        format_if_nonzero!(resync),
518                    );
519                }
520            }
521            packet => {
522                debug!("Received unhandled ControlPacket {:#?}", packet);
523            }
524        }
525    }
526    Ok(())
527}