mum_cli/network/udp.rs

1use crate::error::UdpError;
2use crate::network::ConnectionInfo;
3use crate::state::{State, StatePhase};
4
5use futures_util::future::join4;
6use futures_util::stream::{SplitSink, SplitStream, Stream};
7use futures_util::{FutureExt, SinkExt, StreamExt};
8use log::*;
9use mumble_protocol::crypt::ClientCryptState;
10use mumble_protocol::ping::{PingPacket, PongPacket};
11use mumble_protocol::voice::VoicePacket;
12use mumble_protocol::Serverbound;
13use std::collections::{hash_map::Entry, HashMap};
14use std::net::{Ipv6Addr, SocketAddr};
15use std::sync::{
16    atomic::{AtomicU64, Ordering},
17    Arc, RwLock,
18};
19use tokio::sync::{mpsc, oneshot, watch, Mutex};
20use tokio::time::{interval, timeout, Duration};
21use tokio::{join, net::UdpSocket};
22use tokio_util::udp::UdpFramed;
23
24use super::{run_until, VoiceStreamType};
25
26pub type PingRequest = (u64, SocketAddr, Box<dyn FnOnce(Option<PongPacket>) + Send>);
27
28type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>;
29type UdpReceiver = SplitStream<UdpFramed<ClientCryptState>>;
30
31pub async fn handle(
32    state: Arc<RwLock<State>>,
33    mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
34    mut crypt_state_receiver: mpsc::Receiver<ClientCryptState>,
35) -> Result<(), UdpError> {
36    let receiver = state.read().unwrap().audio_input().receiver();
37
38    loop {
39        let connection_info = loop {
40            if connection_info_receiver.changed().await.is_ok() {
41                if let Some(data) = connection_info_receiver.borrow().clone() {
42                    break data;
43                }
44            } else {
45                return Err(UdpError::NoConnectionInfoReceived);
46            }
47        };
48        let (sink, source) = connect(&mut crypt_state_receiver).await?;
49
50        let sink = Arc::new(Mutex::new(sink));
51        let source = Arc::new(Mutex::new(source));
52
53        let phase_watcher = state.read().unwrap().phase_receiver();
54        let last_ping_recv = AtomicU64::new(0);
55
56        run_until(
57            |phase| matches!(phase, StatePhase::Disconnected),
58            join4(
59                listen(Arc::clone(&state), Arc::clone(&source), &last_ping_recv),
60                send_voice(
61                    Arc::clone(&sink),
62                    connection_info.socket_addr,
63                    phase_watcher.clone(),
64                    Arc::clone(&receiver),
65                ),
66                send_pings(
67                    Arc::clone(&state),
68                    Arc::clone(&sink),
69                    connection_info.socket_addr,
70                    &last_ping_recv,
71                ),
72                new_crypt_state(&mut crypt_state_receiver, sink, source),
73            )
74            .map(|_| ()),
75            phase_watcher,
76        )
77        .await;
78
79        debug!("Fully disconnected UDP stream, waiting for new connection info");
80    }
81}
82
83async fn connect(
84    crypt_state: &mut mpsc::Receiver<ClientCryptState>,
85) -> Result<(UdpSender, UdpReceiver), UdpError> {
86    // Bind UDP socket
87    let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)).await?;
88
89    // Wait for initial CryptState
90    let crypt_state = match crypt_state.recv().await {
91        Some(crypt_state) => crypt_state,
92        // disconnected before we received the CryptSetup packet, oh well
93        None => return Err(UdpError::DisconnectBeforeCryptSetup),
94    };
95    debug!("UDP connected");
96
97    // Wrap the raw UDP packets in Mumble's crypto and voice codec (CryptState does both)
98    Ok(UdpFramed::new(udp_socket, crypt_state).split())
99}
100
101async fn new_crypt_state(
102    crypt_state: &mut mpsc::Receiver<ClientCryptState>,
103    sink: Arc<Mutex<UdpSender>>,
104    source: Arc<Mutex<UdpReceiver>>,
105) {
106    loop {
107        if let Some(crypt_state) = crypt_state.recv().await {
108            info!("Received new crypt state");
109            let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16))
110                .await
111                .expect("Failed to bind UDP socket");
112            let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split();
113            *sink.lock().await = new_sink;
114            *source.lock().await = new_source;
115        }
116    }
117}
118
119async fn listen(
120    state: Arc<RwLock<State>>,
121    source: Arc<Mutex<UdpReceiver>>,
122    last_ping_recv: &AtomicU64,
123) {
124    loop {
125        let packet = source.lock().await.next().await.unwrap();
126        let (packet, _src_addr) = match packet {
127            Ok(packet) => packet,
128            Err(err) => {
129                warn!("Got an invalid UDP packet: {}", err);
130                // To be expected, considering this is the internet, just ignore it
131                continue;
132            }
133        };
134        match packet {
135            VoicePacket::Ping { timestamp } => {
136                state
137                    .read()
138                    .unwrap()
139                    .broadcast_phase(StatePhase::Connected(VoiceStreamType::Udp));
140                last_ping_recv.store(timestamp, Ordering::Relaxed);
141            }
142            VoicePacket::Audio {
143                session_id,
144                // seq_num,
145                payload,
146                // position_info,
147                ..
148            } => {
149                state.read().unwrap().audio_output().decode_packet_payload(
150                    VoiceStreamType::Udp,
151                    session_id,
152                    payload,
153                );
154            }
155        }
156    }
157}
158
159async fn send_pings(
160    state: Arc<RwLock<State>>,
161    sink: Arc<Mutex<UdpSender>>,
162    server_addr: SocketAddr,
163    last_ping_recv: &AtomicU64,
164) {
165    let mut last_send = None;
166    let mut interval = interval(Duration::from_millis(1000));
167
168    loop {
169        interval.tick().await;
170        let last_recv = last_ping_recv.load(Ordering::Relaxed);
171        if last_send.is_some() && last_send.unwrap() != last_recv {
172            debug!("Sending TCP voice");
173            state
174                .read()
175                .unwrap()
176                .broadcast_phase(StatePhase::Connected(VoiceStreamType::Tcp));
177        }
178        match sink
179            .lock()
180            .await
181            .send((
182                VoicePacket::Ping {
183                    timestamp: last_recv + 1,
184                },
185                server_addr,
186            ))
187            .await
188        {
189            Ok(_) => {
190                last_send = Some(last_recv + 1);
191            }
192            Err(e) => {
193                debug!("Error sending UDP ping: {}", e);
194            }
195        }
196    }
197}
198
199async fn send_voice(
200    sink: Arc<Mutex<UdpSender>>,
201    server_addr: SocketAddr,
202    phase_watcher: watch::Receiver<StatePhase>,
203    receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
204) {
205    loop {
206        let mut inner_phase_watcher = phase_watcher.clone();
207        loop {
208            inner_phase_watcher.changed().await.unwrap();
209            if matches!(
210                *inner_phase_watcher.borrow(),
211                StatePhase::Connected(VoiceStreamType::Udp)
212            ) {
213                break;
214            }
215        }
216        run_until(
217            |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::Udp)),
218            async {
219                let mut receiver = receiver.lock().await;
220                loop {
221                    let sending = (receiver.next().await.unwrap(), server_addr);
222                    sink.lock().await.send(sending).await.unwrap();
223                }
224            },
225            phase_watcher.clone(),
226        )
227        .await;
228    }
229}
230
231pub async fn handle_pings(mut ping_request_receiver: mpsc::UnboundedReceiver<PingRequest>) {
232    let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16))
233        .await
234        .expect("Failed to bind UDP socket");
235
236    let pending = Mutex::new(HashMap::new());
237
238    let sender = async {
239        while let Some((id, socket_addr, handle)) = ping_request_receiver.recv().await {
240            debug!("Sending ping with id {} to {}", id, socket_addr);
241            let packet = PingPacket { id };
242            let packet: [u8; 12] = packet.into();
243            udp_socket.send_to(&packet, &socket_addr).await.unwrap();
244            let (tx, rx) = oneshot::channel();
245            match pending.lock().await.entry(id) {
246                Entry::Occupied(_) => {
247                    warn!("Tried to send duplicate ping with id {}", id);
248                    continue;
249                }
250                Entry::Vacant(v) => {
251                    v.insert(tx);
252                }
253            }
254
255            tokio::spawn(async move {
256                handle(match timeout(Duration::from_secs(1), rx).await {
257                    Ok(Ok(r)) => Some(r),
258                    Ok(Err(_)) => {
259                        warn!(
260                            "Ping response sender for server {}, ping id {} dropped",
261                            socket_addr, id
262                        );
263                        None
264                    }
265                    Err(_) => {
266                        debug!(
267                            "Server {} timed out when sending ping id {}",
268                            socket_addr, id
269                        );
270                        None
271                    }
272                });
273            });
274        }
275    };
276
277    let receiver = async {
278        let mut buf = vec![0; 24];
279
280        while let Ok(read) = udp_socket.recv(&mut buf).await {
281            if read != 24 {
282                warn!("Ping response had length {}, expected 24", read);
283                continue;
284            }
285
286            let packet = PongPacket::try_from(buf.as_slice()).unwrap();
287
288            match pending.lock().await.entry(packet.id) {
289                Entry::Occupied(o) => {
290                    let id = *o.key();
291                    if o.remove().send(packet).is_err() {
292                        debug!("Received response to ping with id {} too late", id);
293                    }
294                }
295                Entry::Vacant(v) => {
296                    warn!("Received ping with id {} that we didn't send", v.key());
297                }
298            }
299        }
300    };
301
302    debug!("Waiting for ping requests");
303    join!(sender, receiver);
304}