1use super::protocol::{self, PortPrivacy};
6use crate::auth;
7use crate::constants::{IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, TUNNEL_SERVICE_USER_AGENT};
8use crate::state::{LauncherPaths, PersistedState};
9use crate::util::errors::{
10 wrap, AnyError, CodeError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed,
11 WrappedError,
12};
13use crate::util::input::prompt_placeholder;
14use crate::{debug, info, log, spanf, trace, warning};
15use async_trait::async_trait;
16use futures::future::BoxFuture;
17use futures::{FutureExt, TryFutureExt};
18use lazy_static::lazy_static;
19use rand::prelude::IteratorRandom;
20use regex::Regex;
21use reqwest::StatusCode;
22use serde::{Deserialize, Serialize};
23use std::sync::{Arc, Mutex};
24use std::time::Duration;
25use tokio::sync::{mpsc, watch};
26use tunnels::connections::{ForwardedPortConnection, RelayTunnelHost};
27use tunnels::contracts::{
28 Tunnel, TunnelAccessControl, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN,
29 TUNNEL_ACCESS_SCOPES_CONNECT, TUNNEL_PROTOCOL_AUTO,
30};
31use tunnels::management::{
32 new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions,
33 NO_REQUEST_OPTIONS,
34};
35
/// Substring searched for in the detail of a `429 Too Many Requests` response
/// to tunnel creation; when present, an idle owned tunnel is recycled and the
/// creation retried.
const TUNNEL_COUNT_LIMIT_NAME: &str = "TunnelsPerUserPerLocation";
37
38#[allow(dead_code)]
39mod tunnel_flags {
40 use crate::{log, tunnels::wsl_detect::is_wsl_installed};
41
42 pub const IS_WSL_INSTALLED: u32 = 1 << 0;
43 pub const IS_WINDOWS: u32 = 1 << 1;
44 pub const IS_LINUX: u32 = 1 << 2;
45 pub const IS_MACOS: u32 = 1 << 3;
46
47 pub fn create(log: &log::Logger) -> String {
49 let mut flags = 0;
50
51 #[cfg(windows)]
52 {
53 flags |= IS_WINDOWS;
54 }
55 #[cfg(target_os = "linux")]
56 {
57 flags |= IS_LINUX;
58 }
59 #[cfg(target_os = "macos")]
60 {
61 flags |= IS_MACOS;
62 }
63
64 if is_wsl_installed(log) {
65 flags |= IS_WSL_INSTALLED;
66 }
67
68 format!("_flag{}", flags)
69 }
70}
71
72#[derive(Clone, Serialize, Deserialize)]
73pub struct PersistedTunnel {
74 pub name: String,
75 pub id: String,
76 pub cluster: String,
77}
78
79impl PersistedTunnel {
80 pub fn into_locator(self) -> TunnelLocator {
81 TunnelLocator::ID {
82 cluster: self.cluster,
83 id: self.id,
84 }
85 }
86 pub fn locator(&self) -> TunnelLocator {
87 TunnelLocator::ID {
88 cluster: self.cluster.clone(),
89 id: self.id.clone(),
90 }
91 }
92}
93
/// Source of the access token used when connecting the relay host.
#[async_trait]
trait AccessTokenProvider: Send + Sync {
    /// Returns the token to use for the next relay connection attempt.
    async fn refresh_token(&self) -> Result<String, WrappedError>;

    /// Returns a long-running future tied to the provider's credentials.
    /// Implementations may never complete (see `StaticAccessTokenProvider`).
    /// If it resolves with `Err`, the tunnel task treats the token as no
    /// longer valid and exits.
    fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>>;
}
104
105struct StaticAccessTokenProvider(String);
107
108impl StaticAccessTokenProvider {
109 pub fn new(token: String) -> Self {
110 Self(token)
111 }
112}
113
114#[async_trait]
115impl AccessTokenProvider for StaticAccessTokenProvider {
116 async fn refresh_token(&self) -> Result<String, WrappedError> {
117 Ok(self.0.clone())
118 }
119
120 fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
121 futures::future::pending().boxed()
122 }
123}
124
125struct LookupAccessTokenProvider {
127 auth: auth::Auth,
128 client: TunnelManagementClient,
129 locator: TunnelLocator,
130 log: log::Logger,
131 initial_token: Arc<Mutex<Option<String>>>,
132}
133
134impl LookupAccessTokenProvider {
135 pub fn new(
136 auth: auth::Auth,
137 client: TunnelManagementClient,
138 locator: TunnelLocator,
139 log: log::Logger,
140 initial_token: Option<String>,
141 ) -> Self {
142 Self {
143 auth,
144 client,
145 locator,
146 log,
147 initial_token: Arc::new(Mutex::new(initial_token)),
148 }
149 }
150}
151
152#[async_trait]
153impl AccessTokenProvider for LookupAccessTokenProvider {
154 async fn refresh_token(&self) -> Result<String, WrappedError> {
155 if let Some(token) = self.initial_token.lock().unwrap().take() {
156 return Ok(token);
157 }
158
159 let tunnel_lookup = spanf!(
160 self.log,
161 self.log.span("dev-tunnel.tag.get"),
162 self.client.get_tunnel(
163 &self.locator,
164 &TunnelRequestOptions {
165 token_scopes: vec!["host".to_string()],
166 ..Default::default()
167 }
168 )
169 );
170
171 trace!(self.log, "Successfully refreshed access token");
172
173 match tunnel_lookup {
174 Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)),
175 Err(e) => Err(wrap(e, "failed to lookup tunnel for host token")),
176 }
177 }
178
179 fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
180 let auth = self.auth.clone();
181 auth.keep_token_alive().boxed()
182 }
183}
184
/// Client for creating, adopting, renaming and hosting dev tunnels for the
/// CLI. Each instance tracks at most one tunnel, recorded in
/// `launcher_tunnel`, and labels the tunnels it creates with `tag`.
#[derive(Clone)]
pub struct DevTunnels {
    /// Credentials the management client is authorized with; also used for
    /// token keep-alive while hosting.
    auth: auth::Auth,
    log: log::Logger,
    /// On-disk record of the tunnel this instance owns, if any.
    launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
    /// Management API client.
    client: TunnelManagementClient,
    /// Label distinguishing this instance's kind of tunnel (remote server or
    /// port forwarding).
    tag: &'static str,
}
193
194pub struct ActiveTunnel {
196 pub name: String,
198 pub id: String,
200 manager: ActiveTunnelManager,
201}
202
203impl ActiveTunnel {
204 pub async fn close(&mut self) -> Result<(), AnyError> {
206 self.manager.kill().await?;
207 Ok(())
208 }
209
210 pub async fn add_port_direct(
212 &mut self,
213 port_number: u16,
214 ) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, AnyError> {
215 let port = self.manager.add_port_direct(port_number).await?;
216 Ok(port)
217 }
218
219 pub async fn add_port_tcp(
221 &self,
222 port_number: u16,
223 privacy: PortPrivacy,
224 ) -> Result<(), AnyError> {
225 self.manager.add_port_tcp(port_number, privacy).await?;
226 Ok(())
227 }
228
229 pub async fn remove_port(&self, port_number: u16) -> Result<(), AnyError> {
231 self.manager.remove_port(port_number).await?;
232 Ok(())
233 }
234
235 pub fn get_port_format(&self) -> Result<String, AnyError> {
237 if let Some(details) = &*self.manager.endpoint_rx.borrow() {
238 return details
239 .as_ref()
240 .map(|r| {
241 r.base
242 .port_uri_format
243 .clone()
244 .expect("expected to have port format")
245 })
246 .map_err(|e| e.clone().into());
247 }
248
249 Err(CodeError::NoTunnelEndpoint.into())
250 }
251
252 pub fn get_port_uri(&self, port: u16) -> Result<String, AnyError> {
254 self.get_port_format()
255 .map(|f| f.replace(PORT_TOKEN, &port.to_string()))
256 }
257
258 pub fn status(&self) -> StatusLock {
260 self.manager.get_status()
261 }
262}
263
/// Label applied to tunnels created by `DevTunnels::new_remote_tunnel` instances.
const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher";
/// Label applied to tunnels created by `DevTunnels::new_port_forwarding` instances.
const VSCODE_CLI_FORWARDING_TAG: &str = "vscode-port-forward";
/// Labels used to list candidate CLI-owned tunnels when recycling after the
/// tunnel-count limit is hit.
const OWNED_TUNNEL_TAGS: &[&str] = &[VSCODE_CLI_TUNNEL_TAG, VSCODE_CLI_FORWARDING_TAG];
/// Maximum tunnel name length accepted by `is_valid_name` (checked with
/// `str::len`, i.e. in bytes).
const MAX_TUNNEL_NAME_LENGTH: usize = 20;
268
269fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
270 tunnel
271 .access_tokens
272 .as_ref()
273 .expect("expected to have access tokens")
274 .get("host")
275 .expect("expected to have host token")
276 .to_string()
277}
278
279fn is_valid_name(name: &str) -> Result<(), InvalidTunnelName> {
280 if name.len() > MAX_TUNNEL_NAME_LENGTH {
281 return Err(InvalidTunnelName(format!(
282 "Names cannot be longer than {} characters. Please try a different name.",
283 MAX_TUNNEL_NAME_LENGTH
284 )));
285 }
286
287 let re = Regex::new(r"^([\w-]+)$").unwrap();
288
289 if !re.is_match(name) {
290 return Err(InvalidTunnelName(
291 "Names can only contain letters, numbers, and '-'. Spaces, commas, and all other special characters are not allowed. Please try a different name.".to_string()
292 ));
293 }
294
295 Ok(())
296}
297
lazy_static! {
    // Request options for lookups where the host needs the tunnel's ports and
    // a `host`-scoped access token in the response.
    static ref HOST_TUNNEL_REQUEST_OPTIONS: TunnelRequestOptions = TunnelRequestOptions {
        include_ports: true,
        token_scopes: vec!["host".to_string()],
        ..Default::default()
    };
}
305
/// A tunnel identified by the caller, rather than looked up or created by
/// `DevTunnels`. It is hosted via `DevTunnels::start_existing_tunnel` using the
/// supplied host token.
#[derive(Clone)]
pub struct ExistingTunnel {
    /// Name to give the tunnel; a hostname-derived placeholder is used when `None`.
    pub tunnel_name: Option<String>,
    /// Host-scoped access token for the tunnel. Secret: redacted from `Debug`.
    pub host_token: String,
    /// Service-assigned tunnel ID.
    pub tunnel_id: String,
    /// Cluster hosting the tunnel.
    pub cluster: String,
}

// `Debug` is written by hand so the host token never leaks into logs or panic
// messages; every other field is shown as the derive would show it.
impl std::fmt::Debug for ExistingTunnel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExistingTunnel")
            .field("tunnel_name", &self.tunnel_name)
            .field("host_token", &"<redacted>")
            .field("tunnel_id", &self.tunnel_id)
            .field("cluster", &self.cluster)
            .finish()
    }
}
321
322impl DevTunnels {
323 pub fn new_port_forwarding(
325 log: &log::Logger,
326 auth: auth::Auth,
327 paths: &LauncherPaths,
328 ) -> DevTunnels {
329 let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
330 client.authorization_provider(auth.clone());
331
332 DevTunnels {
333 auth,
334 log: log.clone(),
335 client: client.into(),
336 launcher_tunnel: PersistedState::new(paths.root().join("port_forwarding_tunnel.json")),
337 tag: VSCODE_CLI_FORWARDING_TAG,
338 }
339 }
340
341 pub fn new_remote_tunnel(
343 log: &log::Logger,
344 auth: auth::Auth,
345 paths: &LauncherPaths,
346 ) -> DevTunnels {
347 let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
348 client.authorization_provider(auth.clone());
349
350 DevTunnels {
351 auth,
352 log: log.clone(),
353 client: client.into(),
354 launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
355 tag: VSCODE_CLI_TUNNEL_TAG,
356 }
357 }
358
359 pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> {
360 let tunnel = match self.launcher_tunnel.load() {
361 Some(t) => t,
362 None => {
363 return Ok(());
364 }
365 };
366
367 spanf!(
368 self.log,
369 self.log.span("dev-tunnel.delete"),
370 self.client
371 .delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS)
372 )
373 .map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
374
375 self.launcher_tunnel.save(None)?;
376 Ok(())
377 }
378
379 pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
381 self.update_tunnel_name(self.launcher_tunnel.load(), name)
382 .await
383 .map(|_| ())
384 }
385
386 async fn update_tunnel_name(
389 &mut self,
390 persisted: Option<PersistedTunnel>,
391 name: &str,
392 ) -> Result<(Tunnel, PersistedTunnel), AnyError> {
393 let name = name.to_ascii_lowercase();
394
395 let (mut full_tunnel, mut persisted, is_new) = match persisted {
396 Some(persisted) => {
397 debug!(
398 self.log,
399 "Found a persisted tunnel, seeing if the name matches..."
400 );
401 self.get_or_create_tunnel(persisted, Some(&name), NO_REQUEST_OPTIONS)
402 .await
403 }
404 None => {
405 debug!(self.log, "Creating a new tunnel with the requested name");
406 self.create_tunnel(&name, NO_REQUEST_OPTIONS)
407 .await
408 .map(|(pt, t)| (t, pt, true))
409 }
410 }?;
411
412 let desired_tags = self.get_labels(&name);
413 if is_new || vec_eq_as_set(&full_tunnel.labels, &desired_tags) {
414 return Ok((full_tunnel, persisted));
415 }
416
417 debug!(self.log, "Tunnel name changed, applying updates...");
418
419 full_tunnel.labels = desired_tags;
420
421 let updated_tunnel = spanf!(
422 self.log,
423 self.log.span("dev-tunnel.tag.update"),
424 self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
425 )
426 .map_err(|e| wrap(e, "failed to rename tunnel"))?;
427
428 persisted.name = name;
429 self.launcher_tunnel.save(Some(persisted.clone()))?;
430
431 Ok((updated_tunnel, persisted))
432 }
433
434 async fn get_or_create_tunnel(
438 &mut self,
439 persisted: PersistedTunnel,
440 create_with_new_name: Option<&str>,
441 options: &TunnelRequestOptions,
442 ) -> Result<(Tunnel, PersistedTunnel, bool), AnyError> {
443 let tunnel_lookup = spanf!(
444 self.log,
445 self.log.span("dev-tunnel.tag.get"),
446 self.client.get_tunnel(&persisted.locator(), options)
447 );
448
449 match tunnel_lookup {
450 Ok(ft) => Ok((ft, persisted, false)),
451 Err(HttpError::ResponseError(e))
452 if e.status_code == StatusCode::NOT_FOUND
453 || e.status_code == StatusCode::FORBIDDEN =>
454 {
455 let (persisted, tunnel) = self
456 .create_tunnel(create_with_new_name.unwrap_or(&persisted.name), options)
457 .await?;
458 Ok((tunnel, persisted, true))
459 }
460 Err(e) => Err(wrap(e, "failed to lookup tunnel").into()),
461 }
462 }
463
464 pub async fn start_new_launcher_tunnel(
467 &mut self,
468 preferred_name: Option<&str>,
469 use_random_name: bool,
470 preserve_ports: &[u16],
471 ) -> Result<ActiveTunnel, AnyError> {
472 let (mut tunnel, persisted) = match self.launcher_tunnel.load() {
473 Some(mut persisted) => {
474 if let Some(preferred_name) = preferred_name.map(|n| n.to_ascii_lowercase()) {
475 if persisted.name.to_ascii_lowercase() != preferred_name {
476 (_, persisted) = self
477 .update_tunnel_name(Some(persisted), &preferred_name)
478 .await?;
479 }
480 }
481
482 let (tunnel, persisted, _) = self
483 .get_or_create_tunnel(persisted, None, &HOST_TUNNEL_REQUEST_OPTIONS)
484 .await?;
485 (tunnel, persisted)
486 }
487 None => {
488 debug!(self.log, "No code server tunnel found, creating new one");
489 let name = self
490 .get_name_for_tunnel(preferred_name, use_random_name)
491 .await?;
492 let (persisted, full_tunnel) = self
493 .create_tunnel(&name, &HOST_TUNNEL_REQUEST_OPTIONS)
494 .await?;
495 (full_tunnel, persisted)
496 }
497 };
498
499 tunnel = self
500 .sync_tunnel_tags(
501 &self.client,
502 &persisted.name,
503 tunnel,
504 &HOST_TUNNEL_REQUEST_OPTIONS,
505 )
506 .await?;
507
508 let locator = TunnelLocator::try_from(&tunnel).unwrap();
509 let host_token = get_host_token_from_tunnel(&tunnel);
510
511 for port_to_delete in tunnel
512 .ports
513 .iter()
514 .filter(|p: &&TunnelPort| !preserve_ports.contains(&p.port_number))
515 {
516 let output_fut = self.client.delete_tunnel_port(
517 &locator,
518 port_to_delete.port_number,
519 NO_REQUEST_OPTIONS,
520 );
521 spanf!(
522 self.log,
523 self.log.span("dev-tunnel.port.delete"),
524 output_fut
525 )
526 .map_err(|e| wrap(e, "failed to delete port"))?;
527 }
528
529 for endpoint in tunnel.endpoints {
531 let fut = self.client.delete_tunnel_endpoints(
532 &locator,
533 &endpoint.host_id,
534 NO_REQUEST_OPTIONS,
535 );
536
537 spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut)
538 .map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?;
539 }
540
541 self.start_tunnel(
542 locator.clone(),
543 &persisted,
544 self.client.clone(),
545 LookupAccessTokenProvider::new(
546 self.auth.clone(),
547 self.client.clone(),
548 locator,
549 self.log.clone(),
550 Some(host_token),
551 ),
552 )
553 .await
554 }
555
556 async fn create_tunnel(
557 &mut self,
558 name: &str,
559 options: &TunnelRequestOptions,
560 ) -> Result<(PersistedTunnel, Tunnel), AnyError> {
561 info!(self.log, "Creating tunnel with the name: {}", name);
562
563 let tunnel = match self.get_existing_tunnel_with_name(name).await? {
564 Some(e) => {
565 let loc = TunnelLocator::try_from(&e).unwrap();
566 info!(self.log, "Adopting existing tunnel (ID={:?})", loc);
567 spanf!(
568 self.log,
569 self.log.span("dev-tunnel.tag.get"),
570 self.client.get_tunnel(&loc, &HOST_TUNNEL_REQUEST_OPTIONS)
571 )
572 .map_err(|e| wrap(e, "failed to lookup tunnel"))?
573 }
574 None => loop {
575 let result = spanf!(
576 self.log,
577 self.log.span("dev-tunnel.create"),
578 self.client.create_tunnel(
579 Tunnel {
580 labels: self.get_labels(name),
581 ..Default::default()
582 },
583 options
584 )
585 );
586
587 match result {
588 Err(HttpError::ResponseError(e))
589 if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
590 {
591 if let Some(d) = e.get_details() {
592 let detail = d.detail.unwrap_or_else(|| "unknown".to_string());
593 if detail.contains(TUNNEL_COUNT_LIMIT_NAME)
594 && self.try_recycle_tunnel().await?
595 {
596 continue;
597 }
598
599 return Err(AnyError::from(TunnelCreationFailed(
600 name.to_string(),
601 detail,
602 )));
603 }
604
605 return Err(AnyError::from(TunnelCreationFailed(
606 name.to_string(),
607 "You have exceeded a limit for the port fowarding service. Please remove other machines before trying to add this machine.".to_string(),
608 )));
609 }
610 Err(e) => {
611 return Err(AnyError::from(TunnelCreationFailed(
612 name.to_string(),
613 format!("{:?}", e),
614 )))
615 }
616 Ok(t) => break t,
617 }
618 },
619 };
620
621 let pt = PersistedTunnel {
622 cluster: tunnel.cluster_id.clone().unwrap(),
623 id: tunnel.tunnel_id.clone().unwrap(),
624 name: name.to_string(),
625 };
626
627 self.launcher_tunnel.save(Some(pt.clone()))?;
628 Ok((pt, tunnel))
629 }
630
631 fn get_labels(&self, name: &str) -> Vec<String> {
633 vec![
634 name.to_string(),
635 PROTOCOL_VERSION_TAG.to_string(),
636 self.tag.to_string(),
637 tunnel_flags::create(&self.log),
638 ]
639 }
640
641 async fn sync_tunnel_tags(
644 &self,
645 client: &TunnelManagementClient,
646 name: &str,
647 tunnel: Tunnel,
648 options: &TunnelRequestOptions,
649 ) -> Result<Tunnel, AnyError> {
650 let new_labels = self.get_labels(name);
651 if vec_eq_as_set(&tunnel.labels, &new_labels) {
652 return Ok(tunnel);
653 }
654
655 debug!(
656 self.log,
657 "Updating tunnel tags {} -> {}",
658 tunnel.labels.join(", "),
659 new_labels.join(", ")
660 );
661
662 let tunnel_update = Tunnel {
663 labels: new_labels,
664 tunnel_id: tunnel.tunnel_id.clone(),
665 cluster_id: tunnel.cluster_id.clone(),
666 ..Default::default()
667 };
668
669 let result = spanf!(
670 self.log,
671 self.log.span("dev-tunnel.protocol-tag-update"),
672 client.update_tunnel(&tunnel_update, options)
673 );
674
675 result.map_err(|e| wrap(e, "tunnel tag update failed").into())
676 }
677
678 async fn try_recycle_tunnel(&mut self) -> Result<bool, AnyError> {
681 trace!(
682 self.log,
683 "Tunnel limit hit, trying to recycle an old tunnel"
684 );
685
686 let existing_tunnels = self.list_tunnels_with_tag(OWNED_TUNNEL_TAGS).await?;
687
688 let recyclable = existing_tunnels
689 .iter()
690 .filter(|t| {
691 t.status
692 .as_ref()
693 .and_then(|s| s.host_connection_count.as_ref())
694 .map(|c| c.get_count())
695 .unwrap_or(0) == 0
696 })
697 .choose(&mut rand::thread_rng());
698
699 match recyclable {
700 Some(tunnel) => {
701 trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id);
702 spanf!(
703 self.log,
704 self.log.span("dev-tunnel.delete"),
705 self.client
706 .delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS)
707 )
708 .map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
709 Ok(true)
710 }
711 None => {
712 trace!(self.log, "No tunnels available to recycle");
713 Ok(false)
714 }
715 }
716 }
717
718 async fn list_tunnels_with_tag(
719 &mut self,
720 tags: &[&'static str],
721 ) -> Result<Vec<Tunnel>, AnyError> {
722 let tunnels = spanf!(
723 self.log,
724 self.log.span("dev-tunnel.listall"),
725 self.client.list_all_tunnels(&TunnelRequestOptions {
726 labels: tags.iter().map(|t| t.to_string()).collect(),
727 ..Default::default()
728 })
729 )
730 .map_err(|e| wrap(e, "error listing current tunnels"))?;
731
732 Ok(tunnels)
733 }
734
735 async fn get_existing_tunnel_with_name(&self, name: &str) -> Result<Option<Tunnel>, AnyError> {
736 let existing: Vec<Tunnel> = spanf!(
737 self.log,
738 self.log.span("dev-tunnel.rename.search"),
739 self.client.list_all_tunnels(&TunnelRequestOptions {
740 labels: vec![self.tag.to_string(), name.to_string()],
741 require_all_labels: true,
742 limit: 1,
743 include_ports: true,
744 token_scopes: vec!["host".to_string()],
745 ..Default::default()
746 })
747 )
748 .map_err(|e| wrap(e, "failed to list existing tunnels"))?;
749
750 Ok(existing.into_iter().next())
751 }
752
753 fn get_placeholder_name() -> String {
754 let mut n = clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
755 n.make_ascii_lowercase();
756 n.truncate(MAX_TUNNEL_NAME_LENGTH);
757 n
758 }
759
760 async fn get_name_for_tunnel(
761 &mut self,
762 preferred_name: Option<&str>,
763 mut use_random_name: bool,
764 ) -> Result<String, AnyError> {
765 let existing_tunnels = self.list_tunnels_with_tag(&[self.tag]).await?;
766 let is_name_free = |n: &str| {
767 !existing_tunnels.iter().any(|v| {
768 v.status
769 .as_ref()
770 .and_then(|s| s.host_connection_count.as_ref().map(|c| c.get_count()))
771 .unwrap_or(0) > 0 && v.labels.iter().any(|t| t == n)
772 })
773 };
774
775 if let Some(machine_name) = preferred_name {
776 let name = machine_name.to_ascii_lowercase();
777 if let Err(e) = is_valid_name(&name) {
778 info!(self.log, "{} is an invalid name", e);
779 return Err(AnyError::from(wrap(e, "invalid name")));
780 }
781 if is_name_free(&name) {
782 return Ok(name);
783 }
784 info!(
785 self.log,
786 "{} is already taken, using a random name instead", &name
787 );
788 use_random_name = true;
789 }
790
791 let mut placeholder_name = Self::get_placeholder_name();
792 if !is_name_free(&placeholder_name) {
793 for i in 2.. {
794 let fixed_name = format!("{}{}", placeholder_name, i);
795 if is_name_free(&fixed_name) {
796 placeholder_name = fixed_name;
797 break;
798 }
799 }
800 }
801
802 if use_random_name || !*IS_INTERACTIVE_CLI {
803 return Ok(placeholder_name);
804 }
805
806 loop {
807 let mut name = prompt_placeholder(
808 "What would you like to call this machine?",
809 &placeholder_name,
810 )?;
811
812 name.make_ascii_lowercase();
813
814 if let Err(e) = is_valid_name(&name) {
815 info!(self.log, "{}", e);
816 continue;
817 }
818
819 if is_name_free(&name) {
820 return Ok(name);
821 }
822
823 info!(self.log, "The name {} is already in use", name);
824 }
825 }
826
827 pub async fn start_existing_tunnel(
829 &mut self,
830 tunnel: ExistingTunnel,
831 ) -> Result<ActiveTunnel, AnyError> {
832 let tunnel_details = PersistedTunnel {
833 name: match tunnel.tunnel_name {
834 Some(n) => n,
835 None => Self::get_placeholder_name(),
836 },
837 id: tunnel.tunnel_id,
838 cluster: tunnel.cluster,
839 };
840
841 let mut mgmt = self.client.build();
842 mgmt.authorization(tunnels::management::Authorization::Tunnel(
843 tunnel.host_token.clone(),
844 ));
845
846 let client = mgmt.into();
847 self.sync_tunnel_tags(
848 &client,
849 &tunnel_details.name,
850 Tunnel {
851 cluster_id: Some(tunnel_details.cluster.clone()),
852 tunnel_id: Some(tunnel_details.id.clone()),
853 ..Default::default()
854 },
855 &HOST_TUNNEL_REQUEST_OPTIONS,
856 )
857 .await?;
858
859 self.start_tunnel(
860 tunnel_details.locator(),
861 &tunnel_details,
862 client,
863 StaticAccessTokenProvider::new(tunnel.host_token),
864 )
865 .await
866 }
867
868 async fn start_tunnel(
869 &mut self,
870 locator: TunnelLocator,
871 tunnel_details: &PersistedTunnel,
872 client: TunnelManagementClient,
873 access_token: impl AccessTokenProvider + 'static,
874 ) -> Result<ActiveTunnel, AnyError> {
875 let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token);
876
877 let endpoint_result = spanf!(
878 self.log,
879 self.log.span("dev-tunnel.serve.callback"),
880 manager.get_endpoint()
881 );
882
883 let endpoint = match endpoint_result {
884 Ok(endpoint) => endpoint,
885 Err(e) => {
886 error!(self.log, "Error connecting to tunnel endpoint: {}", e);
887 manager.kill().await.ok();
888 return Err(e);
889 }
890 };
891
892 debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint);
893
894 Ok(ActiveTunnel {
895 name: tunnel_details.name.clone(),
896 id: tunnel_details.id.clone(),
897 manager,
898 })
899 }
900}
901
902#[derive(Clone, Default)]
903pub struct StatusLock(Arc<std::sync::Mutex<protocol::singleton::Status>>);
904
905impl StatusLock {
906 fn succeed(&self) {
907 let mut status = self.0.lock().unwrap();
908 status.tunnel = protocol::singleton::TunnelState::Connected;
909 status.last_connected_at = Some(chrono::Utc::now());
910 }
911
912 fn fail(&self, reason: String) {
913 let mut status = self.0.lock().unwrap();
914 if let protocol::singleton::TunnelState::Connected = status.tunnel {
915 status.last_disconnected_at = Some(chrono::Utc::now());
916 status.tunnel = protocol::singleton::TunnelState::Disconnected;
917 }
918 status.last_fail_reason = Some(reason);
919 }
920
921 pub fn read(&self) -> protocol::singleton::Status {
922 let status = self.0.lock().unwrap();
923 status.clone()
924 }
925}
926
/// Handle to the background task that keeps the relay host connected.
struct ActiveTunnelManager {
    // Taken and dropped by `kill` to signal the background task to close.
    close_tx: Option<mpsc::Sender<()>>,
    // Latest endpoint (or connection error) published by the background task;
    // `None` until the first attempt reports.
    endpoint_rx: watch::Receiver<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
    // Relay host shared with the background task; locked for connects and port changes.
    relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
    status: StatusLock,
}
933
934impl ActiveTunnelManager {
935 pub fn new(
936 log: log::Logger,
937 mgmt: TunnelManagementClient,
938 locator: TunnelLocator,
939 access_token: impl AccessTokenProvider + 'static,
940 ) -> ActiveTunnelManager {
941 let (endpoint_tx, endpoint_rx) = watch::channel(None);
942 let (close_tx, close_rx) = mpsc::channel(1);
943
944 let relay = Arc::new(tokio::sync::Mutex::new(RelayTunnelHost::new(locator, mgmt)));
945 let relay_spawned = relay.clone();
946
947 let status = StatusLock::default();
948
949 let status_spawned = status.clone();
950 tokio::spawn(async move {
951 ActiveTunnelManager::spawn_tunnel(
952 log,
953 relay_spawned,
954 close_rx,
955 endpoint_tx,
956 access_token,
957 status_spawned,
958 )
959 .await;
960 });
961
962 ActiveTunnelManager {
963 endpoint_rx,
964 relay,
965 close_tx: Some(close_tx),
966 status,
967 }
968 }
969
970 pub fn get_status(&self) -> StatusLock {
972 self.status.clone()
973 }
974
975 #[allow(dead_code)] pub async fn add_port_tcp(
978 &self,
979 port_number: u16,
980 privacy: PortPrivacy,
981 ) -> Result<(), WrappedError> {
982 self.relay
983 .lock()
984 .await
985 .add_port(&TunnelPort {
986 port_number,
987 protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
988 access_control: Some(privacy_to_tunnel_acl(privacy)),
989 ..Default::default()
990 })
991 .await
992 .map_err(|e| wrap(e, "error adding port to relay"))?;
993 Ok(())
994 }
995
996 pub async fn add_port_direct(
998 &self,
999 port_number: u16,
1000 ) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, WrappedError> {
1001 self.relay
1002 .lock()
1003 .await
1004 .add_port_raw(&TunnelPort {
1005 port_number,
1006 protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
1007 access_control: Some(privacy_to_tunnel_acl(PortPrivacy::Private)),
1008 ..Default::default()
1009 })
1010 .await
1011 .map_err(|e| wrap(e, "error adding port to relay"))
1012 }
1013
1014 pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> {
1016 self.relay
1017 .lock()
1018 .await
1019 .remove_port(port_number)
1020 .await
1021 .map_err(|e| wrap(e, "error remove port from relay"))
1022 }
1023
1024 pub async fn get_endpoint(&mut self) -> Result<TunnelRelayTunnelEndpoint, AnyError> {
1027 loop {
1028 if let Some(details) = &*self.endpoint_rx.borrow() {
1029 return details.clone().map_err(AnyError::from);
1030 }
1031
1032 if self.endpoint_rx.changed().await.is_err() {
1033 return Err(DevTunnelError("tunnel creation cancelled".to_string()).into());
1034 }
1035 }
1036 }
1037
1038 pub async fn kill(&mut self) -> Result<(), AnyError> {
1041 if let Some(tx) = self.close_tx.take() {
1042 drop(tx);
1043 }
1044
1045 self.relay
1046 .lock()
1047 .await
1048 .unregister()
1049 .await
1050 .map_err(|e| wrap(e, "error unregistering relay"))?;
1051
1052 while self.endpoint_rx.changed().await.is_ok() {}
1053
1054 Ok(())
1055 }
1056
1057 async fn spawn_tunnel(
1058 log: log::Logger,
1059 relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
1060 mut close_rx: mpsc::Receiver<()>,
1061 endpoint_tx: watch::Sender<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
1062 access_token_provider: impl AccessTokenProvider + 'static,
1063 status: StatusLock,
1064 ) {
1065 let mut token_ka = access_token_provider.keep_alive();
1066 let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));
1067
1068 macro_rules! fail {
1069 ($e: expr, $msg: expr) => {
1070 let fmt = format!("{}: {}", $msg, $e);
1071 warning!(log, &fmt);
1072 status.fail(fmt);
1073 endpoint_tx.send(Some(Err($e))).ok();
1074 backoff.delay().await;
1075 };
1076 }
1077
1078 loop {
1079 debug!(log, "Starting tunnel to server...");
1080
1081 let access_token = match access_token_provider.refresh_token().await {
1082 Ok(t) => t,
1083 Err(e) => {
1084 fail!(e, "Error refreshing access token, will retry");
1085 continue;
1086 }
1087 };
1088
1089 let handle_res = {
1092 let mut relay = relay.lock().await;
1093 relay
1094 .connect(&access_token)
1095 .await
1096 .map_err(|e| wrap(e, "error connecting to tunnel"))
1097 };
1098
1099 let mut handle = match handle_res {
1100 Ok(handle) => handle,
1101 Err(e) => {
1102 fail!(e, "Error connecting to relay, will retry");
1103 continue;
1104 }
1105 };
1106
1107 backoff.reset();
1108 status.succeed();
1109 endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok();
1110
1111 tokio::select! {
1112 res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => {
1115 if let Err(e) = res {
1116 fail!(e, "Tunnel exited unexpectedly, reconnecting");
1117 } else {
1118 warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting");
1119 backoff.delay().await;
1120 }
1121 },
1122 Err(e) = &mut token_ka => {
1123 error!(log, "access token is no longer valid, exiting: {}", e);
1124 return;
1125 },
1126 _ = close_rx.recv() => {
1127 trace!(log, "Tunnel closing gracefully");
1128 trace!(log, "Tunnel closed with result: {:?}", handle.close().await);
1129 break;
1130 }
1131 }
1132 }
1133 }
1134}
1135
1136struct Backoff {
1137 failures: u32,
1138 base_duration: Duration,
1139 max_duration: Duration,
1140}
1141
1142impl Backoff {
1143 pub fn new(base_duration: Duration, max_duration: Duration) -> Self {
1144 Self {
1145 failures: 0,
1146 base_duration,
1147 max_duration,
1148 }
1149 }
1150
1151 pub async fn delay(&mut self) {
1152 tokio::time::sleep(self.next()).await
1153 }
1154
1155 pub fn next(&mut self) -> Duration {
1156 self.failures += 1;
1157 let duration = self
1158 .base_duration
1159 .checked_mul(self.failures)
1160 .unwrap_or(self.max_duration);
1161 std::cmp::min(duration, self.max_duration)
1162 }
1163
1164 pub fn reset(&mut self) {
1165 self.failures = 0;
1166 }
1167}
1168
/// Turns a hostname into a tunnel-name candidate:
/// - only the first 60 characters are considered;
/// - ASCII letters and digits are kept;
/// - `-`, `_` and space become `-`;
/// - everything else is dropped;
/// - leading and trailing `-` are trimmed.
///
/// Returns `"remote-machine"` when fewer than two characters remain.
fn clean_hostname_for_tunnel(hostname: &str) -> String {
    let cleaned: String = hostname
        .chars()
        .take(60)
        .filter_map(|c| match c {
            '-' | '_' | ' ' => Some('-'),
            c if c.is_ascii_alphanumeric() => Some(c),
            _ => None,
        })
        .collect();

    match cleaned.trim_matches('-') {
        trimmed if trimmed.len() >= 2 => trimmed.to_owned(),
        _ => "remote-machine".to_string(),
    }
}
1191}
1192
/// Returns whether `a` and `b` hold the same elements, ignoring order and
/// duplicates (true set equality). Used for short tunnel label lists, so the
/// quadratic scan is fine.
fn vec_eq_as_set(a: &[String], b: &[String]) -> bool {
    // Containment must hold in both directions. A length check plus a one-way
    // check wrongly accepts e.g. ["x", "x"] vs ["x", "y"].
    a.iter().all(|item| b.contains(item)) && b.iter().all(|item| a.contains(item))
}
1206
1207fn privacy_to_tunnel_acl(privacy: PortPrivacy) -> TunnelAccessControl {
1208 TunnelAccessControl {
1209 entries: vec![match privacy {
1210 PortPrivacy::Public => tunnels::contracts::TunnelAccessControlEntry {
1211 kind: tunnels::contracts::TunnelAccessControlEntryType::Anonymous,
1212 provider: None,
1213 is_inherited: false,
1214 is_deny: false,
1215 is_inverse: false,
1216 organization: None,
1217 expiration: None,
1218 subjects: vec![],
1219 scopes: vec![TUNNEL_ACCESS_SCOPES_CONNECT.to_string()],
1220 },
1221 PortPrivacy::Private => tunnels::contracts::TunnelAccessControlEntry {
1224 kind: tunnels::contracts::TunnelAccessControlEntryType::Anonymous,
1225 provider: None,
1226 is_inherited: false,
1227 is_deny: true,
1228 is_inverse: false,
1229 organization: None,
1230 expiration: None,
1231 subjects: vec![],
1232 scopes: vec![TUNNEL_ACCESS_SCOPES_CONNECT.to_string()],
1233 },
1234 }],
1235 }
1236}
1237
#[cfg(test)]
mod test {
    use super::*;

    /// Table-driven check of hostname cleaning, including the short-name fallback.
    #[test]
    fn test_clean_hostname_for_tunnel() {
        let cases = [
            ("hello123", "hello123"),
            ("-cool-name-", "cool-name"),
            ("cool!name with_chars", "coolname-with-chars"),
            ("z", "remote-machine"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(
                clean_hostname_for_tunnel(input),
                expected.to_string(),
                "input: {:?}",
                input
            );
        }
    }
}