1use crate::error::UdpError;
2use crate::network::ConnectionInfo;
3use crate::state::{State, StatePhase};
4
5use futures_util::future::join4;
6use futures_util::stream::{SplitSink, SplitStream, Stream};
7use futures_util::{FutureExt, SinkExt, StreamExt};
8use log::*;
9use mumble_protocol::crypt::ClientCryptState;
10use mumble_protocol::ping::{PingPacket, PongPacket};
11use mumble_protocol::voice::VoicePacket;
12use mumble_protocol::Serverbound;
13use std::collections::{hash_map::Entry, HashMap};
14use std::net::{Ipv6Addr, SocketAddr};
15use std::sync::{
16 atomic::{AtomicU64, Ordering},
17 Arc, RwLock,
18};
19use tokio::sync::{mpsc, oneshot, watch, Mutex};
20use tokio::time::{interval, timeout, Duration};
21use tokio::{join, net::UdpSocket};
22use tokio_util::udp::UdpFramed;
23
24use super::{run_until, VoiceStreamType};
25
/// A ping request consumed by `handle_pings`: the ping id, the address to send the ping to,
/// and a callback that `handle_pings` calls with the matching pong, or with `None` when no
/// response is delivered within its timeout.
pub type PingRequest = (u64, SocketAddr, Box<dyn FnOnce(Option<PongPacket>) + Send>);

/// Sending half of the voice UDP socket, framed with a `ClientCryptState` codec.
type UdpSender = SplitSink<UdpFramed<ClientCryptState>, (VoicePacket<Serverbound>, SocketAddr)>;
/// Receiving half of the voice UDP socket, framed with a `ClientCryptState` codec.
type UdpReceiver = SplitStream<UdpFramed<ClientCryptState>>;
30
31pub async fn handle(
32 state: Arc<RwLock<State>>,
33 mut connection_info_receiver: watch::Receiver<Option<ConnectionInfo>>,
34 mut crypt_state_receiver: mpsc::Receiver<ClientCryptState>,
35) -> Result<(), UdpError> {
36 let receiver = state.read().unwrap().audio_input().receiver();
37
38 loop {
39 let connection_info = loop {
40 if connection_info_receiver.changed().await.is_ok() {
41 if let Some(data) = connection_info_receiver.borrow().clone() {
42 break data;
43 }
44 } else {
45 return Err(UdpError::NoConnectionInfoReceived);
46 }
47 };
48 let (sink, source) = connect(&mut crypt_state_receiver).await?;
49
50 let sink = Arc::new(Mutex::new(sink));
51 let source = Arc::new(Mutex::new(source));
52
53 let phase_watcher = state.read().unwrap().phase_receiver();
54 let last_ping_recv = AtomicU64::new(0);
55
56 run_until(
57 |phase| matches!(phase, StatePhase::Disconnected),
58 join4(
59 listen(Arc::clone(&state), Arc::clone(&source), &last_ping_recv),
60 send_voice(
61 Arc::clone(&sink),
62 connection_info.socket_addr,
63 phase_watcher.clone(),
64 Arc::clone(&receiver),
65 ),
66 send_pings(
67 Arc::clone(&state),
68 Arc::clone(&sink),
69 connection_info.socket_addr,
70 &last_ping_recv,
71 ),
72 new_crypt_state(&mut crypt_state_receiver, sink, source),
73 )
74 .map(|_| ()),
75 phase_watcher,
76 )
77 .await;
78
79 debug!("Fully disconnected UDP stream, waiting for new connection info");
80 }
81}
82
83async fn connect(
84 crypt_state: &mut mpsc::Receiver<ClientCryptState>,
85) -> Result<(UdpSender, UdpReceiver), UdpError> {
86 let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16)).await?;
88
89 let crypt_state = match crypt_state.recv().await {
91 Some(crypt_state) => crypt_state,
92 None => return Err(UdpError::DisconnectBeforeCryptSetup),
94 };
95 debug!("UDP connected");
96
97 Ok(UdpFramed::new(udp_socket, crypt_state).split())
99}
100
101async fn new_crypt_state(
102 crypt_state: &mut mpsc::Receiver<ClientCryptState>,
103 sink: Arc<Mutex<UdpSender>>,
104 source: Arc<Mutex<UdpReceiver>>,
105) {
106 loop {
107 if let Some(crypt_state) = crypt_state.recv().await {
108 info!("Received new crypt state");
109 let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16))
110 .await
111 .expect("Failed to bind UDP socket");
112 let (new_sink, new_source) = UdpFramed::new(udp_socket, crypt_state).split();
113 *sink.lock().await = new_sink;
114 *source.lock().await = new_source;
115 }
116 }
117}
118
119async fn listen(
120 state: Arc<RwLock<State>>,
121 source: Arc<Mutex<UdpReceiver>>,
122 last_ping_recv: &AtomicU64,
123) {
124 loop {
125 let packet = source.lock().await.next().await.unwrap();
126 let (packet, _src_addr) = match packet {
127 Ok(packet) => packet,
128 Err(err) => {
129 warn!("Got an invalid UDP packet: {}", err);
130 continue;
132 }
133 };
134 match packet {
135 VoicePacket::Ping { timestamp } => {
136 state
137 .read()
138 .unwrap()
139 .broadcast_phase(StatePhase::Connected(VoiceStreamType::Udp));
140 last_ping_recv.store(timestamp, Ordering::Relaxed);
141 }
142 VoicePacket::Audio {
143 session_id,
144 payload,
146 ..
148 } => {
149 state.read().unwrap().audio_output().decode_packet_payload(
150 VoiceStreamType::Udp,
151 session_id,
152 payload,
153 );
154 }
155 }
156 }
157}
158
159async fn send_pings(
160 state: Arc<RwLock<State>>,
161 sink: Arc<Mutex<UdpSender>>,
162 server_addr: SocketAddr,
163 last_ping_recv: &AtomicU64,
164) {
165 let mut last_send = None;
166 let mut interval = interval(Duration::from_millis(1000));
167
168 loop {
169 interval.tick().await;
170 let last_recv = last_ping_recv.load(Ordering::Relaxed);
171 if last_send.is_some() && last_send.unwrap() != last_recv {
172 debug!("Sending TCP voice");
173 state
174 .read()
175 .unwrap()
176 .broadcast_phase(StatePhase::Connected(VoiceStreamType::Tcp));
177 }
178 match sink
179 .lock()
180 .await
181 .send((
182 VoicePacket::Ping {
183 timestamp: last_recv + 1,
184 },
185 server_addr,
186 ))
187 .await
188 {
189 Ok(_) => {
190 last_send = Some(last_recv + 1);
191 }
192 Err(e) => {
193 debug!("Error sending UDP ping: {}", e);
194 }
195 }
196 }
197}
198
199async fn send_voice(
200 sink: Arc<Mutex<UdpSender>>,
201 server_addr: SocketAddr,
202 phase_watcher: watch::Receiver<StatePhase>,
203 receiver: Arc<Mutex<Box<(dyn Stream<Item = VoicePacket<Serverbound>> + Unpin)>>>,
204) {
205 loop {
206 let mut inner_phase_watcher = phase_watcher.clone();
207 loop {
208 inner_phase_watcher.changed().await.unwrap();
209 if matches!(
210 *inner_phase_watcher.borrow(),
211 StatePhase::Connected(VoiceStreamType::Udp)
212 ) {
213 break;
214 }
215 }
216 run_until(
217 |phase| !matches!(phase, StatePhase::Connected(VoiceStreamType::Udp)),
218 async {
219 let mut receiver = receiver.lock().await;
220 loop {
221 let sending = (receiver.next().await.unwrap(), server_addr);
222 sink.lock().await.send(sending).await.unwrap();
223 }
224 },
225 phase_watcher.clone(),
226 )
227 .await;
228 }
229}
230
231pub async fn handle_pings(mut ping_request_receiver: mpsc::UnboundedReceiver<PingRequest>) {
232 let udp_socket = UdpSocket::bind((Ipv6Addr::from(0u128), 0u16))
233 .await
234 .expect("Failed to bind UDP socket");
235
236 let pending = Mutex::new(HashMap::new());
237
238 let sender = async {
239 while let Some((id, socket_addr, handle)) = ping_request_receiver.recv().await {
240 debug!("Sending ping with id {} to {}", id, socket_addr);
241 let packet = PingPacket { id };
242 let packet: [u8; 12] = packet.into();
243 udp_socket.send_to(&packet, &socket_addr).await.unwrap();
244 let (tx, rx) = oneshot::channel();
245 match pending.lock().await.entry(id) {
246 Entry::Occupied(_) => {
247 warn!("Tried to send duplicate ping with id {}", id);
248 continue;
249 }
250 Entry::Vacant(v) => {
251 v.insert(tx);
252 }
253 }
254
255 tokio::spawn(async move {
256 handle(match timeout(Duration::from_secs(1), rx).await {
257 Ok(Ok(r)) => Some(r),
258 Ok(Err(_)) => {
259 warn!(
260 "Ping response sender for server {}, ping id {} dropped",
261 socket_addr, id
262 );
263 None
264 }
265 Err(_) => {
266 debug!(
267 "Server {} timed out when sending ping id {}",
268 socket_addr, id
269 );
270 None
271 }
272 });
273 });
274 }
275 };
276
277 let receiver = async {
278 let mut buf = vec![0; 24];
279
280 while let Ok(read) = udp_socket.recv(&mut buf).await {
281 if read != 24 {
282 warn!("Ping response had length {}, expected 24", read);
283 continue;
284 }
285
286 let packet = PongPacket::try_from(buf.as_slice()).unwrap();
287
288 match pending.lock().await.entry(packet.id) {
289 Entry::Occupied(o) => {
290 let id = *o.key();
291 if o.remove().send(packet).is_err() {
292 debug!("Received response to ping with id {} too late", id);
293 }
294 }
295 Entry::Vacant(v) => {
296 warn!("Received ping with id {} that we didn't send", v.key());
297 }
298 }
299 }
300 };
301
302 debug!("Waiting for ping requests");
303 join!(sender, receiver);
304}