use clap::{self, Parser};
use simple_logger;
use std::{collections::HashMap, path::PathBuf};
use prompt_generator::{collection_prompt, request_prompt, scan_directory};
mod benchmark;
mod collection;
mod parser;
mod prompt_generator;
mod request;
// Command-line arguments for `hen`.
// NOTE: with clap's derive, the `///` comments on the fields below become the
// `--help` text for each argument, so they are written for end users.
#[derive(clap::Parser, Debug)]
#[command(name = "hen")]
#[command(version = env!("CARGO_PKG_VERSION"))]
#[command(about = "Command line API client.")]
struct Cli {
    /// Path to a collection file, or a directory to search for one.
    path: Option<String>,

    /// Request to run: its index in the collection, or "all".
    selector: Option<String>,

    /// Print the selected request(s) as curl commands instead of running them.
    #[arg(long)]
    export: bool,

    /// Benchmark the selected request(s) with the given count.
    #[arg(long)]
    benchmark: Option<usize>,

    /// Enable debug logging.
    #[arg(short = 'v', long)]
    verbose: bool,

    /// Prompt input as key=value; may be given multiple times.
    #[arg(long = "input", value_parser = parse_input_kv)]
    inputs: Vec<(String, String)>,
}
/// Entry point for `hen`.
///
/// Parses the CLI arguments, loads a request collection (from an explicit
/// path, a directory, or an interactive prompt in the working directory),
/// selects one or more requests, and then exports, benchmarks, or executes
/// each of them.
///
/// Any user-facing failure (bad collection, unknown or malformed selector,
/// failed request) is reported on stderr and the process exits with status 1.
#[tokio::main]
async fn main() {
    let args: Cli = Cli::parse();

    // Register `--input key=value` pairs before any collection is parsed.
    // `args` is logged in full below, so clone the pairs rather than move them.
    if !args.inputs.is_empty() {
        let prompt_inputs: HashMap<String, String> = args.inputs.iter().cloned().collect();
        parser::context::set_prompt_inputs(prompt_inputs);
    }

    if args.verbose {
        simple_logger::init_with_level(log::Level::Debug).expect("Failed to initialize logger.");
    }
    log::debug!("Starting hen with args {:?}", args);

    // Resolve the collection: an explicit file, a directory (auto-selecting
    // its only file, otherwise prompting), or a prompt in the current dir.
    let collection = match args.path {
        Some(path) => {
            let path = PathBuf::from(path);
            if path.is_dir() {
                let hen_files = scan_directory(path.clone());
                if hen_files.len() == 1 {
                    let only = hen_files
                        .get(0)
                        .expect("len() == 1 guarantees a first element")
                        .clone();
                    collection::Collection::new(only)
                } else {
                    collection_prompt(path)
                }
            } else {
                collection::Collection::new(path)
            }
        }
        // The working directory is only needed here; report failures to read
        // it instead of panicking.
        None => match std::env::current_dir() {
            Ok(cwd) => collection_prompt(cwd),
            Err(e) => exit_with_error(format!("Error: {}", e)),
        },
    };

    let collection = match collection {
        Ok(collection) => collection,
        Err(e) => exit_with_error(format!("Error: {}", e)),
    };
    log::debug!("PARSED COLLECTION\n{:#?}", collection);

    let requests = match args.selector {
        Some(selector) if selector == "all" => collection.requests,
        Some(selector) => {
            // The selector comes straight from the user, so a non-numeric
            // value is reported rather than treated as a bug.
            let index = match selector.parse::<usize>() {
                Ok(index) => index,
                Err(_) => exit_with_error(format!(
                    "Invalid selector: {} (expected a request index or \"all\")",
                    selector
                )),
            };
            match collection.requests.get(index) {
                Some(request) => vec![request.clone()],
                None => exit_with_error(format!("Request not found: {}", selector)),
            }
        }
        None => {
            if collection.requests.len() == 1 {
                vec![collection.requests[0].clone()]
            } else {
                vec![request_prompt(collection)]
            }
        }
    };

    for req in requests {
        if args.export {
            println!("{}", req.as_curl());
            continue;
        }
        if let Some(count) = args.benchmark {
            // `req` is owned and not used again in this iteration; no clone needed.
            benchmark::benchmark(req, count).await;
            continue;
        }
        // Stop at the first failing request, as the previous `unwrap` did,
        // but with a readable message and a non-zero exit status.
        match req.exec().await {
            Ok(response) => println!("{}", response),
            Err(e) => exit_with_error(format!("Error: {:?}", e)),
        }
    }
}

/// Prints `message` to stderr and terminates the process with exit status 1,
/// so that callers (e.g. shell scripts) can detect the failure.
fn exit_with_error(message: impl std::fmt::Display) -> ! {
    eprintln!("{}", message);
    std::process::exit(1);
}
/// Value parser used by clap for `--input`: splits `key=value` at the first `=`.
///
/// Both halves are trimmed of surrounding whitespace. The value may be empty
/// and may itself contain `=`; the key must be non-blank.
///
/// # Errors
/// Returns the usage message `--input expects key=value` when `s` contains no
/// `=` or when the key is empty after trimming.
fn parse_input_kv(s: &str) -> Result<(String, String), String> {
    let usage = || String::from("--input expects key=value");
    match s.split_once('=') {
        Some((raw_key, raw_value)) => {
            let key = raw_key.trim();
            if key.is_empty() {
                Err(usage())
            } else {
                Ok((key.to_owned(), raw_value.trim().to_owned()))
            }
        }
        None => Err(usage()),
    }
}