use std::time::Duration;
use capnp::capability::Promise;
use capnp_rpc::{rpc_twoparty_capnp, twoparty, RpcSystem};
/// Capability this client exports to the edge: `register_connection` wraps it with
/// `capnp_rpc::new_client` and hands it to `RpcSystem::new` as the bootstrap capability.
///
/// All three trait impls are empty, so every method uses the generated default.
/// NOTE(review): in capnp-rust that default presumably answers with an
/// "unimplemented" error — confirm against the generated `tunnelrpc_capnp` code.
/// The `session_manager` / `configuration_manager` impls appear to be required because
/// `cloudflared_server` extends those interfaces — verify in the schema.
struct StubCloudflaredServer;
impl tunnelrpc_capnp::session_manager::Server for StubCloudflaredServer {}
impl tunnelrpc_capnp::configuration_manager::Server for StubCloudflaredServer {}
impl tunnelrpc_capnp::cloudflared_server::Server for StubCloudflaredServer {}
use tokio::time::timeout;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use tracing::{debug, info};
use uuid::Uuid;
use crate::error::TunnelError;
use crate::tunnelrpc_capnp;
/// Cause text that `decode_connection_response` treats as a duplicate-registration
/// error from the edge. The comparison is an exact string match, so this must stay
/// byte-identical to what the edge sends.
pub const DUPLICATE_CONNECTION_ERROR: &str =
    "edge already has connection registered for the given connection identifier";
/// Upper bound on how long `register_connection` waits for the registration result
/// (also the timeout applied by `drive`).
pub const DEFAULT_RPC_TIMEOUT: Duration = Duration::from_secs(15);
/// Credentials sent in the `auth` section of the `registerConnection` request.
///
/// `Debug` is implemented by hand so that `tunnel_secret` is never printed: only its
/// length appears. This keeps the secret out of logs that format this value with `{:?}`.
#[derive(Clone)]
pub struct TunnelAuth {
    /// Account tag, sent via `set_account_tag`.
    pub account_tag: String,
    /// Raw tunnel secret bytes, sent via `set_tunnel_secret`. Treat as confidential.
    pub tunnel_secret: Vec<u8>,
}

impl std::fmt::Debug for TunnelAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelAuth")
            .field("account_tag", &self.account_tag)
            // Redact the credential; the length alone is enough for diagnostics.
            .field(
                "tunnel_secret",
                &format_args!("<redacted {} bytes>", self.tunnel_secret.len()),
            )
            .finish()
    }
}
/// Per-connection options written into the `options` section of the
/// `registerConnection` request by `build_register_request`.
#[derive(Debug, Clone)]
pub struct ConnectionOptions {
    /// 16-byte client identifier, sent via `set_client_id`. `default_for_quick_tunnel`
    /// fills it from a freshly generated v4 UUID.
    pub client_id: [u8; 16],
    /// Feature flags advertised to the edge, written as a text list via `init_features`.
    pub features: Vec<String>,
    /// Client version string, sent via `set_version`.
    pub version: String,
    /// Platform string, sent via `set_arch` (`"<os>-<arch>"` in the quick-tunnel defaults).
    pub arch: String,
    /// Raw bytes sent via `set_origin_local_ip`; empty in the quick-tunnel defaults.
    /// NOTE(review): presumably a 4- or 16-byte IP address — confirm with the schema.
    pub origin_local_ip: Vec<u8>,
    /// Sent via `set_replace_existing`; presumably asks the edge to replace an existing
    /// registration — confirm against the edge's semantics.
    pub replace_existing: bool,
    /// Sent via `set_compression_quality`; value meaning is defined by the edge (TODO confirm).
    pub compression_quality: u8,
    /// Sent via `set_num_previous_attempts`; presumably the number of earlier
    /// registration attempts for this connection — verify against callers.
    pub num_previous_attempts: u8,
}
impl ConnectionOptions {
    /// Builds the option set used when registering a quick-tunnel connection.
    ///
    /// Every call generates a new v4 UUID for `client_id`. `version` is copied verbatim,
    /// and `arch` is `"<os>-<arch>"` taken from `std::env::consts`. The feature list is
    /// the fixed quick-tunnel set; all remaining fields are empty, zero or `false`.
    pub fn default_for_quick_tunnel(version: &str) -> Self {
        const QUICK_TUNNEL_FEATURES: [&str; 5] = [
            "allow_remote_config",
            "serialized_headers",
            "support_datagram_v2",
            "support_quic_eof",
            "management_logs",
        ];

        let client_id = *Uuid::new_v4().as_bytes();
        let (os, cpu) = (std::env::consts::OS, std::env::consts::ARCH);
        let features = QUICK_TUNNEL_FEATURES
            .iter()
            .map(|&name| name.to_owned())
            .collect();

        Self {
            client_id,
            features,
            version: String::from(version),
            arch: format!("{os}-{cpu}"),
            origin_local_ip: Vec::new(),
            replace_existing: false,
            compression_quality: 0,
            num_previous_attempts: 0,
        }
    }
}
/// Successful registration result decoded from the edge's `ConnectionDetails`.
#[derive(Debug, Clone)]
pub struct RegistrationDetails {
    /// Connection UUID returned by the edge (`ConnectionDetails.uuid`, required to be 16 bytes).
    pub uuid: Uuid,
    /// Edge location name; empty if the field is missing or not valid UTF-8.
    pub location: String,
    /// The edge's `tunnelIsRemotelyManaged` flag, copied as-is.
    pub tunnel_is_remotely_managed: bool,
}
/// Handle to the RPC driver thread that owns a registered connection's control stream.
///
/// Dropping it sends `ShutdownCommand::Immediate`. `shutdown_graceful` sends
/// `ShutdownCommand::Graceful` instead and waits for the driver to finish.
pub struct ControlSession {
    // Sender for the one shutdown command; `None` once a command has been sent.
    shutdown: Option<tokio::sync::oneshot::Sender<ShutdownCommand>>,
    // Resolves when the driver thread's async body has finished.
    done: Option<tokio::sync::oneshot::Receiver<()>>,
    // Driver thread handle; never joined, so dropping it detaches the thread.
    _join: std::thread::JoinHandle<()>,
}
/// Command sent from `ControlSession` to the RPC driver thread.
enum ShutdownCommand {
    /// Stop the driver without attempting to unregister.
    Immediate,
    /// Attempt `unregisterConnection` (only once the registration call has completed),
    /// waiting at most this long for it, then stop.
    Graceful(std::time::Duration),
}
impl ControlSession {
    /// Extra time, beyond the caller's grace period, allowed for the driver thread to
    /// report completion.
    const DRIVER_EXIT_SLACK: std::time::Duration = std::time::Duration::from_secs(2);

    /// Asks the RPC driver to shut down gracefully and waits for it to finish.
    ///
    /// Sends `ShutdownCommand::Graceful(grace)`, then waits up to `grace` plus two
    /// seconds for the driver's completion signal. The method returns once that signal
    /// arrives or the budget expires. A failed send (driver already gone) and a timeout
    /// are both ignored. Consuming `self` means the `Drop` impl runs afterwards, but it
    /// finds the sender already taken and sends nothing.
    pub async fn shutdown_graceful(mut self, grace: std::time::Duration) {
        if let Some(tx) = self.shutdown.take() {
            let _ = tx.send(ShutdownCommand::Graceful(grace));
        }
        if let Some(rx) = self.done.take() {
            // `+` would panic on overflow for very large grace values; saturate instead.
            let budget = grace.saturating_add(Self::DRIVER_EXIT_SLACK);
            let _ = tokio::time::timeout(budget, rx).await;
        }
    }
}
impl Drop for ControlSession {
    /// Tells the driver thread to stop right away. It does nothing if a command was
    /// already sent, because `shutdown_graceful` takes the sender before `self` drops.
    fn drop(&mut self) {
        if let Some(sender) = self.shutdown.take() {
            // An Err only means the driver already finished and dropped its receiver.
            sender.send(ShutdownCommand::Immediate).ok();
        }
    }
}
/// Opens the control stream on `conn` and registers this connection with the edge.
///
/// The Cap'n Proto RPC machinery runs on a dedicated OS thread (`cfqt-rpc-driver`). That
/// thread has its own current-thread Tokio runtime and a `LocalSet`. The caller waits up
/// to [`DEFAULT_RPC_TIMEOUT`] for the `registerConnection` result.
///
/// On success, returns the edge's [`RegistrationDetails`] and a [`ControlSession`] that
/// keeps the driver (and therefore the control stream) alive. Dropping the session stops
/// the driver immediately. [`ControlSession::shutdown_graceful`] first attempts
/// `unregisterConnection`.
///
/// # Errors
///
/// * `TunnelError::Register` in these cases:
///   - opening the control stream fails;
///   - building or sending the request fails;
///   - the reply is malformed or carries an edge error;
///   - the RPC system terminates before the reply arrives;
///   - no result arrives within [`DEFAULT_RPC_TIMEOUT`];
///   - the driver drops its result channel (e.g. its runtime failed to build).
/// * `TunnelError::Internal` if the driver thread cannot be spawned.
pub async fn register_connection(
    conn: &quinn::Connection,
    auth: &TunnelAuth,
    tunnel_id: Uuid,
    conn_index: u8,
    options: &ConnectionOptions,
) -> Result<(RegistrationDetails, ControlSession), TunnelError> {
    debug!(%tunnel_id, conn_index, "opening control stream");
    let (send, recv) = conn
        .open_bi()
        .await
        .map_err(|e| TunnelError::Register(format!("open_bi on control stream: {e}")))?;
    // Carries the registration outcome from the driver thread back to this task.
    let (done_tx, done_rx) =
        tokio::sync::oneshot::channel::<Result<RegistrationDetails, TunnelError>>();
    // Shutdown command from ControlSession. Dropping the sender (e.g. when this function
    // returns early with an error) also makes the driver stop.
    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<ShutdownCommand>();
    // Signalled once the driver has finished, including any graceful unregister.
    let (driver_done_tx, driver_done_rx) = tokio::sync::oneshot::channel::<()>();
    let auth_owned = auth.clone();
    let options_owned = options.clone();
    let join = std::thread::Builder::new()
        .name("cfqt-rpc-driver".into())
        .spawn(move || {
            // The RPC objects are driven on this thread's own single-threaded runtime and
            // LocalSet. Presumably this is because capnp-rpc types are not `Send`.
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("rpc driver runtime");
            let local = tokio::task::LocalSet::new();
            local.block_on(&rt, async move {
                let reader = recv.compat();
                let writer = send.compat_write();
                let network = Box::new(twoparty::VatNetwork::new(
                    reader,
                    writer,
                    rpc_twoparty_capnp::Side::Client,
                    Default::default(),
                ));
                let stub: tunnelrpc_capnp::cloudflared_server::Client =
                    capnp_rpc::new_client(StubCloudflaredServer);
                let mut rpc_system = RpcSystem::new(network, Some(stub.client));
                let server: tunnelrpc_capnp::registration_server::Client =
                    rpc_system.bootstrap(rpc_twoparty_capnp::Side::Server);
                let request = match build_register_request(
                    &server,
                    &auth_owned,
                    tunnel_id,
                    conn_index,
                    &options_owned,
                ) {
                    Ok(r) => r,
                    Err(e) => {
                        let _ = done_tx.send(Err(e));
                        return;
                    }
                };
                let response_promise = request.send().promise;
                let call = async {
                    let reply = response_promise.await.map_err(|e| {
                        TunnelError::Register(format!("register_connection RPC: {e}"))
                    })?;
                    let response_reader = reply
                        .get()
                        .map_err(|e| TunnelError::Register(format!("response root: {e}")))?;
                    let result = response_reader
                        .get_result()
                        .map_err(|e| TunnelError::Register(format!("response.result: {e}")))?;
                    decode_connection_response(result)
                };
                tokio::pin!(call);
                tokio::pin!(shutdown_rx);
                // `sent_done`: the registration call has produced a result.
                // `registered`: that result was a success.
                let mut sent_done = false;
                let mut registered = false;
                let mut done_tx = Some(done_tx);
                let mut shutdown_kind: Option<ShutdownCommand> = None;
                loop {
                    tokio::select! {
                        biased;
                        res = &mut call, if !sent_done => {
                            registered = res.is_ok();
                            if let Some(tx) = done_tx.take() {
                                let _ = tx.send(res);
                            }
                            sent_done = true;
                        }
                        cmd = &mut shutdown_rx => {
                            // Err (sender dropped without a command) behaves like Immediate.
                            shutdown_kind = cmd.ok();
                            break;
                        }
                        outcome = &mut rpc_system => {
                            // Still holding the result sender means the call never finished.
                            if let Some(tx) = done_tx.take() {
                                let reason = match outcome {
                                    Ok(()) => {
                                        "RPC system terminated before call completed".to_string()
                                    }
                                    Err(e) => format!(
                                        "RPC system terminated before call completed: {e}"
                                    ),
                                };
                                let _ = tx.send(Err(TunnelError::Register(reason)));
                            }
                            break;
                        }
                    }
                }
                if let Some(ShutdownCommand::Graceful(grace)) = shutdown_kind {
                    if registered {
                        let unregister = server.unregister_connection_request().send().promise;
                        // The RpcSystem runs the connection's read/write tasks, and the loop
                        // above has stopped polling it. It must be polled here too, or the
                        // unregister request is never written and its reply never read.
                        let finished = tokio::time::timeout(grace, async {
                            tokio::select! {
                                _ = unregister => {}
                                _ = &mut rpc_system => {}
                            }
                        })
                        .await
                        .is_ok();
                        if !finished {
                            debug!(?grace, "unregister_connection did not finish within grace period");
                        }
                    }
                }
                drop(server);
                let _ = driver_done_tx.send(());
            });
        })
        .map_err(|e| TunnelError::Internal(format!("spawn rpc driver thread: {e}")))?;
    let details = tokio::time::timeout(DEFAULT_RPC_TIMEOUT, done_rx)
        .await
        .map_err(|_| TunnelError::Register("register_connection RPC timed out".into()))?
        .map_err(|_| TunnelError::Register("RPC driver dropped result channel".into()))??;
    info!(uuid = %details.uuid, location = %details.location, "registered with edge");
    Ok((
        details,
        ControlSession {
            shutdown: Some(shutdown_tx),
            done: Some(driver_done_rx),
            _join: join,
        },
    ))
}
fn build_register_request(
server: &tunnelrpc_capnp::registration_server::Client,
auth: &TunnelAuth,
tunnel_id: Uuid,
conn_index: u8,
options: &ConnectionOptions,
) -> Result<
capnp::capability::Request<
tunnelrpc_capnp::registration_server::register_connection_params::Owned,
tunnelrpc_capnp::registration_server::register_connection_results::Owned,
>,
TunnelError,
> {
let mut request = server.register_connection_request();
{
let mut params = request.get();
let mut a = params.reborrow().init_auth();
a.set_account_tag(auth.account_tag.as_str());
a.set_tunnel_secret(&auth.tunnel_secret);
params.set_tunnel_id(tunnel_id.as_bytes());
params.set_conn_index(conn_index);
let mut o = params.reborrow().init_options();
{
let mut client = o.reborrow().init_client();
client.set_client_id(&options.client_id);
client.set_version(options.version.as_str());
client.set_arch(options.arch.as_str());
let mut feats = client.init_features(options.features.len() as u32);
for (i, f) in options.features.iter().enumerate() {
feats.set(i as u32, f.as_str());
}
}
o.set_origin_local_ip(&options.origin_local_ip);
o.set_replace_existing(options.replace_existing);
o.set_compression_quality(options.compression_quality);
o.set_num_previous_attempts(options.num_previous_attempts);
}
Ok(request)
}
/// Converts the edge's `ConnectionResponse` into [`RegistrationDetails`] or an error.
///
/// # Errors
///
/// Returns `TunnelError::Register` in these cases:
/// - the result union has an unknown variant, or its reader fails;
/// - the edge sent a `ConnectionError`. The message is its cause, or `<missing cause>`.
///   A cause equal to [`DUPLICATE_CONNECTION_ERROR`] gets a "duplicate connection" prefix;
/// - the returned UUID is not exactly 16 bytes.
fn decode_connection_response(
    response: tunnelrpc_capnp::connection_response::Reader,
) -> Result<RegistrationDetails, TunnelError> {
    use tunnelrpc_capnp::connection_response::result::WhichReader;

    let union_group = response.get_result();
    let variant = union_group
        .which()
        .map_err(|e| TunnelError::Register(format!("ConnectionResponse union: {e:?}")))?;

    match variant {
        WhichReader::Error(maybe_error) => {
            let edge_error = maybe_error
                .map_err(|e| TunnelError::Register(format!("ConnectionError reader: {e}")))?;
            let cause = edge_error
                .get_cause()
                .ok()
                .and_then(|text| text.to_string().ok())
                .unwrap_or_else(|| "<missing cause>".into());
            let message = if cause == DUPLICATE_CONNECTION_ERROR {
                format!("duplicate connection (edge already has connIndex registered): {cause}")
            } else {
                cause
            };
            Err(TunnelError::Register(message))
        }
        WhichReader::ConnectionDetails(maybe_details) => {
            let details = maybe_details
                .map_err(|e| TunnelError::Register(format!("ConnectionDetails reader: {e}")))?;
            let raw_uuid = details
                .get_uuid()
                .map_err(|e| TunnelError::Register(format!("ConnectionDetails.uuid: {e}")))?;

            let mut uuid_buf = [0u8; 16];
            if raw_uuid.len() != uuid_buf.len() {
                return Err(TunnelError::Register(format!(
                    "ConnectionDetails.uuid wrong length: {}",
                    raw_uuid.len()
                )));
            }
            uuid_buf.copy_from_slice(raw_uuid);

            // A missing or non-UTF-8 location name degrades to an empty string.
            let location = match details.get_location_name() {
                Ok(text) => text.to_string().unwrap_or_default(),
                Err(_) => String::new(),
            };

            Ok(RegistrationDetails {
                uuid: Uuid::from_bytes(uuid_buf),
                location,
                tunnel_is_remotely_managed: details.get_tunnel_is_remotely_managed(),
            })
        }
    }
}
/// Awaits `f`, failing with `TunnelError::Register("<label> timed out")` if it does not
/// complete within [`DEFAULT_RPC_TIMEOUT`].
// dead_code is allowed because this helper may currently have no callers.
#[allow(dead_code)]
async fn drive<F: std::future::Future>(
    f: F,
    label: &'static str,
) -> Result<F::Output, TunnelError> {
    let outcome = timeout(DEFAULT_RPC_TIMEOUT, f).await;
    outcome.map_err(|_elapsed| TunnelError::Register(format!("{label} timed out")))
}
#[allow(dead_code)]
fn _suppress_unused_promise() -> Promise<(), capnp::Error> {
Promise::ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The quick-tunnel defaults advertise `serialized_headers`, carry a 16-byte client
    /// ID, and embed the supplied version string.
    #[test]
    fn default_options_carry_features() {
        let opts = ConnectionOptions::default_for_quick_tunnel("test/0.1");
        let has_serialized_headers = opts
            .features
            .iter()
            .any(|feature| feature == "serialized_headers");
        assert!(has_serialized_headers);
        assert_eq!(opts.client_id.len(), 16);
        assert!(opts.version.contains("test/0.1"));
    }

    /// The duplicate-connection sentinel must stay byte-identical to the edge's text.
    #[test]
    fn duplicate_sentinel_matches_upstream() {
        let expected =
            "edge already has connection registered for the given connection identifier";
        assert_eq!(DUPLICATE_CONNECTION_ERROR, expected);
    }
}