1use std::time::Duration;
24
25use capnp::capability::Promise;
26use capnp_rpc::{rpc_twoparty_capnp, twoparty, RpcSystem};
27
/// Local capability offered to the edge as this side's bootstrap interface
/// (passed as the bootstrap client to `RpcSystem::new` in `register_connection`).
///
/// All three server traits are implemented with empty bodies, so this stub adds no
/// behavior of its own. NOTE(review): presumably every inbound call from the edge
/// therefore gets capnp's default "unimplemented" error; confirm against the
/// generated trait defaults.
struct StubCloudflaredServer;

// Empty impls: rely entirely on the generated default method bodies.
impl tunnelrpc_capnp::session_manager::Server for StubCloudflaredServer {}
impl tunnelrpc_capnp::configuration_manager::Server for StubCloudflaredServer {}
impl tunnelrpc_capnp::cloudflared_server::Server for StubCloudflaredServer {}
48use tokio::time::timeout;
49use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
50use tracing::{debug, info};
51use uuid::Uuid;
52
53use crate::error::TunnelError;
54use crate::tunnelrpc_capnp;
55
/// Connection-error cause string that marks a duplicate registration.
///
/// `decode_connection_response` compares the edge's error cause against this value
/// with exact string equality and reports a dedicated "duplicate connection" message
/// when it matches.
pub const DUPLICATE_CONNECTION_ERROR: &str =
    "edge already has connection registered for the given connection identifier";

/// How long `register_connection` waits for the registration result before giving up.
pub const DEFAULT_RPC_TIMEOUT: Duration = Duration::from_secs(15);
64
/// Credentials sent in the `auth` section of a connection-registration request.
///
/// `Debug` is implemented by hand rather than derived so the tunnel secret can never
/// end up in logs or panic messages via `{:?}`; only its length is printed.
#[derive(Clone)]
pub struct TunnelAuth {
    /// Account tag, sent as the registration request's `auth.account_tag`.
    pub account_tag: String,
    /// Raw tunnel secret bytes, sent as `auth.tunnel_secret`. Treat as sensitive.
    pub tunnel_secret: Vec<u8>,
}

impl std::fmt::Debug for TunnelAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelAuth")
            .field("account_tag", &self.account_tag)
            // Never print the secret itself; the length is enough for diagnostics.
            .field(
                "tunnel_secret",
                &format_args!("<redacted {} bytes>", self.tunnel_secret.len()),
            )
            .finish()
    }
}
71
/// Client/connection options copied into the `options` section of a
/// registration request by `build_register_request`.
#[derive(Debug, Clone)]
pub struct ConnectionOptions {
    /// 16-byte client identifier, sent as `options.client.client_id`.
    pub client_id: [u8; 16],
    /// Feature flags advertised to the edge, sent as `options.client.features`.
    pub features: Vec<String>,
    /// Client version string, sent as `options.client.version`.
    pub version: String,
    /// Platform string, sent as `options.client.arch`.
    pub arch: String,
    /// Raw bytes sent as `options.origin_local_ip`.
    /// NOTE(review): presumably an encoded IP address (possibly empty); confirm format.
    pub origin_local_ip: Vec<u8>,
    /// Forwarded as `options.replace_existing`.
    /// NOTE(review): presumably asks the edge to replace an existing registration.
    pub replace_existing: bool,
    /// Forwarded as `options.compression_quality`.
    pub compression_quality: u8,
    /// Forwarded as `options.num_previous_attempts`.
    pub num_previous_attempts: u8,
}
86
87impl ConnectionOptions {
88 pub fn default_for_quick_tunnel(version: &str) -> Self {
92 Self {
93 client_id: *Uuid::new_v4().as_bytes(),
94 features: vec![
99 "allow_remote_config".into(),
100 "serialized_headers".into(),
101 "support_datagram_v2".into(),
102 "support_quic_eof".into(),
103 "management_logs".into(),
104 ],
105 version: version.to_string(),
106 arch: format!("{}-{}", std::env::consts::OS, std::env::consts::ARCH),
107 origin_local_ip: vec![],
108 replace_existing: false,
109 compression_quality: 0,
110 num_previous_attempts: 0,
111 }
112 }
113}
114
/// Successful registration result, decoded from the edge's `ConnectionDetails`.
#[derive(Debug, Clone)]
pub struct RegistrationDetails {
    /// Connection UUID from `ConnectionDetails.uuid` (validated to be exactly 16 bytes).
    pub uuid: Uuid,
    /// Location name reported by the edge; empty if it was missing or not valid UTF-8.
    pub location: String,
    /// `ConnectionDetails.tunnel_is_remotely_managed` as reported by the edge.
    pub tunnel_is_remotely_managed: bool,
}
123
/// Handle to the background thread that drives a registered control-stream RPC session.
///
/// Dropping the handle sends `ShutdownCommand::Immediate` without waiting. Calling
/// `shutdown_graceful` sends `ShutdownCommand::Graceful` and then waits for the
/// driver to finish, up to a bounded time.
pub struct ControlSession {
    // Tells the driver how to stop; `None` once a command has been sent.
    shutdown: Option<tokio::sync::oneshot::Sender<ShutdownCommand>>,
    // Resolves once the driver has finished tearing the session down.
    done: Option<tokio::sync::oneshot::Receiver<()>>,
    // Driver thread handle; kept alive with the session but never joined.
    _join: std::thread::JoinHandle<()>,
}
138
/// How the RPC driver thread should end the session.
enum ShutdownCommand {
    /// Stop immediately without unregistering from the edge.
    Immediate,
    /// Try to unregister from the edge first, waiting at most the given duration
    /// for that call, then stop.
    Graceful(std::time::Duration),
}
148
149impl ControlSession {
150 pub async fn shutdown_graceful(mut self, grace: std::time::Duration) {
155 if let Some(tx) = self.shutdown.take() {
156 let _ = tx.send(ShutdownCommand::Graceful(grace));
157 }
158 if let Some(rx) = self.done.take() {
159 let budget = grace + std::time::Duration::from_secs(2);
160 let _ = tokio::time::timeout(budget, rx).await;
161 }
162 }
163}
164
165impl Drop for ControlSession {
166 fn drop(&mut self) {
167 if let Some(tx) = self.shutdown.take() {
168 let _ = tx.send(ShutdownCommand::Immediate);
169 }
170 }
173}
174
175pub async fn register_connection(
187 conn: &quinn::Connection,
188 auth: &TunnelAuth,
189 tunnel_id: Uuid,
190 conn_index: u8,
191 options: &ConnectionOptions,
192) -> Result<(RegistrationDetails, ControlSession), TunnelError> {
193 debug!(%tunnel_id, conn_index, "opening control stream");
194 let (send, recv) = conn
195 .open_bi()
196 .await
197 .map_err(|e| TunnelError::Register(format!("open_bi on control stream: {e}")))?;
198 let (done_tx, done_rx) =
208 tokio::sync::oneshot::channel::<Result<RegistrationDetails, TunnelError>>();
209 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<ShutdownCommand>();
210 let (driver_done_tx, driver_done_rx) = tokio::sync::oneshot::channel::<()>();
211
212 let auth_owned = auth.clone();
213 let options_owned = options.clone();
214
215 let join = std::thread::Builder::new()
216 .name("cfqt-rpc-driver".into())
217 .spawn(move || {
218 let rt = tokio::runtime::Builder::new_current_thread()
219 .enable_all()
220 .build()
221 .expect("rpc driver runtime");
222 let local = tokio::task::LocalSet::new();
223 local.block_on(&rt, async move {
224 let reader = recv.compat();
228 let writer = send.compat_write();
229 let network = Box::new(twoparty::VatNetwork::new(
230 reader,
231 writer,
232 rpc_twoparty_capnp::Side::Client,
233 Default::default(),
234 ));
235 let stub: tunnelrpc_capnp::cloudflared_server::Client =
239 capnp_rpc::new_client(StubCloudflaredServer);
240 let mut rpc_system = RpcSystem::new(network, Some(stub.client));
241 let server: tunnelrpc_capnp::registration_server::Client =
242 rpc_system.bootstrap(rpc_twoparty_capnp::Side::Server);
243
244 let request = match build_register_request(
246 &server,
247 &auth_owned,
248 tunnel_id,
249 conn_index,
250 &options_owned,
251 ) {
252 Ok(r) => r,
253 Err(e) => {
254 let _ = done_tx.send(Err(e));
255 return;
256 }
257 };
258 let response_promise = request.send().promise;
259
260 let call = async {
261 let reply = response_promise.await.map_err(|e| {
262 TunnelError::Register(format!("register_connection RPC: {e}"))
263 })?;
264 let response_reader = reply
265 .get()
266 .map_err(|e| TunnelError::Register(format!("response root: {e}")))?;
267 let result = response_reader
268 .get_result()
269 .map_err(|e| TunnelError::Register(format!("response.result: {e}")))?;
270 decode_connection_response(result)
271 };
272
273 tokio::pin!(call);
274 tokio::pin!(shutdown_rx);
275 let mut sent_done = false;
276 let mut done_tx = Some(done_tx);
277 let mut shutdown_kind: Option<ShutdownCommand> = None;
278 loop {
279 tokio::select! {
280 biased;
281 res = &mut call, if !sent_done => {
283 if let Some(tx) = done_tx.take() {
284 let _ = tx.send(res);
285 }
286 sent_done = true;
287 }
288 cmd = &mut shutdown_rx => {
291 shutdown_kind = cmd.ok();
292 break;
293 }
294 _ = &mut rpc_system => {
296 if !sent_done {
297 if let Some(tx) = done_tx.take() {
298 let _ = tx.send(Err(TunnelError::Register(
299 "RPC system terminated before call completed".into(),
300 )));
301 }
302 }
303 break;
304 }
305 }
306 }
307
308 if let Some(ShutdownCommand::Graceful(grace)) = shutdown_kind {
314 if sent_done {
315 let req = server.unregister_connection_request();
316 let _ = tokio::time::timeout(grace, req.send().promise).await;
317 }
318 }
319
320 drop(server);
323 let _ = driver_done_tx.send(());
324 });
325 })
326 .map_err(|e| TunnelError::Internal(format!("spawn rpc driver thread: {e}")))?;
327
328 let details = tokio::time::timeout(DEFAULT_RPC_TIMEOUT, done_rx)
329 .await
330 .map_err(|_| TunnelError::Register("register_connection RPC timed out".into()))?
331 .map_err(|_| TunnelError::Register("RPC driver dropped result channel".into()))??;
332
333 info!(uuid = %details.uuid, location = %details.location, "registered with edge");
334
335 Ok((
336 details,
337 ControlSession {
338 shutdown: Some(shutdown_tx),
339 done: Some(driver_done_rx),
340 _join: join,
341 },
342 ))
343}
344
345fn build_register_request(
348 server: &tunnelrpc_capnp::registration_server::Client,
349 auth: &TunnelAuth,
350 tunnel_id: Uuid,
351 conn_index: u8,
352 options: &ConnectionOptions,
353) -> Result<
354 capnp::capability::Request<
355 tunnelrpc_capnp::registration_server::register_connection_params::Owned,
356 tunnelrpc_capnp::registration_server::register_connection_results::Owned,
357 >,
358 TunnelError,
359> {
360 let mut request = server.register_connection_request();
361 {
362 let mut params = request.get();
363
364 let mut a = params.reborrow().init_auth();
366 a.set_account_tag(auth.account_tag.as_str());
367 a.set_tunnel_secret(&auth.tunnel_secret);
368
369 params.set_tunnel_id(tunnel_id.as_bytes());
371
372 params.set_conn_index(conn_index);
373
374 let mut o = params.reborrow().init_options();
376 {
377 let mut client = o.reborrow().init_client();
378 client.set_client_id(&options.client_id);
379 client.set_version(options.version.as_str());
380 client.set_arch(options.arch.as_str());
381 let mut feats = client.init_features(options.features.len() as u32);
382 for (i, f) in options.features.iter().enumerate() {
383 feats.set(i as u32, f.as_str());
384 }
385 }
386 o.set_origin_local_ip(&options.origin_local_ip);
387 o.set_replace_existing(options.replace_existing);
388 o.set_compression_quality(options.compression_quality);
389 o.set_num_previous_attempts(options.num_previous_attempts);
390 }
391 Ok(request)
392}
393
394fn decode_connection_response(
395 response: tunnelrpc_capnp::connection_response::Reader,
396) -> Result<RegistrationDetails, TunnelError> {
397 use tunnelrpc_capnp::connection_response::result::WhichReader;
398 let result = response.get_result();
399 match result
400 .which()
401 .map_err(|e| TunnelError::Register(format!("ConnectionResponse union: {e:?}")))?
402 {
403 WhichReader::Error(err_reader) => {
404 let err = err_reader
405 .map_err(|e| TunnelError::Register(format!("ConnectionError reader: {e}")))?;
406 let cause = err
407 .get_cause()
408 .ok()
409 .and_then(|t| t.to_string().ok())
410 .unwrap_or_else(|| "<missing cause>".into());
411 if cause == DUPLICATE_CONNECTION_ERROR {
412 return Err(TunnelError::Register(format!(
413 "duplicate connection (edge already has connIndex registered): {cause}"
414 )));
415 }
416 Err(TunnelError::Register(cause))
417 }
418 WhichReader::ConnectionDetails(details_reader) => {
419 let d = details_reader
420 .map_err(|e| TunnelError::Register(format!("ConnectionDetails reader: {e}")))?;
421 let uuid_bytes = d
422 .get_uuid()
423 .map_err(|e| TunnelError::Register(format!("ConnectionDetails.uuid: {e}")))?;
424 if uuid_bytes.len() != 16 {
425 return Err(TunnelError::Register(format!(
426 "ConnectionDetails.uuid wrong length: {}",
427 uuid_bytes.len()
428 )));
429 }
430 let mut u = [0u8; 16];
431 u.copy_from_slice(uuid_bytes);
432 let uuid = Uuid::from_bytes(u);
433 let location = d
434 .get_location_name()
435 .ok()
436 .and_then(|t| t.to_string().ok())
437 .unwrap_or_default();
438 let tunnel_is_remotely_managed = d.get_tunnel_is_remotely_managed();
439 Ok(RegistrationDetails {
440 uuid,
441 location,
442 tunnel_is_remotely_managed,
443 })
444 }
445 }
446}
447
448#[allow(dead_code)]
450async fn drive<F: std::future::Future>(
451 f: F,
452 label: &'static str,
453) -> Result<F::Output, TunnelError> {
454 timeout(DEFAULT_RPC_TIMEOUT, f)
455 .await
456 .map_err(|_| TunnelError::Register(format!("{label} timed out")))
457}
458
459#[allow(dead_code)]
460fn _suppress_unused_promise() -> Promise<(), capnp::Error> {
461 Promise::ok(())
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// The quick-tunnel defaults advertise serialized headers, carry a 16-byte client
    /// id and embed the supplied version string.
    #[test]
    fn default_options_carry_features() {
        let opts = ConnectionOptions::default_for_quick_tunnel("test/0.1");
        let has_serialized_headers = opts.features.iter().any(|f| f == "serialized_headers");
        assert!(has_serialized_headers);
        assert_eq!(opts.client_id.len(), 16);
        assert!(opts.version.contains("test/0.1"));
    }

    /// Guards the duplicate-connection sentinel against accidental edits.
    #[test]
    fn duplicate_sentinel_matches_upstream() {
        const EXPECTED: &str =
            "edge already has connection registered for the given connection identifier";
        assert_eq!(DUPLICATE_CONNECTION_ERROR, EXPECTED);
    }
}