use anyhow::{anyhow, bail, Context, Result};
use clap::Parser;
use std::collections::HashMap;
use tokio::time::Duration;
use tracing::error;
use wasmcloud_control_interface::HostInventory;
use crate::{
cli::{CliConnectionOpts, CommandOutput},
common::{boxed_err_to_anyhow, find_host_id, get_all_inventories, FindIdError, Match},
component::{scale_component, ComponentScaledInfo, ScaleComponentArgs},
config::{host_pid_file, WashConnectionOptions},
context::default_timeout_ms,
id::ServerId,
wait::{wait_for_provider_stop_event, FindEventOutcome, ProviderStoppedInfo},
};
use super::validate_component_id;
/// Stop a component, provider, or host
#[derive(Debug, Clone, Parser)]
pub enum StopCommand {
    /// Stop a component running in a host
    #[clap(name = "component")]
    Component(StopComponentCommand),

    /// Stop a provider running in a host
    #[clap(name = "provider")]
    Provider(StopProviderCommand),

    /// Stop a running host
    #[clap(name = "host")]
    Host(StopHostCommand),
}
/// Stop a component running in a host
#[derive(Debug, Clone, Parser)]
pub struct StopComponentCommand {
    // Connection options; flattened so their flags appear directly on this command.
    // (Plain comment on purpose: no help text is attached to a flattened field.)
    #[clap(flatten)]
    pub opts: CliConnectionOpts,

    /// ID of the host to stop the component on. If omitted, the first host found running the component is used
    #[clap(long = "host-id")]
    pub host_id: Option<String>,

    /// ID of the component to stop
    #[clap(name = "component-id", value_parser = validate_component_id)]
    pub component_id: String,

    /// Do not wait for the component to stop before returning
    #[clap(long = "skip-wait")]
    pub skip_wait: bool,
}
/// Stop a provider running in a host
#[derive(Debug, Clone, Parser)]
pub struct StopProviderCommand {
    // Connection options; flattened so their flags appear directly on this command.
    // (Plain comment on purpose: no help text is attached to a flattened field.)
    #[clap(flatten)]
    pub opts: CliConnectionOpts,

    /// ID of the host to stop the provider on. If omitted, the single host running the provider is used
    #[clap(long = "host-id")]
    pub host_id: Option<String>,

    /// ID of the provider to stop
    #[clap(name = "provider-id", value_parser = validate_component_id)]
    pub provider_id: String,

    /// Return once the host has acknowledged the stop request instead of waiting for the provider to stop
    #[clap(long = "skip-wait")]
    pub skip_wait: bool,
}
/// Stop a running host
#[derive(Debug, Clone, Parser)]
pub struct StopHostCommand {
    // Connection options; flattened so their flags appear directly on this command.
    // (Plain comment on purpose: no help text is attached to a flattened field.)
    #[clap(flatten)]
    pub opts: CliConnectionOpts,

    /// ID of the host to stop
    #[clap(name = "host-id")]
    pub host_id: String,

    /// Host shutdown timeout, in milliseconds
    // NOTE(review): `stop_host` in this file never reads this field, and `stop_hosts`
    // passes `None` as the timeout to the control client — confirm whether this flag
    // is still meant to have an effect.
    #[clap(
        long = "host-timeout",
        default_value_t = default_timeout_ms()
    )]
    pub host_shutdown_timeout: u64,
}
pub async fn stop_provider(cmd: StopProviderCommand) -> Result<CommandOutput> {
let timeout_ms = cmd.opts.timeout_ms;
let wco: WashConnectionOptions = cmd.opts.try_into()?;
let client = wco.into_ctl_client(None).await?;
let mut receiver = client
.events_receiver(vec![
"provider_stopped".to_string(),
"provider_stop_failed".to_string(),
])
.await
.map_err(boxed_err_to_anyhow)?;
let host_id = if let Some(host_id) = cmd.host_id {
find_host_id(&host_id, &client).await?.0
} else {
find_host_with_provider(&cmd.provider_id, &client).await?
};
let ack = client
.stop_provider(&host_id, &cmd.provider_id)
.await
.map_err(boxed_err_to_anyhow)?;
if !ack.success {
bail!("Operation failed: {}", ack.message);
}
if cmd.skip_wait {
let text = format!("Provider {} stop request received", cmd.provider_id);
return Ok(CommandOutput::new(
text.clone(),
HashMap::from([
("result".into(), text.into()),
("provider_id".into(), cmd.provider_id.to_string().into()),
("host_id".into(), host_id.to_string().into()),
]),
));
}
let event = wait_for_provider_stop_event(
&mut receiver,
Duration::from_millis(timeout_ms),
host_id.to_string(),
cmd.provider_id.to_string(),
)
.await?;
match event {
FindEventOutcome::Success(ProviderStoppedInfo {
host_id,
provider_id,
}) => {
let text = format!("Provider [{}] stopped successfully", &cmd.provider_id);
Ok(CommandOutput::new(
text.clone(),
HashMap::from([
("result".into(), text.into()),
("provider_id".into(), provider_id.into()),
("host_id".into(), host_id.into()),
]),
))
}
FindEventOutcome::Failure(err) => bail!("{}", err),
}
}
pub async fn handle_stop_component(cmd: StopComponentCommand) -> Result<CommandOutput> {
let timeout_ms = cmd.opts.timeout_ms;
let wco: WashConnectionOptions = cmd.opts.try_into()?;
let client = wco.into_ctl_client(None).await?;
let component_id = cmd.component_id;
let inventory = if let Some(host_id) = cmd.host_id {
client
.get_host_inventory(&host_id)
.await
.map(|inventory| inventory.response)
.map_err(boxed_err_to_anyhow)?
.context("Supplied host did not respond to inventory query")?
} else {
let inventories = get_all_inventories(&client).await?;
inventories
.into_iter()
.find(|inv| {
inv.components
.iter()
.any(|component| component.id == component_id)
})
.ok_or_else(|| anyhow::anyhow!("No host found running component [{}]", component_id))?
};
let Some((host_id, component_ref)) = inventory
.components
.iter()
.find(|component| component.id == component_id)
.map(|component| (inventory.host_id.clone(), component.image_ref.clone()))
else {
bail!(
"No component with id [{component_id}] found on host [{}]",
inventory.host_id
);
};
let ComponentScaledInfo {
component_id,
host_id,
..
} = scale_component(ScaleComponentArgs {
client: &client,
host_id: &host_id,
component_id: &component_id,
component_ref: &component_ref,
max_instances: 0,
annotations: None,
config: vec![],
skip_wait: cmd.skip_wait,
timeout_ms: Some(timeout_ms),
})
.await?;
let text = if cmd.skip_wait {
format!("Request to stop component [{component_id}] received",)
} else {
format!("Component [{component_id}] stopped")
};
Ok(CommandOutput::new(
text.clone(),
HashMap::from([
("result".into(), text.into()),
("component_id".into(), component_id.into()),
("host_id".into(), host_id.into()),
]),
))
}
pub async fn stop_host(cmd: StopHostCommand) -> Result<CommandOutput> {
let wco: WashConnectionOptions = cmd.opts.try_into()?;
let client = wco.into_ctl_client(None).await?;
let (_, hosts_remain) = stop_hosts(client, Some(&cmd.host_id), false).await?;
let pid_file_exists = tokio::fs::try_exists(host_pid_file()?).await?;
if !hosts_remain && pid_file_exists {
tokio::fs::remove_file(host_pid_file()?).await?;
}
Ok(CommandOutput::from_key_and_text(
"result",
format!("Host {} acknowledged stop request", cmd.host_id),
))
}
/// Finds the single host whose inventory lists a provider with ID `provider_id`.
///
/// Hosts whose ID cannot be parsed into a [`ServerId`] are skipped. Selection among the
/// matching hosts (none / one / several) is delegated to `find_host_with_filter`.
async fn find_host_with_provider(
    provider_id: &str,
    ctl_client: &wasmcloud_control_interface::Client,
) -> Result<ServerId, FindIdError> {
    find_host_with_filter(ctl_client, |inv| {
        let runs_provider = inv.providers.iter().any(|prov| prov.id == provider_id);
        if !runs_provider {
            return None;
        }
        // An unparseable host ID drops the host from the candidates instead of failing.
        let server_id: ServerId = inv.host_id.parse().ok()?;
        Some((server_id, inv.friendly_name))
    })
    .await
}
/// Fetches all host inventories and returns the ID of the one host accepted by `filter`.
///
/// `filter` receives each inventory by value and returns the host's ID and friendly
/// name when the host qualifies.
///
/// # Errors
/// - `FindIdError::NoMatches` when `filter` accepts no host.
/// - `FindIdError::MultipleMatches` (listing every accepted host) when it accepts more
///   than one.
/// - Whatever error fetching the inventories produces.
async fn find_host_with_filter<F>(
    ctl_client: &wasmcloud_control_interface::Client,
    filter: F,
) -> Result<ServerId, FindIdError>
where
    F: FnMut(HostInventory) -> Option<(ServerId, String)>,
{
    let mut candidates: Vec<(ServerId, String)> = get_all_inventories(ctl_client)
        .await?
        .into_iter()
        .filter_map(filter)
        .collect();

    match candidates.len() {
        0 => Err(FindIdError::NoMatches),
        1 => {
            // Exactly one element, so index 0 is always present.
            let (only_id, _friendly_name) = candidates.swap_remove(0);
            Ok(only_id)
        }
        _ => Err(FindIdError::MultipleMatches(
            candidates
                .into_iter()
                .map(|(id, friendly_name)| Match {
                    id: id.into_string(),
                    friendly_name: Some(friendly_name),
                })
                .collect(),
        )),
    }
}
/// Stops one or more hosts reachable through `client`.
///
/// The hosts that answer `get_hosts` are collected first; what happens next depends on
/// the arguments:
/// - `host_id` is `Some`: only that host is asked to stop.
/// - no host answered: nothing is stopped.
/// - exactly one host answered: that host is stopped, regardless of `all`.
/// - several hosts answered and `all` is `true`: every host is asked to stop
///   concurrently; a failed request is logged and does not abort the others.
/// - several hosts answered and `all` is `false`: an error listing the running host IDs.
///
/// # Returns
/// The IDs of the hosts that were asked to stop, and a flag telling whether any host is
/// believed to be left afterwards. With an explicit `host_id` the flag is `true` when a
/// host with a *different* ID answered `get_hosts`; with `all` it is `true` when at
/// least one stop request failed.
///
/// # Errors
/// Fails when the host list cannot be fetched, when a single-host stop request fails,
/// or when several hosts are running and neither `host_id` nor `all` was given.
pub async fn stop_hosts(
    client: wasmcloud_control_interface::client::Client,
    host_id: Option<&String>,
    all: bool,
) -> Result<(Vec<String>, bool)> {
    let hosts = client
        .get_hosts()
        .await
        .map_err(|e| anyhow!(e))?
        .into_iter()
        .filter_map(|r| r.response)
        .collect::<Vec<_>>();

    // NOTE(review): in every branch below the value returned by `stop_host` is
    // discarded; only an `Err` counts as failure — confirm that is intended.
    if let Some(host_id) = host_id {
        client.stop_host(host_id, None).await.map_err(|e| {
            anyhow!(
                "Could not stop host, ensure a host with that ID is running: {:?}",
                e
            )
        })?;
        // Compare IDs rather than counting responses: the targeted host is not
        // guaranteed to be among the hosts that answered, so `hosts.len() > 1` could
        // report "none remain" while a different host is still running.
        let hosts_remaining = hosts.iter().any(|host| host.id != *host_id);
        Ok((vec![host_id.clone()], hosts_remaining))
    } else if hosts.is_empty() {
        Ok((vec![], false))
    } else if hosts.len() == 1 {
        let host_id = &hosts[0].id;
        client
            .stop_host(host_id, None)
            .await
            .map_err(|e| anyhow!(e))?;
        Ok((vec![host_id.to_string()], false))
    } else if all {
        let host_stops = hosts
            .iter()
            .map(|host| async {
                let host_id = &host.id;
                match client.stop_host(host_id, None).await {
                    Ok(_) => Some(host_id.to_owned()),
                    Err(e) => {
                        error!("Could not stop host {}: {:?}", host_id, e);
                        None
                    }
                }
            })
            .collect::<Vec<_>>();
        let all_stops = futures::future::join_all(host_stops).await;
        // Record the number of attempts before consuming the results.
        let attempted = all_stops.len();
        let host_ids = all_stops.into_iter().flatten().collect::<Vec<_>>();
        let hosts_remaining = attempted > host_ids.len();
        Ok((host_ids, hosts_remaining))
    } else {
        bail!(
            "More than one host is running, please specify a host ID or use --all\nRunning hosts: {:?}", hosts.into_iter().map(|h| h.id).collect::<Vec<_>>()
        )
    }
}