1use crate::error::{ServerSendError, TcpError};
2use crate::network::ConnectionInfo;
3use crate::notifications;
4use crate::state::server::Server;
5use crate::state::{State, StatePhase};
6
7use futures_util::select;
8use futures_util::stream::{SplitSink, SplitStream, Stream};
9use futures_util::{FutureExt, SinkExt, StreamExt};
10use log::*;
11use mumble_protocol::control::{msgs, ClientControlCodec, ControlCodec, ControlPacket};
12use mumble_protocol::crypt::ClientCryptState;
13use mumble_protocol::voice::VoicePacket;
14use mumble_protocol::{Clientbound, Serverbound};
15use mumlib::command::MumbleEventKind;
16use std::collections::HashMap;
17use std::convert::Into;
18use std::fmt::Debug;
19use std::net::SocketAddr;
20use std::sync::{Arc, RwLock};
21use tokio::net::TcpStream;
22use tokio::sync::{mpsc, watch, Mutex};
23use tokio::time::{self, Duration};
24use tokio_native_tls::{TlsConnector, TlsStream};
25use tokio_util::codec::{Decoder, Framed};
26
27use super::{run_until, VoiceStreamType};
28
29type TcpSender = SplitSink<
30 Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>,
31 ControlPacket<Serverbound>,
32>;
33type TcpReceiver =
34 SplitStream<Framed<TlsStream<TcpStream>, ControlCodec<Serverbound, Clientbound>>>;
35
36pub(crate) type TcpEventCallback = Box<dyn FnOnce(TcpEventData<'_>)>;
37pub(crate) type TcpEventSubscriber = Box<dyn FnMut(TcpEventData<'_>) -> bool>; #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
41pub enum DisconnectedReason {
42 InvalidTls,
43 User,
44 TcpError,
45}
46
47#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
49pub enum TcpEvent {
50 Connected, Disconnected(DisconnectedReason), TextMessage, }
54
55#[derive(Clone, Debug)]
63pub enum TcpEventData<'a> {
64 Connected(Result<&'a msgs::ServerSync, mumlib::Error>),
65 Disconnected(DisconnectedReason),
66 TextMessage(&'a msgs::TextMessage),
67}
68
69impl From<&TcpEventData<'_>> for TcpEvent {
70 fn from(t: &TcpEventData<'_>) -> Self {
71 match t {
72 TcpEventData::Connected(_) => TcpEvent::Connected,
73 TcpEventData::Disconnected(reason) => TcpEvent::Disconnected(*reason),
74 TcpEventData::TextMessage(_) => TcpEvent::TextMessage,
75 }
76 }
77}
78
79#[derive(Clone, Default)]
80pub struct TcpEventQueue {
81 callbacks: Arc<RwLock<HashMap<TcpEvent, Vec<TcpEventCallback>>>>,
82 subscribers: Arc<RwLock<HashMap<TcpEvent, Vec<TcpEventSubscriber>>>>,
83}
84
85impl TcpEventQueue {
86 pub fn new() -> Self {
88 Self {
89 callbacks: Arc::new(RwLock::new(HashMap::new())),
90 subscribers: Arc::new(RwLock::new(HashMap::new())),
91 }
92 }
93
94 pub fn register_callback(&self, at: TcpEvent, callback: TcpEventCallback) {
96 self.callbacks
97 .write()
98 .unwrap()
99 .entry(at)
100 .or_default()
101 .push(callback);
102 }
103
104 pub fn register_subscriber(&self, at: TcpEvent, callback: TcpEventSubscriber) {
106 self.subscribers
107 .write()
108 .unwrap()
109 .entry(at)
110 .or_default()
111 .push(callback);
112 }
113
114 pub fn resolve(&self, data: TcpEventData<'_>) {
117 if let Some(vec) = self
118 .callbacks
119 .write()
120 .unwrap()
121 .get_mut(&TcpEvent::from(&data))
122 {
123 let old = std::mem::take(vec);
124 for handler in old {
125 handler(data.clone());
126 }
127 }
128 if let Some(vec) = self
129 .subscribers
130 .write()
131 .unwrap()
132 .get_mut(&TcpEvent::from(&data))
133 {
134 let old = std::mem::take(vec);
135 for mut e in old {
136 if e(data.clone()) {
137 vec.push(e)
138 }
139 }
140 }
141 }
142}
143
144impl Debug for TcpEventQueue {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("TcpEventQueue").finish()
147 }
148}
149
150pub async fn handle(
151 state: Arc<RwLock<State>>,
152 mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
153 crypt_state_sender: mpsc::Sender<ClientCryptState>,
154 packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
155 mut packet_receiver: mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
156 event_queue: TcpEventQueue,
157) -> Result<(), TcpError> {
158 loop {
159 let connection_info = loop {
160 if connection_info_receiver.changed().await.is_ok() {
161 if let Some(data) = connection_info_receiver.borrow().clone() {
162 break data;
163 }
164 } else {
165 return Err(TcpError::NoConnectionInfoReceived);
166 }
167 };
168 let connect_result = connect(
169 connection_info.socket_addr,
170 connection_info.hostname,
171 connection_info.accept_invalid_cert,
172 )
173 .await;
174
175 let (mut sink, stream) = match connect_result {
176 Ok(ok) => ok,
177 Err(TcpError::TlsConnectError(_)) => {
178 warn!("Invalid TLS");
179 state
180 .read()
181 .unwrap()
182 .broadcast_phase(StatePhase::Disconnected);
183 event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::InvalidTls));
184 continue;
185 }
186 Err(e) => {
187 return Err(e);
188 }
189 };
190
191 let (username, password) = {
193 let state_lock = state.read().unwrap();
194 (
195 state_lock.username().unwrap().to_string(),
196 state_lock.password().map(|x| x.to_string()),
197 )
198 };
199 authenticate(&mut sink, username, password).await?;
200 let (phase_watcher, input_receiver) = {
201 let state_lock = state.read().unwrap();
202 (
203 state_lock.phase_receiver(),
204 state_lock.audio_input().receiver(),
205 )
206 };
207
208 info!("Logging in...");
209
210 let phase_watcher_inner = phase_watcher.clone();
211
212 let result = run_until(
213 |phase| matches!(phase, StatePhase::Disconnected),
214 async {
215 select! {
216 r = send_pings(packet_sender.clone(), 10).fuse() => r,
217 r = listen(
218 Arc::clone(&state),
219 stream,
220 crypt_state_sender.clone(),
221 event_queue.clone(),
222 ).fuse() => r,
223 r = send_voice(
224 packet_sender.clone(),
225 Arc::clone(&input_receiver),
226 phase_watcher_inner,
227 ).fuse() => r,
228 r = send_packets(sink, &mut packet_receiver).fuse() => r,
229 }
230 },
231 phase_watcher,
232 )
233 .await
234 .unwrap_or(Ok(()));
235
236 match result {
237 Ok(()) => event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::User)),
238 Err(_) => event_queue.resolve(TcpEventData::Disconnected(DisconnectedReason::TcpError)),
239 }
240
241 debug!("Fully disconnected TCP stream, waiting for new connection info");
242 }
243}
244
245async fn connect(
246 server_addr: SocketAddr,
247 server_host: String,
248 accept_invalid_cert: bool,
249) -> Result<(TcpSender, TcpReceiver), TcpError> {
250 let stream = TcpStream::connect(&server_addr).await?;
251 debug!("TCP connected");
252
253 let mut builder = native_tls::TlsConnector::builder();
254 builder.danger_accept_invalid_certs(accept_invalid_cert);
255 let connector: TlsConnector = builder
256 .build()
257 .map_err(TcpError::TlsConnectorBuilderError)?
258 .into();
259 let tls_stream = connector
260 .connect(&server_host, stream)
261 .await
262 .map_err(TcpError::TlsConnectError)?;
263 debug!("TLS connected");
264
265 Ok(ClientControlCodec::new().framed(tls_stream).split())
267}
268
269async fn authenticate(
270 sink: &mut TcpSender,
271 username: String,
272 password: Option<String>,
273) -> Result<(), TcpError> {
274 let mut msg = msgs::Authenticate::new();
275 msg.set_username(username);
276 if let Some(password) = password {
277 msg.set_password(password);
278 }
279 msg.set_opus(true);
280 sink.send(msg.into()).await?;
281 Ok(())
282}
283
284async fn send_pings(
285 packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
286 delay_seconds: u64,
287) -> Result<(), TcpError> {
288 let mut interval = time::interval(Duration::from_secs(delay_seconds));
289 loop {
290 interval.tick().await;
291 trace!("Sending TCP ping");
292 let msg = msgs::Ping::new();
293 packet_sender.send(msg.into())?;
294 }
295}
296
297async fn send_packets(
298 mut sink: TcpSender,
299 packet_receiver: &mut mpsc::UnboundedReceiver<ControlPacket<Serverbound>>,
300) -> Result<(), TcpError> {
301 loop {
302 let packet = packet_receiver.recv().await.unwrap();
304 sink.send(packet).await?;
305 }
306}
307
308async fn send_voice(
309 packet_sender: mpsc::UnboundedSender<ControlPacket<Serverbound>>,
310 receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
311 phase_watcher: watch::Receiver<StatePhase>,
312) -> Result<(), TcpError> {
313 loop {
314 let mut inner_phase_watcher = phase_watcher.clone();
315 loop {
316 inner_phase_watcher.changed().await.unwrap();
317 if matches!(
318 *inner_phase_watcher.borrow(),
319 StatePhase::Connected(VoiceStreamType::Tcp)
320 ) {
321 break;
322 }
323 }
324 run_until(
325 |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::Tcp)),
326 async {
327 loop {
328 packet_sender.send(
329 receiver
330 .lock()
331 .await
332 .next()
333 .await
334 .expect("No audio stream")
335 .into(),
336 )?;
337 }
338 },
339 inner_phase_watcher.clone(),
340 )
341 .await
342 .unwrap_or(Ok::<(), ServerSendError>(()))?;
343 }
344}
345
346async fn listen(
347 state: Arc<RwLock<State>>,
348 mut stream: TcpReceiver,
349 crypt_state_sender: mpsc::Sender<ClientCryptState>,
350 event_queue: TcpEventQueue,
351) -> Result<(), TcpError> {
352 let mut crypt_state = None;
353 let mut crypt_state_sender = Some(crypt_state_sender);
354
355 let mut last_late = 0;
356 let mut last_lost = 0;
357 let mut last_resync = 0;
358
359 loop {
360 let packet = match stream.next().await {
361 Some(Ok(packet)) => packet,
362 Some(Err(e)) => {
363 error!("TCP error: {:?}", e);
364 continue; }
366 None => {
367 warn!("TCP stream gone");
370 state
371 .read()
372 .unwrap()
373 .broadcast_phase(StatePhase::Disconnected);
374 break;
375 }
376 };
377 match packet {
378 ControlPacket::TextMessage(msg) => {
379 let mut state = state.write().unwrap();
380 let server = state.server();
381 let user = (if let Server::Connected(s) = server {
382 Some(s)
383 } else {
384 None
385 })
386 .and_then(|server| server.users().get(&msg.get_actor()))
387 .map(|user| user.name());
388 if let Some(user) = user {
389 notifications::send(format!("{}: {}", user, msg.get_message()));
390 let user = user.to_string();
392 state.push_event(MumbleEventKind::TextMessageReceived(user))
393 }
395 state.register_message((msg.get_message().to_owned(), msg.get_actor()));
396 drop(state);
397 event_queue.resolve(TcpEventData::TextMessage(&*msg));
398 }
399 ControlPacket::CryptSetup(msg) => {
400 debug!("Crypt setup");
401 crypt_state = Some(ClientCryptState::new_from(
403 msg.get_key()
404 .try_into()
405 .expect("Server sent private key with incorrect size"),
406 msg.get_client_nonce()
407 .try_into()
408 .expect("Server sent client_nonce with incorrect size"),
409 msg.get_server_nonce()
410 .try_into()
411 .expect("Server sent server_nonce with incorrect size"),
412 ));
413 }
414 ControlPacket::ServerSync(msg) => {
415 info!("Logged in");
416 if let Some(sender) = crypt_state_sender.take() {
417 let _ = sender
418 .send(
419 crypt_state
420 .take()
421 .expect("Server didn't send us any CryptSetup packet!"),
422 )
423 .await;
424 }
425 let mut state = state.write().unwrap();
426 let server = state.server_mut();
427 if let Server::Connecting(sb) = server {
428 let s = sb.clone().server_sync(*msg.clone());
429 *server = Server::Connected(s);
430 state.initialized();
431 } else {
432 warn!(
433 "Got a ServerSync packet while not connecting. Current state is:\n{:#?}",
434 server
435 );
436 }
437 drop(state);
438 event_queue.resolve(TcpEventData::Connected(Ok(&msg)));
439 }
440 ControlPacket::Reject(msg) => {
441 debug!("Login rejected: {:?}", msg);
442 match msg.get_field_type() {
443 msgs::Reject_RejectType::WrongServerPW => {
444 event_queue.resolve(TcpEventData::Connected(Err(
445 mumlib::Error::InvalidServerPassword,
446 )));
447 }
448 ty => {
449 warn!("Unhandled reject type: {:?}", ty);
450 }
451 }
452 }
453 ControlPacket::UserState(msg) => {
454 state.write().unwrap().user_state(*msg);
455 }
456 ControlPacket::UserRemove(msg) => {
457 state.write().unwrap().remove_user(*msg);
458 }
459 ControlPacket::ChannelState(msg) => {
460 if let Server::Connecting(sb) = state.write().unwrap().server_mut() {
461 sb.channel_state(*msg);
462 }
463 }
464 ControlPacket::ChannelRemove(msg) => match state.write().unwrap().server_mut() {
465 Server::Connecting(sb) => sb.channel_remove(*msg),
466 Server::Connected(server) => server.channel_remove(*msg),
467 Server::Disconnected => warn!("Got ChannelRemove packet while disconnected"),
468 },
469 ControlPacket::UDPTunnel(msg) => {
470 match *msg {
471 VoicePacket::Ping { .. } => {}
472 VoicePacket::Audio {
473 session_id,
474 payload,
476 ..
478 } => {
479 state.read().unwrap().audio_output().decode_packet_payload(
480 VoiceStreamType::Tcp,
481 session_id,
482 payload,
483 );
484 }
485 }
486 }
487 ControlPacket::Ping(msg) => {
488 trace!("Received Ping {:?}", *msg);
489
490 let late = msg.get_late();
491 let lost = msg.get_lost();
492 let resync = msg.get_resync();
493
494 let late = late - last_late;
495 let lost = lost - last_lost;
496 let resync = resync - last_resync;
497
498 last_late += late;
499 last_lost += lost;
500 last_resync += resync;
501
502 macro_rules! format_if_nonzero {
503 ($value:expr) => {
504 if $value != 0 {
505 format!("\n {}: {}", stringify!($value), $value)
506 } else {
507 String::new()
508 }
509 };
510 }
511
512 if late != 0 || lost != 0 || resync != 0 {
513 debug!(
514 "Ping:{}{}{}",
515 format_if_nonzero!(late),
516 format_if_nonzero!(lost),
517 format_if_nonzero!(resync),
518 );
519 }
520 }
521 packet => {
522 debug!("Received unhandled ControlPacket {:#?}", packet);
523 }
524 }
525 }
526 Ok(())
527}