use crate::{
GlobalArgs, Subcommand,
completions::{complete_profile, complete_recipe},
};
use anyhow::{Context, anyhow, bail};
use async_trait::async_trait;
use clap::{Parser, ValueHint};
use dialoguer::{Input, Password, Select as DialoguerSelect};
use indexmap::IndexMap;
use itertools::Itertools;
use slumber_config::Config;
use slumber_core::{
collection::{
Authentication, ProfileId, QueryParameterValue, Recipe, RecipeBody,
RecipeId, ValueTemplate,
},
database::{CollectionDatabase, Database},
http::{
BodyOverride, BuildFieldOverride, BuildOptions, Exchange, HttpEngine,
RequestBody, RequestRecord, RequestSeed, ResponseRecord,
StoredRequestError, TriggeredRequestError,
},
render::{HttpProvider, Prompt, Prompter, SelectOption, TemplateContext},
util::MaybeStr,
};
use slumber_template::{Expression, Template};
use slumber_util::{OptionExt, ResultTracedAnyhow};
use std::{
error::Error,
fs::OpenOptions,
io::{self, IsTerminal, Write},
path::{Path, PathBuf},
process::ExitCode,
};
use tracing::warn;
/// Process exit code returned when `--exit-status` is passed and the response
/// status is 400 or higher.
const HTTP_ERROR_EXIT_CODE: u8 = 2;
// `request` subcommand (aliases `req`/`rq`): build a request from a recipe
// and send it, or only print it with `--dry-run`.
#[derive(Clone, Debug, Parser)]
#[clap(visible_aliases = &["req", "rq"])]
pub struct RequestCommand {
    #[clap(flatten)]
    build_request: BuildRequestCommand,
    #[clap(flatten)]
    display: DisplayExchangeCommand,
    /// Print the request that would be sent, without sending it
    ///
    /// Implies --verbose. Requests triggered by templates are not sent
    /// either; rendering fails if one would be needed.
    #[clap(long)]
    dry_run: bool,
    /// Exit with status code 2 if the response status is 400 or higher
    #[clap(long)]
    exit_status: bool,
    /// Save the request and response to the collection's database
    #[clap(long)]
    persist: bool,
}
// Arguments that select a recipe/profile and override parts of the request
// built from it. Flattened into commands that need to build a request.
#[derive(Clone, Debug, Parser)]
pub struct BuildRequestCommand {
    /// ID of the recipe to build the request from
    #[clap(add = complete_recipe())]
    recipe_id: RecipeId,
    /// ID of the profile to render templates with
    ///
    /// Defaults to the collection's default profile, if it has one.
    #[clap(
        long = "profile",
        short,
        add = complete_profile(),
    )]
    profile: Option<ProfileId>,
    /// Use Basic authentication, given as `username:password`
    ///
    /// Both parts may be templates. If the `:password` part is omitted, you
    /// will be prompted for the password.
    #[clap(
        long,
        visible_alias = "user",
        conflicts_with = "bearer",
        value_hint = ValueHint::Other,
        value_name = "username:password",
    )]
    basic: Option<String>,
    /// Use Bearer authentication with this token (may be a template)
    #[clap(
        long,
        visible_alias = "token",
        value_hint = ValueHint::Other,
        value_name = "token",
    )]
    bearer: Option<Template>,
    /// Override the request body
    ///
    /// If the recipe has a JSON body, the value is parsed as JSON; otherwise
    /// it is used as a raw template. Not supported for form bodies (use
    /// --form instead).
    #[clap(long, visible_alias = "data", value_hint = ValueHint::Other)]
    body: Option<String>,
    /// Override a form field in the recipe body
    ///
    /// `--form field=value` sets the field to a template value.
    /// `--form field` (no `=`) omits the field from the request.
    #[clap(
        long,
        short = 'F',
        value_parser = parse_recipe_override,
        value_hint = ValueHint::Other, // Disable completions
        value_name = "field=value",
        verbatim_doc_comment,
    )]
    form: Vec<(String, BuildFieldOverride)>,
    /// Override a request header
    ///
    /// `--header name=value` sets the header to a template value.
    /// `--header name` (no `=`) omits the header from the request.
    #[clap(
        long,
        short = 'H',
        value_parser = parse_recipe_override,
        value_hint = ValueHint::Other, // Disable completions
        value_name = "header=value",
        verbatim_doc_comment,
    )]
    header: Vec<(String, BuildFieldOverride)>,
    /// Override a profile field with a template value
    ///
    /// Given as `field=value`; only the first `=` separates the two.
    #[clap(
        long = "override",
        short = 'o',
        value_parser = parse_profile_override,
        value_hint = ValueHint::Other, // Disable completions
        value_name = "field=value",
        verbatim_doc_comment,
    )]
    overrides: Vec<(String, Template)>,
    /// Override a query parameter
    ///
    /// `--query param=value` sets the parameter to a template value.
    /// `--query param` (no `=`) omits the parameter.
    /// Repeat the flag to give multiple values for one parameter. If the
    /// recipe defines more values for it than you pass, the extras are
    /// omitted.
    #[clap(
        long,
        value_parser = parse_recipe_override,
        value_hint = ValueHint::Other, // Disable completions
        value_name = "query=value",
        verbatim_doc_comment,
    )]
    query: Vec<(String, BuildFieldOverride)>,
    /// Override the request URL (may be a template)
    #[clap(long, value_hint = ValueHint::Other)]
    url: Option<Template>,
}
// Arguments controlling how the request and response are printed
#[derive(Clone, Debug, Parser)]
pub struct DisplayExchangeCommand {
    /// Print the request line, request headers and body, and the response
    /// status and headers to stderr
    #[clap(short, long)]
    verbose: bool,
    /// Write the response body to this file instead of stdout
    ///
    /// The file is created or truncated. Pass `-` to write to stdout; this
    /// also allows non-text bodies to be printed to a terminal.
    #[clap(long, value_name = "path")]
    output: Option<PathBuf>,
}
impl Subcommand for RequestCommand {
async fn execute(mut self, global: GlobalArgs) -> anyhow::Result<ExitCode> {
let trigger_dependencies = !self.dry_run;
let (database, http_engine, seed, template_context) = self
.build_request
.build_seed(global, trigger_dependencies)?;
let ticket = http_engine.build(seed, &template_context).await.map_err(
|error| {
if error.has_trigger_disabled_error() {
anyhow::Error::from(error.error).context(
"Triggered requests are disabled with `--dry-run`",
)
} else {
error.error.into()
}
},
)?;
if self.dry_run {
self.display.verbose = true;
self.display.write_request(ticket.record());
Ok(ExitCode::SUCCESS)
} else {
self.display.write_request(ticket.record());
let persist_to = if self.persist { Some(database) } else { None };
let exchange = ticket.send(persist_to).await?;
let status = exchange.response.status;
self.display.write_response(&exchange.response)?;
if self.exit_status && status.as_u16() >= 400 {
Ok(ExitCode::from(HTTP_ERROR_EXIT_CODE))
} else {
Ok(ExitCode::SUCCESS)
}
}
}
}
impl BuildRequestCommand {
pub fn build_seed(
self,
global: GlobalArgs,
trigger_dependencies: bool,
) -> anyhow::Result<(
CollectionDatabase,
HttpEngine,
RequestSeed,
TemplateContext,
)> {
let collection_file = global.collection_file()?;
let config = Config::load()?;
let collection = collection_file.load()?;
let database = Database::load()?.into_collection(&collection_file)?;
database.set_name(&collection);
let http_engine = HttpEngine::new(&config.http);
if let Some(profile_id) = &self.profile {
collection.profiles.get(profile_id).ok_or_else(|| {
anyhow!(
"No profile with ID `{profile_id}`; options are: {}",
collection.profiles.keys().format(", ")
)
})?;
}
let selected_profile = self.profile.or_else(|| {
let default_profile = collection.default_profile()?;
Some(default_profile.id.clone())
});
let authentication = match (self.basic, self.bearer) {
(None, None) => None,
(None, Some(token)) => Some(Authentication::Bearer { token }),
(Some(value), None) => Some(get_basic_auth(&value)?),
(Some(_), Some(_)) => {
unreachable!("--basic and --bearer are mutually exclusive")
}
};
let recipe = collection.recipes.try_get_recipe(&self.recipe_id)?;
let body_override = self.body.try_map::<_, anyhow::Error>(|ovr| {
match &recipe.body {
None | Some(RecipeBody::Stream(_) | RecipeBody::Raw(_)) => {
let template: Template = ovr.parse()?;
Ok(BodyOverride::Raw(template))
}
Some(RecipeBody::Json(_)) => {
let json = ValueTemplate::parse_json(&ovr)?;
Ok(BodyOverride::Json(json))
}
Some(
RecipeBody::FormUrlencoded(_)
| RecipeBody::FormMultipart(_),
) => bail!(
"--body not supported for form bodies; \
use --form instead"
),
}
})?;
let build_options = BuildOptions {
url: self.url,
authentication,
headers: IndexMap::from_iter(self.header),
body: body_override,
query_parameters: get_query_parameters(recipe, self.query),
form_fields: IndexMap::from_iter(self.form),
};
let template_context = TemplateContext {
selected_profile,
collection: collection.into(),
http_provider: Box::new(CliHttpProvider {
database: database.clone(),
http_engine: http_engine.clone(),
trigger_dependencies,
}),
overrides: self
.overrides
.into_iter()
.map(|(key, template)| (key, ValueTemplate::String(template)))
.collect(),
prompter: Box::new(CliPrompter),
show_sensitive: true,
root_dir: collection_file.parent().to_owned(),
state: Default::default(),
};
let seed = RequestSeed::new(self.recipe_id, build_options);
Ok((database, http_engine, seed, template_context))
}
}
impl DisplayExchangeCommand {
pub fn write_request(&self, request: &RequestRecord) {
if self.verbose {
eprintln!(
"> {} {} {}",
request.method, request.url, request.http_version
);
for (header, value) in &request.headers {
eprintln!("> {}: {}", header, MaybeStr(value.as_bytes()));
}
match &request.body {
RequestBody::None => {}
RequestBody::Stream => eprintln!("> <stream body>"),
RequestBody::TooLarge => {
eprintln!("> body too large to display");
}
RequestBody::Some(bytes) => {
let text = std::str::from_utf8(bytes).unwrap_or("<binary>");
eprintln!("> {text}");
}
}
}
}
pub fn write_response(
&self,
response: &ResponseRecord,
) -> anyhow::Result<()> {
if self.verbose {
eprintln!();
eprintln!("< {}", response.status);
for (header, value) in &response.headers {
eprintln!("< {}: {}", header, MaybeStr(value.as_bytes()));
}
}
let (mut output, allow_binary) = if let Some(path) = &self.output {
let output: Box<dyn Write> = if path == Path::new("-") {
Box::new(io::stdout())
} else {
Box::new(
OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.open(path)
.with_context(|| {
format!("Error opening file `{}`", path.display())
})?,
)
};
(output as Box<dyn Write>, true)
} else {
let stdout = io::stdout();
let allow_binary = !stdout.is_terminal();
(Box::new(stdout) as Box<dyn Write>, allow_binary)
};
if response.body.text().is_none() && !allow_binary {
eprintln!(
"Response body is not text. Binary output can mess up your \
terminal. Pass `--output -` if you're sure you want to print \
the output, or consider `--output <FILE>` to save to a file."
);
} else {
output.write_all(response.body.bytes())?;
}
Ok(())
}
}
/// [`HttpProvider`] used when rendering templates from the CLI. It reads past
/// exchanges from the collection database and, if allowed, builds and sends
/// triggered requests.
#[derive(Debug)]
struct CliHttpProvider {
    /// Source of previously stored exchanges (used by `get_latest_request`)
    database: CollectionDatabase,
    /// Engine used to build and send triggered requests
    http_engine: HttpEngine,
    /// When false, `send_request` refuses with
    /// `TriggeredRequestError::NotAllowed`
    trigger_dependencies: bool,
}
// Template-facing HTTP access for the CLI: stored exchanges come from the
// database, and triggered requests go out through the HTTP engine.
#[async_trait(?Send)]
impl HttpProvider for CliHttpProvider {
    async fn get_latest_request(
        &self,
        profile_id: Option<&ProfileId>,
        recipe_id: &RecipeId,
    ) -> Result<Option<Exchange>, StoredRequestError> {
        let latest =
            self.database.get_latest_request(profile_id.into(), recipe_id);
        latest.map_err(StoredRequestError::new)
    }

    async fn send_request(
        &self,
        seed: RequestSeed,
        template_context: &TemplateContext,
    ) -> Result<Exchange, TriggeredRequestError> {
        // Callers can forbid chained requests entirely (e.g. dry runs)
        if !self.trigger_dependencies {
            return Err(TriggeredRequestError::NotAllowed);
        }
        let ticket = self.http_engine.build(seed, template_context).await?;
        // `None`: triggered exchanges are not persisted to the database
        Ok(ticket.send(None).await?)
    }
}
/// [`Prompter`] that asks for input on the terminal using `dialoguer`
/// prompts (plain text, hidden password, and selection list).
#[derive(Debug)]
struct CliPrompter;
impl CliPrompter {
fn text(
message: String,
default: Option<String>,
sensitive: bool,
) -> anyhow::Result<String> {
if sensitive {
if default.is_some() {
warn!(
"Default value not supported for sensitive prompts in CLI"
);
}
Password::new()
.with_prompt(message)
.allow_empty_password(true)
.interact()
} else {
let mut input = Input::new().with_prompt(message).allow_empty(true);
if let Some(default) = default {
input = input.default(default);
}
input.interact()
}
.context("Error reading value from prompt")
.traced()
}
fn select(
message: String,
mut options: Vec<SelectOption>,
) -> anyhow::Result<slumber_template::Value> {
let index = DialoguerSelect::new()
.with_prompt(message)
.items(&options)
.default(0)
.interact()
.context("Error reading value from select")
.traced()?;
Ok(options.swap_remove(index).value)
}
}
// Answers template prompts on the terminal. If reading fails, the error
// (already passed through `.traced()` in `text`/`select`) is dropped here and
// the channel goes out of scope without a reply.
impl Prompter for CliPrompter {
    fn prompt(&self, prompt: Prompt) {
        match prompt {
            Prompt::Text {
                message,
                default,
                sensitive,
                channel,
            } => {
                let Ok(text) = Self::text(message, default, sensitive) else {
                    return;
                };
                channel.reply(text);
            }
            Prompt::Select {
                message,
                options,
                channel,
            } => {
                let Ok(value) = Self::select(message, options) else {
                    return;
                };
                channel.reply(value);
            }
        }
    }
}
fn parse_profile_override(
s: &str,
) -> Result<(String, Template), anyhow::Error> {
let (key, value) = s
.split_once('=')
.ok_or_else(|| anyhow!("invalid key=value: no \"=\" found in `{s}`"))?;
Ok((key.to_owned(), value.parse()?))
}
/// Parse a `--header`/`--query`/`--form` argument.
///
/// `name=value` overrides the field with a template parsed from `value`
/// (split on the first `=`). A bare `name` with no `=` omits the field.
fn parse_recipe_override(
    s: &str,
) -> Result<(String, BuildFieldOverride), Box<dyn Error + Send + Sync + 'static>>
{
    match s.split_once('=') {
        None => Ok((s.to_owned(), BuildFieldOverride::Omit)),
        Some((name, raw_value)) => {
            let template = raw_value.parse::<Template>()?;
            Ok((name.to_owned(), BuildFieldOverride::Override(template)))
        }
    }
}
/// Build Basic authentication from a `--basic` value.
///
/// `username:password` is split on the first `:`, so the password may
/// contain colons, and both halves are parsed as templates. Without a `:`,
/// the whole value is the username. The password then becomes a call to the
/// `prompt` template function with `sensitive` set, so it is asked for
/// interactively.
fn get_basic_auth(value: &str) -> anyhow::Result<Authentication> {
    let (username, password): (Template, Template) =
        match value.split_once(':') {
            Some((user, pass)) => (
                user.parse::<Template>()
                    .context("Invalid username template")?,
                pass.parse::<Template>()
                    .context("Invalid password template")?,
            ),
            None => {
                let user = value
                    .parse::<Template>()
                    .context("Invalid username template")?;
                let ask_for_password = Expression::call(
                    "prompt",
                    [],
                    [
                        ("message", Some("Password".into())),
                        ("sensitive", Some(true.into())),
                    ],
                );
                (user, ask_for_password.into())
            }
        };
    Ok(Authentication::Basic {
        username,
        password: Some(password),
    })
}
fn get_query_parameters(
recipe: &Recipe,
overrides: Vec<(String, BuildFieldOverride)>,
) -> IndexMap<(String, usize), BuildFieldOverride> {
let get_n = |param: &str| -> usize {
match recipe.query.get(param) {
None => 0,
Some(QueryParameterValue::One(_)) => 1,
Some(QueryParameterValue::Many(values)) => values.len(),
}
};
overrides
.into_iter()
.sorted_by(|(a, _), (b, _)| String::cmp(a, b))
.chunk_by(|(param, _)| param.clone())
.into_iter()
.flat_map(|(param, values)| {
values
.pad_using(get_n(¶m), move |_| {
(param.clone(), BuildFieldOverride::Omit)
})
.enumerate()
.map(|(i, (param, value))| ((param, i), value))
})
.collect()
}