cli/tunnels/
dev_tunnels.rs

1/*---------------------------------------------------------------------------------------------
2 *  Copyright (c) Microsoft Corporation. All rights reserved.
3 *  Licensed under the MIT License. See License.txt in the project root for license information.
4 *--------------------------------------------------------------------------------------------*/
5use super::protocol::{self, PortPrivacy};
6use crate::auth;
7use crate::constants::{IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, TUNNEL_SERVICE_USER_AGENT};
8use crate::state::{LauncherPaths, PersistedState};
9use crate::util::errors::{
10	wrap, AnyError, CodeError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed,
11	WrappedError,
12};
13use crate::util::input::prompt_placeholder;
14use crate::{debug, info, log, spanf, trace, warning};
15use async_trait::async_trait;
16use futures::future::BoxFuture;
17use futures::{FutureExt, TryFutureExt};
18use lazy_static::lazy_static;
19use rand::prelude::IteratorRandom;
20use regex::Regex;
21use reqwest::StatusCode;
22use serde::{Deserialize, Serialize};
23use std::sync::{Arc, Mutex};
24use std::time::Duration;
25use tokio::sync::{mpsc, watch};
26use tunnels::connections::{ForwardedPortConnection, RelayTunnelHost};
27use tunnels::contracts::{
28	Tunnel, TunnelAccessControl, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN,
29	TUNNEL_ACCESS_SCOPES_CONNECT, TUNNEL_PROTOCOL_AUTO,
30};
31use tunnels::management::{
32	new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions,
33	NO_REQUEST_OPTIONS,
34};
35
/// Marker looked for in the error detail of a "too many requests" response
/// when creating a tunnel; its presence triggers an attempt to recycle an
/// unused tunnel.
static TUNNEL_COUNT_LIMIT_NAME: &str = "TunnelsPerUserPerLocation";
37
38#[allow(dead_code)]
39mod tunnel_flags {
40	use crate::{log, tunnels::wsl_detect::is_wsl_installed};
41
42	pub const IS_WSL_INSTALLED: u32 = 1 << 0;
43	pub const IS_WINDOWS: u32 = 1 << 1;
44	pub const IS_LINUX: u32 = 1 << 2;
45	pub const IS_MACOS: u32 = 1 << 3;
46
47	/// Creates a flag string for the tunnel
48	pub fn create(log: &log::Logger) -> String {
49		let mut flags = 0;
50
51		#[cfg(windows)]
52		{
53			flags |= IS_WINDOWS;
54		}
55		#[cfg(target_os = "linux")]
56		{
57			flags |= IS_LINUX;
58		}
59		#[cfg(target_os = "macos")]
60		{
61			flags |= IS_MACOS;
62		}
63
64		if is_wsl_installed(log) {
65			flags |= IS_WSL_INSTALLED;
66		}
67
68		format!("_flag{}", flags)
69	}
70}
71
72#[derive(Clone, Serialize, Deserialize)]
73pub struct PersistedTunnel {
74	pub name: String,
75	pub id: String,
76	pub cluster: String,
77}
78
79impl PersistedTunnel {
80	pub fn into_locator(self) -> TunnelLocator {
81		TunnelLocator::ID {
82			cluster: self.cluster,
83			id: self.id,
84		}
85	}
86	pub fn locator(&self) -> TunnelLocator {
87		TunnelLocator::ID {
88			cluster: self.cluster.clone(),
89			id: self.id.clone(),
90		}
91	}
92}
93
94#[async_trait]
95trait AccessTokenProvider: Send + Sync {
96	/// Gets the current access token.
97	async fn refresh_token(&self) -> Result<String, WrappedError>;
98
99	/// Maintains the stored credential by refreshing it against the service
100	/// to ensure its stays current. Returns a future that should be polled and
101	/// only completes if a refresh fails in a consistent way.
102	fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>>;
103}
104
105/// Access token provider that provides a fixed token without refreshing.
106struct StaticAccessTokenProvider(String);
107
108impl StaticAccessTokenProvider {
109	pub fn new(token: String) -> Self {
110		Self(token)
111	}
112}
113
114#[async_trait]
115impl AccessTokenProvider for StaticAccessTokenProvider {
116	async fn refresh_token(&self) -> Result<String, WrappedError> {
117		Ok(self.0.clone())
118	}
119
120	fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
121		futures::future::pending().boxed()
122	}
123}
124
125/// Access token provider that looks up the token from the tunnels API.
126struct LookupAccessTokenProvider {
127	auth: auth::Auth,
128	client: TunnelManagementClient,
129	locator: TunnelLocator,
130	log: log::Logger,
131	initial_token: Arc<Mutex<Option<String>>>,
132}
133
134impl LookupAccessTokenProvider {
135	pub fn new(
136		auth: auth::Auth,
137		client: TunnelManagementClient,
138		locator: TunnelLocator,
139		log: log::Logger,
140		initial_token: Option<String>,
141	) -> Self {
142		Self {
143			auth,
144			client,
145			locator,
146			log,
147			initial_token: Arc::new(Mutex::new(initial_token)),
148		}
149	}
150}
151
152#[async_trait]
153impl AccessTokenProvider for LookupAccessTokenProvider {
154	async fn refresh_token(&self) -> Result<String, WrappedError> {
155		if let Some(token) = self.initial_token.lock().unwrap().take() {
156			return Ok(token);
157		}
158
159		let tunnel_lookup = spanf!(
160			self.log,
161			self.log.span("dev-tunnel.tag.get"),
162			self.client.get_tunnel(
163				&self.locator,
164				&TunnelRequestOptions {
165					token_scopes: vec!["host".to_string()],
166					..Default::default()
167				}
168			)
169		);
170
171		trace!(self.log, "Successfully refreshed access token");
172
173		match tunnel_lookup {
174			Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)),
175			Err(e) => Err(wrap(e, "failed to lookup tunnel for host token")),
176		}
177	}
178
179	fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
180		let auth = self.auth.clone();
181		auth.keep_token_alive().boxed()
182	}
183}
184
185#[derive(Clone)]
186pub struct DevTunnels {
187	auth: auth::Auth,
188	log: log::Logger,
189	launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
190	client: TunnelManagementClient,
191	tag: &'static str,
192}
193
194/// Representation of a tunnel returned from the `start` methods.
195pub struct ActiveTunnel {
196	/// Name of the tunnel
197	pub name: String,
198	/// Underlying dev tunnels ID
199	pub id: String,
200	manager: ActiveTunnelManager,
201}
202
203impl ActiveTunnel {
204	/// Closes and unregisters the tunnel.
205	pub async fn close(&mut self) -> Result<(), AnyError> {
206		self.manager.kill().await?;
207		Ok(())
208	}
209
210	/// Forwards a port to local connections.
211	pub async fn add_port_direct(
212		&mut self,
213		port_number: u16,
214	) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, AnyError> {
215		let port = self.manager.add_port_direct(port_number).await?;
216		Ok(port)
217	}
218
219	/// Forwards a port over TCP.
220	pub async fn add_port_tcp(
221		&self,
222		port_number: u16,
223		privacy: PortPrivacy,
224	) -> Result<(), AnyError> {
225		self.manager.add_port_tcp(port_number, privacy).await?;
226		Ok(())
227	}
228
229	/// Removes a forwarded port TCP.
230	pub async fn remove_port(&self, port_number: u16) -> Result<(), AnyError> {
231		self.manager.remove_port(port_number).await?;
232		Ok(())
233	}
234
235	/// Gets the template string for forming forwarded port web URIs..
236	pub fn get_port_format(&self) -> Result<String, AnyError> {
237		if let Some(details) = &*self.manager.endpoint_rx.borrow() {
238			return details
239				.as_ref()
240				.map(|r| {
241					r.base
242						.port_uri_format
243						.clone()
244						.expect("expected to have port format")
245				})
246				.map_err(|e| e.clone().into());
247		}
248
249		Err(CodeError::NoTunnelEndpoint.into())
250	}
251
252	/// Gets the public URI on which a forwarded port can be access in browser.
253	pub fn get_port_uri(&self, port: u16) -> Result<String, AnyError> {
254		self.get_port_format()
255			.map(|f| f.replace(PORT_TOKEN, &port.to_string()))
256	}
257
258	/// Gets an object to read the current tunnel status.
259	pub fn status(&self) -> StatusLock {
260		self.manager.get_status()
261	}
262}
263
/// Label applied to tunnels created through `DevTunnels::new_remote_tunnel`.
const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher";
/// Label applied to tunnels created through `DevTunnels::new_port_forwarding`.
const VSCODE_CLI_FORWARDING_TAG: &str = "vscode-port-forward";
/// Labels of tunnels that may be recycled when the tunnel limit is hit.
const OWNED_TUNNEL_TAGS: &[&str] = &[VSCODE_CLI_TUNNEL_TAG, VSCODE_CLI_FORWARDING_TAG];
/// Maximum length (as measured by `str::len`) accepted for a tunnel name.
const MAX_TUNNEL_NAME_LENGTH: usize = 20;
268
269fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
270	tunnel
271		.access_tokens
272		.as_ref()
273		.expect("expected to have access tokens")
274		.get("host")
275		.expect("expected to have host token")
276		.to_string()
277}
278
279fn is_valid_name(name: &str) -> Result<(), InvalidTunnelName> {
280	if name.len() > MAX_TUNNEL_NAME_LENGTH {
281		return Err(InvalidTunnelName(format!(
282			"Names cannot be longer than {} characters. Please try a different name.",
283			MAX_TUNNEL_NAME_LENGTH
284		)));
285	}
286
287	let re = Regex::new(r"^([\w-]+)$").unwrap();
288
289	if !re.is_match(name) {
290		return Err(InvalidTunnelName(
291            "Names can only contain letters, numbers, and '-'. Spaces, commas, and all other special characters are not allowed. Please try a different name.".to_string()
292        ));
293	}
294
295	Ok(())
296}
297
298lazy_static! {
299	static ref HOST_TUNNEL_REQUEST_OPTIONS: TunnelRequestOptions = TunnelRequestOptions {
300		include_ports: true,
301		token_scopes: vec!["host".to_string()],
302		..Default::default()
303	};
304}
305
/// Details of a preexisting tunnel, passed to `start_existing_tunnel` to host
/// it rather than creating a new one.
#[derive(Clone, Debug)]
pub struct ExistingTunnel {
	/// Name to assign to the preexisting tunnel used to connect to the
	/// VS Code Server. A placeholder name is used when absent.
	pub tunnel_name: Option<String>,

	/// Token used to authenticate against and host the preexisting tunnel
	pub host_token: String,

	/// ID of the preexisting tunnel
	pub tunnel_id: String,

	/// Cluster of the preexisting tunnel
	pub cluster: String,
}
321
322impl DevTunnels {
323	/// Creates a new DevTunnels client used for port forwarding.
324	pub fn new_port_forwarding(
325		log: &log::Logger,
326		auth: auth::Auth,
327		paths: &LauncherPaths,
328	) -> DevTunnels {
329		let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
330		client.authorization_provider(auth.clone());
331
332		DevTunnels {
333			auth,
334			log: log.clone(),
335			client: client.into(),
336			launcher_tunnel: PersistedState::new(paths.root().join("port_forwarding_tunnel.json")),
337			tag: VSCODE_CLI_FORWARDING_TAG,
338		}
339	}
340
341	/// Creates a new DevTunnels client used for the Remote Tunnels extension to access the VS Code Server.
342	pub fn new_remote_tunnel(
343		log: &log::Logger,
344		auth: auth::Auth,
345		paths: &LauncherPaths,
346	) -> DevTunnels {
347		let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
348		client.authorization_provider(auth.clone());
349
350		DevTunnels {
351			auth,
352			log: log.clone(),
353			client: client.into(),
354			launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
355			tag: VSCODE_CLI_TUNNEL_TAG,
356		}
357	}
358
359	pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> {
360		let tunnel = match self.launcher_tunnel.load() {
361			Some(t) => t,
362			None => {
363				return Ok(());
364			}
365		};
366
367		spanf!(
368			self.log,
369			self.log.span("dev-tunnel.delete"),
370			self.client
371				.delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS)
372		)
373		.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
374
375		self.launcher_tunnel.save(None)?;
376		Ok(())
377	}
378
379	/// Renames the current tunnel to the new name.
380	pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
381		self.update_tunnel_name(self.launcher_tunnel.load(), name)
382			.await
383			.map(|_| ())
384	}
385
386	/// Updates the name of the existing persisted tunnel to the new name.
387	/// Gracefully creates a new tunnel if the previous one was deleted.
388	async fn update_tunnel_name(
389		&mut self,
390		persisted: Option<PersistedTunnel>,
391		name: &str,
392	) -> Result<(Tunnel, PersistedTunnel), AnyError> {
393		let name = name.to_ascii_lowercase();
394
395		let (mut full_tunnel, mut persisted, is_new) = match persisted {
396			Some(persisted) => {
397				debug!(
398					self.log,
399					"Found a persisted tunnel, seeing if the name matches..."
400				);
401				self.get_or_create_tunnel(persisted, Some(&name), NO_REQUEST_OPTIONS)
402					.await
403			}
404			None => {
405				debug!(self.log, "Creating a new tunnel with the requested name");
406				self.create_tunnel(&name, NO_REQUEST_OPTIONS)
407					.await
408					.map(|(pt, t)| (t, pt, true))
409			}
410		}?;
411
412		let desired_tags = self.get_labels(&name);
413		if is_new || vec_eq_as_set(&full_tunnel.labels, &desired_tags) {
414			return Ok((full_tunnel, persisted));
415		}
416
417		debug!(self.log, "Tunnel name changed, applying updates...");
418
419		full_tunnel.labels = desired_tags;
420
421		let updated_tunnel = spanf!(
422			self.log,
423			self.log.span("dev-tunnel.tag.update"),
424			self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
425		)
426		.map_err(|e| wrap(e, "failed to rename tunnel"))?;
427
428		persisted.name = name;
429		self.launcher_tunnel.save(Some(persisted.clone()))?;
430
431		Ok((updated_tunnel, persisted))
432	}
433
434	/// Gets the persisted tunnel from the service, or creates a new one.
435	/// If `create_with_new_name` is given, the new tunnel has that name
436	/// instead of the one previously persisted.
437	async fn get_or_create_tunnel(
438		&mut self,
439		persisted: PersistedTunnel,
440		create_with_new_name: Option<&str>,
441		options: &TunnelRequestOptions,
442	) -> Result<(Tunnel, PersistedTunnel, /* is_new */ bool), AnyError> {
443		let tunnel_lookup = spanf!(
444			self.log,
445			self.log.span("dev-tunnel.tag.get"),
446			self.client.get_tunnel(&persisted.locator(), options)
447		);
448
449		match tunnel_lookup {
450			Ok(ft) => Ok((ft, persisted, false)),
451			Err(HttpError::ResponseError(e))
452				if e.status_code == StatusCode::NOT_FOUND
453					|| e.status_code == StatusCode::FORBIDDEN =>
454			{
455				let (persisted, tunnel) = self
456					.create_tunnel(create_with_new_name.unwrap_or(&persisted.name), options)
457					.await?;
458				Ok((tunnel, persisted, true))
459			}
460			Err(e) => Err(wrap(e, "failed to lookup tunnel").into()),
461		}
462	}
463
464	/// Starts a new tunnel for the code server on the port. Unlike `start_new_tunnel`,
465	/// this attempts to reuse or create a tunnel of a preferred name or of a generated friendly tunnel name.
466	pub async fn start_new_launcher_tunnel(
467		&mut self,
468		preferred_name: Option<&str>,
469		use_random_name: bool,
470		preserve_ports: &[u16],
471	) -> Result<ActiveTunnel, AnyError> {
472		let (mut tunnel, persisted) = match self.launcher_tunnel.load() {
473			Some(mut persisted) => {
474				if let Some(preferred_name) = preferred_name.map(|n| n.to_ascii_lowercase()) {
475					if persisted.name.to_ascii_lowercase() != preferred_name {
476						(_, persisted) = self
477							.update_tunnel_name(Some(persisted), &preferred_name)
478							.await?;
479					}
480				}
481
482				let (tunnel, persisted, _) = self
483					.get_or_create_tunnel(persisted, None, &HOST_TUNNEL_REQUEST_OPTIONS)
484					.await?;
485				(tunnel, persisted)
486			}
487			None => {
488				debug!(self.log, "No code server tunnel found, creating new one");
489				let name = self
490					.get_name_for_tunnel(preferred_name, use_random_name)
491					.await?;
492				let (persisted, full_tunnel) = self
493					.create_tunnel(&name, &HOST_TUNNEL_REQUEST_OPTIONS)
494					.await?;
495				(full_tunnel, persisted)
496			}
497		};
498
499		tunnel = self
500			.sync_tunnel_tags(
501				&self.client,
502				&persisted.name,
503				tunnel,
504				&HOST_TUNNEL_REQUEST_OPTIONS,
505			)
506			.await?;
507
508		let locator = TunnelLocator::try_from(&tunnel).unwrap();
509		let host_token = get_host_token_from_tunnel(&tunnel);
510
511		for port_to_delete in tunnel
512			.ports
513			.iter()
514			.filter(|p: &&TunnelPort| !preserve_ports.contains(&p.port_number))
515		{
516			let output_fut = self.client.delete_tunnel_port(
517				&locator,
518				port_to_delete.port_number,
519				NO_REQUEST_OPTIONS,
520			);
521			spanf!(
522				self.log,
523				self.log.span("dev-tunnel.port.delete"),
524				output_fut
525			)
526			.map_err(|e| wrap(e, "failed to delete port"))?;
527		}
528
529		// cleanup any old trailing tunnel endpoints
530		for endpoint in tunnel.endpoints {
531			let fut = self.client.delete_tunnel_endpoints(
532				&locator,
533				&endpoint.host_id,
534				NO_REQUEST_OPTIONS,
535			);
536
537			spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut)
538				.map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?;
539		}
540
541		self.start_tunnel(
542			locator.clone(),
543			&persisted,
544			self.client.clone(),
545			LookupAccessTokenProvider::new(
546				self.auth.clone(),
547				self.client.clone(),
548				locator,
549				self.log.clone(),
550				Some(host_token),
551			),
552		)
553		.await
554	}
555
556	async fn create_tunnel(
557		&mut self,
558		name: &str,
559		options: &TunnelRequestOptions,
560	) -> Result<(PersistedTunnel, Tunnel), AnyError> {
561		info!(self.log, "Creating tunnel with the name: {}", name);
562
563		let tunnel = match self.get_existing_tunnel_with_name(name).await? {
564			Some(e) => {
565				let loc = TunnelLocator::try_from(&e).unwrap();
566				info!(self.log, "Adopting existing tunnel (ID={:?})", loc);
567				spanf!(
568					self.log,
569					self.log.span("dev-tunnel.tag.get"),
570					self.client.get_tunnel(&loc, &HOST_TUNNEL_REQUEST_OPTIONS)
571				)
572				.map_err(|e| wrap(e, "failed to lookup tunnel"))?
573			}
574			None => loop {
575				let result = spanf!(
576					self.log,
577					self.log.span("dev-tunnel.create"),
578					self.client.create_tunnel(
579						Tunnel {
580							labels: self.get_labels(name),
581							..Default::default()
582						},
583						options
584					)
585				);
586
587				match result {
588					Err(HttpError::ResponseError(e))
589						if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
590					{
591						if let Some(d) = e.get_details() {
592							let detail = d.detail.unwrap_or_else(|| "unknown".to_string());
593							if detail.contains(TUNNEL_COUNT_LIMIT_NAME)
594								&& self.try_recycle_tunnel().await?
595							{
596								continue;
597							}
598
599							return Err(AnyError::from(TunnelCreationFailed(
600								name.to_string(),
601								detail,
602							)));
603						}
604
605						return Err(AnyError::from(TunnelCreationFailed(
606								name.to_string(),
607								"You have exceeded a limit for the port fowarding service. Please remove other machines before trying to add this machine.".to_string(),
608							)));
609					}
610					Err(e) => {
611						return Err(AnyError::from(TunnelCreationFailed(
612							name.to_string(),
613							format!("{:?}", e),
614						)))
615					}
616					Ok(t) => break t,
617				}
618			},
619		};
620
621		let pt = PersistedTunnel {
622			cluster: tunnel.cluster_id.clone().unwrap(),
623			id: tunnel.tunnel_id.clone().unwrap(),
624			name: name.to_string(),
625		};
626
627		self.launcher_tunnel.save(Some(pt.clone()))?;
628		Ok((pt, tunnel))
629	}
630
631	/// Gets the expected tunnel tags
632	fn get_labels(&self, name: &str) -> Vec<String> {
633		vec![
634			name.to_string(),
635			PROTOCOL_VERSION_TAG.to_string(),
636			self.tag.to_string(),
637			tunnel_flags::create(&self.log),
638		]
639	}
640
641	/// Ensures the tunnel contains a tag for the current PROTCOL_VERSION, and no
642	/// other version tags.
643	async fn sync_tunnel_tags(
644		&self,
645		client: &TunnelManagementClient,
646		name: &str,
647		tunnel: Tunnel,
648		options: &TunnelRequestOptions,
649	) -> Result<Tunnel, AnyError> {
650		let new_labels = self.get_labels(name);
651		if vec_eq_as_set(&tunnel.labels, &new_labels) {
652			return Ok(tunnel);
653		}
654
655		debug!(
656			self.log,
657			"Updating tunnel tags {} -> {}",
658			tunnel.labels.join(", "),
659			new_labels.join(", ")
660		);
661
662		let tunnel_update = Tunnel {
663			labels: new_labels,
664			tunnel_id: tunnel.tunnel_id.clone(),
665			cluster_id: tunnel.cluster_id.clone(),
666			..Default::default()
667		};
668
669		let result = spanf!(
670			self.log,
671			self.log.span("dev-tunnel.protocol-tag-update"),
672			client.update_tunnel(&tunnel_update, options)
673		);
674
675		result.map_err(|e| wrap(e, "tunnel tag update failed").into())
676	}
677
678	/// Tries to delete an unused tunnel, and then creates a tunnel with the
679	/// given `new_name`.
680	async fn try_recycle_tunnel(&mut self) -> Result<bool, AnyError> {
681		trace!(
682			self.log,
683			"Tunnel limit hit, trying to recycle an old tunnel"
684		);
685
686		let existing_tunnels = self.list_tunnels_with_tag(OWNED_TUNNEL_TAGS).await?;
687
688		let recyclable = existing_tunnels
689			.iter()
690			.filter(|t| {
691				t.status
692					.as_ref()
693					.and_then(|s| s.host_connection_count.as_ref())
694					.map(|c| c.get_count())
695					.unwrap_or(0) == 0
696			})
697			.choose(&mut rand::thread_rng());
698
699		match recyclable {
700			Some(tunnel) => {
701				trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id);
702				spanf!(
703					self.log,
704					self.log.span("dev-tunnel.delete"),
705					self.client
706						.delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS)
707				)
708				.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
709				Ok(true)
710			}
711			None => {
712				trace!(self.log, "No tunnels available to recycle");
713				Ok(false)
714			}
715		}
716	}
717
718	async fn list_tunnels_with_tag(
719		&mut self,
720		tags: &[&'static str],
721	) -> Result<Vec<Tunnel>, AnyError> {
722		let tunnels = spanf!(
723			self.log,
724			self.log.span("dev-tunnel.listall"),
725			self.client.list_all_tunnels(&TunnelRequestOptions {
726				labels: tags.iter().map(|t| t.to_string()).collect(),
727				..Default::default()
728			})
729		)
730		.map_err(|e| wrap(e, "error listing current tunnels"))?;
731
732		Ok(tunnels)
733	}
734
735	async fn get_existing_tunnel_with_name(&self, name: &str) -> Result<Option<Tunnel>, AnyError> {
736		let existing: Vec<Tunnel> = spanf!(
737			self.log,
738			self.log.span("dev-tunnel.rename.search"),
739			self.client.list_all_tunnels(&TunnelRequestOptions {
740				labels: vec![self.tag.to_string(), name.to_string()],
741				require_all_labels: true,
742				limit: 1,
743				include_ports: true,
744				token_scopes: vec!["host".to_string()],
745				..Default::default()
746			})
747		)
748		.map_err(|e| wrap(e, "failed to list existing tunnels"))?;
749
750		Ok(existing.into_iter().next())
751	}
752
753	fn get_placeholder_name() -> String {
754		let mut n = clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
755		n.make_ascii_lowercase();
756		n.truncate(MAX_TUNNEL_NAME_LENGTH);
757		n
758	}
759
760	async fn get_name_for_tunnel(
761		&mut self,
762		preferred_name: Option<&str>,
763		mut use_random_name: bool,
764	) -> Result<String, AnyError> {
765		let existing_tunnels = self.list_tunnels_with_tag(&[self.tag]).await?;
766		let is_name_free = |n: &str| {
767			!existing_tunnels.iter().any(|v| {
768				v.status
769					.as_ref()
770					.and_then(|s| s.host_connection_count.as_ref().map(|c| c.get_count()))
771					.unwrap_or(0) > 0 && v.labels.iter().any(|t| t == n)
772			})
773		};
774
775		if let Some(machine_name) = preferred_name {
776			let name = machine_name.to_ascii_lowercase();
777			if let Err(e) = is_valid_name(&name) {
778				info!(self.log, "{} is an invalid name", e);
779				return Err(AnyError::from(wrap(e, "invalid name")));
780			}
781			if is_name_free(&name) {
782				return Ok(name);
783			}
784			info!(
785				self.log,
786				"{} is already taken, using a random name instead", &name
787			);
788			use_random_name = true;
789		}
790
791		let mut placeholder_name = Self::get_placeholder_name();
792		if !is_name_free(&placeholder_name) {
793			for i in 2.. {
794				let fixed_name = format!("{}{}", placeholder_name, i);
795				if is_name_free(&fixed_name) {
796					placeholder_name = fixed_name;
797					break;
798				}
799			}
800		}
801
802		if use_random_name || !*IS_INTERACTIVE_CLI {
803			return Ok(placeholder_name);
804		}
805
806		loop {
807			let mut name = prompt_placeholder(
808				"What would you like to call this machine?",
809				&placeholder_name,
810			)?;
811
812			name.make_ascii_lowercase();
813
814			if let Err(e) = is_valid_name(&name) {
815				info!(self.log, "{}", e);
816				continue;
817			}
818
819			if is_name_free(&name) {
820				return Ok(name);
821			}
822
823			info!(self.log, "The name {} is already in use", name);
824		}
825	}
826
827	/// Hosts an existing tunnel, where the tunnel ID and host token are given.
828	pub async fn start_existing_tunnel(
829		&mut self,
830		tunnel: ExistingTunnel,
831	) -> Result<ActiveTunnel, AnyError> {
832		let tunnel_details = PersistedTunnel {
833			name: match tunnel.tunnel_name {
834				Some(n) => n,
835				None => Self::get_placeholder_name(),
836			},
837			id: tunnel.tunnel_id,
838			cluster: tunnel.cluster,
839		};
840
841		let mut mgmt = self.client.build();
842		mgmt.authorization(tunnels::management::Authorization::Tunnel(
843			tunnel.host_token.clone(),
844		));
845
846		let client = mgmt.into();
847		self.sync_tunnel_tags(
848			&client,
849			&tunnel_details.name,
850			Tunnel {
851				cluster_id: Some(tunnel_details.cluster.clone()),
852				tunnel_id: Some(tunnel_details.id.clone()),
853				..Default::default()
854			},
855			&HOST_TUNNEL_REQUEST_OPTIONS,
856		)
857		.await?;
858
859		self.start_tunnel(
860			tunnel_details.locator(),
861			&tunnel_details,
862			client,
863			StaticAccessTokenProvider::new(tunnel.host_token),
864		)
865		.await
866	}
867
868	async fn start_tunnel(
869		&mut self,
870		locator: TunnelLocator,
871		tunnel_details: &PersistedTunnel,
872		client: TunnelManagementClient,
873		access_token: impl AccessTokenProvider + 'static,
874	) -> Result<ActiveTunnel, AnyError> {
875		let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token);
876
877		let endpoint_result = spanf!(
878			self.log,
879			self.log.span("dev-tunnel.serve.callback"),
880			manager.get_endpoint()
881		);
882
883		let endpoint = match endpoint_result {
884			Ok(endpoint) => endpoint,
885			Err(e) => {
886				error!(self.log, "Error connecting to tunnel endpoint: {}", e);
887				manager.kill().await.ok();
888				return Err(e);
889			}
890		};
891
892		debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint);
893
894		Ok(ActiveTunnel {
895			name: tunnel_details.name.clone(),
896			id: tunnel_details.id.clone(),
897			manager,
898		})
899	}
900}
901
902#[derive(Clone, Default)]
903pub struct StatusLock(Arc<std::sync::Mutex<protocol::singleton::Status>>);
904
905impl StatusLock {
906	fn succeed(&self) {
907		let mut status = self.0.lock().unwrap();
908		status.tunnel = protocol::singleton::TunnelState::Connected;
909		status.last_connected_at = Some(chrono::Utc::now());
910	}
911
912	fn fail(&self, reason: String) {
913		let mut status = self.0.lock().unwrap();
914		if let protocol::singleton::TunnelState::Connected = status.tunnel {
915			status.last_disconnected_at = Some(chrono::Utc::now());
916			status.tunnel = protocol::singleton::TunnelState::Disconnected;
917		}
918		status.last_fail_reason = Some(reason);
919	}
920
921	pub fn read(&self) -> protocol::singleton::Status {
922		let status = self.0.lock().unwrap();
923		status.clone()
924	}
925}
926
927struct ActiveTunnelManager {
928	close_tx: Option<mpsc::Sender<()>>,
929	endpoint_rx: watch::Receiver<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
930	relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
931	status: StatusLock,
932}
933
934impl ActiveTunnelManager {
935	pub fn new(
936		log: log::Logger,
937		mgmt: TunnelManagementClient,
938		locator: TunnelLocator,
939		access_token: impl AccessTokenProvider + 'static,
940	) -> ActiveTunnelManager {
941		let (endpoint_tx, endpoint_rx) = watch::channel(None);
942		let (close_tx, close_rx) = mpsc::channel(1);
943
944		let relay = Arc::new(tokio::sync::Mutex::new(RelayTunnelHost::new(locator, mgmt)));
945		let relay_spawned = relay.clone();
946
947		let status = StatusLock::default();
948
949		let status_spawned = status.clone();
950		tokio::spawn(async move {
951			ActiveTunnelManager::spawn_tunnel(
952				log,
953				relay_spawned,
954				close_rx,
955				endpoint_tx,
956				access_token,
957				status_spawned,
958			)
959			.await;
960		});
961
962		ActiveTunnelManager {
963			endpoint_rx,
964			relay,
965			close_tx: Some(close_tx),
966			status,
967		}
968	}
969
970	/// Gets a copy of the current tunnel status information
971	pub fn get_status(&self) -> StatusLock {
972		self.status.clone()
973	}
974
975	/// Adds a port for TCP/IP forwarding.
976	#[allow(dead_code)] // todo: port forwarding
977	pub async fn add_port_tcp(
978		&self,
979		port_number: u16,
980		privacy: PortPrivacy,
981	) -> Result<(), WrappedError> {
982		self.relay
983			.lock()
984			.await
985			.add_port(&TunnelPort {
986				port_number,
987				protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
988				access_control: Some(privacy_to_tunnel_acl(privacy)),
989				..Default::default()
990			})
991			.await
992			.map_err(|e| wrap(e, "error adding port to relay"))?;
993		Ok(())
994	}
995
996	/// Adds a port for TCP/IP forwarding.
997	pub async fn add_port_direct(
998		&self,
999		port_number: u16,
1000	) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, WrappedError> {
1001		self.relay
1002			.lock()
1003			.await
1004			.add_port_raw(&TunnelPort {
1005				port_number,
1006				protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
1007				access_control: Some(privacy_to_tunnel_acl(PortPrivacy::Private)),
1008				..Default::default()
1009			})
1010			.await
1011			.map_err(|e| wrap(e, "error adding port to relay"))
1012	}
1013
1014	/// Removes a port from TCP/IP forwarding.
1015	pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> {
1016		self.relay
1017			.lock()
1018			.await
1019			.remove_port(port_number)
1020			.await
1021			.map_err(|e| wrap(e, "error remove port from relay"))
1022	}
1023
1024	/// Gets the most recent details from the tunnel process. Returns None if
1025	/// the process exited before providing details.
1026	pub async fn get_endpoint(&mut self) -> Result<TunnelRelayTunnelEndpoint, AnyError> {
1027		loop {
1028			if let Some(details) = &*self.endpoint_rx.borrow() {
1029				return details.clone().map_err(AnyError::from);
1030			}
1031
1032			if self.endpoint_rx.changed().await.is_err() {
1033				return Err(DevTunnelError("tunnel creation cancelled".to_string()).into());
1034			}
1035		}
1036	}
1037
1038	/// Kills the process, and waits for it to exit.
1039	/// See https://tokio.rs/tokio/topics/shutdown#waiting-for-things-to-finish-shutting-down for how this works
1040	pub async fn kill(&mut self) -> Result<(), AnyError> {
1041		if let Some(tx) = self.close_tx.take() {
1042			drop(tx);
1043		}
1044
1045		self.relay
1046			.lock()
1047			.await
1048			.unregister()
1049			.await
1050			.map_err(|e| wrap(e, "error unregistering relay"))?;
1051
1052		while self.endpoint_rx.changed().await.is_ok() {}
1053
1054		Ok(())
1055	}
1056
	/// Runs the tunnel connection loop until it is asked to close or the
	/// access token keep-alive fails.
	///
	/// Each iteration refreshes the access token, connects the relay with it,
	/// publishes the resulting endpoint on `endpoint_tx`, and then waits for
	/// one of three things: the connection ending (retry after a backoff
	/// delay), the token keep-alive yielding an error (return), or
	/// `close_rx.recv()` completing (close the connection and stop).
	///
	/// Failures are logged as warnings, recorded via `status.fail`, and
	/// published on `endpoint_tx` as `Some(Err(..))` before the backoff delay.
	/// A successful connect resets the backoff, calls `status.succeed`, and
	/// publishes `Some(Ok(endpoint))`.
	///
	/// `endpoint_tx` is owned by this function and is therefore dropped when
	/// it returns.
	async fn spawn_tunnel(
		log: log::Logger,
		relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
		mut close_rx: mpsc::Receiver<()>,
		endpoint_tx: watch::Sender<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
		access_token_provider: impl AccessTokenProvider + 'static,
		status: StatusLock,
	) {
		// Created once and polled across loop iterations (via `&mut` below).
		let mut token_ka = access_token_provider.keep_alive();
		// Retry delays start at 5 seconds and are capped at 120 seconds.
		let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));

		// Shared failure path: log, record the failure, publish the error to
		// endpoint watchers, then wait out the backoff delay. This is a macro
		// rather than a closure or function because it awaits and uses several
		// locals of the enclosing function.
		// NOTE(review): the delay is awaited outside of the `select!` below, so
		// a close request is not observed until the delay has elapsed.
		macro_rules! fail {
			($e: expr, $msg: expr) => {
				let fmt = format!("{}: {}", $msg, $e);
				warning!(log, &fmt);
				status.fail(fmt);
				endpoint_tx.send(Some(Err($e))).ok();
				backoff.delay().await;
			};
		}

		loop {
			debug!(log, "Starting tunnel to server...");

			let access_token = match access_token_provider.refresh_token().await {
				Ok(t) => t,
				Err(e) => {
					fail!(e, "Error refreshing access token, will retry");
					continue;
				}
			};

			// we don't bother making a client that can refresh the token, since
			// the tunnel won't be able to host as soon as the access token expires.
			// The relay lock is scoped to this block, so it is released as soon
			// as the connect attempt finishes.
			let handle_res = {
				let mut relay = relay.lock().await;
				relay
					.connect(&access_token)
					.await
					.map_err(|e| wrap(e, "error connecting to tunnel"))
			};

			let mut handle = match handle_res {
				Ok(handle) => handle,
				Err(e) => {
					fail!(e, "Error connecting to relay, will retry");
					continue;
				}
			};

			// Connected: clear the failure count and publish the endpoint.
			backoff.reset();
			status.succeed();
			endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok();

			tokio::select! {
				// The error is mapped like this to prevent it being used across an
				// await, which Rust dislikes since there's a non-sendable dyn Error
				// in there.
				res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => {
					// Either way the connection ended; wait, then loop to reconnect.
					if let Err(e) = res {
						fail!(e, "Tunnel exited unexpectedly, reconnecting");
					} else {
						warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting");
						backoff.delay().await;
					}
				},
				// This arm only matches when the keep-alive resolves to an `Err`.
				Err(e) = &mut token_ka => {
					error!(log, "access token is no longer valid, exiting: {}", e);
					return;
				},
				// `_` accepts whatever `recv()` yields, so any completion of the
				// receive is treated as a request to close.
				_ = close_rx.recv() => {
					trace!(log, "Tunnel closing gracefully");
					trace!(log, "Tunnel closed with result: {:?}", handle.close().await);
					break;
				}
			}
		}
	}
1134}
1135
1136struct Backoff {
1137	failures: u32,
1138	base_duration: Duration,
1139	max_duration: Duration,
1140}
1141
1142impl Backoff {
1143	pub fn new(base_duration: Duration, max_duration: Duration) -> Self {
1144		Self {
1145			failures: 0,
1146			base_duration,
1147			max_duration,
1148		}
1149	}
1150
1151	pub async fn delay(&mut self) {
1152		tokio::time::sleep(self.next()).await
1153	}
1154
1155	pub fn next(&mut self) -> Duration {
1156		self.failures += 1;
1157		let duration = self
1158			.base_duration
1159			.checked_mul(self.failures)
1160			.unwrap_or(self.max_duration);
1161		std::cmp::min(duration, self.max_duration)
1162	}
1163
1164	pub fn reset(&mut self) {
1165		self.failures = 0;
1166	}
1167}
1168
/// Cleans up the hostname so it can be used as a tunnel name.
/// See TUNNEL_NAME_PATTERN in the tunnels SDK for the rules we try to use.
///
/// Only the first 60 characters of `hostname` are considered. Of those,
/// `-`, `_` and spaces become `-`, ASCII letters and digits are kept, and
/// everything else is dropped. Leading and trailing dashes are then removed.
/// If fewer than two characters remain, `"remote-machine"` is returned.
fn clean_hostname_for_tunnel(hostname: &str) -> String {
	let sanitized: String = hostname
		.chars()
		.take(60)
		.filter_map(|c| match c {
			'-' | '_' | ' ' => Some('-'),
			c if c.is_ascii_alphanumeric() => Some(c),
			_ => None,
		})
		.collect();

	let name = sanitized.trim_matches('-');
	if name.len() >= 2 {
		name.to_owned()
	} else {
		// placeholder if the result was too short to be a usable name
		"remote-machine".to_string()
	}
}
1192
/// Compares two string slices while ignoring element order.
///
/// Returns `true` only when both slices have the same length and every value
/// that occurs in one also occurs in the other.
fn vec_eq_as_set(a: &[String], b: &[String]) -> bool {
	// Containment must hold in both directions: checking only `a` against `b`
	// would report e.g. ["x", "x"] and ["x", "y"] as equal, since every
	// element of the first is present in the second.
	a.len() == b.len()
		&& a.iter().all(|item| b.contains(item))
		&& b.iter().all(|item| a.contains(item))
}
1206
1207fn privacy_to_tunnel_acl(privacy: PortPrivacy) -> TunnelAccessControl {
1208	TunnelAccessControl {
1209		entries: vec![match privacy {
1210			PortPrivacy::Public => tunnels::contracts::TunnelAccessControlEntry {
1211				kind: tunnels::contracts::TunnelAccessControlEntryType::Anonymous,
1212				provider: None,
1213				is_inherited: false,
1214				is_deny: false,
1215				is_inverse: false,
1216				organization: None,
1217				expiration: None,
1218				subjects: vec![],
1219				scopes: vec![TUNNEL_ACCESS_SCOPES_CONNECT.to_string()],
1220			},
1221			// Ensure private ports are actually private and do not inherit any
1222			// default visibility that may be set on the tunnel:
1223			PortPrivacy::Private => tunnels::contracts::TunnelAccessControlEntry {
1224				kind: tunnels::contracts::TunnelAccessControlEntryType::Anonymous,
1225				provider: None,
1226				is_inherited: false,
1227				is_deny: true,
1228				is_inverse: false,
1229				organization: None,
1230				expiration: None,
1231				subjects: vec![],
1232				scopes: vec![TUNNEL_ACCESS_SCOPES_CONNECT.to_string()],
1233			},
1234		}],
1235	}
1236}
1237
#[cfg(test)]
mod test {
	use super::*;

	/// Checks hostname sanitisation against a table of (input, expected) pairs.
	#[test]
	fn test_clean_hostname_for_tunnel() {
		let cases = [
			// already valid names pass through untouched
			("hello123", "hello123"),
			// leading and trailing dashes are trimmed
			("-cool-name-", "cool-name"),
			// separators become dashes, other punctuation is dropped
			("cool!name with_chars", "coolname-with-chars"),
			// results that are too short fall back to the placeholder
			("z", "remote-machine"),
		];

		for (input, expected) in cases {
			assert_eq!(
				clean_hostname_for_tunnel(input),
				expected.to_string(),
				"unexpected result for input {:?}",
				input
			);
		}
	}
}