use crate::config;
use crate::error::*;
use crate::package_definition::{
ArgType, ArgumentDefinition, CommandDefinition, OutputDeclaration, OutputType,
PackageDefinition,
};
use crate::paths;
use crate::receipt::{self, Receipt};
use clap::ArgAction;
use std::collections::BTreeMap;
use std::io::{BufRead, BufReader, BufWriter, IsTerminal, Write};
use std::net::TcpStream;
use std::path::Path;
use std::process::{self, Command, Stdio};
use std::sync::{Arc, Mutex};
use zacor_package::protocol::{self as proto, Message};
/// A package whose install receipt and stored `package.yaml` have both been
/// loaded and checked by `resolve`.
pub struct ResolvedPackage {
    /// The install receipt read for this package.
    pub receipt: Receipt,
    /// The parsed `package.yaml` for the receipt's current version.
    pub definition: PackageDefinition,
    /// The version string taken from `receipt.current`.
    pub version: String,
}
fn resolve(home: &Path, name: &str) -> Result<ResolvedPackage> {
let receipt = receipt::read(home, name)?
.ok_or_else(|| anyhow!("package '{}' not found\nhint: install it with `zacor install <source>`", name))?;
if !receipt.active {
bail!("package '{}' is disabled\nhint: run `zacor enable {}`", name, name);
}
let version = receipt.current.clone();
let def_path = paths::definition_path(home, name, &version);
if !def_path.exists() {
bail!(
"package.yaml for '{}' v{} not found in store\nhint: reinstall with `zacor install <source>`",
name, version
);
}
let definition = crate::package_definition::parse_file(&def_path)
.with_context(|| format!("corrupt package.yaml for '{}' v{}\nhint: reinstall with `zacor install <source>`", name, version))?;
Ok(ResolvedPackage {
receipt,
definition,
version,
})
}
/// Build the top-level clap parser for a package definition.
///
/// The `default` command, if present, contributes its arguments directly to
/// the top-level command. Every other entry in `def.commands` becomes a
/// subcommand.
///
/// Without a `default` command, a subcommand is required. When `default` is
/// the only command and the package itself has no description, the default
/// command's description becomes the top-level `about`.
pub fn build_clap_command(def: &PackageDefinition) -> clap::Command {
    let mut root = clap::Command::new(def.name.clone())
        .version(def.version.clone())
        .disable_help_subcommand(true);
    if let Some(ref about) = def.description {
        root = root.about(about.clone());
    }

    let subcommands: Vec<(&String, &CommandDefinition)> = def
        .commands
        .iter()
        .filter(|(key, _)| key.as_str() != "default")
        .collect();

    match def.commands.get("default") {
        Some(default_cmd) => {
            if subcommands.is_empty()
                && def.description.is_none()
                && let Some(ref about) = default_cmd.description
            {
                root = root.about(about.clone());
            }
            for (arg_name, arg_def) in &default_cmd.args {
                root = root.arg(build_arg(arg_name, arg_def));
            }
            if default_cmd.args.values().any(|a| a.rest) {
                root = root.trailing_var_arg(true);
            }
            if !subcommands.is_empty() {
                // Running the package with no subcommand falls back to `default`.
                root = root.subcommand_required(false);
            }
        }
        None => {
            root = root.subcommand_required(true);
        }
    }

    for (sub_name, sub_def) in subcommands {
        root = root.subcommand(build_subcommand(sub_name, sub_def));
    }
    root
}
/// Build the clap subcommand `name` from `def`, recursing into nested
/// commands. A command that has nested commands requires one of them.
fn build_subcommand(name: &str, def: &CommandDefinition) -> clap::Command {
    let base = match def.description {
        Some(ref about) => clap::Command::new(name.to_string()).about(about.clone()),
        None => clap::Command::new(name.to_string()),
    };

    let with_args = def
        .args
        .iter()
        .fold(base, |cmd, (arg_name, arg_def)| cmd.arg(build_arg(arg_name, arg_def)));

    let with_rest = if def.args.values().any(|a| a.rest) {
        with_args.trailing_var_arg(true)
    } else {
        with_args
    };

    let with_children = def
        .commands
        .iter()
        .fold(with_rest, |cmd, (child_name, child_def)| {
            cmd.subcommand(build_subcommand(child_name, child_def))
        });

    if def.commands.is_empty() {
        with_children
    } else {
        with_children.subcommand_required(true)
    }
}
/// Translate one argument definition into a clap argument.
///
/// - An explicit `flag` makes the argument a `--flag` option. A bool argument
///   without a `flag` is exposed as `--<name>`. Anything else without a `flag`
///   is positional.
/// - `number` arguments must parse as a finite number (see `parse_number`).
/// - `integer` arguments must parse as a whole number (an `i64`).
/// - `choice` arguments are restricted to `values` when those are given.
/// - `rest` arguments accept any number of values.
/// - Non-bool required arguments without a default are marked required.
///   Required arguments with a default get that default instead.
fn build_arg(name: &str, def: &ArgumentDefinition) -> clap::Arg {
    let mut arg = clap::Arg::new(name.to_string());
    if let Some(ref flag) = def.flag {
        arg = arg.long(flag.clone());
    } else if def.arg_type == ArgType::Bool {
        arg = arg.long(name.to_string());
    }
    match def.arg_type {
        ArgType::Bool => {
            arg = arg.action(ArgAction::SetTrue);
        }
        ArgType::Number => {
            arg = arg.value_parser(parse_number);
        }
        ArgType::Integer => {
            // Integers get their own validator so fractional or exponent
            // forms such as `1.5` are rejected at parse time.
            arg = arg.value_parser(parse_integer);
        }
        ArgType::Path => {
            arg = arg.value_hint(clap::ValueHint::AnyPath);
        }
        ArgType::Choice => {
            if let Some(ref values) = def.values {
                arg = arg.value_parser(clap::builder::PossibleValuesParser::new(values.clone()));
            }
        }
        ArgType::String => {}
    }
    if def.rest {
        arg = arg.num_args(0..);
    }
    // Bool flags are never required; a default already satisfies a required arg.
    if def.arg_type != ArgType::Bool && def.required && def.default.is_none() {
        arg = arg.required(true);
    }
    if def.required
        && let Some(ref default) = def.default
    {
        arg = arg.default_value(config::yaml_value_to_string(default));
    }
    arg
}

/// clap value parser for `integer` arguments.
///
/// Accepts anything that parses as an `i64` and passes the original text
/// through unchanged, so downstream code sees exactly what the user typed.
fn parse_integer(s: &str) -> std::result::Result<String, String> {
    s.parse::<i64>()
        .map_err(|_| format!("'{}' is not a valid integer", s))?;
    Ok(s.to_string())
}
/// clap value parser for `number` arguments.
///
/// Accepts any text that parses as a *finite* `f64` and passes the original
/// text through unchanged.
///
/// `f64::from_str` also accepts `inf`, `infinity` and `NaN`
/// (case-insensitively), and out-of-range literals parse to infinity. None of
/// those is a usable numeric argument, so they are rejected with the same
/// message as other invalid input.
fn parse_number(s: &str) -> std::result::Result<String, String> {
    match s.parse::<f64>() {
        Ok(value) if value.is_finite() => Ok(s.to_string()),
        _ => Err(format!("'{}' is not a valid number", s)),
    }
}
/// Parse `args` with `cmd` and return the selected command path and the
/// argument values that were supplied.
///
/// `pkg_name` is prepended as the program name so clap sees a normal argv.
/// A matched subcommand yields a dotted path such as `"transcribe.batch"`.
/// Otherwise the path is `"default"`, with values taken from the default
/// command if the package defines one.
fn clap_parse(
    cmd: clap::Command,
    pkg_name: &str,
    args: &[String],
    def: &PackageDefinition,
) -> std::result::Result<(String, BTreeMap<String, String>), clap::Error> {
    let argv = std::iter::once(pkg_name.to_string()).chain(args.iter().cloned());
    let matches = cmd.try_get_matches_from(argv)?;

    let selected = matches.subcommand().and_then(|(sub_name, sub_matches)| {
        def.commands
            .get(sub_name)
            .map(|cmd_def| (sub_name, sub_matches, cmd_def))
    });
    if let Some((sub_name, sub_matches, cmd_def)) = selected {
        let (tail, flags) = extract_from_command(sub_matches, cmd_def);
        let path = match tail.as_str() {
            "" => sub_name.to_string(),
            nested => format!("{}.{}", sub_name, nested),
        };
        return Ok((path, flags));
    }

    let flags = def
        .commands
        .get("default")
        .map(|default_cmd| extract_args(&matches, &default_cmd.args))
        .unwrap_or_default();
    Ok(("default".to_string(), flags))
}
/// Follow nested subcommand matches beneath `cmd_def`.
///
/// Returns the dotted path *below* `cmd_def` (empty when `cmd_def` itself is
/// the selected leaf) together with the argument values of that leaf.
fn extract_from_command(
    matches: &clap::ArgMatches,
    cmd_def: &CommandDefinition,
) -> (String, BTreeMap<String, String>) {
    let nested = matches.subcommand().and_then(|(sub_name, sub_matches)| {
        cmd_def
            .commands
            .get(sub_name)
            .map(|sub_def| (sub_name, sub_matches, sub_def))
    });

    match nested {
        None => (String::new(), extract_args(matches, &cmd_def.args)),
        Some((sub_name, sub_matches, sub_def)) => {
            let (tail, flags) = extract_from_command(sub_matches, sub_def);
            let path = if tail.is_empty() {
                sub_name.to_string()
            } else {
                format!("{}.{}", sub_name, tail)
            };
            (path, flags)
        }
    }
}
/// Collect the values clap matched for each argument in `arg_defs`.
///
/// - Bool arguments are recorded as `"true"`, and only when set.
/// - `rest` arguments are joined with single spaces and omitted when the
///   join is empty.
/// - Every other argument is recorded only when clap has a value for it.
fn extract_args(
    matches: &clap::ArgMatches,
    arg_defs: &BTreeMap<String, ArgumentDefinition>,
) -> BTreeMap<String, String> {
    arg_defs
        .iter()
        .filter_map(|(name, def)| {
            let value = if def.arg_type == ArgType::Bool {
                matches.get_flag(name).then(|| "true".to_string())
            } else if def.rest {
                matches
                    .get_many::<String>(name)
                    .map(|vals| vals.cloned().collect::<Vec<_>>().join(" "))
                    .filter(|joined| !joined.is_empty())
            } else {
                matches.get_one::<String>(name).cloned()
            };
            value.map(|v| (name.clone(), v))
        })
        .collect()
}
fn find_command<'a>(
commands: &'a BTreeMap<String, CommandDefinition>,
path: &str,
) -> Result<&'a CommandDefinition> {
let parts: Vec<&str> = path.split('.').collect();
let mut current = commands;
let mut cmd = None;
for part in &parts {
match current.get(*part) {
Some(c) => {
cmd = Some(c);
current = &c.commands;
}
None => bail!("command '{}' not found", path),
}
}
cmd.ok_or_else(|| anyhow!("empty command path"))
}
fn resolve_mode(resolved: &ResolvedPackage) -> receipt::Mode {
if let Some(mode) = resolved.receipt.mode {
return mode;
}
if let Some(ref exec) = resolved.definition.execution {
if let Some(ref default) = exec.default {
if let Ok(mode) = default.parse::<receipt::Mode>() {
return mode;
}
}
}
receipt::Mode::Command
}
/// Dispatch a parsed invocation to the matching execution strategy.
///
/// - Non-protocol packages go through `execute_command`.
/// - Protocol packages use a service session when the resolved mode is
///   `Service` and the definition declares a `service` section.
/// - Otherwise, protocol packages are spawned as a protocol child.
///
/// Returns the exit code to report.
// Every argument is forwarded to one of the strategies below; bundling them
// into a struct would only move the list elsewhere.
#[allow(clippy::too_many_arguments)]
fn execute(
    home: &Path,
    resolved: &ResolvedPackage,
    env_vars: &BTreeMap<String, String>,
    placeholders: &BTreeMap<String, String>,
    command_path: &str,
    command: &CommandDefinition,
    parsed_flags: &BTreeMap<String, String>,
    raw_json: bool,
    force_text: bool,
) -> Result<i32> {
    if !resolved.definition.protocol {
        return execute_command(
            home,
            resolved,
            env_vars,
            placeholders,
            command,
            raw_json,
            force_text,
        );
    }

    let wants_service = resolve_mode(resolved) == receipt::Mode::Service;
    if wants_service && resolved.definition.service.is_some() {
        execute_service(
            home,
            resolved,
            command_path,
            command,
            parsed_flags,
            raw_json,
            force_text,
        )
    } else {
        execute_protocol(
            home,
            resolved,
            command_path,
            command,
            parsed_flags,
            raw_json,
            force_text,
            env_vars,
        )
    }
}
fn send_message(
writer: &Arc<Mutex<BufWriter<Box<dyn Write + Send>>>>,
msg: &Message,
) -> Result<()> {
let json = serde_json::to_string(msg).context("failed to serialize protocol message")?;
let mut w = writer.lock().unwrap();
writeln!(w, "{}", json).context("failed to write to module")?;
w.flush().context("failed to flush module writer")
}
/// Relay this process's stdin to the module, one line per `Input` message.
///
/// Stops at EOF, on a read error, or when a send fails. In every case it then
/// makes a best-effort attempt to send a final empty `Input` with
/// `eof: true`.
fn forward_stdin_as_input(writer: Arc<Mutex<BufWriter<Box<dyn Write + Send>>>>) {
    let stdin = std::io::stdin();
    let mut input = BufReader::new(stdin.lock());
    let mut chunk = String::new();

    while matches!(input.read_line(&mut chunk), Ok(n) if n > 0) {
        let msg = Message::Input(proto::Input {
            data: chunk.clone(),
            eof: false,
        });
        if send_message(&writer, &msg).is_err() {
            break;
        }
        chunk.clear();
    }

    let _ = send_message(
        &writer,
        &Message::Input(proto::Input {
            data: String::new(),
            eof: true,
        }),
    );
}
/// Drive one protocol session with a module and present its results.
///
/// Sends `invoke_msg` over `writer`, then reads newline-delimited JSON
/// messages from `reader` until a `Done` message arrives or the stream ends.
///
/// - `Output`: when rendering is enabled (no `raw_json`, an output
///   declaration exists, and either `force_text` or stdout is a terminal),
///   records are rendered through `command.output`. Streaming declarations
///   print a header once and then each row as it arrives. Other declarations
///   are buffered and rendered after the session. Without rendering, each
///   record is printed as one JSON line.
/// - `Progress`: a progress bar is drawn on stderr when stdout is a terminal.
/// - `CapabilityReq`: answered via `crate::capability_provider::handle`.
/// - `Done`: its error (if any) is printed to stderr and its exit code is
///   used.
///
/// Blank or unparsable lines are skipped. If the invoke message requests
/// input, stdin is forwarded to the module on a detached background thread.
///
/// Returns the `Done` exit code, or 1 if the stream ended without one.
///
/// # Errors
/// Only if the invoke message cannot be sent or the forwarding thread cannot
/// be spawned.
pub(crate) fn run_protocol_session(
    reader: impl BufRead,
    writer: impl Write + Send + 'static,
    invoke_msg: &Message,
    command: &CommandDefinition,
    raw_json: bool,
    force_text: bool,
) -> Result<i32> {
    let has_input = match invoke_msg {
        Message::Invoke(inv) => inv.input,
        _ => false,
    };
    let module_writer: Arc<Mutex<BufWriter<Box<dyn Write + Send>>>> =
        Arc::new(Mutex::new(BufWriter::new(Box::new(writer))));
    send_message(&module_writer, invoke_msg)?;
    if has_input {
        let w = module_writer.clone();
        // The join handle is dropped: the thread runs until stdin hits EOF
        // or a send to the module fails.
        std::thread::Builder::new()
            .name("zr-input-fwd".into())
            .spawn(move || forward_stdin_as_input(w))
            .context("failed to spawn input forwarding thread")?;
    }

    let is_tty = std::io::stdout().is_terminal();
    let render = !raw_json && command.output.is_some() && (force_text || is_tty);
    let streaming = command.output.as_ref().is_some_and(|o| o.stream);
    let mut records: Vec<serde_json::Value> = Vec::new();
    let mut exit_code: Option<i32> = None;
    let mut streaming_started = false;
    let stdout_handle = std::io::stdout();
    let mut stdout_writer = BufWriter::new(stdout_handle.lock());

    for line in reader.lines() {
        // A read error ends the session the same way EOF does.
        let Ok(line) = line else {
            break;
        };
        if line.is_empty() {
            continue;
        }
        // Lines that do not decode as protocol messages are skipped.
        let Ok(msg) = serde_json::from_str::<Message>(&line) else {
            continue;
        };
        match msg {
            Message::Output(output) => {
                if render && streaming {
                    if !streaming_started {
                        if let Some(output_decl) = &command.output {
                            crate::render::render_streaming_header(
                                output_decl,
                                &mut stdout_writer,
                            );
                        }
                        streaming_started = true;
                    }
                    if let Some(output_decl) = &command.output {
                        crate::render::render_streaming_row(
                            &output.record,
                            output_decl,
                            &mut stdout_writer,
                        );
                    }
                    // Without this flush, streamed rows sit in the BufWriter
                    // until the session ends. The JSON path below already
                    // flushes after every record.
                    let _ = stdout_writer.flush();
                } else if render {
                    records.push(output.record);
                } else {
                    let json = serde_json::to_string(&output.record).unwrap_or_default();
                    let _ = writeln!(stdout_writer, "{}", json);
                    let _ = stdout_writer.flush();
                }
            }
            Message::Progress(progress) => {
                if is_tty {
                    render_progress(progress.fraction);
                }
            }
            Message::CapabilityReq(req) => {
                let res = crate::capability_provider::handle(&req);
                if send_message(&module_writer, &Message::CapabilityRes(res)).is_err() {
                    break;
                }
            }
            Message::Done(done) => {
                if let Some(ref error) = done.error {
                    eprintln!("error: {}", error);
                }
                exit_code = Some(done.exit_code);
                break;
            }
            // Any other message kind is ignored here.
            _ => {}
        }
    }

    if is_tty {
        // Return to column 0 and clear the line, removing any progress bar.
        eprint!("\r\x1b[K");
    }

    if render && !streaming && !records.is_empty() {
        if let Some(output_decl) = &command.output {
            match output_decl.resolved_output_type() {
                OutputType::Text => {
                    crate::render::render_text(&records, output_decl, &mut stdout_writer);
                }
                OutputType::Record => {
                    // Record output shows only the first record received.
                    if let Some(record) = records.first() {
                        crate::render::render_record(record, output_decl, &mut stdout_writer);
                    }
                }
                OutputType::Table => {
                    crate::render::render_table(&records, output_decl, &mut stdout_writer);
                }
            }
        }
    }

    let _ = stdout_writer.flush();
    Ok(exit_code.unwrap_or(1))
}
fn execute_service(
home: &Path,
resolved: &ResolvedPackage,
command_path: &str,
command: &CommandDefinition,
parsed_flags: &BTreeMap<String, String>,
raw_json: bool,
force_text: bool,
) -> Result<i32> {
let service = resolved.definition.service.as_ref().unwrap();
let port = service.port.ok_or_else(|| {
anyhow!(
"service package '{}' must declare a port in service.port",
resolved.definition.name
)
})?;
ensure_service_running(home, &resolved.definition.name, port)?;
let stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.with_context(|| format!("failed to connect to service '{}' on port {}", resolved.definition.name, port))?;
let reader = BufReader::new(stream.try_clone().context("failed to clone TCP stream")?);
let has_input = command.input.is_some();
let invoke_msg = Message::Invoke(proto::Invoke::from_str_args(
command_path,
parsed_flags,
has_input,
));
run_protocol_session(reader, stream, &invoke_msg, command, raw_json, force_text)
}
/// Make sure something is accepting connections on `127.0.0.1:<port>`.
///
/// If a probe connection succeeds, returns immediately. Otherwise, connects
/// to the daemon (starting it if needed) and asks it to start service `name`.
///
/// # Errors
/// Fails if the daemon cannot be reached or reports that the start failed.
fn ensure_service_running(home: &Path, name: &str, port: u16) -> Result<()> {
    let already_listening = TcpStream::connect(format!("127.0.0.1:{}", port)).is_ok();
    if already_listening {
        return Ok(());
    }

    let client = crate::daemon_client::connect_or_start_daemon(home)?;
    let response = crate::daemon_client::start_service(&client, name)?;
    if response.ok {
        return Ok(());
    }

    let reason = response.error.unwrap_or_else(|| "unknown error".into());
    bail!("failed to start service '{}': {}", name, reason);
}
fn execute_protocol(
home: &Path,
resolved: &ResolvedPackage,
command_path: &str,
command: &CommandDefinition,
parsed_flags: &BTreeMap<String, String>,
raw_json: bool,
force_text: bool,
env_vars: &BTreeMap<String, String>,
) -> Result<i32> {
let binary_name = resolved.definition.binary.as_ref().ok_or_else(|| {
anyhow!(
"protocol package '{}' must have a binary",
resolved.definition.name
)
})?;
let bin_path = paths::store_binary_path(
home,
&resolved.definition.name,
&resolved.version,
binary_name,
);
if !bin_path.exists() {
bail!(
"binary '{}' not found for '{}' v{}\nhint: reinstall with `zacor install <source>`",
binary_name,
resolved.definition.name,
resolved.version
);
}
#[cfg(windows)]
let _job = crate::job_object::JobObject::setup().ok();
let mut child = Command::new(&bin_path)
.envs(env_vars)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()
.with_context(|| {
format!(
"failed to spawn package '{}'",
resolved.definition.name
)
})?;
#[cfg(windows)]
if let Some(ref job) = _job {
let _ = job.assign(&child);
}
let child_stdin = child.stdin.take().unwrap();
let child_stdout = child.stdout.take().unwrap();
let has_input = command.input.is_some();
let invoke_msg = Message::Invoke(proto::Invoke::from_str_args(
command_path,
parsed_flags,
has_input,
));
let reader = BufReader::new(child_stdout);
let result = run_protocol_session(reader, child_stdin, &invoke_msg, command, raw_json, force_text);
let _ = child.wait();
result
}
/// Redraw a 20-cell progress bar plus percentage on the current stderr line.
///
/// `fraction` is clamped to `0.0..=1.0`. The leading `\r` overwrites the
/// previous bar in place.
fn render_progress(fraction: f64) {
    const WIDTH: usize = 20;
    let ratio = fraction.clamp(0.0, 1.0);
    let done = (ratio * WIDTH as f64) as usize;
    let percent = (ratio * 100.0) as u32;
    eprint!(
        "\r{}{} {:>3}%",
        "█".repeat(done),
        "░".repeat(WIDTH - done),
        percent
    );
}
/// Whether records should be rendered as text rather than emitted as raw
/// JSON.
///
/// Rendering requires an output declaration and no `raw_json`. It also
/// requires either `force_text` or a terminal on stdout.
fn should_render(output: &Option<OutputDeclaration>, raw_json: bool, force_text: bool) -> bool {
    if raw_json || output.is_none() {
        return false;
    }
    force_text || std::io::stdout().is_terminal()
}
/// Execute a non-protocol package command.
///
/// - A package binary is preferred. It runs via `exec_binary`, with the
///   output declaration passed along only when rendering applies.
/// - Otherwise, the command's `invoke` template runs via
///   `crate::execute::exec_invoke`.
///
/// # Errors
/// Fails if the binary is missing from the store, or if there is neither a
/// binary nor an invoke template.
fn execute_command(
    home: &Path,
    resolved: &ResolvedPackage,
    env_vars: &BTreeMap<String, String>,
    placeholders: &BTreeMap<String, String>,
    command: &CommandDefinition,
    raw_json: bool,
    force_text: bool,
) -> Result<i32> {
    let render = should_render(&command.output, raw_json, force_text);
    match (&resolved.definition.binary, &command.invoke) {
        (Some(binary_name), _) => {
            let bin_path = paths::store_binary_path(
                home,
                &resolved.definition.name,
                &resolved.version,
                binary_name,
            );
            if !bin_path.exists() {
                bail!(
                    "binary '{}' not found for '{}' v{}\nhint: reinstall with `zacor install <source>`",
                    binary_name,
                    resolved.definition.name,
                    resolved.version
                );
            }
            let output_decl = command.output.as_ref().filter(|_| render);
            exec_binary(&bin_path, &resolved.definition.name, env_vars, output_decl)
        }
        (None, Some(invoke)) => crate::execute::exec_invoke(invoke, env_vars, placeholders),
        (None, None) => {
            bail!(
                "package '{}' has no binary and no invoke template for this command",
                resolved.definition.name
            );
        }
    }
}
/// Run a package's native binary and return its exit code.
///
/// - `output == None`: the child's stdout is inherited. On Unix this case
///   replaces the current process via `exec`, so the function returns only if
///   `exec` fails.
/// - `Some(output)`: the child's stdout is piped into
///   `crate::render::render_jsonl` with that declaration before waiting.
///
/// stdin and stderr are always inherited. On Windows, the child is assigned
/// to a Job Object when one can be created; failures are printed as
/// warnings. A status without a code is reported as exit code 1 (on Unix,
/// `ExitStatus::code` is `None` when the child was killed by a signal).
///
/// # Errors
/// Fails if the binary cannot be exec'd/spawned or cannot be waited on.
fn exec_binary(
    bin: &Path,
    name: &str,
    env_vars: &BTreeMap<String, String>,
    output: Option<&OutputDeclaration>,
) -> Result<i32> {
    #[cfg(unix)]
    if output.is_none() {
        use std::os::unix::process::CommandExt;
        // `exec` only returns on failure; on success this process becomes the package.
        let err = Command::new(bin)
            .envs(env_vars)
            .stdin(process::Stdio::inherit())
            .stdout(process::Stdio::inherit())
            .stderr(process::Stdio::inherit())
            .exec();
        return Err(anyhow!(err).context(format!("failed to exec package '{}'", name)));
    }
    // Job Object creation is best-effort on Windows: warn and continue without it.
    #[cfg(windows)]
    let _job = match crate::job_object::JobObject::setup() {
        Ok(job) => Some(job),
        Err(e) => {
            eprintln!("warning: failed to create Job Object: {:#}", e);
            None
        }
    };
    // On Unix `output` is always `Some` here (the `None` case exec'd above);
    // the inherit branch is only reachable on other platforms.
    let stdout_cfg = if output.is_some() {
        process::Stdio::piped()
    } else {
        process::Stdio::inherit()
    };
    let mut child = Command::new(bin)
        .envs(env_vars)
        .stdin(process::Stdio::inherit())
        .stdout(stdout_cfg)
        .stderr(process::Stdio::inherit())
        .spawn()
        .with_context(|| format!("failed to execute package '{}'", name))?;
    #[cfg(windows)]
    if let Some(ref job) = _job
        && let Err(e) = job.assign(&child)
    {
        eprintln!("warning: failed to assign process to Job Object: {:#}", e);
    }
    // Render the child's stdout as it is produced, before waiting for exit.
    if let Some(output_decl) = output
        && let Some(child_stdout) = child.stdout.take()
    {
        let reader = BufReader::new(child_stdout);
        let stdout = std::io::stdout();
        let writer = std::io::BufWriter::new(stdout.lock());
        crate::render::render_jsonl(reader, output_decl, writer);
    }
    let status = child
        .wait()
        .with_context(|| format!("failed to wait for package '{}'", name))?;
    Ok(status.code().unwrap_or(1))
}
/// Entry point for running an installed package with `args`.
///
/// Steps:
/// 1. Resolve the package.
/// 2. Parse `args` with a clap parser built from its definition.
/// 3. Locate the selected command.
/// 4. Build the environment from global and (if discovered) project config.
/// 5. Execute the command.
///
/// clap errors meant for stderr are printed there and yield exit code 2.
/// Help/version output goes to stdout and yields 0.
pub fn run(home: &Path, name: &str, args: &[String], raw_json: bool, force_text: bool) -> Result<i32> {
    let resolved = resolve(home, name)?;
    let parsed = clap_parse(
        build_clap_command(&resolved.definition),
        name,
        args,
        &resolved.definition,
    );
    let (command_path, parsed_flags) = match parsed {
        Ok(selection) => selection,
        Err(e) if e.use_stderr() => {
            eprint!("{}", e);
            return Ok(2);
        }
        Err(e) => {
            print!("{}", e);
            return Ok(0);
        }
    };
    let command = find_command(&resolved.definition.commands, &command_path)?;

    let cwd = std::env::current_dir().ok();
    let project_root = cwd
        .as_ref()
        .and_then(|dir| paths::discover_project_root(dir, home));
    // Unreadable project or global config falls back to "none"/defaults.
    let project_config = project_root
        .as_ref()
        .and_then(|root| config::read_project(root).ok());
    let global_config = config::read_global(home).unwrap_or_default();

    let (env_vars, placeholders) = crate::execute::build_env_vars(
        home,
        &resolved.definition.name,
        &command_path,
        &resolved.version,
        &parsed_flags,
        command,
        &resolved.receipt,
        &global_config,
        &resolved.definition.config,
        project_root.as_deref(),
        resolved.definition.project_data,
        project_config.as_ref(),
        cwd.as_deref(),
    );

    execute(
        home,
        &resolved,
        &env_vars,
        &placeholders,
        &command_path,
        command,
        &parsed_flags,
        raw_json,
        force_text,
    )
}
// Unit tests for package resolution, clap parser construction and parsing,
// command lookup, execution-mode resolution, and render selection.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util;

    // `run` on a package with no receipt fails with a "not found" error.
    #[test]
    fn test_dispatch_missing_package() {
        let home = test_util::temp_home("dispatch");
        let result = run(home.path(), "nonexistent", &[], false, false);
        assert!(result.is_err());
        let err = format!("{:#}", result.unwrap_err());
        assert!(err.contains("not found"), "got: {}", err);
    }

    // An inactive receipt is rejected with a hint pointing at `zacor enable`.
    #[test]
    fn test_dispatch_disabled_package() {
        let home = test_util::temp_home("dispatch");
        let mut r = receipt::Receipt::new(
            "1.0.0".to_string(),
            receipt::SourceRecord::Local {
                path: "/tmp/mymod".to_string(),
            },
        );
        r.active = false;
        receipt::write(home.path(), "mymod", &r).unwrap();
        let result = run(home.path(), "mymod", &[], false, false);
        assert!(result.is_err());
        let err = format!("{:#}", result.unwrap_err());
        assert!(err.contains("disabled"), "got: {}", err);
        assert!(err.contains("zacor enable"), "got: {}", err);
    }

    // Only a receipt is written (no package.yaml in the store), so resolution
    // must fail with a store/reinstall message.
    #[test]
    fn test_dispatch_corrupt_definition() {
        let home = test_util::temp_home("dispatch");
        receipt::write(
            home.path(),
            "broken",
            &receipt::Receipt::new(
                "1.0.0".to_string(),
                receipt::SourceRecord::Local {
                    path: "/tmp/broken".to_string(),
                },
            ),
        )
        .unwrap();
        let result = run(home.path(), "broken", &[], false, false);
        assert!(result.is_err());
        let err = format!("{:#}", result.unwrap_err());
        assert!(err.contains("not found in store") || err.contains("reinstall"), "got: {}", err);
    }

    // A package with only a `default` command takes its args at the top level.
    #[test]
    fn test_build_clap_single_default() {
        let yaml = r#"
name: echo
version: "0.2.0"
description: "Echo text"
commands:
  default:
    description: Echo text
    args:
      text:
        type: string
        required: true
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let matches = cmd.try_get_matches_from(["echo", "hello"]).unwrap();
        assert_eq!(matches.get_one::<String>("text").unwrap(), "hello");
        assert!(matches.subcommand().is_none());
    }

    // With `default` plus a named command, both the top-level args and the
    // named subcommand are reachable.
    #[test]
    fn test_build_clap_default_plus_named() {
        let yaml = r#"
name: my-pkg
version: "1.0.0"
commands:
  default:
    args:
      text:
        type: string
  transcribe:
    description: Transcribe audio
    args:
      file:
        type: path
        required: true
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let matches = cmd.try_get_matches_from(["my-pkg", "hello"]).unwrap();
        assert!(matches.subcommand().is_none());
        assert_eq!(matches.get_one::<String>("text").unwrap(), "hello");
        let cmd = build_clap_command(&def);
        let matches = cmd
            .try_get_matches_from(["my-pkg", "transcribe", "file.mp3"])
            .unwrap();
        let (name, sub) = matches.subcommand().unwrap();
        assert_eq!(name, "transcribe");
        assert_eq!(sub.get_one::<String>("file").unwrap(), "file.mp3");
    }

    // Without a `default` command, invoking with no subcommand is an error.
    #[test]
    fn test_build_clap_named_only() {
        let yaml = r#"
name: my-pkg
version: "1.0.0"
commands:
  transcribe:
    description: Transcribe audio
  translate:
    description: Translate text
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let result = cmd.try_get_matches_from(["my-pkg"]);
        assert!(result.is_err());
    }

    // Nested `commands` become nested clap subcommands.
    #[test]
    fn test_build_clap_nested_commands() {
        let yaml = r#"
name: my-pkg
version: "1.0.0"
commands:
  transcribe:
    description: Transcribe audio
    commands:
      batch:
        description: Batch transcribe
        args:
          files:
            type: string
            required: true
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let matches = cmd
            .try_get_matches_from(["my-pkg", "transcribe", "batch", "*.mp3"])
            .unwrap();
        let (name, sub) = matches.subcommand().unwrap();
        assert_eq!(name, "transcribe");
        let (nested_name, nested_sub) = sub.subcommand().unwrap();
        assert_eq!(nested_name, "batch");
        assert_eq!(nested_sub.get_one::<String>("files").unwrap(), "*.mp3");
    }

    // Number and choice validation reject bad values; valid values of every
    // arg type round-trip as strings (bool via `get_flag`).
    #[test]
    fn test_build_clap_arg_types() {
        let yaml = r#"
name: test
version: "1.0.0"
commands:
  default:
    args:
      input:
        type: string
        required: true
      count:
        type: number
        flag: count
      verbose:
        type: bool
        flag: verbose
      file:
        type: path
        flag: file
      format:
        type: choice
        flag: format
        values: [json, csv, text]
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let result = cmd.try_get_matches_from(["test", "hello", "--count", "abc"]);
        assert!(result.is_err());
        let cmd = build_clap_command(&def);
        let result = cmd.try_get_matches_from(["test", "hello", "--format", "invalid"]);
        assert!(result.is_err());
        let cmd = build_clap_command(&def);
        let matches = cmd
            .try_get_matches_from([
                "test", "hello", "--count", "42", "--verbose", "--file", "/path", "--format", "json",
            ])
            .unwrap();
        assert_eq!(matches.get_one::<String>("input").unwrap(), "hello");
        assert_eq!(matches.get_one::<String>("count").unwrap(), "42");
        assert!(matches.get_flag("verbose"));
        assert_eq!(matches.get_one::<String>("file").unwrap(), "/path");
        assert_eq!(matches.get_one::<String>("format").unwrap(), "json");
    }

    // Flagged options may appear before or after the positional argument.
    #[test]
    fn test_build_clap_flag_vs_positional() {
        let yaml = r#"
name: test
version: "1.0.0"
commands:
  default:
    args:
      text:
        type: string
        required: true
      model:
        type: choice
        flag: model
        values: [base, large]
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let matches = cmd
            .try_get_matches_from(["test", "hello", "--model", "large"])
            .unwrap();
        assert_eq!(matches.get_one::<String>("text").unwrap(), "hello");
        assert_eq!(matches.get_one::<String>("model").unwrap(), "large");
        let cmd = build_clap_command(&def);
        let matches = cmd
            .try_get_matches_from(["test", "--model", "base", "hello"])
            .unwrap();
        assert_eq!(matches.get_one::<String>("text").unwrap(), "hello");
        assert_eq!(matches.get_one::<String>("model").unwrap(), "base");
    }

    // `clap_parse` reports the path "default" when no subcommand is used.
    #[test]
    fn test_clap_parse_default_command() {
        let yaml = r#"
name: echo
version: "0.2.0"
commands:
  default:
    args:
      text:
        type: string
        required: true
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let (path, flags) = clap_parse(cmd, "echo", &["hello".to_string()], &def).unwrap();
        assert_eq!(path, "default");
        assert_eq!(flags["text"], "hello");
    }

    // A named subcommand yields its own name as the path plus its args.
    #[test]
    fn test_clap_parse_named_command() {
        let yaml = r#"
name: my-pkg
version: "1.0.0"
commands:
  transcribe:
    description: Transcribe audio
    args:
      file:
        type: path
        required: true
  translate:
    description: Translate text
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let (path, flags) =
            clap_parse(cmd, "my-pkg", &["transcribe".to_string(), "file.mp3".to_string()], &def)
                .unwrap();
        assert_eq!(path, "transcribe");
        assert_eq!(flags["file"], "file.mp3");
    }

    // Nested subcommands produce a dotted path and the leaf's args.
    #[test]
    fn test_clap_parse_nested_command() {
        let yaml = r#"
name: my-pkg
version: "1.0.0"
commands:
  transcribe:
    description: Transcribe
    commands:
      batch:
        description: Batch
        args:
          files:
            type: string
            required: true
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let (path, flags) = clap_parse(
            cmd,
            "my-pkg",
            &["transcribe".to_string(), "batch".to_string(), "*.mp3".to_string()],
            &def,
        )
        .unwrap();
        assert_eq!(path, "transcribe.batch");
        assert_eq!(flags["files"], "*.mp3");
    }

    // A set bool flag is recorded as "true"; an unset one is omitted.
    #[test]
    fn test_clap_parse_bool_flag() {
        let yaml = r#"
name: test
version: "1.0.0"
commands:
  default:
    args:
      verbose:
        type: bool
        flag: verbose
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let (_, flags) = clap_parse(cmd, "test", &["--verbose".to_string()], &def).unwrap();
        assert_eq!(flags["verbose"], "true");
        let cmd = build_clap_command(&def);
        let (_, flags) = clap_parse(cmd, "test", &[], &def).unwrap();
        assert!(!flags.contains_key("verbose"));
    }

    // Unknown flags surface as a clap error.
    #[test]
    fn test_clap_parse_unknown_flag_error() {
        let yaml = r#"
name: echo
version: "0.2.0"
commands:
  default:
    args:
      text:
        type: string
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let result = clap_parse(cmd, "echo", &["--unknown".to_string(), "hello".to_string()], &def);
        assert!(result.is_err());
    }

    // Bool args without an explicit `flag` are still exposed as `--<name>`.
    #[test]
    fn test_bool_auto_flags() {
        let yaml = r#"
name: test
version: "1.0.0"
commands:
  default:
    args:
      changes:
        type: bool
      drafts:
        type: bool
"#;
        let def = crate::package_definition::parse(yaml).unwrap();
        let cmd = build_clap_command(&def);
        let (_, flags) = clap_parse(cmd, "test", &["--changes".to_string()], &def).unwrap();
        assert_eq!(flags["changes"], "true");
        assert!(!flags.contains_key("drafts"));
        let cmd = build_clap_command(&def);
        let (_, flags) = clap_parse(cmd, "test", &["--changes".to_string(), "--drafts".to_string()], &def).unwrap();
        assert_eq!(flags["changes"], "true");
        assert_eq!(flags["drafts"], "true");
        let cmd = build_clap_command(&def);
        let (_, flags) = clap_parse(cmd, "test", &[], &def).unwrap();
        assert!(!flags.contains_key("changes"));
        assert!(!flags.contains_key("drafts"));
    }

    // A single-segment path finds a top-level command.
    #[test]
    fn test_find_command_default() {
        let mut commands = BTreeMap::new();
        commands.insert("default".to_string(), CommandDefinition::default());
        let cmd = find_command(&commands, "default").unwrap();
        assert!(cmd.args.is_empty());
    }

    // A dotted path walks into nested `commands`.
    #[test]
    fn test_find_command_nested() {
        let mut inner = BTreeMap::new();
        inner.insert("batch".to_string(), CommandDefinition::default());
        let mut commands = BTreeMap::new();
        commands.insert(
            "transcribe".to_string(),
            CommandDefinition {
                commands: inner,
                ..Default::default()
            },
        );
        let cmd = find_command(&commands, "transcribe.batch").unwrap();
        assert!(cmd.args.is_empty());
    }

    // A missing command path is an error.
    #[test]
    fn test_find_command_not_found() {
        let commands = BTreeMap::new();
        let result = find_command(&commands, "nonexistent");
        assert!(result.is_err());
    }

    // Build an in-memory protocol `ResolvedPackage`. Optionally sets the
    // receipt mode, `execution.default`, and a `service` section on port 9999.
    fn make_resolved(
        mode: Option<receipt::Mode>,
        exec_default: Option<&str>,
        service: bool,
    ) -> ResolvedPackage {
        let mut r = receipt::Receipt::new(
            "1.0.0".to_string(),
            receipt::SourceRecord::Local {
                path: "/tmp/test".to_string(),
            },
        );
        r.mode = mode;
        let mut def = crate::package_definition::parse(
            r#"
name: test
version: "1.0.0"
protocol: true
commands:
  default:
    description: test
"#,
        )
        .unwrap();
        if let Some(default) = exec_default {
            def.execution = Some(crate::package_definition::ExecutionSection {
                default: Some(default.to_string()),
            });
        }
        if service {
            def.service = Some(crate::package_definition::ServiceSection {
                start: "test".into(),
                port: Some(9999),
                health: None,
                startup: None,
            });
        }
        ResolvedPackage {
            receipt: r,
            definition: def,
            version: "1.0.0".to_string(),
        }
    }

    // The receipt's mode wins over the definition's `execution.default`.
    #[test]
    fn test_mode_resolution_receipt_overrides_definition() {
        let resolved = make_resolved(Some(receipt::Mode::Service), Some("command"), true);
        assert_eq!(resolve_mode(&resolved), receipt::Mode::Service);
    }

    // Without a receipt mode, `execution.default` is used.
    #[test]
    fn test_mode_resolution_definition_default() {
        let resolved = make_resolved(None, Some("service"), true);
        assert_eq!(resolve_mode(&resolved), receipt::Mode::Service);
    }

    // With neither set, the mode falls back to `Command`.
    #[test]
    fn test_mode_resolution_fallback_to_command() {
        let resolved = make_resolved(None, None, false);
        assert_eq!(resolve_mode(&resolved), receipt::Mode::Command);
    }

    // An explicit `Command` in the receipt beats a `service` default.
    #[test]
    fn test_mode_resolution_receipt_command_overrides_service_default() {
        let resolved = make_resolved(Some(receipt::Mode::Command), Some("service"), true);
        assert_eq!(resolve_mode(&resolved), receipt::Mode::Command);
    }

    // `force_text` enables rendering unless `raw_json` is set or there is no
    // output declaration.
    #[test]
    fn test_should_render_force_text() {
        let output = Some(OutputDeclaration {
            output_type: Some(OutputType::Table),
            cardinality: None,
            display: None,
            schema: None,
            field: None,
            stream: false,
        });
        assert!(should_render(&output, false, true));
        assert!(!should_render(&output, true, true));
        assert!(!should_render(&None, false, true));
    }

    // Checks clap's `conflicts_with` on a stand-alone parser with `--text`
    // and `--json` flags. Presumably mirrors the real CLI's flags; confirm
    // against the top-level argument definitions.
    #[test]
    fn test_text_and_json_conflict() {
        use clap::{CommandFactory, Parser};
        #[derive(Parser)]
        struct TestCli {
            #[arg(long, conflicts_with = "text")]
            json: bool,
            #[arg(long, conflicts_with = "json")]
            text: bool,
        }
        assert!(TestCli::command().try_get_matches_from(["test", "--text"]).is_ok());
        assert!(TestCli::command().try_get_matches_from(["test", "--json"]).is_ok());
        assert!(TestCli::command().try_get_matches_from(["test", "--text", "--json"]).is_err());
    }
}