1use anyhow::{anyhow, bail, Context, Result};
2use backoff::backoff::Backoff;
3use cloudpub_common::data::DataChannel;
4use cloudpub_common::fair_channel::{fair_channel, FairSender};
5use cloudpub_common::protocol::message::Message;
6use cloudpub_common::protocol::{
7 AgentInfo, ConnectState, Data, DataChannelAck, DataChannelData, DataChannelDataUdp,
8 DataChannelEof, ErrorInfo, ErrorKind, HeartBeat, Protocol,
9};
10use cloudpub_common::transport::{AddrMaybeCached, SocketOpts, Transport, WebsocketTransport};
11use cloudpub_common::utils::{
12 get_platform, proto_to_socket_addr, socket_addr_to_proto, udp_connect,
13};
14use cloudpub_common::VERSION;
15use dashmap::DashMap;
16use parking_lot::RwLock;
17use std::net::SocketAddr;
18use std::sync::Arc;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::{TcpStream, UdpSocket};
21use tokio::sync::mpsc;
22use tokio::time::{self, Duration, Instant};
23use tracing::{debug, error, info, trace, warn};
24
25use cloudpub_common::constants::{
26 run_control_chan_backoff, CONTROL_CHANNEL_SIZE, DATA_BUFFER_SIZE, DATA_CHANNEL_SIZE,
27 DEFAULT_CLIENT_RETRY_INTERVAL_SECS, UDP_BUFFER_SIZE, UDP_TIMEOUT,
28};
29use futures::future::FutureExt;
30
31use crate::config::{ClientConfig, ClientOpts};
32use crate::upgrade::handle_upgrade_available;
33use bytes::Bytes;
34use cloudpub_common::transport::ProtobufStream;
35
36#[cfg(feature = "plugins")]
37use crate::plugins::plugin_trait::PluginHandle;
38#[cfg(feature = "plugins")]
39use crate::plugins::registry::PluginRegistry;
40
41type Service = Arc<DataChannel>;
42
43type Services = Arc<DashMap<String, Service>>;
44
45struct Client<T: Transport> {
47 config: Arc<RwLock<ClientConfig>>,
48 opts: ClientOpts,
49 services: Services,
50 transport: Arc<T>,
51 connected: bool,
52 #[cfg(feature = "plugins")]
53 plugin_processes: Arc<DashMap<String, PluginHandle>>,
54 data_channels: Arc<DashMap<u32, Arc<DataChannel>>>,
55}
56
57impl<T: 'static + Transport> Client<T> {
58 async fn from(config: Arc<RwLock<ClientConfig>>, opts: ClientOpts) -> Result<Client<T>> {
60 let transport = Arc::new(
61 T::new(&config.clone().read().transport)
62 .with_context(|| "Failed to create the transport")?,
63 );
64 Ok(Client {
65 config,
66 opts,
67 services: Default::default(),
68 transport,
69 connected: false,
70 #[cfg(feature = "plugins")]
71 plugin_processes: Arc::new(DashMap::new()),
72 data_channels: Arc::new(DashMap::new()),
73 })
74 }
75
76 async fn run(
78 &mut self,
79 mut command_rx: mpsc::Receiver<Message>,
80 result_tx: mpsc::Sender<Message>,
81 ) -> Result<()> {
82 let transport = self.transport.clone();
83
84 let mut retry_backoff = run_control_chan_backoff(DEFAULT_CLIENT_RETRY_INTERVAL_SECS);
85
86 let mut start = Instant::now();
87 result_tx
88 .send(Message::ConnectState(ConnectState::Connecting.into()))
89 .await
90 .context("Can't send Connecting event")?;
91 while let Err(err) = self
92 .run_control_channel(transport.clone(), &mut command_rx, &result_tx)
93 .boxed()
94 .await
95 {
96 if result_tx.is_closed() {
97 break;
99 }
100
101 if self.connected {
102 result_tx
103 .send(Message::Error(ErrorInfo {
104 kind: ErrorKind::HandshakeFailed.into(),
105 message: crate::t!("error-network"),
106 guid: String::new(),
107 }))
108 .await
109 .context("Can't send Error event")?;
110 result_tx
111 .send(Message::ConnectState(ConnectState::Disconnected.into()))
112 .await
113 .context("Can't send Disconnected event")?;
114 result_tx
115 .send(Message::ConnectState(ConnectState::Connecting.into()))
116 .await
117 .context("Can't send Connecting event")?;
118 self.connected = false;
119 }
120
121 self.services.clear();
122 #[cfg(feature = "plugins")]
123 self.plugin_processes.clear();
124 self.data_channels.clear();
125
126 if start.elapsed() > Duration::from_secs(3) {
127 retry_backoff.reset();
129 }
130
131 if let Some(duration) = retry_backoff.next_backoff() {
132 warn!("{:#}. Retry in {:?}...", err, duration);
133 time::sleep(duration).await;
134 }
135
136 start = Instant::now();
137 }
138
139 self.services.clear();
140 #[cfg(feature = "plugins")]
141 self.plugin_processes.clear();
142 self.data_channels.clear();
143
144 Ok(())
145 }
146
147 async fn run_control_channel(
148 &mut self,
149 transport: Arc<T>,
150 command_rx: &mut mpsc::Receiver<Message>,
151 result_tx: &mpsc::Sender<Message>,
152 ) -> Result<()> {
153 let url = self.config.read().server.clone();
154 let port = url.port().unwrap_or(443);
155 let host = url.host_str().context("Failed to get host")?;
156 let mut host_and_port = format!("{}:{}", host, port);
157
158 let (mut conn, _remote_addr) = loop {
159 let mut remote_addr = AddrMaybeCached::new(&host_and_port);
160 remote_addr
161 .resolve()
162 .await
163 .context("Failed to resolve server address")?;
164
165 let mut conn = transport.connect(&remote_addr).await.context(format!(
166 "Failed to connect control channel to {}",
167 &host_and_port
168 ))?;
169
170 self.connected = true;
171
172 T::hint(&conn, SocketOpts::for_control_channel());
173
174 let (email, password) = if let Some(ref cred) = self.opts.credentials {
175 (cred.0.clone(), cred.1.clone())
176 } else {
177 (String::new(), String::new())
178 };
179
180 let token = self
181 .config
182 .read()
183 .token
184 .clone()
185 .unwrap_or_default()
186 .to_string();
187
188 let hwid = self.config.read().get_hwid();
189
190 let agent_info = AgentInfo {
191 agent_id: self.config.read().agent_id.clone(),
192 token,
193 email,
194 password,
195 hostname: hostname::get()?.to_string_lossy().into_owned(),
196 version: VERSION.to_string(),
197 gui: self.opts.gui,
198 platform: get_platform(),
199 hwid,
200 server_host_and_port: host_and_port.clone(),
201 transient: self.opts.transient,
202 secondary: self.opts.secondary,
203 is_service: self.opts.is_service,
204 };
205
206 debug!("Sending hello: {:?}", agent_info);
207
208 let hello_send = Message::AgentHello(agent_info);
209
210 conn.send_message(&hello_send)
211 .await
212 .context("Failed to send hello message")?;
213
214 debug!("Reading ack");
215 match conn
216 .recv_message()
217 .await
218 .context("Failed to read ack message")?
219 {
220 Some(msg) => match msg {
221 Message::AgentAck(args) => {
222 if !args.token.is_empty() {
223 let mut c = self.config.write();
224 c.token = Some(args.token.as_str().into());
225 c.save().context("Write config")?;
226 }
227 break (conn, remote_addr);
228 }
229 Message::Redirect(r) => {
230 host_and_port = r.host_and_port.clone();
231 debug!("Redirecting to {}", host_and_port);
232 continue;
233 }
234 Message::Error(err) => {
235 result_tx
236 .send(Message::Error(err.clone()))
237 .await
238 .context("Can't send server error event")?;
239 bail!("Error: {:?}", err.kind);
240 }
241 v => bail!("Unexpected ack message: {:?}", v),
242 },
243 None => bail!("Connection closed while reading ack message"),
244 };
245 };
246
247 debug!("Control channel established");
248
249 result_tx
250 .send(Message::ConnectState(ConnectState::Connected.into()))
251 .await
252 .context("Can't send Connected event")?;
253
254 let (to_server_tx, mut to_server_rx) = fair_channel::<Message>(CONTROL_CHANNEL_SIZE);
255 let heartbeat_timeout = self.config.read().heartbeat_timeout;
258
259 loop {
260 tokio::select! {
261 cmd = to_server_rx.recv() => {
262 if let Some(cmd) = cmd {
263 conn.send_message(&cmd).await.context("Failed to send command")?;
264 }
265 },
266 cmd = command_rx.recv() => {
267 if let Some(cmd) = cmd {
268 debug!("Received message: {:?}", cmd);
269 match cmd {
270 Message::PerformUpgrade(info) => {
271 let config_clone = self.config.clone();
272 if let Err(e) = handle_upgrade_available(
273 &info.version,
274 config_clone,
275 self.opts.gui,
276 command_rx,
277 result_tx,
278 )
279 .await
280 {
281 result_tx.send(Message::Error(ErrorInfo {
282 kind: ErrorKind::Fatal.into(),
283 message: e.to_string(),
284 guid: String::new(),
285 }))
286 .await
287 .context("Can't send Error event")?;
288 }
289 }
290 Message::Stop(_x) => {
291 info!("Stopping the client");
292 break;
293 }
294 Message::Break(break_msg) => {
295 info!("Breaking operation for guid: {}", break_msg.guid);
296 #[cfg(feature = "plugins")]
297 if let Some((_, handle)) = self.plugin_processes.remove(&break_msg.guid) {
298 info!("Dropped plugin handle for guid: {}", break_msg.guid);
299 drop(handle);
300 }
301 }
302 cmd => {
303 conn.send_message(&cmd).await.context("Failed to send message")?;
304 }
305 };
306 } else {
307 debug!("No more commands, shutting down...");
308 break;
309 }
310 },
311 val = conn.recv_message() => {
312 match val? {
313 Some(val) => {
314 match val {
315 Message::EndpointAck(mut endpoint) => {
316 #[cfg(feature = "plugins")]
317 {
318 let to_server_tx = to_server_tx.clone();
319 let config = self.config.clone();
320 let opts = self.opts.clone();
321 if endpoint.error.is_empty() {
322 let protocol: Protocol = endpoint
323 .client
324 .as_ref()
325 .unwrap()
326 .local_proto
327 .try_into()
328 .unwrap_or(Protocol::Tcp);
329 if let Some(plugin) = PluginRegistry::new().get(protocol) {
330 let handle = PluginHandle::spawn(
331 plugin,
332 endpoint.clone(),
333 config,
334 opts,
335 to_server_tx,
336 );
337 self.plugin_processes.insert(endpoint.guid.clone(), handle);
338 } else {
339 endpoint.status = Some("online".into());
340 let _ = to_server_tx.send(Message::EndpointStatus(endpoint.clone())).await;
341 }
342 }
343 }
344 #[cfg(not(feature = "plugins"))]
345 {
346 endpoint.status = Some("online".into());
347 let _ = to_server_tx.send(Message::EndpointStatus(endpoint.clone())).await;
348 }
349 result_tx
350 .send(Message::EndpointAck(endpoint))
351 .await
352 .context("Can't send EndpointAck event")?;
353 }
354
355 Message::CreateDataChannelWithId(create_msg) => {
356 let channel_id = create_msg.channel_id;
357 let endpoint = create_msg.endpoint.unwrap();
358
359 trace!("Creating data channel {} for endpoint {:?}", channel_id, endpoint.guid);
360
361 let (to_service_tx, to_service_rx) = mpsc::channel::<Data>(DATA_CHANNEL_SIZE);
363
364 let data_channel = Arc::new(DataChannel::new_client(channel_id, to_service_tx.clone()));
366 self.data_channels.insert(channel_id, data_channel.clone());
367
368 let client = endpoint.client.unwrap();
370 #[allow(unused_mut)]
371 let mut local_addr = format!("{}:{}", client.local_addr, client.local_port);
372 #[cfg(feature = "plugins")]
373 if let Some(handle) = self.plugin_processes.get(&endpoint.guid) {
374 if let Some(port) = handle.value().port() {
375 local_addr = format!("127.0.0.1:{}", port);
376 }
377 }
378
379 let data_channels = self.data_channels.clone();
381 let protocol: Protocol = client.local_proto.try_into().unwrap();
382
383 let to_server_tx_cloned = to_server_tx.clone();
384 tokio::spawn(async move {
385 if let Err(err) = if protocol == Protocol::Udp {
386 handle_udp_data_channel(
387 data_channel,
388 local_addr,
389 to_server_tx_cloned.clone(),
390 to_service_rx
391 ).await
392 } else {
393 handle_tcp_data_channel(
394 data_channel,
395 local_addr,
396 to_server_tx_cloned.clone(),
397 to_service_rx
398 ).await
399 } {
400 error!("DataChannel {{ channel_id: {} }}: {:?}", channel_id, err);
401 to_server_tx_cloned
402 .send(Message::DataChannelEof(
403 DataChannelEof {
404 channel_id,
405 error: err.to_string()
406 })
407 ).await.ok();
408 }
409 if let Some((_, dc)) = data_channels.remove(&channel_id) { dc.close() }
410 });
411 },
412
413 Message::DataChannelData(data) => {
414 let to_service_tx = self.data_channels.get(&data.channel_id).map(|ch| ch.data_tx.clone());
416 if let Some(tx) = to_service_tx {
417 if let Err(err) = tx.send(Data {
418 data: data.data.into(),
419 socket_addr: None
420 }).await {
421 self.data_channels.remove(&data.channel_id);
422 error!("Error send to data channel {}: {:?}", data.channel_id, err);
423 }
424 } else {
425 trace!("Data channel {} not found, dropping data", data.channel_id);
426 }
427 },
428
429 Message::DataChannelDataUdp(data) => {
430 let to_service_tx = self.data_channels.get(&data.channel_id).map(|ch| ch.data_tx.clone());
432 if let Some(tx) = to_service_tx {
433 let socket_addr = data.socket_addr.as_ref()
434 .map(proto_to_socket_addr)
435 .transpose()
436 .unwrap_or_else(|err| {
437 error!("Invalid socket address for UDP data channel {}: {:?}", data.channel_id, err);
438 None
439 });
440
441 if let Err(err) = tx.send(Data {
442 data: data.data.into(),
443 socket_addr,
444 }).await {
445 self.data_channels.remove(&data.channel_id);
446 error!("Error send to UDP data channel {}: {:?}", data.channel_id, err);
447 }
448 } else {
449 trace!("UDP Data channel {} not found, dropping data", data.channel_id);
450 }
451 },
452
453 Message::DataChannelEof(eof) => {
454 if let Some((_, dc)) = self.data_channels.remove(&eof.channel_id) { dc.close() }
456 if eof.error.is_empty() {
457 trace!("Data channel {} closed by server", eof.channel_id);
459 } else {
460 trace!("Data channel {} closed by server with error: {}", eof.channel_id, eof.error);
462 }
463 },
464
465 Message::DataChannelAck(DataChannelAck { channel_id, consumed }) => {
466 if let Some(ch) = self.data_channels.get(&channel_id) {
467 ch.add_capacity(consumed);
468 }
469 }
470
471 Message::EndpointStopAck(ref ep) => {
472 self.services.remove(&ep.guid);
473 #[cfg(feature = "plugins")]
474 self.plugin_processes.remove(&ep.guid);
475 result_tx.send(val).await.context("Can't send result message")?;
476 }
477
478 Message::EndpointRemoveAck(ref ep) => {
479 self.services.remove(&ep.guid);
480 #[cfg(feature = "plugins")]
481 self.plugin_processes.remove(&ep.guid);
482 result_tx.send(val).await.context("Can't send result message")?;
483 }
484
485 Message::HeartBeat(_) => {
486 conn.send_message(&Message::HeartBeat(HeartBeat{})).await.context("Failed to send heartbeat")?;
487 },
488
489 Message::Error(ref err) => {
490 let kind: ErrorKind = err.kind.try_into().unwrap_or(ErrorKind::Fatal);
491 result_tx.send(val.clone()).await.context("Can't send result message")?;
492 if kind == ErrorKind::Fatal || kind == ErrorKind::AuthFailed {
493 error!("Fatal error received, stop client: {:?}", err);
494 break;
495 }
496 }
497
498 Message::Break(break_msg) => {
499 info!("Breaking operation for guid: {}", break_msg.guid);
500 #[cfg(feature = "plugins")]
501 self.plugin_processes.remove(&break_msg.guid);
502 }
503
504 Message::PerformUpgrade(info) => {
505 let config_clone = self.config.clone();
506 #[cfg(feature = "plugins")]
507 self.plugin_processes.clear();
508 self.services.clear();
509 self.data_channels.clear();
510
511 if let Err(e) = handle_upgrade_available(
512 &info.version,
513 config_clone,
514 self.opts.gui,
515 command_rx,
516 result_tx,
517 )
518 .await
519 {
520 conn.send_message(&Message::Error(ErrorInfo {
521 kind: ErrorKind::UpgradeFailed.into(),
522 message: e.to_string(),
523 guid: String::new(),
524 }))
525 .await
526 .context("Can't send Error event")?;
527 }
528 }
529
530 v => {
531 result_tx.send(v).await.context("Can't send result message")?;
532 }
533 }
534 },
535 None => {
536 debug!("Connection closed by server");
537 break;
538 }
539 }
540 },
541 _ = time::sleep(Duration::from_secs(heartbeat_timeout)), if heartbeat_timeout != 0 => {
542 return Err(anyhow!("Heartbeat timed out"))
543 }
544 }
545 }
546
547 info!("Control channel shutdown");
548 result_tx
549 .send(Message::ConnectState(ConnectState::Disconnected.into()))
550 .await
551 .context("Can't send Disconnected event")?;
552 conn.close().await.ok();
553 time::sleep(Duration::from_millis(100)).await; Ok(())
555 }
556}
557
558pub async fn run_client(
559 config: Arc<RwLock<ClientConfig>>,
560 opts: ClientOpts,
561 command_rx: mpsc::Receiver<Message>,
562 result_tx: mpsc::Sender<Message>,
563) -> Result<()> {
564 let mut client = Client::<WebsocketTransport>::from(config, opts)
565 .await
566 .context("Failed to create Websocket client")?;
567 client.run(command_rx, result_tx).await
568}
569async fn handle_tcp_data_channel(
570 data_channel: Arc<DataChannel>,
571 local_addr: String,
572 to_server_tx: FairSender<Message>,
573 mut data_rx: mpsc::Receiver<Data>,
574) -> Result<()> {
575 trace!("Handling client {:?} to {}", data_channel, local_addr);
576
577 let mut local_stream = TcpStream::connect(&local_addr)
579 .await
580 .with_context(|| format!("Failed to connect to local service at {}", local_addr))?;
581
582 local_stream
584 .set_nodelay(true)
585 .context("Failed to set TCP_NODELAY")?;
586
587 let mut buf = [0u8; DATA_BUFFER_SIZE]; loop {
590 tokio::select! {
591 res = local_stream.read(&mut buf) => {
592 match res {
593 Ok(0) => {
594 trace!("EOF received from local service for {:?}", data_channel);
595 if let Err(err) = to_server_tx.send(Message::DataChannelEof(DataChannelEof {
596 channel_id: data_channel.id,
597 error: String::new()
598 }))
599 .await {
600 trace!("Failed to send EOF to server for {:?}: {:#}", data_channel, err);
601 }
602 break;
603 },
604 Ok(n) => {
605 if data_channel.wait_for_capacity(n as u32).await.is_err() {
607 trace!("Data channel {} closed when waiting for capacity", data_channel.id);
608 break;
609 }
610 if let Err(err) = to_server_tx.send(Message::DataChannelData(DataChannelData {
611 channel_id: data_channel.id,
612 data: buf[0..n].to_vec()
613 }))
614 .await {
615 trace!("Failed to send data to server for {:?}: {:#}", data_channel, err);
616 break;
617 }
618 },
619 Err(e) => {
620 return Err(e).context("Failed to read from local service");
621 }
622 }
623 }
624
625 data_result = data_rx.recv() => {
627 match data_result {
628 Some(data) => {
629 trace!("Received {} bytes from server for {:?}", data.data.len(), data_channel);
630 local_stream.write_all(&data.data).await.context("Failed to write data to local service")?;
631 to_server_tx.send(Message::DataChannelAck(
632 DataChannelAck {
633 channel_id: data_channel.id,
634 consumed: data.data.len() as u32
635 }
636 )).await.with_context(|| "Failed to send TCP traffic ack to the server")?;
637 },
638 None => {
639 trace!("EOF received from server for {:?}", data_channel);
640 break;
641 }
642 }
643 }
644
645 _ = data_channel.closed() => {
646 trace!("Data channel {} closed", data_channel.id);
647 break;
648 }
649 }
650 }
651 Ok(())
652}
653
654type UdpPortMap = Arc<DashMap<SocketAddr, mpsc::Sender<Bytes>>>;
656
657async fn handle_udp_data_channel(
658 data_channel: Arc<DataChannel>,
659 local_addr: String,
660 to_server_tx: FairSender<Message>,
661 mut data_rx: mpsc::Receiver<Data>,
662) -> Result<()> {
663 trace!(
664 "Handling client UDP channel {:?} to {}",
665 data_channel,
666 local_addr
667 );
668
669 let port_map: UdpPortMap = Arc::new(DashMap::new());
670
671 loop {
672 let data_channel = data_channel.clone();
673 tokio::select! {
675 data = data_rx.recv() => {
676 match data {
677 Some(data) => {
678 let external_addr = data.socket_addr.unwrap();
679
680 if !port_map.contains_key(&external_addr) {
681 match udp_connect(&local_addr).await {
686 Ok(s) => {
687 let (to_service_tx, to_service_rx) = mpsc::channel(DATA_CHANNEL_SIZE);
688 port_map.insert(external_addr, to_service_tx);
689 tokio::spawn(run_udp_forwarder(
690 s,
691 to_service_rx,
692 to_server_tx.clone(),
693 external_addr,
694 data_channel,
695 port_map.clone(),
696 ));
697 }
698 Err(e) => {
699 error!(
700 "Failed to create UDP forwarder for {}: {:#}",
701 external_addr, e
702 );
703 }
704 }
705 }
706
707 if let Some(tx) = port_map.get(&external_addr) {
709 let _ = tx.send(data.data).await;
710 }
711 }
712 None => {
713 trace!("EOF received from server for UDP {:?}", data_channel);
714 break;
715 }
716 }
717 }
718 _ = data_channel.closed() => {
719 trace!("Data channel {} closed", data_channel.id);
720 break;
721 }
722 }
723 }
724 Ok(())
725}
726
727async fn run_udp_forwarder(
729 s: UdpSocket,
730 mut to_service_rx: mpsc::Receiver<Bytes>,
731 to_server_tx: FairSender<Message>,
732 from: SocketAddr,
733 data_channel: Arc<DataChannel>,
734 port_map: UdpPortMap,
735) -> Result<()> {
736 trace!("UDP forwarder created for {} on {:?}", from, data_channel);
737 let mut buf = vec![0u8; UDP_BUFFER_SIZE];
738
739 loop {
740 tokio::select! {
741 data = to_service_rx.recv() => {
743 if let Some(data) = data {
744 s.send(&data).await.with_context(|| "Failed to send UDP traffic to the service")?;
745 to_server_tx.send(Message::DataChannelAck(
746 DataChannelAck {
747 channel_id: data_channel.id,
748 consumed: data.len() as u32
749 }
750 )).await.with_context(|| "Failed to send UDP traffic ack to the server")?;
751 } else {
752 break;
753 }
754 },
755
756 val = s.recv(&mut buf) => {
758 let len = match val {
759 Ok(v) => v,
760 Err(_) => break
761 };
762
763 if data_channel.wait_for_capacity(len as u32).await.is_err() {
764 break;
765 }
766
767 to_server_tx.send(Message::DataChannelDataUdp(
768 DataChannelDataUdp {
769 channel_id: data_channel.id,
770 data: buf[..len].to_vec(),
771 socket_addr: Some(socket_addr_to_proto(&from)),
772 })).await.with_context(|| "Failed to send UDP traffic to the server")?;
773 },
774
775 _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
777 break;
778 }
779
780 _ = data_channel.closed() => {
781 trace!("Data channel {} closed", data_channel.id);
782 break;
783 }
784 }
785 }
786
787 port_map.remove(&from);
788
789 debug!("UDP forwarder dropped for {} on {:?}", from, data_channel);
790 Ok(())
791}