1use anyhow::{anyhow, bail, Context, Result};
2use backoff::backoff::Backoff;
3use cloudpub_common::data::DataChannel;
4use cloudpub_common::fair_channel::{fair_channel, FairSender};
5use cloudpub_common::protocol::message::Message;
6use cloudpub_common::protocol::{
7 AgentInfo, ConnectState, Data, DataChannelAck, DataChannelData, DataChannelDataUdp,
8 DataChannelEof, ErrorInfo, ErrorKind, HeartBeat, Protocol,
9};
10use cloudpub_common::transport::{AddrMaybeCached, SocketOpts, Transport, WebsocketTransport};
11use cloudpub_common::utils::{
12 get_platform, proto_to_socket_addr, socket_addr_to_proto, udp_connect,
13};
14use cloudpub_common::VERSION;
15use dashmap::DashMap;
16use parking_lot::RwLock;
17use std::net::SocketAddr;
18use std::sync::Arc;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::{TcpStream, UdpSocket};
21use tokio::sync::mpsc;
22use tokio::time::{self, Duration, Instant};
23use tracing::{debug, error, info, trace, warn};
24
25use cloudpub_common::constants::{
26 run_control_chan_backoff, CONTROL_CHANNEL_SIZE, DATA_BUFFER_SIZE, DATA_CHANNEL_SIZE,
27 DEFAULT_CLIENT_RETRY_INTERVAL_SECS, UDP_BUFFER_SIZE, UDP_TIMEOUT,
28};
29use futures::future::FutureExt;
30
31use crate::config::{ClientConfig, ClientOpts};
32use crate::upgrade::handle_upgrade_available;
33use bytes::Bytes;
34use cloudpub_common::transport::ProtobufStream;
35
36#[cfg(feature = "plugins")]
37use crate::plugins::plugin_trait::PluginHandle;
38#[cfg(feature = "plugins")]
39use crate::plugins::registry::PluginRegistry;
40
/// Shared, reference-counted handle to a [`DataChannel`].
type Service = Arc<DataChannel>;

/// Concurrent registry of [`Service`] handles keyed by a `String`.
///
/// Within this file entries are removed by endpoint GUID when a stop/remove
/// acknowledgement arrives, and the whole map is cleared between sessions.
type Services = Arc<DashMap<String, Service>>;
44
/// State of one agent connection to the server, generic over the transport
/// used for the control channel.
struct Client<T: Transport> {
    /// Shared configuration. Read for the server URL, token, agent id, hardware
    /// id and heartbeat timeout; written when the server hands out a new token.
    config: Arc<RwLock<ClientConfig>>,
    /// Start-up options (credentials, `gui`, `transient`, `secondary`) that are
    /// reported to the server in the hello message.
    opts: ClientOpts,
    /// Registered services; entries are dropped on stop/remove acknowledgements
    /// and the map is cleared whenever the control channel goes down.
    services: Services,
    /// Transport used to open the control connection.
    transport: Arc<T>,
    /// Set to `true` as soon as the transport-level connection succeeds and
    /// reset by `run` after a disconnect has been reported.
    connected: bool,
    /// Running plugin handles keyed by endpoint GUID.
    #[cfg(feature = "plugins")]
    plugin_processes: Arc<DashMap<String, PluginHandle>>,
    /// Live data channels keyed by the channel id assigned by the server.
    data_channels: Arc<DashMap<u32, Arc<DataChannel>>>,
}
56
57impl<T: 'static + Transport> Client<T> {
58 async fn from(config: Arc<RwLock<ClientConfig>>, opts: ClientOpts) -> Result<Client<T>> {
60 let transport = Arc::new(
61 T::new(&config.clone().read().transport)
62 .with_context(|| "Failed to create the transport")?,
63 );
64 Ok(Client {
65 config,
66 opts,
67 services: Default::default(),
68 transport,
69 connected: false,
70 #[cfg(feature = "plugins")]
71 plugin_processes: Arc::new(DashMap::new()),
72 data_channels: Arc::new(DashMap::new()),
73 })
74 }
75
76 async fn run(
78 &mut self,
79 mut command_rx: mpsc::Receiver<Message>,
80 result_tx: mpsc::Sender<Message>,
81 ) -> Result<()> {
82 let transport = self.transport.clone();
83
84 let mut retry_backoff = run_control_chan_backoff(DEFAULT_CLIENT_RETRY_INTERVAL_SECS);
85
86 let mut start = Instant::now();
87 result_tx
88 .send(Message::ConnectState(ConnectState::Connecting.into()))
89 .await
90 .context("Can't send Connecting event")?;
91 while let Err(err) = self
92 .run_control_channel(transport.clone(), &mut command_rx, &result_tx)
93 .boxed()
94 .await
95 {
96 if result_tx.is_closed() {
97 break;
99 }
100
101 if self.connected {
102 result_tx
103 .send(Message::Error(ErrorInfo {
104 kind: ErrorKind::HandshakeFailed.into(),
105 message: crate::t!("error-network"),
106 guid: String::new(),
107 }))
108 .await
109 .context("Can't send Error event")?;
110 result_tx
111 .send(Message::ConnectState(ConnectState::Disconnected.into()))
112 .await
113 .context("Can't send Disconnected event")?;
114 result_tx
115 .send(Message::ConnectState(ConnectState::Connecting.into()))
116 .await
117 .context("Can't send Connecting event")?;
118 self.connected = false;
119 }
120
121 self.services.clear();
122 #[cfg(feature = "plugins")]
123 self.plugin_processes.clear();
124 self.data_channels.clear();
125
126 if start.elapsed() > Duration::from_secs(3) {
127 retry_backoff.reset();
129 }
130
131 if let Some(duration) = retry_backoff.next_backoff() {
132 warn!("{:#}. Retry in {:?}...", err, duration);
133 time::sleep(duration).await;
134 }
135
136 start = Instant::now();
137 }
138
139 self.services.clear();
140 #[cfg(feature = "plugins")]
141 self.plugin_processes.clear();
142 self.data_channels.clear();
143
144 Ok(())
145 }
146
147 async fn run_control_channel(
148 &mut self,
149 transport: Arc<T>,
150 command_rx: &mut mpsc::Receiver<Message>,
151 result_tx: &mpsc::Sender<Message>,
152 ) -> Result<()> {
153 let url = self.config.read().server.clone();
154 let port = url.port().unwrap_or(443);
155 let host = url.host_str().context("Failed to get host")?;
156 let mut host_and_port = format!("{}:{}", host, port);
157
158 let (mut conn, _remote_addr) = loop {
159 let mut remote_addr = AddrMaybeCached::new(&host_and_port);
160 remote_addr
161 .resolve()
162 .await
163 .context("Failed to resolve server address")?;
164
165 let mut conn = transport.connect(&remote_addr).await.context(format!(
166 "Failed to connect control channel to {}",
167 &host_and_port
168 ))?;
169
170 self.connected = true;
171
172 T::hint(&conn, SocketOpts::for_control_channel());
173
174 let (email, password) = if let Some(ref cred) = self.opts.credentials {
175 (cred.0.clone(), cred.1.clone())
176 } else {
177 (String::new(), String::new())
178 };
179
180 let token = self
181 .config
182 .read()
183 .token
184 .clone()
185 .unwrap_or_default()
186 .to_string();
187
188 let hwid = self.config.read().get_hwid();
189
190 let agent_info = AgentInfo {
191 agent_id: self.config.read().agent_id.clone(),
192 token,
193 email,
194 password,
195 hostname: hostname::get()?.to_string_lossy().into_owned(),
196 version: VERSION.to_string(),
197 gui: self.opts.gui,
198 platform: get_platform(),
199 hwid,
200 server_host_and_port: host_and_port.clone(),
201 transient: self.opts.transient,
202 secondary: self.opts.secondary,
203 };
204
205 debug!("Sending hello: {:?}", agent_info);
206
207 let hello_send = Message::AgentHello(agent_info);
208
209 conn.send_message(&hello_send)
210 .await
211 .context("Failed to send hello message")?;
212
213 debug!("Reading ack");
214 match conn
215 .recv_message()
216 .await
217 .context("Failed to read ack message")?
218 {
219 Some(msg) => match msg {
220 Message::AgentAck(args) => {
221 if !args.token.is_empty() {
222 let mut c = self.config.write();
223 c.token = Some(args.token.as_str().into());
224 c.save().context("Write config")?;
225 }
226 break (conn, remote_addr);
227 }
228 Message::Redirect(r) => {
229 host_and_port = r.host_and_port.clone();
230 debug!("Redirecting to {}", host_and_port);
231 continue;
232 }
233 Message::Error(err) => {
234 result_tx
235 .send(Message::Error(err.clone()))
236 .await
237 .context("Can't send server error event")?;
238 bail!("Error: {:?}", err.kind);
239 }
240 v => bail!("Unexpected ack message: {:?}", v),
241 },
242 None => bail!("Connection closed while reading ack message"),
243 };
244 };
245
246 debug!("Control channel established");
247
248 result_tx
249 .send(Message::ConnectState(ConnectState::Connected.into()))
250 .await
251 .context("Can't send Connected event")?;
252
253 let (to_server_tx, mut to_server_rx) = fair_channel::<Message>(CONTROL_CHANNEL_SIZE);
254 let heartbeat_timeout = self.config.read().heartbeat_timeout;
257
258 loop {
259 tokio::select! {
260 cmd = to_server_rx.recv() => {
261 if let Some(cmd) = cmd {
262 conn.send_message(&cmd).await.context("Failed to send command")?;
263 }
264 },
265 cmd = command_rx.recv() => {
266 if let Some(cmd) = cmd {
267 debug!("Received message: {:?}", cmd);
268 match cmd {
269 Message::PerformUpgrade(info) => {
270 let config_clone = self.config.clone();
271 if let Err(e) = handle_upgrade_available(
272 &info.version,
273 config_clone,
274 self.opts.gui,
275 command_rx,
276 result_tx,
277 )
278 .await
279 {
280 result_tx.send(Message::Error(ErrorInfo {
281 kind: ErrorKind::Fatal.into(),
282 message: e.to_string(),
283 guid: String::new(),
284 }))
285 .await
286 .context("Can't send Error event")?;
287 }
288 }
289 Message::Stop(_x) => {
290 info!("Stopping the client");
291 break;
292 }
293 Message::Break(break_msg) => {
294 info!("Breaking operation for guid: {}", break_msg.guid);
295 #[cfg(feature = "plugins")]
296 if let Some((_, handle)) = self.plugin_processes.remove(&break_msg.guid) {
297 info!("Dropped plugin handle for guid: {}", break_msg.guid);
298 drop(handle);
299 }
300 }
301 cmd => {
302 conn.send_message(&cmd).await.context("Failed to send message")?;
303 }
304 };
305 } else {
306 debug!("No more commands, shutting down...");
307 break;
308 }
309 },
310 val = conn.recv_message() => {
311 match val? {
312 Some(val) => {
313 match val {
314 Message::EndpointAck(mut endpoint) => {
315 #[cfg(feature = "plugins")]
316 {
317 let to_server_tx = to_server_tx.clone();
318 let config = self.config.clone();
319 let opts = self.opts.clone();
320 if endpoint.error.is_empty() {
321 let protocol: Protocol = endpoint
322 .client
323 .as_ref()
324 .unwrap()
325 .local_proto
326 .try_into()
327 .unwrap_or(Protocol::Tcp);
328 if let Some(plugin) = PluginRegistry::new().get(protocol) {
329 let handle = PluginHandle::spawn(
330 plugin,
331 endpoint.clone(),
332 config,
333 opts,
334 to_server_tx,
335 );
336 self.plugin_processes.insert(endpoint.guid.clone(), handle);
337 } else {
338 endpoint.status = Some("online".into());
339 let _ = to_server_tx.send(Message::EndpointStatus(endpoint.clone())).await;
340 }
341 }
342 }
343 #[cfg(not(feature = "plugins"))]
344 {
345 endpoint.status = Some("online".into());
346 let _ = to_server_tx.send(Message::EndpointStatus(endpoint.clone())).await;
347 }
348 result_tx
349 .send(Message::EndpointAck(endpoint))
350 .await
351 .context("Can't send EndpointAck event")?;
352 }
353
354 Message::CreateDataChannelWithId(create_msg) => {
355 let channel_id = create_msg.channel_id;
356 let endpoint = create_msg.endpoint.unwrap();
357
358 trace!("Creating data channel {} for endpoint {:?}", channel_id, endpoint.guid);
359
360 let (to_service_tx, to_service_rx) = mpsc::channel::<Data>(DATA_CHANNEL_SIZE);
362
363 let data_channel = Arc::new(DataChannel::new_client(channel_id, to_service_tx.clone()));
365 self.data_channels.insert(channel_id, data_channel.clone());
366
367 let client = endpoint.client.unwrap();
369 let mut local_addr = format!("{}:{}", client.local_addr, client.local_port);
370 #[cfg(feature = "plugins")]
371 if let Some(handle) = self.plugin_processes.get(&endpoint.guid) {
372 if let Some(port) = handle.value().port() {
373 local_addr = format!("127.0.0.1:{}", port);
374 }
375 }
376
377 let data_channels = self.data_channels.clone();
379 let protocol: Protocol = client.local_proto.try_into().unwrap();
380
381 let to_server_tx_cloned = to_server_tx.clone();
382 tokio::spawn(async move {
383 if let Err(err) = if protocol == Protocol::Udp {
384 handle_udp_data_channel(
385 data_channel,
386 local_addr,
387 to_server_tx_cloned.clone(),
388 to_service_rx
389 ).await
390 } else {
391 handle_tcp_data_channel(
392 data_channel,
393 local_addr,
394 to_server_tx_cloned.clone(),
395 to_service_rx
396 ).await
397 } {
398 error!("DataChannel {{ channel_id: {} }}: {:?}", channel_id, err);
399 to_server_tx_cloned
400 .send(Message::DataChannelEof(
401 DataChannelEof {
402 channel_id,
403 error: err.to_string()
404 })
405 ).await.ok();
406 }
407 if let Some((_, dc)) = data_channels.remove(&channel_id) { dc.close() }
408 });
409 },
410
411 Message::DataChannelData(data) => {
412 let to_service_tx = self.data_channels.get(&data.channel_id).map(|ch| ch.data_tx.clone());
414 if let Some(tx) = to_service_tx {
415 if let Err(err) = tx.send(Data {
416 data: data.data.into(),
417 socket_addr: None
418 }).await {
419 self.data_channels.remove(&data.channel_id);
420 error!("Error send to data channel {}: {:?}", data.channel_id, err);
421 }
422 } else {
423 trace!("Data channel {} not found, dropping data", data.channel_id);
424 }
425 },
426
427 Message::DataChannelDataUdp(data) => {
428 let to_service_tx = self.data_channels.get(&data.channel_id).map(|ch| ch.data_tx.clone());
430 if let Some(tx) = to_service_tx {
431 let socket_addr = data.socket_addr.as_ref()
432 .map(proto_to_socket_addr)
433 .transpose()
434 .unwrap_or_else(|err| {
435 error!("Invalid socket address for UDP data channel {}: {:?}", data.channel_id, err);
436 None
437 });
438
439 if let Err(err) = tx.send(Data {
440 data: data.data.into(),
441 socket_addr,
442 }).await {
443 self.data_channels.remove(&data.channel_id);
444 error!("Error send to UDP data channel {}: {:?}", data.channel_id, err);
445 }
446 } else {
447 trace!("UDP Data channel {} not found, dropping data", data.channel_id);
448 }
449 },
450
451 Message::DataChannelEof(eof) => {
452 if let Some((_, dc)) = self.data_channels.remove(&eof.channel_id) { dc.close() }
454 if eof.error.is_empty() {
455 trace!("Data channel {} closed by server", eof.channel_id);
457 } else {
458 trace!("Data channel {} closed by server with error: {}", eof.channel_id, eof.error);
460 }
461 },
462
463 Message::DataChannelAck(DataChannelAck { channel_id, consumed }) => {
464 if let Some(ch) = self.data_channels.get(&channel_id) {
465 ch.add_capacity(consumed);
466 }
467 }
468
469 Message::EndpointStopAck(ref ep) => {
470 self.services.remove(&ep.guid);
471 #[cfg(feature = "plugins")]
472 self.plugin_processes.remove(&ep.guid);
473 result_tx.send(val).await.context("Can't send result message")?;
474 }
475
476 Message::EndpointRemoveAck(ref ep) => {
477 self.services.remove(&ep.guid);
478 #[cfg(feature = "plugins")]
479 self.plugin_processes.remove(&ep.guid);
480 result_tx.send(val).await.context("Can't send result message")?;
481 }
482
483 Message::HeartBeat(_) => {
484 conn.send_message(&Message::HeartBeat(HeartBeat{})).await.context("Failed to send heartbeat")?;
485 },
486
487 Message::Error(ref err) => {
488 let kind: ErrorKind = err.kind.try_into().unwrap_or(ErrorKind::Fatal);
489 result_tx.send(val.clone()).await.context("Can't send result message")?;
490 if kind == ErrorKind::Fatal || kind == ErrorKind::AuthFailed {
491 error!("Fatal error received, stop client: {:?}", err);
492 break;
493 }
494 }
495
496 Message::Break(break_msg) => {
497 info!("Breaking operation for guid: {}", break_msg.guid);
498 #[cfg(feature = "plugins")]
499 self.plugin_processes.remove(&break_msg.guid);
500 }
501
502 v => {
503 result_tx.send(v).await.context("Can't send result message")?;
504 }
505 }
506 },
507 None => {
508 debug!("Connection closed by server");
509 break;
510 }
511 }
512 },
513 _ = time::sleep(Duration::from_secs(heartbeat_timeout)), if heartbeat_timeout != 0 => {
514 return Err(anyhow!("Heartbeat timed out"))
515 }
516 }
517 }
518
519 info!("Control channel shutdown");
520 result_tx
521 .send(Message::ConnectState(ConnectState::Disconnected.into()))
522 .await
523 .context("Can't send Disconnected event")?;
524 conn.close().await.ok();
525 time::sleep(Duration::from_millis(100)).await; Ok(())
527 }
528}
529
530pub async fn run_client(
531 config: Arc<RwLock<ClientConfig>>,
532 opts: ClientOpts,
533 command_rx: mpsc::Receiver<Message>,
534 result_tx: mpsc::Sender<Message>,
535) -> Result<()> {
536 let mut client = Client::<WebsocketTransport>::from(config, opts)
537 .await
538 .context("Failed to create Websocket client")?;
539 client.run(command_rx, result_tx).await
540}
541async fn handle_tcp_data_channel(
542 data_channel: Arc<DataChannel>,
543 local_addr: String,
544 to_server_tx: FairSender<Message>,
545 mut data_rx: mpsc::Receiver<Data>,
546) -> Result<()> {
547 trace!("Handling client {:?} to {}", data_channel, local_addr);
548
549 let mut local_stream = TcpStream::connect(&local_addr)
551 .await
552 .with_context(|| format!("Failed to connect to local service at {}", local_addr))?;
553
554 local_stream
556 .set_nodelay(true)
557 .context("Failed to set TCP_NODELAY")?;
558
559 let mut buf = [0u8; DATA_BUFFER_SIZE]; loop {
562 tokio::select! {
563 res = local_stream.read(&mut buf) => {
564 match res {
565 Ok(0) => {
566 trace!("EOF received from local service for {:?}", data_channel);
567 if let Err(err) = to_server_tx.send(Message::DataChannelEof(DataChannelEof {
568 channel_id: data_channel.id,
569 error: String::new()
570 }))
571 .await {
572 trace!("Failed to send EOF to server for {:?}: {:#}", data_channel, err);
573 }
574 break;
575 },
576 Ok(n) => {
577 if data_channel.wait_for_capacity(n as u32).await.is_err() {
579 trace!("Data channel {} closed when waiting for capacity", data_channel.id);
580 break;
581 }
582 if let Err(err) = to_server_tx.send(Message::DataChannelData(DataChannelData {
583 channel_id: data_channel.id,
584 data: buf[0..n].to_vec()
585 }))
586 .await {
587 trace!("Failed to send data to server for {:?}: {:#}", data_channel, err);
588 break;
589 }
590 },
591 Err(e) => {
592 return Err(e).context("Failed to read from local service");
593 }
594 }
595 }
596
597 data_result = data_rx.recv() => {
599 match data_result {
600 Some(data) => {
601 trace!("Received {} bytes from server for {:?}", data.data.len(), data_channel);
602 local_stream.write_all(&data.data).await.context("Failed to write data to local service")?;
603 to_server_tx.send(Message::DataChannelAck(
604 DataChannelAck {
605 channel_id: data_channel.id,
606 consumed: data.data.len() as u32
607 }
608 )).await.with_context(|| "Failed to send TCP traffic ack to the server")?;
609 },
610 None => {
611 trace!("EOF received from server for {:?}", data_channel);
612 break;
613 }
614 }
615 }
616
617 _ = data_channel.closed() => {
618 trace!("Data channel {} closed", data_channel.id);
619 break;
620 }
621 }
622 }
623 Ok(())
624}
625
/// Maps the external peer address of a UDP flow to the sender that feeds the
/// forwarder task serving that peer.
type UdpPortMap = Arc<DashMap<SocketAddr, mpsc::Sender<Bytes>>>;
628
629async fn handle_udp_data_channel(
630 data_channel: Arc<DataChannel>,
631 local_addr: String,
632 to_server_tx: FairSender<Message>,
633 mut data_rx: mpsc::Receiver<Data>,
634) -> Result<()> {
635 trace!(
636 "Handling client UDP channel {:?} to {}",
637 data_channel,
638 local_addr
639 );
640
641 let port_map: UdpPortMap = Arc::new(DashMap::new());
642
643 loop {
644 let data_channel = data_channel.clone();
645 tokio::select! {
647 data = data_rx.recv() => {
648 match data {
649 Some(data) => {
650 let external_addr = data.socket_addr.unwrap();
651
652 if !port_map.contains_key(&external_addr) {
653 match udp_connect(&local_addr).await {
658 Ok(s) => {
659 let (to_service_tx, to_service_rx) = mpsc::channel(DATA_CHANNEL_SIZE);
660 port_map.insert(external_addr, to_service_tx);
661 tokio::spawn(run_udp_forwarder(
662 s,
663 to_service_rx,
664 to_server_tx.clone(),
665 external_addr,
666 data_channel,
667 port_map.clone(),
668 ));
669 }
670 Err(e) => {
671 error!(
672 "Failed to create UDP forwarder for {}: {:#}",
673 external_addr, e
674 );
675 }
676 }
677 }
678
679 if let Some(tx) = port_map.get(&external_addr) {
681 let _ = tx.send(data.data).await;
682 }
683 }
684 None => {
685 trace!("EOF received from server for UDP {:?}", data_channel);
686 break;
687 }
688 }
689 }
690 _ = data_channel.closed() => {
691 trace!("Data channel {} closed", data_channel.id);
692 break;
693 }
694 }
695 }
696 Ok(())
697}
698
699async fn run_udp_forwarder(
701 s: UdpSocket,
702 mut to_service_rx: mpsc::Receiver<Bytes>,
703 to_server_tx: FairSender<Message>,
704 from: SocketAddr,
705 data_channel: Arc<DataChannel>,
706 port_map: UdpPortMap,
707) -> Result<()> {
708 trace!("UDP forwarder created for {} on {:?}", from, data_channel);
709 let mut buf = vec![0u8; UDP_BUFFER_SIZE];
710
711 loop {
712 tokio::select! {
713 data = to_service_rx.recv() => {
715 if let Some(data) = data {
716 s.send(&data).await.with_context(|| "Failed to send UDP traffic to the service")?;
717 to_server_tx.send(Message::DataChannelAck(
718 DataChannelAck {
719 channel_id: data_channel.id,
720 consumed: data.len() as u32
721 }
722 )).await.with_context(|| "Failed to send UDP traffic ack to the server")?;
723 } else {
724 break;
725 }
726 },
727
728 val = s.recv(&mut buf) => {
730 let len = match val {
731 Ok(v) => v,
732 Err(_) => break
733 };
734
735 if data_channel.wait_for_capacity(len as u32).await.is_err() {
736 break;
737 }
738
739 to_server_tx.send(Message::DataChannelDataUdp(
740 DataChannelDataUdp {
741 channel_id: data_channel.id,
742 data: buf[..len].to_vec(),
743 socket_addr: Some(socket_addr_to_proto(&from)),
744 })).await.with_context(|| "Failed to send UDP traffic to the server")?;
745 },
746
747 _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
749 break;
750 }
751
752 _ = data_channel.closed() => {
753 trace!("Data channel {} closed", data_channel.id);
754 break;
755 }
756 }
757 }
758
759 port_map.remove(&from);
760
761 debug!("UDP forwarder dropped for {} on {:?}", from, data_channel);
762 Ok(())
763}