use std::{
str::FromStr,
sync::{Arc, RwLock},
};
use anyhow::{Result, anyhow, ensure};
use iroh::{Endpoint, EndpointAddr, EndpointId, endpoint::ConnectError};
use iroh_metrics::{MetricsGroup, Registry, encoding::Encoder};
use irpc_iroh::IrohLazyRemoteConnection;
use n0_error::StackResultExt;
use n0_future::{task::AbortOnDropHandle, time::Duration};
use rcan::Rcan;
use tokio::sync::oneshot;
use tracing::{debug, trace, warn};
use uuid::Uuid;
use crate::{
api_secret::ApiSecret,
caps::Caps,
net_diagnostics::{DiagnosticsReport, checks::run_diagnostics},
protocol::{
ALPN, Auth, IrohServicesClient, NameEndpoint, Ping, Pong, PutMetrics,
PutNetworkDiagnostics, RemoteError,
},
};
/// Handle to an iroh-services client.
///
/// Cloning is cheap: every clone talks to the same background actor task, which
/// is aborted once the last clone (and with it the last reference to the task
/// handle) is dropped.
#[derive(Debug, Clone)]
pub struct Client {
    /// Local endpoint; used to run network diagnostics and to sign capability grants.
    endpoint: Endpoint,
    /// Request channel into the background client actor.
    message_channel: tokio::sync::mpsc::Sender<ClientActorMessage>,
    /// Keeps the actor task alive; dropping the last reference aborts it.
    _actor_task: Arc<AbortOnDropHandle<()>>,
}
/// Builder for a [`Client`]; obtain one with [`Client::builder`] or [`ClientBuilder::new`].
pub struct ClientBuilder {
    /// Validity period of capability tokens minted by `api_secret` / `ssh_key`.
    cap_expiry: Duration,
    /// Capability token presented to the remote when the client authorizes.
    cap: Option<Rcan<Caps>>,
    /// Local endpoint the client runs on.
    endpoint: Endpoint,
    /// Optional name registered with the remote when the client starts.
    name: Option<String>,
    /// Period of automatic metrics pushes; `None` disables them.
    metrics_interval: Option<Duration>,
    /// Remote endpoint the client connects to.
    remote: Option<EndpointAddr>,
    /// Metrics groups whose values are pushed to the remote.
    registry: Registry,
}
/// Default validity period of capability tokens created by this module (30 days).
const DEFAULT_CAP_EXPIRY: Duration = Duration::from_secs(30 * 24 * 60 * 60);

/// Environment variable read by [`ClientBuilder::api_secret_from_env`].
pub const API_SECRET_ENV_VAR_NAME: &str = "IROH_SERVICES_API_SECRET";
impl ClientBuilder {
    /// Creates a builder for a client running on `endpoint`.
    ///
    /// The endpoint's own metrics are registered up front. Defaults: no name,
    /// no remote, no capability, metrics pushed every 60 seconds, and capability
    /// tokens minted by this builder valid for 30 days.
    pub fn new(endpoint: &Endpoint) -> Self {
        let mut registry = Registry::default();
        registry.register_all(endpoint.metrics());
        Self {
            cap: None,
            cap_expiry: DEFAULT_CAP_EXPIRY,
            endpoint: endpoint.clone(),
            name: None,
            metrics_interval: Some(Duration::from_secs(60)),
            remote: None,
            registry,
        }
    }

    /// Adds `metrics_group` to the set of metrics pushed to the remote.
    pub fn register_metrics_group(mut self, metrics_group: Arc<dyn MetricsGroup>) -> Self {
        self.registry.register(metrics_group);
        self
    }

    /// Sets how often metrics are pushed automatically (default: 60 seconds).
    pub fn metrics_interval(mut self, interval: Duration) -> Self {
        self.metrics_interval = Some(interval);
        self
    }

    /// Disables automatic metrics pushes; [`Client::push_metrics`] remains available.
    pub fn disable_metrics_interval(mut self) -> Self {
        self.metrics_interval = None;
        self
    }

    /// Sets the name registered with the remote when the client starts.
    ///
    /// # Errors
    /// Fails with [`BuildError::InvalidName`] if the name is shorter than
    /// [`CLIENT_NAME_MIN_LENGTH`] or longer than [`CLIENT_NAME_MAX_LENGTH`] bytes.
    pub fn name(mut self, name: impl Into<String>) -> Result<Self> {
        let name = name.into();
        validate_name(&name).map_err(BuildError::InvalidName)?;
        self.name = Some(name);
        Ok(self)
    }

    /// Configures authentication from the API secret stored in the
    /// [`API_SECRET_ENV_VAR_NAME`] environment variable.
    ///
    /// # Errors
    /// Fails if the variable cannot be read or parsed, or if [`Self::api_secret`] fails.
    pub fn api_secret_from_env(self) -> Result<Self> {
        let secret = ApiSecret::from_env_var(API_SECRET_ENV_VAR_NAME)?;
        self.api_secret(secret)
    }

    /// Configures authentication from an API secret in its string form.
    ///
    /// # Errors
    /// Fails with "invalid iroh services api secret" if `secret_key` does not
    /// parse, or if [`Self::api_secret`] fails.
    pub fn api_secret_from_str(self, secret_key: &str) -> Result<Self> {
        let secret = ApiSecret::from_str(secret_key).context("invalid iroh services api secret")?;
        self.api_secret(secret)
    }

    /// Configures authentication from an [`ApiSecret`].
    ///
    /// Sets the remote to the secret's endpoint and mints a capability token
    /// issued by the secret's key, addressed to this endpoint, carrying
    /// `Caps::for_shared_secret()` and valid for the builder's cap expiry.
    ///
    /// # Errors
    /// Fails if the token cannot be created or is rejected by [`Self::rcan`].
    pub fn api_secret(mut self, ticket: ApiSecret) -> Result<Self> {
        let local_id = self.endpoint.id();
        let rcan = crate::caps::create_api_token_from_secret_key(
            ticket.secret,
            local_id,
            self.cap_expiry,
            Caps::for_shared_secret(),
        )?;
        self.remote = Some(ticket.remote);
        self.rcan(rcan)
    }

    /// Reads an OpenSSH private key from `path` and configures it via [`Self::ssh_key`].
    ///
    /// # Errors
    /// Fails if the file cannot be read, the key cannot be parsed, or
    /// [`Self::ssh_key`] fails.
    #[cfg(not(target_arch = "wasm32"))]
    pub async fn ssh_key_from_file<P: AsRef<std::path::Path>>(self, path: P) -> Result<Self> {
        let file_content = tokio::fs::read_to_string(path).await?;
        let private_key = ssh_key::PrivateKey::from_openssh(&file_content)?;
        self.ssh_key(&private_key)
    }

    /// Configures authentication with a capability token minted from an SSH key.
    ///
    /// The token carries `Caps::all()`, is addressed to this endpoint and is valid
    /// for the builder's cap expiry. Unlike [`Self::api_secret`], this does not set
    /// the remote; configure it with [`Self::remote`].
    ///
    /// # Errors
    /// Fails if the token cannot be created or is rejected by [`Self::rcan`].
    #[cfg(not(target_arch = "wasm32"))]
    pub fn ssh_key(self, key: &ssh_key::PrivateKey) -> Result<Self> {
        let local_id = self.endpoint.id();
        let rcan = crate::caps::create_api_token_from_ssh_key(
            key,
            local_id,
            self.cap_expiry,
            Caps::all(),
        )?;
        // Go through `rcan` so SSH-derived tokens get the same audience check as
        // tokens derived from an API secret.
        self.rcan(rcan)
    }

    /// Uses `cap` as the capability token presented to the remote.
    ///
    /// # Errors
    /// Fails with "invalid audience" if `cap` is not addressed to this endpoint.
    pub fn rcan(mut self, cap: Rcan<Caps>) -> Result<Self> {
        ensure!(
            EndpointId::from_verifying_key(*cap.audience()) == self.endpoint.id(),
            "invalid audience"
        );
        self.cap = Some(cap);
        Ok(self)
    }

    /// Sets the remote endpoint the client connects to.
    pub fn remote(mut self, remote: impl Into<EndpointAddr>) -> Self {
        self.remote = Some(remote.into());
        self
    }

    /// Builds the client and spawns its background actor task.
    ///
    /// The actor first registers the configured name (if any) with the remote,
    /// then serves requests from the returned [`Client`] and pushes metrics at the
    /// configured interval.
    ///
    /// # Errors
    /// [`BuildError::MissingRemote`] if no remote was configured, and
    /// [`BuildError::MissingCapability`] if no capability token was configured.
    #[must_use = "dropping the client will silently cancel all client tasks"]
    pub async fn build(self) -> Result<Client, BuildError> {
        debug!("starting iroh-services client");
        let remote = self.remote.ok_or(BuildError::MissingRemote)?;
        let capabilities = self.cap.ok_or(BuildError::MissingCapability)?;
        let conn = IrohLazyRemoteConnection::new(self.endpoint.clone(), remote, ALPN.to_vec());
        let irpc_client = IrohServicesClient::boxed(conn);
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        let actor_task = AbortOnDropHandle::new(n0_future::task::spawn(
            ClientActor {
                capabilities,
                client: irpc_client,
                name: self.name.clone(),
                session_id: Uuid::new_v4(),
                authorized: false,
            }
            .run(self.name, self.registry, self.metrics_interval, rx),
        ));
        Ok(Client {
            endpoint: self.endpoint,
            message_channel: tx,
            _actor_task: Arc::new(actor_task),
        })
    }
}
/// Errors produced while configuring or building a [`Client`].
#[derive(thiserror::Error, Debug)]
pub enum BuildError {
    /// No remote endpoint was configured on the builder.
    #[error("Missing remote endpoint to dial")]
    MissingRemote,
    /// No capability token was configured on the builder.
    #[error("Missing capability")]
    MissingCapability,
    /// The remote closed the connection with application error code 401
    /// (see the `From<irpc::Error>` impl below).
    #[error("Unauthorized")]
    Unauthorized,
    /// An error reported by the remote service.
    #[error("Remote error: {0}")]
    Remote(#[from] RemoteError),
    /// Any other RPC-level failure.
    #[error("Rpc connection error: {0}")]
    Rpc(irpc::Error),
    /// Connecting to the remote failed.
    /// NOTE(review): not constructed in this file — confirm where it is produced.
    #[error("Connection error: {0}")]
    Connect(ConnectError),
    /// The configured client name failed validation.
    #[error("Invalid endpoint name: {0}")]
    InvalidName(#[from] ValidateNameError),
}
impl From<irpc::Error> for BuildError {
fn from(value: irpc::Error) -> Self {
match value {
irpc::Error::Request {
source:
irpc::RequestError::Connection {
source: iroh::endpoint::ConnectionError::ApplicationClosed(frame),
..
},
..
} if frame.error_code == 401u32.into() => Self::Unauthorized,
value => Self::Rpc(value),
}
}
}
/// Minimum accepted client name length, in bytes (compared against `str::len`).
pub const CLIENT_NAME_MIN_LENGTH: usize = 2;
/// Maximum accepted client name length, in bytes (compared against `str::len`).
pub const CLIENT_NAME_MAX_LENGTH: usize = 128;
#[derive(Debug, thiserror::Error)]
pub enum ValidateNameError {
#[error("Name is too long (must be no more than {CLIENT_NAME_MAX_LENGTH} characters).")]
TooLong,
#[error("Name is too short (must be at least {CLIENT_NAME_MIN_LENGTH} characters).")]
TooShort,
}
/// Checks that `name` is between [`CLIENT_NAME_MIN_LENGTH`] and
/// [`CLIENT_NAME_MAX_LENGTH`] bytes long, both bounds inclusive.
///
/// Length is measured with `str::len`, i.e. in UTF-8 bytes rather than characters.
fn validate_name(name: &str) -> Result<(), ValidateNameError> {
    let byte_len = name.len();
    if byte_len > CLIENT_NAME_MAX_LENGTH {
        return Err(ValidateNameError::TooLong);
    }
    if byte_len < CLIENT_NAME_MIN_LENGTH {
        return Err(ValidateNameError::TooShort);
    }
    Ok(())
}
/// Errors returned by [`Client`] operations.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// A name passed to [`Client::set_name`] failed local validation.
    #[error("Invalid endpoint name: {0}")]
    InvalidName(#[from] ValidateNameError),
    /// An error reported by the remote. The client actor also maps RPC transport
    /// failures to `RemoteError::InternalServerError`, so those surface here too.
    #[error("Remote error: {0}")]
    Remote(#[from] RemoteError),
    /// An RPC-level failure.
    #[error("Connection error: {0}")]
    Rpc(#[from] irpc::Error),
    /// Any other failure, e.g. the internal client actor could not be reached.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
impl Client {
pub fn builder(endpoint: &Endpoint) -> ClientBuilder {
ClientBuilder::new(endpoint)
}
pub async fn name(&self) -> Result<Option<String>, Error> {
let (tx, rx) = oneshot::channel();
self.message_channel
.send(ClientActorMessage::ReadName { done: tx })
.await
.map_err(|_| Error::Other(anyhow!("sending name read request")))?;
rx.await
.map_err(|e| Error::Other(anyhow!("response on internal channel: {:?}", e)))
}
pub async fn set_name(&self, name: impl Into<String>) -> Result<(), Error> {
set_name_inner(self.message_channel.clone(), name.into()).await
}
pub async fn ping(&self) -> Result<Pong, Error> {
let (tx, rx) = oneshot::channel();
self.message_channel
.send(ClientActorMessage::Ping { done: tx })
.await
.map_err(|_| Error::Other(anyhow!("sending ping request")))?;
rx.await
.map_err(|e| Error::Other(anyhow!("response on internal channel: {:?}", e)))?
.map_err(Error::Remote)
}
pub async fn push_metrics(&self) -> Result<(), Error> {
let (tx, rx) = oneshot::channel();
self.message_channel
.send(ClientActorMessage::SendMetrics { done: tx })
.await
.map_err(|_| Error::Other(anyhow!("sending metrics")))?;
rx.await
.map_err(|e| Error::Other(anyhow!("response on internal channel: {:?}", e)))?
.map_err(Error::Remote)
}
pub async fn grant_capability(
&self,
remote_id: EndpointId,
caps: impl IntoIterator<Item = impl Into<crate::caps::Cap>>,
) -> Result<(), Error> {
let cap = crate::caps::create_grant_token(
self.endpoint.secret_key().clone(),
remote_id,
DEFAULT_CAP_EXPIRY,
Caps::new(caps),
)
.map_err(Error::Other)?;
let (tx, rx) = oneshot::channel();
self.message_channel
.send(ClientActorMessage::GrantCap {
cap: Box::new(cap),
done: tx,
})
.await
.map_err(|_| Error::Other(anyhow!("granting capability")))?;
rx.await
.map_err(|e| Error::Other(anyhow!("response on internal channel: {:?}", e)))?
}
pub async fn net_diagnostics(&self, send: bool) -> Result<DiagnosticsReport, Error> {
let report = run_diagnostics(&self.endpoint).await?;
if send {
let (tx, rx) = oneshot::channel();
self.message_channel
.send(ClientActorMessage::PutNetworkDiagnostics {
done: tx,
report: Box::new(report.clone()),
})
.await
.map_err(|_| Error::Other(anyhow!("sending network diagnostics report")))?;
let _ = rx
.await
.map_err(|e| Error::Other(anyhow!("response on internal channel: {:?}", e)))?;
}
Ok(report)
}
}
/// Requests sent from [`Client`] handles to the background `ClientActor`.
///
/// Every variant carries a oneshot sender on which the actor replies.
enum ClientActorMessage {
    /// Push the current metrics now (sent by [`Client::push_metrics`]).
    SendMetrics {
        done: oneshot::Sender<Result<(), RemoteError>>,
    },
    /// Ping the remote (sent by [`Client::ping`]).
    Ping {
        done: oneshot::Sender<Result<Pong, RemoteError>>,
    },
    /// Forward a grant token to the remote (sent by [`Client::grant_capability`]).
    GrantCap {
        cap: Box<Rcan<Caps>>,
        done: oneshot::Sender<Result<(), Error>>,
    },
    /// Upload a diagnostics report (sent by [`Client::net_diagnostics`]).
    PutNetworkDiagnostics {
        report: Box<DiagnosticsReport>,
        done: oneshot::Sender<Result<(), Error>>,
    },
    /// Read the actor's recorded name (sent by [`Client::name`]).
    ReadName {
        done: oneshot::Sender<Option<String>>,
    },
    /// Register `name` with the remote (sent via [`Client::set_name`]).
    NameEndpoint {
        name: String,
        done: oneshot::Sender<Result<(), RemoteError>>,
    },
}
/// State owned by the background task spawned in [`ClientBuilder::build`].
struct ClientActor {
    /// Capability token presented to the remote in the `Auth` RPC.
    capabilities: Rcan<Caps>,
    /// RPC client for the iroh-services protocol.
    client: IrohServicesClient,
    /// Name reported by [`Client::name`].
    name: Option<String>,
    /// Random id (generated at build time) sent with every metrics update.
    session_id: Uuid,
    /// Whether an `Auth` RPC has succeeded and is still assumed valid.
    authorized: bool,
}
impl ClientActor {
async fn run(
mut self,
initial_name: Option<String>,
registry: Registry,
interval: Option<Duration>,
mut inbox: tokio::sync::mpsc::Receiver<ClientActorMessage>,
) {
let registry = Arc::new(RwLock::new(registry));
let mut encoder = Encoder::new(registry);
let mut metrics_timer = interval.map(|interval| n0_future::time::interval(interval));
trace!("starting client actor");
if let Some(name) = initial_name
&& let Err(err) = self.send_name_endpoint(name).await
{
warn!(err = %err, "failed setting endpoint name on startup");
}
loop {
trace!("client actor tick");
tokio::select! {
biased;
Some(msg) = inbox.recv() => {
match msg {
ClientActorMessage::Ping{ done } => {
let res = self.send_ping().await;
if let Err(err) = done.send(res) {
debug!("failed to send ping: {:#?}", err);
self.authorized = false;
}
},
ClientActorMessage::SendMetrics{ done } => {
trace!("sending metrics manually triggered");
let res = self.send_metrics(&mut encoder).await;
if let Err(err) = done.send(res) {
debug!("failed to push metrics: {:#?}", err);
self.authorized = false;
}
}
ClientActorMessage::GrantCap{ cap, done } => {
let res = self.grant_cap(*cap).await;
if let Err(err) = done.send(res) {
warn!("failed to grant capability: {:#?}", err);
}
}
ClientActorMessage::ReadName{ done } => {
if let Err(err) = done.send(self.name.clone()) {
warn!("sending name value: {:#?}", err);
}
}
ClientActorMessage::NameEndpoint{ name, done } => {
let res = self.send_name_endpoint(name).await;
if let Err(err) = done.send(res) {
warn!("failed to name endpoint: {:#?}", err);
}
}
ClientActorMessage::PutNetworkDiagnostics{ report, done } => {
let res = self.put_network_diagnostics(*report).await;
if let Err(err) = done.send(res) {
warn!("failed to publish network diagnostics: {:#?}", err);
}
}
}
}
_ = async {
if let Some(ref mut timer) = metrics_timer {
timer.tick().await;
} else {
std::future::pending::<()>().await;
}
} => {
trace!("metrics send tick");
if let Err(err) = self.send_metrics(&mut encoder).await {
debug!("failed to push metrics: {:#?}", err);
self.authorized = false;
}
},
}
}
}
async fn auth(&mut self) -> Result<(), RemoteError> {
if self.authorized {
return Ok(());
}
trace!("client authorizing");
self.client
.rpc(Auth {
caps: self.capabilities.clone(),
})
.await
.inspect_err(|e| debug!("authorization failed: {:?}", e))
.map_err(|e| RemoteError::AuthError(e.to_string()))?;
self.authorized = true;
Ok(())
}
async fn send_ping(&mut self) -> Result<Pong, RemoteError> {
trace!("client actor send ping");
self.auth().await?;
let req = rand::random();
self.client
.rpc(Ping { req_id: req })
.await
.inspect_err(|e| warn!("rpc ping error: {e}"))
.map_err(|_| RemoteError::InternalServerError)
}
async fn send_name_endpoint(&mut self, name: String) -> Result<(), RemoteError> {
trace!("client sending name endpoint request");
self.auth().await?;
self.client
.rpc(NameEndpoint { name: name.clone() })
.await
.inspect_err(|e| debug!("name endpoint error: {e}"))
.map_err(|_| RemoteError::InternalServerError)??;
self.name = Some(name);
Ok(())
}
async fn send_metrics(&mut self, encoder: &mut Encoder) -> Result<(), RemoteError> {
trace!("client actor send metrics");
self.auth().await?;
let update = encoder.export();
let req = PutMetrics {
session_id: self.session_id,
update,
};
self.client
.rpc(req)
.await
.map_err(|_| RemoteError::InternalServerError)??;
Ok(())
}
async fn grant_cap(&mut self, cap: Rcan<Caps>) -> Result<(), Error> {
trace!("client actor grant capability");
self.auth().await?;
self.client
.rpc(crate::protocol::GrantCap { cap })
.await
.map_err(|_| RemoteError::InternalServerError)??;
Ok(())
}
async fn put_network_diagnostics(
&mut self,
report: crate::net_diagnostics::DiagnosticsReport,
) -> Result<(), Error> {
trace!("client actor publish network diagnostics");
self.auth().await?;
let req = PutNetworkDiagnostics { report };
self.client
.rpc(req)
.await
.map_err(|_| RemoteError::InternalServerError)??;
Ok(())
}
}
async fn set_name_inner(
message_channel: tokio::sync::mpsc::Sender<ClientActorMessage>,
name: String,
) -> Result<(), Error> {
validate_name(&name)?;
debug!(name_len = name.len(), "calling set name");
let (tx, rx) = oneshot::channel();
message_channel
.send(ClientActorMessage::NameEndpoint { name, done: tx })
.await
.map_err(|_| Error::Other(anyhow!("sending name endpoint request")))?;
rx.await
.map_err(|e| Error::Other(anyhow!("response on internal channel: {:?}", e)))?
.map_err(Error::Remote)
}
#[cfg(test)]
mod tests {
    use iroh::{Endpoint, EndpointAddr, SecretKey, endpoint::presets};
    use rand::{RngExt, SeedableRng};
    use temp_env_vars::temp_env_vars;
    use crate::{
        Client,
        api_secret::ApiSecret,
        caps::{Cap, Caps},
        client::{API_SECRET_ENV_VAR_NAME, BuildError, ValidateNameError},
    };

    /// `api_secret_from_env` must set the remote from the secret and mint a
    /// `Cap::Client` capability issued by the shared secret for the local endpoint.
    #[tokio::test]
    #[temp_env_vars]
    async fn test_api_key_from_env() {
        let mut rng = rand::rngs::ChaCha8Rng::seed_from_u64(0);
        let shared_secret = SecretKey::from_bytes(&rng.random());
        let fake_endpoint_id = SecretKey::from_bytes(&rng.random()).public();
        let api_secret = ApiSecret::new(shared_secret.clone(), fake_endpoint_id);
        // SAFETY: `set_var` is unsafe because other threads may access the process
        // environment concurrently. NOTE(review): this relies on `#[temp_env_vars]`
        // serializing env-mutating tests and restoring the variable afterwards — confirm.
        unsafe {
            std::env::set_var(API_SECRET_ENV_VAR_NAME, api_secret.to_string());
        };
        let endpoint = Endpoint::builder(presets::Minimal).bind().await.unwrap();
        let builder = Client::builder(&endpoint).api_secret_from_env().unwrap();
        let fake_endpoint_addr: EndpointAddr = fake_endpoint_id.into();
        assert_eq!(builder.remote, Some(fake_endpoint_addr));
        let cap = builder.cap.as_ref().expect("expected capability to be set");
        assert_eq!(cap.capability(), &Caps::new([Cap::Client]));
        assert_eq!(cap.audience(), &endpoint.id().as_verifying_key());
        assert_eq!(cap.issuer(), &shared_secret.public().as_verifying_key());
    }

    /// With automatic metrics disabled, a manual push to the fake (unreachable)
    /// remote must return an error.
    #[tokio::test]
    async fn test_no_metrics_interval() {
        let mut rng = rand::rngs::ChaCha8Rng::seed_from_u64(1);
        let shared_secret = SecretKey::from_bytes(&rng.random());
        let fake_endpoint_id = SecretKey::from_bytes(&rng.random()).public();
        let api_secret = ApiSecret::new(shared_secret.clone(), fake_endpoint_id);
        let endpoint = Endpoint::builder(presets::Minimal).bind().await.unwrap();
        let client = Client::builder(&endpoint)
            .disable_metrics_interval()
            .api_secret(api_secret)
            .unwrap()
            .build()
            .await
            .unwrap();
        let err = client.push_metrics().await;
        assert!(err.is_err());
    }

    /// Builder names are validated by UTF-8 byte length: 12 bytes is accepted,
    /// 1 byte is too short and 516 bytes (129 four-byte emoji) is too long.
    #[tokio::test]
    async fn test_name() {
        let mut rng = rand::rngs::ChaCha8Rng::seed_from_u64(0);
        let shared_secret = SecretKey::from_bytes(&rng.random());
        let fake_endpoint_id = SecretKey::from_bytes(&rng.random()).public();
        let api_secret = ApiSecret::new(shared_secret.clone(), fake_endpoint_id);
        let endpoint = Endpoint::builder(presets::Minimal).bind().await.unwrap();
        let builder = Client::builder(&endpoint)
            .name("my-node 👋")
            .unwrap()
            .api_secret(api_secret)
            .unwrap();
        assert_eq!(builder.name, Some("my-node 👋".to_string()));
        let Err(err) = Client::builder(&endpoint).name("a") else {
            panic!("name should fail for strings under 2 bytes");
        };
        assert!(matches!(
            err.downcast_ref::<BuildError>(),
            Some(BuildError::InvalidName(ValidateNameError::TooShort))
        ));
        let too_long_name = "👋".repeat(129);
        let Err(err) = Client::builder(&endpoint).name(&too_long_name) else {
            panic!("name should fail for strings over 128 bytes");
        };
        assert!(matches!(
            err.downcast_ref::<BuildError>(),
            Some(BuildError::InvalidName(ValidateNameError::TooLong))
        ));
    }
}