use std::pin::pin;
use std::time::Duration;
use anyhow::{anyhow, ensure, Context as _, Result};
use serde::Deserialize;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;
use tokio_stream::StreamExt;
use tracing::warn;
use wasmcloud_core::health_subject;
/// Decodes a JSON-encoded value of type `T` from `buf`.
///
/// # Errors
///
/// Returns an error carrying the context "failed to deserialize" when `buf`
/// is not valid JSON for `T`.
pub fn deserialize<'de, T>(buf: &'de [u8]) -> Result<T>
where
    T: Deserialize<'de>,
{
    let decoded = serde_json::from_slice::<T>(buf);
    decoded.context("failed to deserialize")
}
/// Arguments for [`assert_start_provider`] and [`assert_start_provider_timeout`].
pub struct StartProviderArgs<'a> {
    /// Control-interface client used to send the start request (and, in
    /// `assert_start_provider`, to obtain the lattice and NATS client used for
    /// health checks).
    pub client: &'a wasmcloud_control_interface::Client,
    /// ID of the host that should start the provider.
    pub host_id: &'a str,
    /// Provider ID passed to `start_provider`; also used to build the
    /// provider's health-check subject.
    pub provider_id: &'a str,
    /// Passed to `start_provider` as the provider reference (presumably an
    /// image/artifact reference — TODO confirm).
    pub provider_ref: &'a str,
    /// Passed through as the final argument of `start_provider` (presumably the
    /// names of configs to attach — TODO confirm).
    pub config: Vec<String>,
}
/// Arguments for [`assert_stop_provider`].
pub struct StopProviderArgs<'a> {
    /// Control-interface client used to send the stop request and to obtain
    /// the lattice and NATS client used for health checks.
    pub client: &'a wasmcloud_control_interface::Client,
    /// ID of the host that should stop the provider.
    pub host_id: &'a str,
    /// ID of the provider to stop; also used to build its health-check subject.
    pub provider_id: &'a str,
}
/// Reply body of a provider health-check request, decoded from JSON via
/// [`deserialize`] in [`assert_start_provider`].
///
/// Unknown fields make decoding fail (`deny_unknown_fields`); missing fields
/// fall back to their `Default` values (`false` / `None`).
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct ProviderHealthCheckResponse {
    /// Whether the provider reports itself as healthy.
    #[serde(default)]
    healthy: bool,
    /// Optional message accompanying the health status; the start assertion
    /// requires this to be absent.
    #[serde(default)]
    message: Option<String>,
}
/// Starts a provider on a host and asserts that it answers a health check as healthy.
///
/// After the host acknowledges the start request, the provider's health-check
/// subject on the client's lattice is polled on a one-second interval, for at
/// most 30 attempts, until a request receives a reply. Failed requests are
/// logged and retried on the next tick. The first reply is decoded as a
/// `ProviderHealthCheckResponse`. It must report `healthy == true` and carry
/// no `message`.
///
/// # Errors
///
/// Returns an error if:
/// - the start request fails or is not acknowledged as successful;
/// - no health-check request succeeds within the polling window;
/// - the reply cannot be decoded;
/// - the reply reports the provider as unhealthy or includes a message.
pub async fn assert_start_provider(
    StartProviderArgs {
        client,
        host_id,
        provider_id,
        provider_ref,
        config,
    }: StartProviderArgs<'_>,
) -> Result<()> {
    let lattice = client.lattice();
    let rpc_client = client.nats_client();
    let resp = client
        .start_provider(host_id, provider_ref, provider_id, None, config)
        .await
        .map_err(|e| anyhow!(e).context("failed to start provider"))?;
    ensure!(
        resp.succeeded(),
        "host `{host_id}` did not successfully start provider `{provider_id}`"
    );

    // Take the first health-check request that gets a reply; request errors
    // are logged and the next interval tick retries.
    let res = pin!(IntervalStream::new(interval(Duration::from_secs(1)))
        .take(30)
        .then(|_| rpc_client.request(health_subject(lattice, provider_id), "".into()))
        .filter_map(|res| match res {
            Err(error) => {
                warn!(?error, "failed to connect to provider");
                None
            }
            Ok(res) => Some(res),
        }))
    .next()
    .await
    .context("failed to perform health check request")?;

    // `deserialize` already yields an `anyhow::Error`, so context can be attached directly.
    let ProviderHealthCheckResponse { healthy, message } =
        deserialize(&res.payload).context("failed to decode health check response")?;
    ensure!(
        healthy,
        "provider `{provider_id}` reported unhealthy (message: {message:?})"
    );
    ensure!(
        message.is_none(),
        "provider `{provider_id}` health check returned unexpected message: {message:?}"
    );
    Ok(())
}
/// Starts a provider and asserts that the request times out instead of being answered.
///
/// The assertion passes only when `start_provider` returns an error whose text
/// contains "timed out".
///
/// # Errors
///
/// Returns an error if `start_provider` returns any response at all, or if it
/// fails with an error whose text does not contain "timed out".
pub async fn assert_start_provider_timeout(args: StartProviderArgs<'_>) -> Result<()> {
    let StartProviderArgs {
        client,
        host_id,
        provider_id,
        provider_ref,
        config,
    } = args;

    let outcome = client
        .start_provider(host_id, provider_ref, provider_id, None, config)
        .await;
    match outcome {
        Err(e) => {
            ensure!(e.to_string().contains("timed out"));
            Ok(())
        }
        Ok(_) => anyhow::bail!("start_provider should not have received a response"),
    }
}
/// Stops a provider on a host and asserts that it stops answering health checks.
///
/// After the host acknowledges the stop request, the provider's health-check
/// subject on the client's lattice is polled on a one-second interval, for at
/// most 30 attempts, until a request fails.
///
/// # Errors
///
/// Returns an error if:
/// - the stop request fails or is not acknowledged as successful;
/// - every health-check request in the polling window receives a reply.
pub async fn assert_stop_provider(
    StopProviderArgs {
        client,
        host_id,
        provider_id,
    }: StopProviderArgs<'_>,
) -> Result<()> {
    let lattice = client.lattice();
    let rpc_client = client.nats_client();
    let resp = client
        .stop_provider(host_id, provider_id)
        .await
        .map_err(|e| anyhow!(e).context("failed to stop provider"))?;
    ensure!(
        resp.succeeded(),
        "host `{host_id}` did not successfully stop provider `{provider_id}`"
    );

    // A reply means the provider is still running; wait for the first request
    // that fails.
    pin!(IntervalStream::new(interval(Duration::from_secs(1)))
        .take(30)
        .then(|_| rpc_client.request(health_subject(lattice, provider_id), "".into()))
        .filter_map(|res| res.is_err().then_some(())))
    .next()
    .await
    .context("provider did not stop and continued to respond to health check requests")?;
    Ok(())
}