cloudpub_client/
client.rs

1use anyhow::{anyhow, bail, Context, Result};
2use backoff::backoff::Backoff;
3use cloudpub_common::data::DataChannel;
4use cloudpub_common::fair_channel::{fair_channel, FairSender};
5use cloudpub_common::protocol::message::Message;
6use cloudpub_common::protocol::{
7    AgentInfo, ConnectState, Data, DataChannelAck, DataChannelData, DataChannelDataUdp,
8    DataChannelEof, ErrorInfo, ErrorKind, HeartBeat, Protocol,
9};
10use cloudpub_common::transport::{AddrMaybeCached, SocketOpts, Transport, WebsocketTransport};
11use cloudpub_common::utils::{
12    get_platform, proto_to_socket_addr, socket_addr_to_proto, udp_connect,
13};
14use cloudpub_common::VERSION;
15use dashmap::DashMap;
16use parking_lot::RwLock;
17use std::net::SocketAddr;
18use std::sync::Arc;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::{TcpStream, UdpSocket};
21use tokio::sync::mpsc;
22use tokio::time::{self, Duration, Instant};
23use tracing::{debug, error, info, trace, warn};
24
25use cloudpub_common::constants::{
26    run_control_chan_backoff, CONTROL_CHANNEL_SIZE, DATA_BUFFER_SIZE, DATA_CHANNEL_SIZE,
27    DEFAULT_CLIENT_RETRY_INTERVAL_SECS, UDP_BUFFER_SIZE, UDP_TIMEOUT,
28};
29use futures::future::FutureExt;
30
31use crate::config::{ClientConfig, ClientOpts};
32use crate::upgrade::handle_upgrade_available;
33use bytes::Bytes;
34use cloudpub_common::transport::ProtobufStream;
35
36#[cfg(feature = "plugins")]
37use crate::plugins::plugin_trait::PluginHandle;
38#[cfg(feature = "plugins")]
39use crate::plugins::registry::PluginRegistry;
40
41type Service = Arc<DataChannel>;
42
43type Services = Arc<DashMap<String, Service>>;
44
45// Holds the state of a client
46struct Client<T: Transport> {
47    config: Arc<RwLock<ClientConfig>>,
48    opts: ClientOpts,
49    services: Services,
50    transport: Arc<T>,
51    connected: bool,
52    #[cfg(feature = "plugins")]
53    plugin_processes: Arc<DashMap<String, PluginHandle>>,
54    data_channels: Arc<DashMap<u32, Arc<DataChannel>>>,
55}
56
57impl<T: 'static + Transport> Client<T> {
58    // Create a Client from `[client]` config block
59    async fn from(config: Arc<RwLock<ClientConfig>>, opts: ClientOpts) -> Result<Client<T>> {
60        let transport = Arc::new(
61            T::new(&config.clone().read().transport)
62                .with_context(|| "Failed to create the transport")?,
63        );
64        Ok(Client {
65            config,
66            opts,
67            services: Default::default(),
68            transport,
69            connected: false,
70            #[cfg(feature = "plugins")]
71            plugin_processes: Arc::new(DashMap::new()),
72            data_channels: Arc::new(DashMap::new()),
73        })
74    }
75
76    // The entrypoint of Client
77    async fn run(
78        &mut self,
79        mut command_rx: mpsc::Receiver<Message>,
80        result_tx: mpsc::Sender<Message>,
81    ) -> Result<()> {
82        let transport = self.transport.clone();
83
84        let mut retry_backoff = run_control_chan_backoff(DEFAULT_CLIENT_RETRY_INTERVAL_SECS);
85
86        let mut start = Instant::now();
87        result_tx
88            .send(Message::ConnectState(ConnectState::Connecting.into()))
89            .await
90            .context("Can't send Connecting event")?;
91        while let Err(err) = self
92            .run_control_channel(transport.clone(), &mut command_rx, &result_tx)
93            .boxed()
94            .await
95        {
96            if result_tx.is_closed() {
97                // The client is shutting down
98                break;
99            }
100
101            if self.connected {
102                result_tx
103                    .send(Message::Error(ErrorInfo {
104                        kind: ErrorKind::HandshakeFailed.into(),
105                        message: crate::t!("error-network"),
106                        guid: String::new(),
107                    }))
108                    .await
109                    .context("Can't send Error event")?;
110                result_tx
111                    .send(Message::ConnectState(ConnectState::Disconnected.into()))
112                    .await
113                    .context("Can't send Disconnected event")?;
114                result_tx
115                    .send(Message::ConnectState(ConnectState::Connecting.into()))
116                    .await
117                    .context("Can't send Connecting event")?;
118                self.connected = false;
119            }
120
121            self.services.clear();
122            #[cfg(feature = "plugins")]
123            self.plugin_processes.clear();
124            self.data_channels.clear();
125
126            if start.elapsed() > Duration::from_secs(3) {
127                // The client runs for at least 3 secs and then disconnects
128                retry_backoff.reset();
129            }
130
131            if let Some(duration) = retry_backoff.next_backoff() {
132                warn!("{:#}. Retry in {:?}...", err, duration);
133                time::sleep(duration).await;
134            }
135
136            start = Instant::now();
137        }
138
139        self.services.clear();
140        #[cfg(feature = "plugins")]
141        self.plugin_processes.clear();
142        self.data_channels.clear();
143
144        Ok(())
145    }
146
147    async fn run_control_channel(
148        &mut self,
149        transport: Arc<T>,
150        command_rx: &mut mpsc::Receiver<Message>,
151        result_tx: &mpsc::Sender<Message>,
152    ) -> Result<()> {
153        let url = self.config.read().server.clone();
154        let port = url.port().unwrap_or(443);
155        let host = url.host_str().context("Failed to get host")?;
156        let mut host_and_port = format!("{}:{}", host, port);
157
158        let (mut conn, _remote_addr) = loop {
159            let mut remote_addr = AddrMaybeCached::new(&host_and_port);
160            remote_addr
161                .resolve()
162                .await
163                .context("Failed to resolve server address")?;
164
165            let mut conn = transport.connect(&remote_addr).await.context(format!(
166                "Failed to connect control channel to {}",
167                &host_and_port
168            ))?;
169
170            self.connected = true;
171
172            T::hint(&conn, SocketOpts::for_control_channel());
173
174            let (email, password) = if let Some(ref cred) = self.opts.credentials {
175                (cred.0.clone(), cred.1.clone())
176            } else {
177                (String::new(), String::new())
178            };
179
180            let token = self
181                .config
182                .read()
183                .token
184                .clone()
185                .unwrap_or_default()
186                .to_string();
187
188            let hwid = self.config.read().get_hwid();
189
190            let agent_info = AgentInfo {
191                agent_id: self.config.read().agent_id.clone(),
192                token,
193                email,
194                password,
195                hostname: hostname::get()?.to_string_lossy().into_owned(),
196                version: VERSION.to_string(),
197                gui: self.opts.gui,
198                platform: get_platform(),
199                hwid,
200                server_host_and_port: host_and_port.clone(),
201                transient: self.opts.transient,
202                secondary: self.opts.secondary,
203                is_service: self.opts.is_service,
204            };
205
206            debug!("Sending hello: {:?}", agent_info);
207
208            let hello_send = Message::AgentHello(agent_info);
209
210            conn.send_message(&hello_send)
211                .await
212                .context("Failed to send hello message")?;
213
214            debug!("Reading ack");
215            match conn
216                .recv_message()
217                .await
218                .context("Failed to read ack message")?
219            {
220                Some(msg) => match msg {
221                    Message::AgentAck(args) => {
222                        if !args.token.is_empty() {
223                            let mut c = self.config.write();
224                            c.token = Some(args.token.as_str().into());
225                            c.save().context("Write config")?;
226                        }
227                        break (conn, remote_addr);
228                    }
229                    Message::Redirect(r) => {
230                        host_and_port = r.host_and_port.clone();
231                        debug!("Redirecting to {}", host_and_port);
232                        continue;
233                    }
234                    Message::Error(err) => {
235                        result_tx
236                            .send(Message::Error(err.clone()))
237                            .await
238                            .context("Can't send server error event")?;
239                        bail!("Error: {:?}", err.kind);
240                    }
241                    v => bail!("Unexpected ack message: {:?}", v),
242                },
243                None => bail!("Connection closed while reading ack message"),
244            };
245        };
246
247        debug!("Control channel established");
248
249        result_tx
250            .send(Message::ConnectState(ConnectState::Connected.into()))
251            .await
252            .context("Can't send Connected event")?;
253
254        let (to_server_tx, mut to_server_rx) = fair_channel::<Message>(CONTROL_CHANNEL_SIZE);
255        // Used to break long running setup
256
257        let heartbeat_timeout = self.config.read().heartbeat_timeout;
258
259        loop {
260            tokio::select! {
261                cmd = to_server_rx.recv() => {
262                    if let Some(cmd) = cmd {
263                        conn.send_message(&cmd).await.context("Failed to send command")?;
264                    }
265                },
266                cmd = command_rx.recv() => {
267                    if let Some(cmd) = cmd {
268                        debug!("Received message: {:?}", cmd);
269                        match cmd {
270                            Message::PerformUpgrade(info) => {
271                                let config_clone = self.config.clone();
272                                if let Err(e) = handle_upgrade_available(
273                                    &info.version,
274                                    config_clone,
275                                    self.opts.gui,
276                                    command_rx,
277                                    result_tx,
278                                )
279                                .await
280                                {
281                                    result_tx.send(Message::Error(ErrorInfo {
282                                        kind: ErrorKind::Fatal.into(),
283                                        message: e.to_string(),
284                                        guid: String::new(),
285                                    }))
286                                    .await
287                                    .context("Can't send Error event")?;
288                                }
289                            }
290                            Message::Stop(_x) => {
291                                info!("Stopping the client");
292                                break;
293                            }
294                            Message::Break(break_msg) => {
295                                info!("Breaking operation for guid: {}", break_msg.guid);
296                                #[cfg(feature = "plugins")]
297                                if let Some((_, handle)) = self.plugin_processes.remove(&break_msg.guid) {
298                                    info!("Dropped plugin handle for guid: {}", break_msg.guid);
299                                    drop(handle);
300                                }
301                            }
302                            cmd => {
303                                conn.send_message(&cmd).await.context("Failed to send message")?;
304                            }
305                        };
306                    } else {
307                        debug!("No more commands, shutting down...");
308                        break;
309                    }
310                },
311                val = conn.recv_message() => {
312                    match val? {
313                        Some(val) => {
314                            match val {
315                                Message::EndpointAck(mut endpoint) => {
316                                    #[cfg(feature = "plugins")]
317                                    {
318                                        let to_server_tx = to_server_tx.clone();
319                                        let config = self.config.clone();
320                                        let opts = self.opts.clone();
321                                        if endpoint.error.is_empty() {
322                                            let protocol: Protocol = endpoint
323                                                .client
324                                                .as_ref()
325                                                .unwrap()
326                                                .local_proto
327                                                .try_into()
328                                                .unwrap_or(Protocol::Tcp);
329                                            if let Some(plugin) = PluginRegistry::new().get(protocol) {
330                                                let handle = PluginHandle::spawn(
331                                                    plugin,
332                                                    endpoint.clone(),
333                                                    config,
334                                                    opts,
335                                                    to_server_tx,
336                                                );
337                                                self.plugin_processes.insert(endpoint.guid.clone(), handle);
338                                            } else {
339                                                endpoint.status = Some("online".into());
340                                                let _ = to_server_tx.send(Message::EndpointStatus(endpoint.clone())).await;
341                                            }
342                                        }
343                                    }
344                                    #[cfg(not(feature = "plugins"))]
345                                    {
346                                        endpoint.status = Some("online".into());
347                                        let _ = to_server_tx.send(Message::EndpointStatus(endpoint.clone())).await;
348                                    }
349                                    result_tx
350                                        .send(Message::EndpointAck(endpoint))
351                                        .await
352                                        .context("Can't send EndpointAck event")?;
353                                }
354
355                                Message::CreateDataChannelWithId(create_msg) => {
356                                    let channel_id = create_msg.channel_id;
357                                    let endpoint = create_msg.endpoint.unwrap();
358
359                                    trace!("Creating data channel {} for endpoint {:?}", channel_id, endpoint.guid);
360
361                                    // Create channels for data flow
362                                    let (to_service_tx, to_service_rx) = mpsc::channel::<Data>(DATA_CHANNEL_SIZE);
363
364                                    // Register the data channel
365                                    let data_channel = Arc::new(DataChannel::new_client(channel_id, to_service_tx.clone()));
366                                    self.data_channels.insert(channel_id, data_channel.clone());
367
368                                    // Check if endpoint handled by plugin server
369                                    let client = endpoint.client.unwrap();
370                                    #[allow(unused_mut)]
371                                    let mut local_addr = format!("{}:{}", client.local_addr, client.local_port);
372                                    #[cfg(feature = "plugins")]
373                                    if let Some(handle) = self.plugin_processes.get(&endpoint.guid) {
374                                        if let Some(port) = handle.value().port() {
375                                            local_addr = format!("127.0.0.1:{}", port);
376                                        }
377                                    }
378
379                                    // Immediately start handling the data channel
380                                    let data_channels = self.data_channels.clone();
381                                    let protocol: Protocol = client.local_proto.try_into().unwrap();
382
383                                    let to_server_tx_cloned = to_server_tx.clone();
384                                    tokio::spawn(async move {
385                                        if let Err(err) = if protocol == Protocol::Udp {
386                                            handle_udp_data_channel(
387                                                data_channel,
388                                                local_addr,
389                                                to_server_tx_cloned.clone(),
390                                                to_service_rx
391                                            ).await
392                                        } else {
393                                            handle_tcp_data_channel(
394                                                data_channel,
395                                                local_addr,
396                                                to_server_tx_cloned.clone(),
397                                                to_service_rx
398                                            ).await
399                                        } {
400                                            error!("DataChannel {{ channel_id: {} }}: {:?}", channel_id, err);
401                                            to_server_tx_cloned
402                                                .send(Message::DataChannelEof(
403                                                        DataChannelEof {
404                                                            channel_id,
405                                                            error: err.to_string()
406                                                    })
407                                                ).await.ok();
408                                        }
409                                        if let Some((_, dc)) = data_channels.remove(&channel_id) { dc.close() }
410                                    });
411                                },
412
413                                Message::DataChannelData(data) => {
414                                    // Forward data to the appropriate data channel
415                                    let to_service_tx = self.data_channels.get(&data.channel_id).map(|ch| ch.data_tx.clone());
416                                    if let Some(tx) = to_service_tx {
417                                        if let Err(err) = tx.send(Data {
418                                            data: data.data.into(),
419                                            socket_addr: None
420                                        }).await {
421                                            self.data_channels.remove(&data.channel_id);
422                                            error!("Error send to data channel {}: {:?}", data.channel_id, err);
423                                        }
424                                    } else {
425                                        trace!("Data channel {} not found, dropping data", data.channel_id);
426                                    }
427                                },
428
429                                Message::DataChannelDataUdp(data) => {
430                                    // Forward UDP data to the appropriate data channel
431                                    let to_service_tx = self.data_channels.get(&data.channel_id).map(|ch| ch.data_tx.clone());
432                                    if let Some(tx) = to_service_tx {
433                                        let socket_addr = data.socket_addr.as_ref()
434                                            .map(proto_to_socket_addr)
435                                            .transpose()
436                                            .unwrap_or_else(|err| {
437                                                error!("Invalid socket address for UDP data channel {}: {:?}", data.channel_id, err);
438                                                None
439                                            });
440
441                                        if let Err(err) = tx.send(Data {
442                                            data: data.data.into(),
443                                            socket_addr,
444                                        }).await {
445                                            self.data_channels.remove(&data.channel_id);
446                                            error!("Error send to UDP data channel {}: {:?}", data.channel_id, err);
447                                        }
448                                    } else {
449                                        trace!("UDP Data channel {} not found, dropping data", data.channel_id);
450                                    }
451                                },
452
453                                Message::DataChannelEof(eof) => {
454                                    // Signal EOF by dropping the data channel
455                                    if let Some((_, dc)) = self.data_channels.remove(&eof.channel_id) { dc.close() }
456                                    if eof.error.is_empty() {
457                                        // Normal EOF without error
458                                        trace!("Data channel {} closed by server", eof.channel_id);
459                                    } else {
460                                        // EOF with error
461                                        trace!("Data channel {} closed by server with error: {}", eof.channel_id, eof.error);
462                                    }
463                                },
464
465                                Message::DataChannelAck(DataChannelAck { channel_id, consumed }) => {
466                                    if let Some(ch) = self.data_channels.get(&channel_id) {
467                                        ch.add_capacity(consumed);
468                                    }
469                                }
470
471                                Message::EndpointStopAck(ref ep) => {
472                                    self.services.remove(&ep.guid);
473                                    #[cfg(feature = "plugins")]
474                                    self.plugin_processes.remove(&ep.guid);
475                                    result_tx.send(val).await.context("Can't send result message")?;
476                                }
477
478                                Message::EndpointRemoveAck(ref ep) => {
479                                    self.services.remove(&ep.guid);
480                                    #[cfg(feature = "plugins")]
481                                    self.plugin_processes.remove(&ep.guid);
482                                    result_tx.send(val).await.context("Can't send result message")?;
483                                }
484
485                                Message::HeartBeat(_) => {
486                                    conn.send_message(&Message::HeartBeat(HeartBeat{})).await.context("Failed to send heartbeat")?;
487                                },
488
489                                Message::Error(ref err) => {
490                                    let kind: ErrorKind = err.kind.try_into().unwrap_or(ErrorKind::Fatal);
491                                    result_tx.send(val.clone()).await.context("Can't send result message")?;
492                                    if kind == ErrorKind::Fatal || kind == ErrorKind::AuthFailed {
493                                        error!("Fatal error received, stop client: {:?}", err);
494                                        break;
495                                    }
496                                }
497
498                                Message::Break(break_msg) => {
499                                    info!("Breaking operation for guid: {}", break_msg.guid);
500                                    #[cfg(feature = "plugins")]
501                                    self.plugin_processes.remove(&break_msg.guid);
502                                }
503
504                                Message::PerformUpgrade(info) => {
505                                    let config_clone = self.config.clone();
506                                    #[cfg(feature = "plugins")]
507                                    self.plugin_processes.clear();
508                                    self.services.clear();
509                                    self.data_channels.clear();
510
511                                    if let Err(e) = handle_upgrade_available(
512                                        &info.version,
513                                        config_clone,
514                                        self.opts.gui,
515                                        command_rx,
516                                        result_tx,
517                                    )
518                                    .await
519                                    {
520                                        conn.send_message(&Message::Error(ErrorInfo {
521                                            kind: ErrorKind::UpgradeFailed.into(),
522                                            message: e.to_string(),
523                                            guid: String::new(),
524                                        }))
525                                        .await
526                                        .context("Can't send Error event")?;
527                                    }
528                                }
529
530                                v => {
531                                    result_tx.send(v).await.context("Can't send result message")?;
532                                }
533                            }
534                        },
535                        None => {
536                            debug!("Connection closed by server");
537                            break;
538                        }
539                    }
540                },
541                _ = time::sleep(Duration::from_secs(heartbeat_timeout)), if heartbeat_timeout != 0 => {
542                    return Err(anyhow!("Heartbeat timed out"))
543                }
544            }
545        }
546
547        info!("Control channel shutdown");
548        result_tx
549            .send(Message::ConnectState(ConnectState::Disconnected.into()))
550            .await
551            .context("Can't send Disconnected event")?;
552        conn.close().await.ok();
553        time::sleep(Duration::from_millis(100)).await; // Give some time for the connection to close gracefully
554        Ok(())
555    }
556}
557
558pub async fn run_client(
559    config: Arc<RwLock<ClientConfig>>,
560    opts: ClientOpts,
561    command_rx: mpsc::Receiver<Message>,
562    result_tx: mpsc::Sender<Message>,
563) -> Result<()> {
564    let mut client = Client::<WebsocketTransport>::from(config, opts)
565        .await
566        .context("Failed to create Websocket client")?;
567    client.run(command_rx, result_tx).await
568}
569async fn handle_tcp_data_channel(
570    data_channel: Arc<DataChannel>,
571    local_addr: String,
572    to_server_tx: FairSender<Message>,
573    mut data_rx: mpsc::Receiver<Data>,
574) -> Result<()> {
575    trace!("Handling client {:?} to {}", data_channel, local_addr);
576
577    // Connect to local service immediately
578    let mut local_stream = TcpStream::connect(&local_addr)
579        .await
580        .with_context(|| format!("Failed to connect to local service at {}", local_addr))?;
581
582    // Set TCP_NODELAY for low latency
583    local_stream
584        .set_nodelay(true)
585        .context("Failed to set TCP_NODELAY")?;
586
587    let mut buf = [0u8; DATA_BUFFER_SIZE]; // Smaller buffer for low latency
588
589    loop {
590        tokio::select! {
591            res = local_stream.read(&mut buf) => {
592                match res {
593                    Ok(0) => {
594                        trace!("EOF received from local service for {:?}", data_channel);
595                        if let Err(err) = to_server_tx.send(Message::DataChannelEof(DataChannelEof {
596                            channel_id: data_channel.id,
597                            error: String::new()
598                        }))
599                        .await {
600                            trace!("Failed to send EOF to server for {:?}: {:#}", data_channel, err);
601                        }
602                        break;
603                    },
604                    Ok(n) => {
605                        //debug!("Read {} bytes from local service for {:?}", n, data_channel);
606                        if data_channel.wait_for_capacity(n as u32).await.is_err() {
607                            trace!("Data channel {} closed when waiting for capacity", data_channel.id);
608                            break;
609                        }
610                        if let Err(err) = to_server_tx.send(Message::DataChannelData(DataChannelData {
611                            channel_id: data_channel.id,
612                            data: buf[0..n].to_vec()
613                        }))
614                        .await {
615                            trace!("Failed to send data to server for {:?}: {:#}", data_channel, err);
616                            break;
617                        }
618                    },
619                    Err(e) => {
620                        return Err(e).context("Failed to read from local service");
621                    }
622                }
623            }
624
625            // Receive data from server via control channel and write to local service
626            data_result = data_rx.recv() => {
627                match data_result {
628                    Some(data) => {
629                        trace!("Received {} bytes from server for {:?}", data.data.len(), data_channel);
630                        local_stream.write_all(&data.data).await.context("Failed to write data to local service")?;
631                        to_server_tx.send(Message::DataChannelAck(
632                            DataChannelAck {
633                                channel_id: data_channel.id,
634                                consumed: data.data.len() as u32
635                            }
636                        )).await.with_context(|| "Failed to send TCP traffic ack to the server")?;
637                    },
638                    None => {
639                        trace!("EOF received from server for {:?}", data_channel);
640                        break;
641                    }
642                }
643            }
644
645            _ = data_channel.closed() => {
646                trace!("Data channel {} closed", data_channel.id);
647                break;
648            }
649        }
650    }
651    Ok(())
652}
653
654// UDP port map for managing forwarders per remote address
655type UdpPortMap = Arc<DashMap<SocketAddr, mpsc::Sender<Bytes>>>;
656
657async fn handle_udp_data_channel(
658    data_channel: Arc<DataChannel>,
659    local_addr: String,
660    to_server_tx: FairSender<Message>,
661    mut data_rx: mpsc::Receiver<Data>,
662) -> Result<()> {
663    trace!(
664        "Handling client UDP channel {:?} to {}",
665        data_channel,
666        local_addr
667    );
668
669    let port_map: UdpPortMap = Arc::new(DashMap::new());
670
671    loop {
672        let data_channel = data_channel.clone();
673        // Receive data from server via control channel
674        tokio::select! {
675            data = data_rx.recv() => {
676                match data {
677                    Some(data) => {
678                        let external_addr = data.socket_addr.unwrap();
679
680                        if !port_map.contains_key(&external_addr) {
681                            // This packet is from an address we haven't seen for a while,
682                            // which is not in the UdpPortMap.
683                            // So set up a mapping (and a forwarder) for it
684
685                            match udp_connect(&local_addr).await {
686                                Ok(s) => {
687                                    let (to_service_tx, to_service_rx) = mpsc::channel(DATA_CHANNEL_SIZE);
688                                    port_map.insert(external_addr, to_service_tx);
689                                    tokio::spawn(run_udp_forwarder(
690                                        s,
691                                        to_service_rx,
692                                        to_server_tx.clone(),
693                                        external_addr,
694                                        data_channel,
695                                        port_map.clone(),
696                                    ));
697                                }
698                                Err(e) => {
699                                    error!(
700                                        "Failed to create UDP forwarder for {}: {:#}",
701                                        external_addr, e
702                                    );
703                                }
704                            }
705                        }
706
707                        // Now there should be a udp forwarder that can receive the packet
708                        if let Some(tx) = port_map.get(&external_addr) {
709                            let _ = tx.send(data.data).await;
710                        }
711                    }
712                    None => {
713                        trace!("EOF received from server for UDP {:?}", data_channel);
714                        break;
715                    }
716                }
717            }
718            _ = data_channel.closed() => {
719                trace!("Data channel {} closed", data_channel.id);
720                break;
721            }
722        }
723    }
724    Ok(())
725}
726
727// Run a UdpSocket for the visitor `from`
728async fn run_udp_forwarder(
729    s: UdpSocket,
730    mut to_service_rx: mpsc::Receiver<Bytes>,
731    to_server_tx: FairSender<Message>,
732    from: SocketAddr,
733    data_channel: Arc<DataChannel>,
734    port_map: UdpPortMap,
735) -> Result<()> {
736    trace!("UDP forwarder created for {} on {:?}", from, data_channel);
737    let mut buf = vec![0u8; UDP_BUFFER_SIZE];
738
739    loop {
740        tokio::select! {
741            // Receive from the server
742            data = to_service_rx.recv() => {
743                if let Some(data) = data {
744                    s.send(&data).await.with_context(|| "Failed to send UDP traffic to the service")?;
745                    to_server_tx.send(Message::DataChannelAck(
746                        DataChannelAck {
747                            channel_id: data_channel.id,
748                            consumed: data.len() as u32
749                        }
750                    )).await.with_context(|| "Failed to send UDP traffic ack to the server")?;
751                } else {
752                    break;
753                }
754            },
755
756            // Receive from the service
757            val = s.recv(&mut buf) => {
758                let len = match val {
759                    Ok(v) => v,
760                    Err(_) => break
761                };
762
763                if data_channel.wait_for_capacity(len as u32).await.is_err() {
764                    break;
765                }
766
767                to_server_tx.send(Message::DataChannelDataUdp(
768                    DataChannelDataUdp {
769                    channel_id: data_channel.id,
770                    data: buf[..len].to_vec(),
771                    socket_addr: Some(socket_addr_to_proto(&from)),
772                })).await.with_context(|| "Failed to send UDP traffic to the server")?;
773            },
774
775            // No traffic for the duration of UDP_TIMEOUT, clean up the state
776            _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
777                break;
778            }
779
780            _ = data_channel.closed() => {
781                trace!("Data channel {} closed", data_channel.id);
782                break;
783            }
784        }
785    }
786
787    port_map.remove(&from);
788
789    debug!("UDP forwarder dropped for {} on {:?}", from, data_channel);
790    Ok(())
791}