#![cfg(not(target_arch = "wasm32"))]
#![allow(clippy::let_and_return)]
use std::{
convert::{TryFrom, TryInto},
fmt,
sync::Arc,
time::Duration,
};
use async_nats::HeaderMap;
use futures::Future;
#[cfg(feature = "prometheus")]
use prometheus::{IntCounter, Opts};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value as JsonValue;
use tracing::{debug, error, info, instrument, trace, warn};
#[cfg(feature = "otel")]
use crate::otel::OtelHeaderInjector;
use crate::{
chunkify::{needs_chunking, ChunkEndpoint},
common::Message,
core::{Invocation, InvocationResponse, WasmCloudEntity},
error::{RpcError, RpcResult},
provider_main::get_host_bridge_safe,
wascap::{jwt, prelude::Claims},
};
/// Default RPC timeout: two seconds.
///
/// NOTE: despite the `_MILLIS` suffix this constant is a [`Duration`], not a millisecond count.
pub(crate) const DEFAULT_RPC_TIMEOUT_MILLIS: Duration = Duration::from_secs(2);
/// Extra time added to an RPC's timeout when its argument is large enough to be chunked.
pub(crate) const CHUNK_RPC_EXTRA_TIME: Duration = Duration::from_millis(13_000);
/// Client that sends signed wasmbus RPC invocations over a NATS connection.
///
/// Cloning copies the NATS client handle, `host_id` and `timeout`; the signing key pair
/// (and, with the `prometheus` feature, the counters) are shared through `Arc`.
#[derive(Clone)]
pub struct RpcClient {
    /// NATS connection used for all requests and publishes.
    client: async_nats::Client,
    /// Key pair whose public key is the claims issuer and which signs invocation claims.
    key: Arc<wascap::prelude::KeyPair>,
    /// Host id copied into every outgoing invocation.
    host_id: String,
    /// Default timeout for `send`, `request` and `publish`; `None` disables it.
    timeout: Option<Duration>,
    /// RPC counters, shared between clones of this client.
    #[cfg(feature = "prometheus")]
    pub(crate) stats: Arc<RpcStats>,
}
impl fmt::Debug for RpcClient {
    /// Writes a fixed placeholder; none of the fields (NATS client, key pair, ...) are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "RpcClient()")
    }
}
/// Prometheus counters for RPC traffic. Every counter carries a constant `public_key` label
/// (see `RpcStats::init`).
#[cfg(feature = "prometheus")]
#[derive(Debug)]
pub struct RpcStats {
    /// Number of rpc nats messages sent.
    pub(crate) rpc_sent: IntCounter,
    /// Number of errors sending rpc (`send`, `send_timeout`, `post`).
    pub(crate) rpc_sent_err: IntCounter,
    /// Number of sent rpc messages whose argument was chunkified.
    pub(crate) rpc_sent_chunky: IntCounter,
    /// Number of responses to sent rpc that were chunkified.
    pub(crate) rpc_sent_resp_chunky: IntCounter,
    /// Total bytes of argument sent in rpc requests.
    pub(crate) rpc_sent_bytes: IntCounter,
    /// NOTE(review): the help text says "bytes sent in responses to incoming rpc", but
    /// `inner_rpc` increments it with the length of responses *received* for outgoing rpc —
    /// confirm which meaning is intended.
    pub(crate) rpc_sent_resp_bytes: IntCounter,
    /// Number of rpc messages that incurred a timeout error.
    pub(crate) rpc_sent_timeouts: IntCounter,
    /// Number of rpc messages received.
    pub(crate) rpc_recv: IntCounter,
    /// Number of errors encountered responding to incoming rpc.
    pub(crate) rpc_recv_err: IntCounter,
    /// Number of received rpc that were chunkified (see `RpcClient::dechunk`).
    pub(crate) rpc_recv_chunky: IntCounter,
    /// Number of chunkified responses to received rpc.
    pub(crate) rpc_recv_resp_chunky: IntCounter,
    /// Total bytes in received rpc.
    pub(crate) rpc_recv_bytes: IntCounter,
    /// Total bytes in responses to incoming rpc.
    pub(crate) rpc_recv_resp_bytes: IntCounter,
}
/// Builds the NATS subject an invocation for `entity` is sent on.
///
/// The subject is `wasmbus.rpc.<lattice_prefix>.<public_key>`, with `.<link_name>` appended
/// when the entity has a non-empty link name.
#[doc(hidden)]
pub fn rpc_topic(entity: &WasmCloudEntity, lattice_prefix: &str) -> String {
    if entity.link_name.is_empty() {
        return format!("wasmbus.rpc.{}.{}", lattice_prefix, entity.public_key);
    }
    format!(
        "wasmbus.rpc.{}.{}.{}",
        lattice_prefix, entity.public_key, entity.link_name
    )
}
impl RpcClient {
pub fn new(
nats: async_nats::Client,
host_id: String,
timeout: Option<Duration>,
key_pair: Arc<wascap::prelude::KeyPair>,
) -> Self {
Self::new_client(nats, host_id, timeout, key_pair)
}
pub(crate) fn new_client(
nats: async_nats::Client,
host_id: String,
timeout: Option<Duration>,
key_pair: Arc<wascap::prelude::KeyPair>,
) -> Self {
RpcClient {
client: nats,
host_id,
timeout,
#[cfg(feature = "prometheus")]
stats: Arc::new(RpcStats::init(key_pair.public_key())),
key: key_pair,
}
}
pub fn client(&self) -> async_nats::Client {
self.client.clone()
}
pub fn set_timeout(&mut self, timeout: Option<Duration>) {
self.timeout = timeout;
}
pub async fn send_json<Target, Arg, Resp>(
&self,
origin: WasmCloudEntity,
target: Target,
lattice: &str,
method: &str,
data: JsonValue,
) -> RpcResult<JsonValue>
where
Arg: DeserializeOwned + Serialize,
Resp: DeserializeOwned + Serialize,
Target: Into<WasmCloudEntity>,
{
let msg = JsonMessage(method, data).try_into()?;
let bytes = self.send(origin, target, lattice, msg).await?;
let resp = response_to_json::<Resp>(&bytes)?;
Ok(resp)
}
pub async fn send<Target>(
&self,
origin: WasmCloudEntity,
target: Target,
lattice: &str,
message: Message<'_>,
) -> RpcResult<Vec<u8>>
where
Target: Into<WasmCloudEntity>,
{
let rc = self.inner_rpc(origin, target, lattice, message, true, self.timeout).await;
#[cfg(feature = "prometheus")]
{
if rc.is_err() {
self.stats.rpc_sent_err.inc()
}
}
rc
}
pub async fn send_timeout<Target>(
&self,
origin: WasmCloudEntity,
target: Target,
lattice: &str,
message: Message<'_>,
timeout: Duration,
) -> RpcResult<Vec<u8>>
where
Target: Into<WasmCloudEntity>,
{
let rc = self.inner_rpc(origin, target, lattice, message, true, Some(timeout)).await;
#[cfg(feature = "prometheus")]
{
if rc.is_err() {
self.stats.rpc_sent_err.inc();
}
}
rc
}
#[doc(hidden)]
pub async fn post<Target>(
&self,
origin: WasmCloudEntity,
target: Target,
lattice: &str,
message: Message<'_>,
) -> RpcResult<()>
where
Target: Into<WasmCloudEntity>,
{
let rc = self.inner_rpc(origin, target, lattice, message, false, None).await;
match rc {
Err(e) => {
#[cfg(feature = "prometheus")]
self.stats.rpc_sent_err.inc();
Err(e)
}
Ok(_) => Ok(()),
}
}
#[instrument(level = "debug", skip(self, origin, target, message), fields( provider_id = tracing::field::Empty, method = tracing::field::Empty, lattice_id = tracing::field::Empty, subject = tracing::field::Empty, issuer = tracing::field::Empty, sender_key = tracing::field::Empty, contract_id = tracing::field::Empty, link_name = tracing::field::Empty, target_key = tracing::field::Empty ))]
async fn inner_rpc<Target>(
&self,
origin: WasmCloudEntity,
target: Target,
lattice: &str,
message: Message<'_>,
expect_response: bool,
timeout: Option<Duration>,
) -> RpcResult<Vec<u8>>
where
Target: Into<WasmCloudEntity>,
{
let target = target.into();
let origin_url = origin.url();
let subject = make_uuid();
let issuer = self.key.public_key();
let raw_target_url = target.url();
let target_url = format!("{}/{}", raw_target_url, &message.method);
let span = tracing::span::Span::current();
if let Some(hb) = get_host_bridge_safe() {
span.record("provider_id", &tracing::field::display(&hb.provider_key()));
}
span.record("method", &tracing::field::display(&message.method));
span.record("lattice_id", &tracing::field::display(&lattice));
span.record("subject", &tracing::field::display(&subject));
span.record("issuer", &tracing::field::display(&issuer));
if !origin.public_key.is_empty() {
span.record("sender_key", &tracing::field::display(&origin.public_key));
}
if !target.contract_id.is_empty() {
span.record("contract_id", &tracing::field::display(&target.contract_id));
}
if !target.link_name.is_empty() {
span.record("link_name", &tracing::field::display(&target.link_name));
}
if !target.public_key.is_empty() {
span.record("target_key", &tracing::field::display(&target.public_key));
}
let claims = Claims::<jwt::Invocation>::new(
issuer.clone(),
subject.clone(),
&target_url,
&origin_url,
&invocation_hash(&target_url, &origin_url, message.method, &message.arg),
);
let topic = rpc_topic(&target, lattice);
let method = message.method.to_string();
let len = message.arg.len();
let chunkify = needs_chunking(len);
let (invocation, body) = {
let mut inv = Invocation {
origin,
target,
operation: method.clone(),
id: subject,
encoded_claims: claims.encode(&self.key).unwrap_or_default(),
host_id: self.host_id.clone(),
content_length: Some(len as u64),
..Default::default()
};
if chunkify {
(inv, Some(Vec::from(message.arg)))
} else {
inv.msg = Vec::from(message.arg);
(inv, None)
}
};
let nats_body = crate::common::serialize(&invocation)?;
if let Some(body) = body {
let inv_id = invocation.id.clone();
debug!(invocation_id = %inv_id, %len, "chunkifying invocation");
let lattice = lattice.to_string();
if let Err(error) = ChunkEndpoint::with_client(lattice, self.client(), None)
.chunkify(&inv_id, &mut body.as_slice())
.await
{
error!(%error, "chunking error");
return Err(RpcError::Other(error.to_string()));
}
}
let timeout = if chunkify {
timeout.map(|t| t + CHUNK_RPC_EXTRA_TIME)
} else {
timeout
};
#[cfg(feature = "prometheus")]
{
self.stats.rpc_sent.inc();
if let Some(len) = invocation.content_length {
self.stats.rpc_sent_bytes.inc_by(len);
}
if chunkify {
self.stats.rpc_sent_chunky.inc();
}
}
if expect_response {
let this = self.clone();
let topic_ = topic.clone();
let payload = if let Some(timeout) = timeout {
match tokio::time::timeout(timeout, this.request(topic, nats_body)).await {
Err(elapsed) => {
#[cfg(feature = "prometheus")]
self.stats.rpc_sent_timeouts.inc();
Err(RpcError::Timeout(elapsed.to_string()))
}
Ok(Ok(data)) => Ok(data),
Ok(Err(err)) => Err(RpcError::Nats(err.to_string())),
}
} else {
this.request(topic, nats_body)
.await
.map_err(|e| RpcError::Nats(e.to_string()))
}
.map_err(|error| {
error!(%error, topic=%topic_, "sending request");
error
})?;
let inv_response =
crate::common::deserialize::<InvocationResponse>(&payload).map_err(|e| {
RpcError::Deser(format!("response to {}: {}", &method, &e.to_string()))
})?;
match inv_response.error {
None => {
#[cfg(feature = "prometheus")]
if let Some(len) = inv_response.content_length {
self.stats.rpc_sent_resp_bytes.inc_by(len);
}
let msg = if inv_response.content_length.is_some()
&& inv_response.content_length.unwrap() > inv_response.msg.len() as u64
{
let lattice = lattice.to_string();
#[cfg(feature = "prometheus")]
{
self.stats.rpc_sent_resp_chunky.inc();
}
ChunkEndpoint::with_client(lattice, self.client(), None)
.get_unchunkified_response(&inv_response.invocation_id)
.await?
} else {
inv_response.msg
};
trace!("rpc ok response");
Ok(msg)
}
Some(err) => {
error!(error = %err, "rpc error response");
Err(RpcError::Rpc(err))
}
}
} else {
self.publish(topic, nats_body)
.await
.map_err(|e| RpcError::Nats(format!("publish error: {}: {}", target_url, e)))?;
Ok(Vec::new())
}
}
#[instrument(level = "debug", skip_all, fields(subject = %subject))]
pub async fn request(&self, subject: String, payload: Vec<u8>) -> RpcResult<Vec<u8>> {
#[cfg(feature = "otel")]
let headers: Option<HeaderMap> = Some(OtelHeaderInjector::default_with_span().into());
#[cfg(not(feature = "otel"))]
let headers: Option<HeaderMap> = None;
let nc = self.client();
match self
.maybe_timeout(self.timeout, async move {
if let Some(headers) = headers {
nc.request_with_headers(subject, headers, payload.into()).await
} else {
nc.request(subject, payload.into()).await
}
})
.await
{
Err(error) => {
error!(%error, "sending request");
Err(error)
}
Ok(message) => Ok(message.payload.to_vec()),
}
}
#[instrument(level = "debug", skip_all, fields(subject = %subject))]
pub async fn publish(&self, subject: String, payload: Vec<u8>) -> RpcResult<()> {
#[cfg(feature = "otel")]
let headers: Option<HeaderMap> = Some(OtelHeaderInjector::default_with_span().into());
#[cfg(not(feature = "otel"))]
let headers: Option<HeaderMap> = None;
let nc = self.client();
self.maybe_timeout(self.timeout, async move {
if let Some(headers) = headers {
nc.publish_with_headers(subject, headers, payload.into())
.await
.map_err(|e| RpcError::Nats(e.to_string()))
} else {
nc.publish(subject, payload.into())
.await
.map_err(|e| RpcError::Nats(e.to_string()))
}
})
.await?;
let nc = self.client();
tokio::spawn(async move {
if let Err(error) = nc.flush().await {
error!(%error, "flush after publish");
}
});
Ok(())
}
pub async fn publish_invocation_response(
&self,
reply_to: String,
response: InvocationResponse,
lattice: &str,
) -> RpcResult<()> {
let content_length = Some(response.msg.len() as u64);
let response = {
let inv_id = response.invocation_id.clone();
if needs_chunking(response.msg.len()) {
#[cfg(feature = "prometheus")]
{
self.stats.rpc_recv_resp_chunky.inc();
}
let buf = response.msg;
ChunkEndpoint::with_client(lattice.to_string(), self.client(), None)
.chunkify_response(&inv_id, &mut buf.as_slice())
.await?;
InvocationResponse {
msg: Vec::new(),
content_length,
..response
}
} else {
InvocationResponse { content_length, ..response }
}
};
match crate::common::serialize(&response) {
Ok(t) => Ok(self.publish(reply_to, t).await?),
Err(e) => {
Err(RpcError::Ser(format!("InvocationResponse: {}", e)))
}
}
}
pub async fn dechunk(&self, mut inv: Invocation, lattice: &str) -> RpcResult<Invocation> {
if inv.content_length.is_some() && inv.content_length.unwrap() > inv.msg.len() as u64 {
#[cfg(feature = "prometheus")]
{
self.stats.rpc_recv_chunky.inc();
}
inv.msg = ChunkEndpoint::with_client(lattice.to_string(), self.client(), None)
.get_unchunkified(&inv.id.clone())
.await
.map_err(|e| e.to_string())?;
}
Ok(inv)
}
pub async fn validate_invocation(
&self,
inv: Invocation,
) -> Result<(Invocation, Claims<jwt::Invocation>), String> {
let vr = jwt::validate_token::<jwt::Invocation>(&inv.encoded_claims)
.map_err(|e| format!("{}", e))?;
if vr.expired {
return Err("Invocation claims token expired".into());
}
if !vr.signature_valid {
return Err("Invocation claims signature invalid".into());
}
if vr.cannot_use_yet {
return Err("Attempt to use invocation before claims token allows".into());
}
let target_url = format!("{}/{}", inv.target.url(), &inv.operation);
let hash = invocation_hash(&target_url, &inv.origin.url(), &inv.operation, &inv.msg);
let claims =
Claims::<jwt::Invocation>::decode(&inv.encoded_claims).map_err(|e| format!("{}", e))?;
let inv_claims = claims
.metadata
.as_ref()
.ok_or_else(|| "No wascap metadata found on claims".to_string())?;
if inv_claims.invocation_hash != hash {
return Err(format!(
"Invocation hash does not match signed claims hash ({} / {})",
inv_claims.invocation_hash, hash
));
}
if !inv.host_id.starts_with('N') && inv.host_id.len() != 56 {
return Err(format!("Invalid host ID on invocation: '{}'", inv.host_id));
}
if inv_claims.target_url != target_url {
return Err(format!(
"Invocation claims and invocation target URL do not match: {} != {}",
&inv_claims.target_url, &target_url
));
}
if inv_claims.origin_url != inv.origin.url() {
return Err("Invocation claims and invocation origin URL do not match".into());
}
Ok((inv, claims))
}
async fn maybe_timeout<F, T, E>(&self, t: Option<Duration>, f: F) -> RpcResult<T>
where
F: Future<Output = Result<T, E>> + Send + Sync + 'static,
T: 'static,
E: ToString,
{
if let Some(timeout) = t {
match tokio::time::timeout(timeout, f).await {
Err(elapsed) => {
#[cfg(feature = "prometheus")]
self.stats.rpc_sent_timeouts.inc();
Err(RpcError::Timeout(elapsed.to_string()))
}
Ok(Ok(data)) => Ok(data),
Ok(Err(err)) => Err(RpcError::Nats(err.to_string())),
}
} else {
f.await.map_err(|e| RpcError::Nats(e.to_string()))
}
}
}
/// Returns `opts` with an event callback installed that logs every NATS connection event.
///
/// Events are logged at these levels:
/// * connects at `info`;
/// * disconnects, lame-duck mode and slow consumers at `warn`;
/// * client and server errors at `error`.
pub fn with_connection_event_logging(
    opts: async_nats::ConnectOptions,
) -> async_nats::ConnectOptions {
    use async_nats::Event;
    let log_event = |evt: Event| async move {
        match evt {
            Event::Connected => info!("nats client connected"),
            Event::Disconnected => warn!("nats client disconnected"),
            Event::LameDuckMode => warn!("nats lame duck mode"),
            Event::SlowConsumer(val) => warn!("nats slow consumer detected ({})", val),
            Event::ServerError(err) => error!("nats server error: '{:?}'", err),
            Event::ClientError(err) => error!("nats client error: '{:?}'", err),
        }
    };
    opts.event_callback(log_event)
}
/// An operation name, its origin and its serialized argument bytes.
///
/// NOTE(review): no usage is visible in this module; presumably `origin` is the caller's
/// identity and `arg` the encoded argument — confirm against callers.
#[derive(Clone)]
pub struct InvocationArg {
    /// Origin of the invocation (format not established here — TODO confirm).
    pub origin: String,
    /// Operation (method) name.
    pub operation: String,
    /// Serialized argument bytes.
    pub arg: Vec<u8>,
}
/// Computes the invocation hash that is embedded in, and checked against, invocation claims.
///
/// The hash is SHA-256 over `origin_url`, `target_url`, `method` and `args`, in that order,
/// rendered as uppercase hex. The digest order differs from the parameter order.
pub(crate) fn invocation_hash(
    target_url: &str,
    origin_url: &str,
    method: &str,
    args: &[u8],
) -> String {
    use sha2::Digest as _;
    let parts: [&[u8]; 4] = [
        origin_url.as_bytes(),
        target_url.as_bytes(),
        method.as_bytes(),
        args,
    ];
    let mut hasher = sha2::Sha256::new();
    for part in parts.iter().copied() {
        hasher.update(part);
    }
    let digest = hasher.finalize();
    data_encoding::HEXUPPER.encode(digest.as_slice())
}
#[doc(hidden)]
pub fn make_uuid() -> String {
use uuid::Uuid;
Uuid::new_v4()
.as_simple()
.encode_lower(&mut Uuid::encode_buffer())
.to_string()
}
/// A method name paired with a JSON argument, convertible into a wire [`Message`].
struct JsonMessage<'m>(&'m str, JsonValue);

impl<'m> TryFrom<JsonMessage<'m>> for Message<'m> {
    type Error = RpcError;

    /// Re-encodes the JSON argument (as a generic `JsonValue`) into the wire format.
    fn try_from(jm: JsonMessage<'m>) -> Result<Message<'m>, Self::Error> {
        let JsonMessage(method, params) = jm;
        json_to_args::<JsonValue>(params).map(|bytes| Message {
            method,
            arg: std::borrow::Cow::Owned(bytes),
        })
    }
}
/// Decodes JSON `v` into a `T`, then encodes that value with the crate's wire serializer.
///
/// # Errors
/// Returns [`RpcError::Deser`] (`"invalid params: ..."`) if `v` does not match `T`; a failure
/// in the wire serializer is passed through unchanged.
fn json_to_args<T>(v: JsonValue) -> RpcResult<Vec<u8>>
where
    T: Serialize + DeserializeOwned,
{
    let typed = serde_json::from_value::<T>(v)
        .map_err(|e| RpcError::Deser(format!("invalid params: {}.", e)))?;
    crate::common::serialize(&typed)
}
/// Decodes wire bytes `msg` as a `T` and converts the result to a JSON value.
///
/// # Errors
/// Decoding errors are passed through. JSON conversion failures become [`RpcError::Ser`].
fn response_to_json<T>(msg: &[u8]) -> RpcResult<JsonValue>
where
    T: Serialize + DeserializeOwned,
{
    let decoded = crate::common::deserialize::<T>(msg)?;
    serde_json::to_value(decoded)
        .map_err(|e| RpcError::Ser(format!("response serialization : {}.", e)))
}
#[cfg(feature = "prometheus")]
impl RpcStats {
    /// Creates every RPC counter, each with a constant `public_key` label set to `public_key`.
    ///
    /// # Panics
    /// Panics if prometheus rejects a counter's options. The metric names, help strings and
    /// label name used here are fixed and valid, so this is treated as a programming error.
    fn init(public_key: String) -> RpcStats {
        let mut labels = std::collections::HashMap::new();
        labels.insert("public_key".to_string(), public_key);
        // All counters share the same constant labels; only the name and help text differ.
        let counter = |name: &str, help: &str| -> IntCounter {
            IntCounter::with_opts(Opts::new(name, help).const_labels(labels.clone()))
                .expect("rpc counter name, help and labels are statically valid")
        };
        RpcStats {
            rpc_sent: counter("rpc_sent", "number of rpc nats messages sent"),
            rpc_sent_err: counter("rpc_sent_err", "number of errors sending rpc"),
            rpc_sent_chunky: counter(
                "rpc_sent_chunky",
                "number of rpc messages that were chunkified",
            ),
            rpc_sent_resp_chunky: counter(
                "rpc_sent_resp_chunky",
                "number of responses to sent rpc that were chunkified",
            ),
            rpc_sent_bytes: counter("rpc_sent_bytes", "total bytes sent in rpc requests"),
            rpc_sent_resp_bytes: counter(
                "rpc_sent_resp_bytes",
                "total bytes sent in responses to incoming rpc",
            ),
            rpc_sent_timeouts: counter(
                "rpc_sent_timeouts",
                "number of rpc messages that incurred timeout error",
            ),
            rpc_recv: counter("rpc_recv", "number of rpc messages received"),
            rpc_recv_err: counter(
                "rpc_recv_err",
                "number of errors encountered responding to incoming rpc",
            ),
            rpc_recv_chunky: counter(
                "rpc_recv_chunky",
                "number of received rpc that were chunkified",
            ),
            rpc_recv_resp_chunky: counter(
                "rpc_recv_resp_chunky",
                "number of chunkified responses to received rpc",
            ),
            rpc_recv_bytes: counter("rpc_recv_bytes", "total bytes in received rpc"),
            rpc_recv_resp_bytes: counter(
                "rpc_recv_resp_bytes",
                "total bytes in responses to incoming rpc",
            ),
        }
    }
}