use std::collections::HashMap;
use anyhow::{bail, Context, Result};
use clap::Parser;
use tokio::time::Duration;
use crate::{
cli::{input_vec_to_hashmap, CliConnectionOpts, CommandOutput},
common::{boxed_err_to_anyhow, find_host_id},
component::{scale_component, ComponentScaledInfo, ScaleComponentArgs},
config::{
WashConnectionOptions, DEFAULT_NATS_TIMEOUT_MS, DEFAULT_START_COMPONENT_TIMEOUT_MS,
DEFAULT_START_PROVIDER_TIMEOUT_MS,
},
context::default_timeout_ms,
wait::{wait_for_provider_start_event, FindEventOutcome, ProviderStartedInfo},
};
use super::validate_component_id;
// Subcommands of `start`: `component` or `provider`.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive API, `///` doc comments
// become `--help` text, so adding them here would change user-visible CLI output.
#[derive(Debug, Clone, Parser)]
pub enum StartCommand {
    // `start component ...`; presumably dispatched to `handle_start_component` (dispatch is
    // not visible in this file section; confirm).
    #[clap(name = "component")]
    Component(StartComponentCommand),
    // `start provider ...`; presumably dispatched to `handle_start_provider` (confirm).
    #[clap(name = "provider")]
    Provider(StartProviderCommand),
}
// Command-line arguments for `start component`.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive API, `///` doc comments
// become `--help` text, so adding them would change user-visible CLI output.
#[derive(Debug, Clone, Parser)]
pub struct StartComponentCommand {
    // Shared connection options. `handle_start_component` also reuses its `timeout_ms` as the
    // start timeout, replacing it with `DEFAULT_START_COMPONENT_TIMEOUT_MS` when it still equals
    // `DEFAULT_NATS_TIMEOUT_MS`.
    #[clap(flatten)]
    pub opts: CliConnectionOpts,
    // Host to start on, resolved through `find_host_id`. When omitted, a component auction
    // selects the host instead.
    #[clap(long = "host-id")]
    pub host_id: Option<String>,
    // Positional component reference. Filesystem paths are rewritten to `file://` URLs by
    // `resolve_ref`; other values are passed through unchanged.
    #[clap(name = "component-ref")]
    pub component_ref: String,
    // Positional ID for the component, checked by `validate_component_id` at parse time.
    #[clap(name = "component-id", value_parser = validate_component_id)]
    pub component_id: String,
    // Instance count passed to `scale_component`; defaults to 1. Also accepted as
    // `--max-concurrent`, `--max`, or `--count`.
    #[clap(
        long = "max-instances",
        alias = "max-concurrent",
        alias = "max",
        alias = "count",
        default_value_t = 1
    )]
    pub max_instances: u32,
    // Optional `-c/--constraint` values, converted with `input_vec_to_hashmap` and only used
    // for the auction (when `--host-id` is absent). Presumably `KEY=VALUE` pairs; confirm
    // against `input_vec_to_hashmap`.
    #[clap(short = 'c', long = "constraint", name = "constraints")]
    pub constraints: Option<Vec<String>>,
    // Timeout handed to `into_ctl_client` for auctions; defaults to `default_timeout_ms()`.
    #[clap(long = "auction-timeout-ms", default_value_t = default_timeout_ms())]
    pub auction_timeout_ms: u64,
    // Forwarded to `scale_component`. When set, the reported message says the request was
    // received rather than that the component started.
    #[clap(long = "skip-wait")]
    pub skip_wait: bool,
    // Forwarded unchanged as `scale_component`'s `config` (presumably config names; confirm).
    #[clap(long = "config")]
    pub config: Vec<String>,
}
/// Normalizes a component/provider reference, turning local filesystem paths into `file://`
/// URLs.
///
/// The rules are tried in order:
/// 1. A string starting with `/` is prefixed with `file://` as-is (no filesystem access).
/// 2. A string naming a path that currently exists is canonicalized and prefixed with `file://`.
/// 3. A `file://` URL whose path part exists is canonicalized and re-prefixed with `file://`.
/// 4. Anything else is returned unchanged.
///
/// An error from the existence check is treated the same as "does not exist".
///
/// # Errors
///
/// Returns an error when an existing path cannot be canonicalized.
pub(crate) async fn resolve_ref(s: impl AsRef<str>) -> Result<String> {
    let reference = s.as_ref();

    // Absolute paths are taken at face value.
    if reference.starts_with('/') {
        return Ok(format!("file://{}", reference));
    }

    // A plain path that exists on disk: make it absolute.
    if tokio::fs::try_exists(reference).await.is_ok_and(|exists| exists) {
        let absolute = tokio::fs::canonicalize(reference)
            .await
            .with_context(|| format!("failed to resolve absolute path: {}", reference))?;
        return Ok(format!("file://{}", absolute.display()));
    }

    // An explicit `file://` URL: canonicalize its path part when it exists. The error message
    // deliberately reports the full original reference, scheme included.
    if let Some(path) = reference.strip_prefix("file://") {
        if tokio::fs::try_exists(path).await.is_ok_and(|exists| exists) {
            let absolute = tokio::fs::canonicalize(path)
                .await
                .with_context(|| format!("failed to resolve absolute path: {}", reference))?;
            return Ok(format!("file://{}", absolute.display()));
        }
    }

    Ok(reference.to_string())
}
/// Handles `start component`: picks a host and scales the component up on it.
///
/// The host is the one given with `--host-id` (resolved via `find_host_id`) or, when absent,
/// the first host that answered a component auction with a response. The component is then
/// started through `scale_component` with the requested instance count.
///
/// # Errors
///
/// Returns an error if connecting fails, the reference cannot be resolved, the auction fails or
/// yields no usable host, the chosen host id cannot be parsed, or `scale_component` fails.
pub async fn handle_start_component(cmd: StartComponentCommand) -> Result<CommandOutput> {
    // If the connection timeout is still the generic NATS default, use the component-start
    // default instead.
    let timeout_ms = match cmd.opts.timeout_ms {
        configured if configured == DEFAULT_NATS_TIMEOUT_MS => DEFAULT_START_COMPONENT_TIMEOUT_MS,
        configured => configured,
    };

    let conn_opts: WashConnectionOptions = cmd.opts.try_into()?;
    let client = conn_opts
        .into_ctl_client(Some(cmd.auction_timeout_ms))
        .await?;

    let component_ref = resolve_ref(&cmd.component_ref).await?;

    let host = match cmd.host_id {
        Some(requested) => find_host_id(&requested, &client).await?.0,
        None => {
            let constraints = input_vec_to_hashmap(cmd.constraints.unwrap_or_default())?;
            let bids = client
                .perform_component_auction(&component_ref, &cmd.component_id, constraints)
                .await
                .map_err(boxed_err_to_anyhow)
                .with_context(|| {
                    format!(
                        "Failed to auction component {} to hosts in lattice",
                        &component_ref
                    )
                })?;
            if bids.is_empty() {
                bail!("No suitable hosts found for component {}", component_ref);
            }
            // Only bids carrying a response are usable; the first such bid wins.
            let winner = bids
                .into_iter()
                .find_map(|bid| bid.response)
                .context("No suitable hosts found")?;
            winner
                .host_id
                .parse()
                .with_context(|| format!("Failed to parse host id: {}", winner.host_id))?
        }
    };

    let ComponentScaledInfo {
        host_id,
        component_ref,
        component_id,
    } = scale_component(ScaleComponentArgs {
        client: &client,
        host_id: &host,
        component_ref: &component_ref,
        component_id: &cmd.component_id,
        max_instances: cmd.max_instances,
        skip_wait: cmd.skip_wait,
        timeout_ms: Some(timeout_ms),
        annotations: None,
        config: cmd.config,
    })
    .await?;

    let text = match cmd.skip_wait {
        false => format!(
            "Component [{component_id}] (ref: [{component_ref}]) started on host [{host_id}]"
        ),
        true => format!("Start component [{component_ref}] request received on host [{host_id}]"),
    };

    Ok(CommandOutput::new(
        text.clone(),
        HashMap::from([
            ("result".into(), text.into()),
            ("component_ref".into(), component_ref.into()),
            ("component_id".into(), component_id.into()),
            ("host_id".into(), host_id.into()),
        ]),
    ))
}
// Command-line arguments for `start provider`.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive API, `///` doc comments
// become `--help` text, so adding them would change user-visible CLI output.
#[derive(Debug, Clone, Parser)]
pub struct StartProviderCommand {
    // Shared connection options. `handle_start_provider` also reuses its `timeout_ms` as the
    // start-event wait timeout, replacing it with `DEFAULT_START_PROVIDER_TIMEOUT_MS` when it
    // still equals `DEFAULT_NATS_TIMEOUT_MS`.
    #[clap(flatten)]
    pub opts: CliConnectionOpts,
    // Host to start on, resolved through `find_host_id`. When omitted, a provider auction
    // selects the host instead.
    #[clap(long = "host-id")]
    pub host_id: Option<String>,
    // Positional provider reference. Filesystem paths are rewritten to `file://` URLs by
    // `resolve_ref`; other values are passed through unchanged.
    #[clap(name = "provider-ref")]
    pub provider_ref: String,
    // Positional ID for the provider, checked by `validate_component_id` at parse time.
    #[clap(name = "provider-id", value_parser = validate_component_id)]
    pub provider_id: String,
    // Link name used for the provider auction (and echoed in `--skip-wait` output); it is not
    // passed to the start request itself. Defaults to `default`.
    #[clap(short = 'l', long = "link-name", default_value = "default")]
    pub link_name: String,
    // Optional `-c/--constraint` values, converted with `input_vec_to_hashmap` and only used
    // for the auction (when `--host-id` is absent). Presumably `KEY=VALUE` pairs; confirm.
    #[clap(short = 'c', long = "constraint", name = "constraints")]
    pub constraints: Option<Vec<String>>,
    // Timeout handed to `into_ctl_client` for auctions; defaults to `default_timeout_ms()`.
    #[clap(long = "auction-timeout-ms", default_value_t = default_timeout_ms())]
    pub auction_timeout_ms: u64,
    // Forwarded unchanged to the `start_provider` request (presumably config names; confirm).
    #[clap(long = "config")]
    pub config: Vec<String>,
    // When set, return right after the start request is acknowledged instead of waiting for
    // the provider start event.
    #[clap(long = "skip-wait")]
    pub skip_wait: bool,
}
/// Handles `start provider`: picks a host, asks it to start the provider and, unless
/// `--skip-wait` is given, waits for the resulting lattice event.
///
/// The host is the one given with `--host-id` (resolved via `find_host_id`) or, when absent,
/// the first host that answered a provider auction (using `--link-name` and any
/// `--constraint`s) with a response.
///
/// # Errors
///
/// Returns an error if connecting fails, the reference cannot be resolved, the auction fails or
/// yields no usable host, the chosen host id cannot be parsed, subscribing to lattice events
/// fails (only when waiting), the start request fails or is not accepted, waiting for the start
/// event fails, or the host reports that the provider failed to start.
pub async fn handle_start_provider(cmd: StartProviderCommand) -> Result<CommandOutput> {
    // If the connection timeout is still the generic NATS default, use the provider-start
    // default instead.
    let timeout_ms = if cmd.opts.timeout_ms == DEFAULT_NATS_TIMEOUT_MS {
        DEFAULT_START_PROVIDER_TIMEOUT_MS
    } else {
        cmd.opts.timeout_ms
    };
    let client = <CliConnectionOpts as TryInto<WashConnectionOptions>>::try_into(cmd.opts)?
        .into_ctl_client(Some(cmd.auction_timeout_ms))
        .await?;
    let provider_ref = resolve_ref(&cmd.provider_ref).await?;

    let host = match cmd.host_id {
        Some(host) => find_host_id(&host, &client).await?.0,
        None => {
            let suitable_hosts = client
                .perform_provider_auction(
                    &provider_ref,
                    &cmd.link_name,
                    input_vec_to_hashmap(cmd.constraints.unwrap_or_default())?,
                )
                .await
                .map_err(boxed_err_to_anyhow)
                .with_context(|| {
                    format!(
                        "Failed to auction provider {} with link name {} to hosts in lattice",
                        &provider_ref, &cmd.link_name
                    )
                })?;
            if suitable_hosts.is_empty() {
                bail!("No suitable hosts found for provider {}", provider_ref);
            }
            // Only bids carrying a response are usable; the first such bid wins.
            let ack = suitable_hosts
                .into_iter()
                .find_map(|h| h.response)
                .context("No suitable hosts found")?;
            ack.host_id
                .parse()
                .with_context(|| format!("Failed to parse host id: {}", ack.host_id))?
        }
    };

    // Subscribe before sending the start request so the event it triggers cannot be missed.
    // With `--skip-wait` the events are never read, so don't subscribe at all. This also keeps
    // an event-channel failure from failing a fire-and-forget start.
    let receiver = if cmd.skip_wait {
        None
    } else {
        let receiver = client
            .events_receiver(vec![
                "provider_started".to_string(),
                "provider_start_failed".to_string(),
            ])
            .await
            .map_err(boxed_err_to_anyhow)
            .context("Failed to get lattice event channel")?;
        Some(receiver)
    };

    let ack = client
        .start_provider(&host, &provider_ref, &cmd.provider_id, None, cmd.config)
        .await
        .map_err(boxed_err_to_anyhow)
        .with_context(|| {
            // Display (not Debug) formatting, consistent with the other messages here.
            format!(
                "Failed to start provider {} on host {}",
                &cmd.provider_id, &host
            )
        })?;
    if !ack.success {
        bail!("Start provider ack not accepted: {}", ack.message);
    }

    // No receiver means `--skip-wait`: report the accepted request and stop here.
    let Some(mut receiver) = receiver else {
        let text = format!("Start provider request received: {}", &provider_ref);
        return Ok(CommandOutput::new(
            text.clone(),
            HashMap::from([
                ("result".into(), text.into()),
                ("provider_ref".into(), provider_ref.into()),
                // Reported for parity with the waiting path and `start component`.
                ("provider_id".into(), cmd.provider_id.into()),
                ("link_name".into(), cmd.link_name.into()),
                ("host_id".into(), host.to_string().into()),
            ]),
        ));
    };

    let event = wait_for_provider_start_event(
        &mut receiver,
        Duration::from_millis(timeout_ms),
        host.to_string(),
        provider_ref.clone(),
    )
    .await
    .with_context(|| {
        format!(
            "Timed out waiting for start event for provider {} on host {}",
            &provider_ref, &host
        )
    })?;

    match event {
        FindEventOutcome::Success(ProviderStartedInfo {
            provider_id,
            provider_ref,
            host_id,
        }) => {
            let text = format!(
                "Provider [{}] (ref: [{}]) started on host [{}]",
                &provider_id, &provider_ref, &host_id
            );
            Ok(CommandOutput::new(
                text.clone(),
                HashMap::from([
                    ("result".into(), text.into()),
                    ("provider_ref".into(), provider_ref.into()),
                    ("provider_id".into(), provider_id.into()),
                    ("host_id".into(), host_id.into()),
                ]),
            ))
        }
        FindEventOutcome::Failure(err) => Err(err).with_context(|| {
            format!(
                "Failed starting provider {} on host {}",
                &provider_ref, &host
            )
        }),
    }
}