use std::collections::HashMap;
use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;
use anyhow::{bail, ensure, Context, Result};
use bytes::{Bytes, BytesMut};
use clap::Args;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::json;
use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor};
use tracing::debug;
use wash_lib::cli::{validate_component_id, CommandOutput};
use wash_lib::config::DEFAULT_LATTICE;
use wasmcloud_core::parse_wit_meta_from_operation;
use wit_bindgen_wrpc::wrpc_transport::InvokeExt as _;
use wrpc_interface_http::InvokeIncomingHandler as _;
use crate::util::{default_timeout_ms, extract_arg_value, msgpack_to_json_val};
/// Scheme of the synthesized HTTP request when `--http-scheme` is not given.
const DEFAULT_HTTP_SCHEME: &str = "http";
/// Host of the synthesized HTTP request when `--http-host` is not given.
const DEFAULT_HTTP_HOST: &str = "localhost";
/// Port of the synthesized HTTP request when `--http-port` is not given.
const DEFAULT_HTTP_PORT: u16 = 8080;
/// One entry of the MessagePack-encoded test report that `call_output`
/// decodes when `is_test` is set.
///
/// Every field falls back to its `Default` value when absent from the payload.
#[derive(Deserialize)]
struct TestResult {
    /// Test name, shown in the report printed by `print_test_results`.
    #[serde(default)]
    pub name: String,
    /// Whether the test passed.
    #[serde(default)]
    pub passed: bool,
    /// Optional raw payload stored under the `snapData` key. For failed tests,
    /// `print_test_results` tries to decode it as JSON `{"error": "..."}`.
    // Only `Deserialize` is derived, so no serialization attributes
    // (such as `skip_serializing_if`) belong here.
    #[serde(rename = "snapData", with = "serde_bytes", default)]
    pub snap_data: Option<Vec<u8>>,
}
/// Prints a colored `Pass`/`Fail` line for each test result to stdout,
/// followed by a `passed/total Passed` summary line.
///
/// For a failed test, the line shows the test name and, when `snap_data`
/// decodes as JSON `{"error": "..."}`, that error message.
///
/// Output is best-effort. Terminal write errors (for example a closed pipe)
/// are ignored rather than causing a panic.
fn print_test_results(results: &[TestResult]) {
    /// Shape of the JSON error report a failing test may carry in `snap_data`.
    #[derive(Deserialize)]
    struct ErrorReport {
        error: String,
    }

    // `Auto` still colors normal terminals, but it honors environment
    // opt-outs (e.g. `NO_COLOR`, `TERM=dumb`) instead of forcing escape codes.
    let mut stdout = StandardStream::stdout(ColorChoice::Auto);
    let mut green = ColorSpec::new();
    green.set_fg(Some(Color::Green));
    let mut red = ColorSpec::new();
    red.set_fg(Some(Color::Red));

    let total = results.len();
    let mut passed = 0usize;
    for test in results {
        if test.passed {
            passed += 1;
            let _ = stdout.set_color(&green);
            let _ = write!(&mut stdout, "Pass");
            let _ = stdout.reset();
            let _ = writeln!(&mut stdout, ": {}", test.name);
        } else {
            let error_msg = test
                .snap_data
                .as_deref()
                .and_then(|bytes| serde_json::from_slice::<ErrorReport>(bytes).ok())
                .map(|report| report.error)
                .unwrap_or_default();
            // Show the test name (so the user knows WHICH test failed) plus the
            // error, skipping whichever parts are empty.
            let detail = [test.name.as_str(), error_msg.as_str()]
                .into_iter()
                .filter(|part| !part.is_empty())
                .collect::<Vec<_>>()
                .join(": ");
            let _ = stdout.set_color(&red);
            let _ = write!(&mut stdout, "Fail");
            let _ = stdout.reset();
            let _ = writeln!(&mut stdout, ": {detail}");
        }
    }

    let status_color = if passed == total { &green } else { &red };
    let _ = write!(&mut stdout, "Test results: ");
    let _ = stdout.set_color(status_color);
    let _ = writeln!(&mut stdout, "{passed}/{total} Passed");
    // Restore the terminal's default colors.
    let _ = stdout.reset();
    let _ = writeln!(&mut stdout);
}
// Entry point for `wash call`: a thin wrapper that flattens every option of
// `CallCommand` into the `call` command.
#[derive(Debug, Args, Clone)]
#[clap(name = "call")]
pub struct CallCli {
    #[clap(flatten)]
    command: CallCommand,
}

impl CallCli {
    /// Consumes the wrapper, yielding the parsed [`CallCommand`].
    pub fn command(self) -> CallCommand {
        let Self { command } = self;
        command
    }
}
/// Invokes a function exported by a component over wRPC and returns its output.
///
/// `function` must have the form `namespace:package/interface.function`. The
/// two HTTP incoming-handler operations (`wrpc:http/incoming-handler.handle`
/// and `wasi:http/incoming-handler.handle`) are sent an HTTP request built from
/// the `--http-*` options. Any other function goes through `wrpc_invoke_simple`,
/// which expects a single string result.
///
/// # Errors
///
/// Fails if:
/// - the component ID is empty;
/// - the function name is malformed;
/// - the HTTP request options are invalid;
/// - the NATS connection cannot be established;
/// - the invocation itself fails.
pub async fn handle_command(
    CallCommand {
        component_id,
        function,
        opts,
        http_handler_invocation_opts,
        http_response_extract_json,
        ..
    }: CallCommand,
) -> Result<CommandOutput> {
    const INVALID_FUNCTION_MSG: &str =
        "Invalid function supplied. Must be in the form of `namespace:package/interface.function`";

    ensure!(!component_id.is_empty(), "component ID may not be empty");
    debug!(
        ?component_id,
        ?function,
        "calling component function over wRPC"
    );

    // Validate the function name before touching the network, so malformed
    // input is reported even when no NATS server is reachable.
    let (namespace, package, interface, name) =
        parse_wit_meta_from_operation(&function).context(INVALID_FUNCTION_MSG)?;
    let name = name.context(INVALID_FUNCTION_MSG)?;
    let instance = format!("{namespace}:{package}/{interface}");

    let lattice = opts
        .lattice
        .clone()
        .unwrap_or_else(|| DEFAULT_LATTICE.to_string());
    let nc = create_client_from_opts_wrpc(&opts)
        .await
        .context("failed to create async nats client")?;
    // wRPC subjects for this component are scoped as `<lattice>.<component_id>`.
    let wrpc_client =
        wrpc_transport_nats::Client::new(nc, format!("{lattice}.{component_id}"), None);

    debug!(
        ?component_id,
        ?instance,
        ?name,
        ?lattice,
        "invoking component"
    );
    match function.as_str() {
        "wrpc:http/incoming-handler.handle" | "wasi:http/incoming-handler.handle" => {
            let request = http_handler_invocation_opts
                .to_request()
                .await
                .context("failed to build HTTP request from invocation options")?;
            wrpc_invoke_http_handler(
                wrpc_client,
                &lattice,
                &component_id,
                opts.timeout_ms,
                request,
                http_response_extract_json,
            )
            .await
        }
        _ => {
            wrpc_invoke_simple(
                wrpc_client,
                &lattice,
                &component_id,
                &instance,
                &name,
                opts.timeout_ms,
            )
            .await
        }
    }
}
// Options for connecting to the NATS server that carries wRPC traffic, and for
// choosing the lattice to address. Field doc comments below become `--help` text.
#[derive(Debug, Clone, Args)]
pub struct ConnectionOpts {
    /// RPC host for the NATS connection
    #[clap(
        short = 'r',
        long = "rpc-host",
        env = "WASMCLOUD_RPC_HOST",
        default_value = "127.0.0.1"
    )]
    rpc_host: String,

    /// RPC port for the NATS connection
    #[clap(
        short = 'p',
        long = "rpc-port",
        env = "WASMCLOUD_RPC_PORT",
        default_value = "4222"
    )]
    rpc_port: String,

    /// JWT for RPC authentication. Must be supplied together with --rpc-seed
    #[clap(
        long = "rpc-jwt",
        env = "WASMCLOUD_RPC_JWT",
        hide_env_values = true,
        requires = "rpc_seed"
    )]
    rpc_jwt: Option<String>,

    /// Seed used to sign the RPC authentication nonce. Must be supplied together with --rpc-jwt
    #[clap(
        long = "rpc-seed",
        env = "WASMCLOUD_RPC_SEED",
        hide_env_values = true,
        requires = "rpc_jwt"
    )]
    rpc_seed: Option<String>,

    /// Credentials file for RPC authentication (ignored when --rpc-jwt is set)
    #[clap(long = "rpc-credsfile", env = "WASH_RPC_CREDS", hide_env_values = true)]
    rpc_credsfile: Option<PathBuf>,

    /// CA certificate file for the RPC connection; when set, TLS is required
    #[clap(
        long = "rpc-ca-file",
        env = "WASH_RPC_TLS_CA_FILE",
        hide_env_values = true
    )]
    rpc_ca_file: Option<PathBuf>,

    /// Lattice to invoke the component in (defaults to the default lattice)
    #[clap(short = 'x', long = "lattice", env = "WASMCLOUD_LATTICE")]
    lattice: Option<String>,

    /// Timeout for the RPC invocation, in milliseconds
    #[clap(
        short = 't',
        long = "rpc-timeout-ms",
        default_value_t = default_timeout_ms(),
        env = "WASMCLOUD_RPC_TIMEOUT_MS"
    )]
    timeout_ms: u64,

    /// Name of the context to use
    // NOTE(review): not read anywhere in this file; presumably consumed by the caller.
    #[clap(long = "context")]
    pub context: Option<String>,
}
// Arguments for invoking one component function over wRPC. Field doc comments
// below become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct CallCommand {
    // NATS connection / authentication / lattice options.
    #[clap(flatten)]
    opts: ConnectionOpts,

    /// ID of the component to call
    #[clap(name = "component-id", value_parser = validate_component_id)]
    pub component_id: String,

    /// Function to invoke, as `namespace:package/interface.function`
    /// (e.g. `wasmcloud:test/handle.operation`)
    #[clap(name = "function")]
    pub function: String,

    /// For HTTP handler invocations, parse the response body as JSON and output only that value
    #[clap(
        long = "http-response-extract-json",
        default_value_t = false,
        env = "WASH_CALL_HTTP_RESPONSE_EXTRACT_JSON"
    )]
    pub http_response_extract_json: bool,

    // Options used to build the request for HTTP incoming-handler invocations.
    #[clap(flatten)]
    pub http_handler_invocation_opts: HttpHandlerInvocationOpts,
}
// Options describing the HTTP request sent when calling an HTTP
// incoming-handler export. Unset options fall back to defaults in `to_request`.
// Field doc comments below become `--help` text.
#[derive(Debug, Clone, Deserialize, Args)]
pub struct HttpHandlerInvocationOpts {
    /// Scheme of the HTTP request (default: http)
    #[clap(long = "http-scheme", env = "WASH_CALL_INVOKE_HTTP_SCHEME")]
    http_scheme: Option<String>,

    /// Host of the HTTP request (default: localhost)
    #[clap(long = "http-host", env = "WASH_CALL_INVOKE_HTTP_HOST")]
    http_host: Option<String>,

    /// Port of the HTTP request (default: 8080)
    #[clap(long = "http-port", env = "WASH_CALL_INVOKE_HTTP_PORT")]
    http_port: Option<u16>,

    /// Method of the HTTP request (default: GET)
    #[clap(long = "http-method", env = "WASH_CALL_INVOKE_HTTP_METHOD")]
    http_method: Option<String>,

    /// Body of the HTTP request (cannot be combined with --http-body-path)
    #[clap(
        long = "http-body",
        env = "WASH_CALL_INVOKE_HTTP_BODY",
        conflicts_with = "http_body_path"
    )]
    http_body: Option<String>,

    /// Path to a file whose contents become the HTTP request body (cannot be combined with --http-body)
    #[clap(
        long = "http-body-path",
        env = "WASH_CALL_INVOKE_HTTP_BODY_PATH",
        conflicts_with = "http_body"
    )]
    http_body_path: Option<PathBuf>,

    /// Value of the request's Content-Type header
    #[clap(long = "http-content-type", env = "WASH_CALL_INVOKE_HTTP_CONTENT_TYPE")]
    http_content_type: Option<String>,
}
impl HttpHandlerInvocationOpts {
    /// Builds the HTTP request to send to an incoming-handler component.
    ///
    /// Defaults for unset options:
    /// - scheme, host and port: `DEFAULT_HTTP_SCHEME`, `DEFAULT_HTTP_HOST` and
    ///   `DEFAULT_HTTP_PORT`;
    /// - method: `GET`.
    ///
    /// The body is taken from `--http-body` if given, otherwise read from the
    /// file at `--http-body-path`, otherwise left empty. A `Content-Type` header
    /// is added only when `--http-content-type` is set.
    ///
    /// # Errors
    ///
    /// Fails if the method is not a valid HTTP method, the body file cannot be
    /// read, or the request cannot be built (e.g. the host is not a valid URI
    /// authority).
    pub async fn to_request(self) -> Result<http::Request<String>> {
        let HttpHandlerInvocationOpts {
            http_scheme,
            http_host,
            http_port,
            http_method,
            http_body,
            http_body_path,
            http_content_type,
            ..
        } = self;
        let host = http_host.unwrap_or_else(|| DEFAULT_HTTP_HOST.into());
        let port = http_port.unwrap_or(DEFAULT_HTTP_PORT);
        let scheme = http_scheme.unwrap_or_else(|| DEFAULT_HTTP_SCHEME.into());
        let method = http::method::Method::from_str(http_method.as_deref().unwrap_or("GET"))
            .context("failed to read method from input")?;
        debug!(?host, ?port, ?scheme, ?method, content_type = ?http_content_type, "building request from options");

        let http_body = match (http_body, http_body_path) {
            (Some(s), _) => s,
            (_, Some(p)) => tokio::fs::read_to_string(p)
                .await
                .context("failed to read http body file")?,
            (None, None) => String::new(),
        };

        // A bare IPv6 literal (e.g. `::1`) must be wrapped in brackets inside a
        // URI authority; otherwise its colons are confused with the port separator.
        let authority_host = if host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{host}]")
        } else {
            host
        };

        let mut req = http::Request::builder()
            .uri(format!("{scheme}://{authority_host}:{port}"))
            .method(method);
        if let Some(content_type) = http_content_type {
            req = req.header("Content-Type", content_type);
        }
        req.body(http_body)
            .context("failed to build HTTP request from handler invocation options")
    }
}
/// JSON-serializable summary of an HTTP handler response. `wrpc_invoke_http_handler`
/// emits it when the response is not being extracted as JSON.
#[derive(Debug, Clone, Serialize)]
struct HttpResponse {
    /// HTTP status code of the response.
    status: u16,
    /// Response headers as populated by `wrpc_invoke_http_handler`: values that
    /// fail `HeaderValue::to_str` become empty strings, and repeated header
    /// names keep only one value (HashMap collection).
    headers: HashMap<String, String>,
    /// Full response body.
    /// NOTE(review): `Bytes` goes through its own `Serialize` impl; with
    /// serde_json this is presumably rendered as an array of numbers rather
    /// than a string. Confirm that this output format is intended.
    body: Bytes,
}
async fn wrpc_invoke_http_handler(
client: wrpc_transport_nats::Client,
lattice: &str,
component_id: &str,
timeout_ms: u64,
request: http::request::Request<String>,
extract_json: bool,
) -> Result<CommandOutput> {
let result = tokio::time::timeout(
std::time::Duration::from_millis(timeout_ms),
client
.invoke_handle_http(Some(gen_wash_call_headers()), request)
)
.await
.with_context(|| format!("component invocation timeout, is component [{component_id}] running in lattice [{lattice}]?"))?
.context("failed to perform HTTP request")?;
match result {
(Ok(mut resp), _errs, io) => {
if let Some(io) = io {
io.await.context("failed to complete async I/O")?;
}
let status = resp.status().as_u16();
let headers =
HashMap::<String, String>::from_iter(resp.headers().into_iter().map(|(k, v)| {
(
k.as_str().into(),
v.to_str().map(|v| v.to_string()).unwrap_or_default(),
)
}));
let mut body = BytesMut::new();
while let Some(bytes) = resp.body_mut().body.next().await {
body.extend(bytes);
}
let body = body.freeze();
let output = if extract_json {
let body_json = serde_json::from_slice(&body)
.context("failed to parse response body bytes into a valid JSON object")?;
CommandOutput::new(
serde_json::to_string_pretty(&body_json)
.context("failed to print http response JSON")?,
HashMap::from([("response".into(), body_json)]),
)
} else {
let http_resp = HttpResponse {
status,
headers,
body,
};
CommandOutput::new(
serde_json::to_string(&http_resp)
.context("failed to print http response JSON")?,
HashMap::from([(
"response".into(),
serde_json::to_value(&http_resp)
.context("failed to convert http response to value")?,
)]),
)
};
Ok(output)
}
_ => bail!("unexpected response after HTTP wRPC invocation"),
}
}
async fn wrpc_invoke_simple(
client: wrpc_transport_nats::Client,
lattice: &str,
component_id: &str,
instance: &str,
function_name: &str,
timeout_ms: u64,
) -> Result<CommandOutput> {
let result = client
.timeout(Duration::from_millis(timeout_ms))
.invoke_values_blocking::<_, ((),), (String,)>(
Some(gen_wash_call_headers()),
instance,
function_name,
((),),
&[[]; 0],
)
.await
.with_context(|| format!("timed out invoking component, is component [{component_id}] running in lattice [{lattice}]?"));
match result {
Ok((result,)) => {
Ok(CommandOutput::new(result.clone(), HashMap::from([("result".to_string(), json!(result))])))
}
Err(e) if e.to_string().contains("transmission failed") => bail!("No component responsed to your request, ensure component {component_id} is running in lattice {lattice}"),
Err(e) => bail!("Error invoking component: {e}"),
}
}
pub fn call_output(
response: Vec<u8>,
save_output: Option<PathBuf>,
bin: char,
is_test: bool,
) -> Result<CommandOutput> {
if let Some(ref save_path) = save_output {
std::fs::write(save_path, response)
.with_context(|| format!("Error saving results to {}", &save_path.display()))?;
return Ok(CommandOutput::new(
"",
HashMap::<String, serde_json::Value>::new(),
));
}
if is_test {
let test_results: Vec<TestResult> =
rmp_serde::from_slice(&response).with_context(|| {
format!(
"Error interpreting response as TestResults. Response: {}",
String::from_utf8_lossy(&response)
)
})?;
print_test_results(&test_results);
return Ok(CommandOutput::new(
"",
HashMap::<String, serde_json::Value>::new(),
));
}
let json = HashMap::from([
(
"response".to_string(),
msgpack_to_json_val(response.clone(), bin),
),
("success".to_string(), serde_json::json!(true)),
]);
Ok(CommandOutput::new(
format!(
"\nCall response (raw): {}",
String::from_utf8_lossy(&response)
),
json,
))
}
/// Opens an `async_nats` client from the connection options.
///
/// The authentication mode is chosen in this order:
/// 1. `--rpc-jwt`, signed with the keypair from `--rpc-seed`. If no seed is
///    present, a freshly generated user keypair is used; clap's `requires`
///    normally prevents this case.
/// 2. `--rpc-credsfile`.
/// 3. No credentials.
///
/// When `--rpc-ca-file` is set, that CA is added as a root certificate and TLS
/// is required, whichever authentication mode is used.
///
/// # Errors
///
/// Fails if the JWT or seed cannot be extracted or parsed, the credentials file
/// cannot be loaded, or the connection cannot be established. The JWT and seed
/// values are secrets, so they are never included in error messages.
async fn create_client_from_opts_wrpc(opts: &ConnectionOpts) -> Result<async_nats::Client> {
    use async_nats::ConnectOptions;

    let ConnectionOpts {
        rpc_host: host,
        rpc_port: port,
        rpc_jwt: jwt,
        rpc_seed: seed,
        rpc_credsfile: credsfile,
        rpc_ca_file: tls_ca_file,
        ..
    } = opts;
    let nats_url = format!("{host}:{port}");

    // Each authentication mode yields its connect options plus the message to
    // report if the final connect fails.
    let (connect_opts, connect_err_msg) = if let Some(jwt) = jwt {
        let jwt_contents =
            extract_arg_value(jwt).context("Failed to extract jwt contents from --rpc-jwt")?;
        let kp = std::sync::Arc::new(if let Some(seed) = seed {
            let seed_value =
                extract_arg_value(seed).context("Failed to extract seed value from --rpc-seed")?;
            nkeys::KeyPair::from_seed(&seed_value)
                .context("Failed to create keypair from --rpc-seed value")?
        } else {
            nkeys::KeyPair::new_user()
        });
        let connect_opts = ConnectOptions::with_jwt(jwt_contents, move |nonce| {
            let key_pair = kp.clone();
            async move { key_pair.sign(&nonce).map_err(async_nats::AuthError::new) }
        });
        (
            connect_opts,
            format!("Failed to connect to NATS server {host}:{port} while creating client"),
        )
    } else if let Some(credsfile_path) = credsfile {
        let connect_opts = ConnectOptions::with_credentials_file(credsfile_path.clone())
            .await
            .with_context(|| {
                format!(
                    "Failed to authenticate to NATS with credentials file {:?}",
                    credsfile_path
                )
            })?;
        (
            connect_opts,
            format!(
                "Failed to connect to NATS {} with credentials file {:?}",
                nats_url, credsfile_path
            ),
        )
    } else {
        (
            ConnectOptions::new(),
            format!("Failed to connect to NATS {}\nNo credentials file was provided, you may need one to connect.", nats_url),
        )
    };

    // TLS applies identically to every authentication mode.
    let connect_opts = match tls_ca_file {
        Some(ca_file) => connect_opts
            .add_root_certificates(ca_file.clone())
            .require_tls(true),
        None => connect_opts,
    };

    connect_opts.connect(&nats_url).await.context(connect_err_msg)
}
/// Builds the NATS headers attached to every invocation made by `wash call`:
/// a single `source-id: wash` entry.
fn gen_wash_call_headers() -> async_nats::HeaderMap {
    let mut call_headers = async_nats::HeaderMap::new();
    call_headers.insert("source-id", "wash");
    call_headers
}
#[cfg(test)]
mod test {
    use super::CallCommand;
    use anyhow::Result;
    use clap::Parser;

    const RPC_HOST: &str = "127.0.0.1";
    const RPC_PORT: &str = "4222";
    const DEFAULT_LATTICE: &str = "default";
    const COMPONENT_ID: &str = "MDPDJEYIAK6MACO67PRFGOSSLODBISK4SCEYDY3HEOY4P5CVJN6UCWUK";

    // Minimal top-level parser so `CallCommand` can be parsed on its own.
    #[derive(Debug, Parser)]
    struct Cmd {
        #[clap(flatten)]
        command: CallCommand,
    }

    /// All connection options and both positionals land in the expected fields.
    #[test]
    fn test_rpc_comprehensive() -> Result<()> {
        let call_all: Cmd = Parser::try_parse_from([
            "call",
            "--context",
            "some-context",
            "--lattice",
            DEFAULT_LATTICE,
            "--rpc-host",
            RPC_HOST,
            "--rpc-port",
            RPC_PORT,
            "--rpc-timeout-ms",
            "0",
            COMPONENT_ID,
            "wasmcloud:test/handle.operation",
        ])?;
        // `CallCommand` is a struct, so destructure it directly instead of
        // matching with an unreachable fallback arm.
        let CallCommand {
            opts,
            component_id,
            function,
            ..
        } = call_all.command;
        assert_eq!(opts.rpc_host, RPC_HOST);
        assert_eq!(opts.rpc_port, RPC_PORT);
        assert_eq!(opts.lattice.as_deref(), Some(DEFAULT_LATTICE));
        assert_eq!(opts.timeout_ms, 0);
        assert_eq!(opts.context.as_deref(), Some("some-context"));
        assert_eq!(component_id, COMPONENT_ID);
        assert_eq!(function, "wasmcloud:test/handle.operation");
        Ok(())
    }

    /// `--http-body` and `--http-body-path` are mutually exclusive.
    #[test]
    fn test_http_body_conflicts_with_body_path() {
        let err = Cmd::try_parse_from([
            "call",
            "--http-body",
            "hello",
            "--http-body-path",
            "body.txt",
            COMPONENT_ID,
            "wasi:http/incoming-handler.handle",
        ])
        .expect_err("conflicting HTTP body options should be rejected");
        assert_eq!(err.kind(), clap::error::ErrorKind::ArgumentConflict);
    }
}