use http::Method;
use reqwest::header::{HeaderName, HeaderValue};
use serde_json::Value;
use std::{
collections::HashMap,
fmt,
path::PathBuf,
sync::atomic::{AtomicU64, Ordering},
time::{SystemTime, UNIX_EPOCH},
};
use crate::parser::{self, context};
mod assertion;
mod planner;
mod response_capture;
mod runner;
mod template;
pub use assertion::Assertion;
pub use planner::RequestPlanner;
pub use response_capture::{ResponseCapture, ResponseSnapshot};
pub use runner::{
execute_plan as execute_request_plan,
execute_plan_with_observer as execute_request_plan_with_observer, ExecutionEvent,
ExecutionObserver, ExecutionOptions, ExecutionRecord, InterruptSignal, RequestFailure,
RequestFailureKind,
};
pub use template::{
expand_templates, FragmentInclude, RequestTemplate, TemplateError, VariableStore,
};
/// Value of a single multipart form field.
#[derive(Debug, Clone)]
pub enum FormDataType {
    /// Literal text; resolved against the request context (and prompt injection) before
    /// being written as the part body.
    Text(String),
    /// Path to a file whose bytes become the part body; the path's file name is sent as the
    /// part's `filename`.
    File(PathBuf),
}
/// A parsed HTTP request definition together with its post-processing steps
/// (captures, callbacks and assertions).
#[derive(Debug, Clone)]
pub struct Request {
    /// Human-readable name; used in error messages and observer events and exported as
    /// `DESCRIPTION`.
    pub description: String,
    /// NOTE(review): not read in this file; presumably the description before any
    /// map-iteration decoration — confirm against the planner.
    pub base_description: String,
    /// HTTP method.
    pub method: Method,
    /// Target URL; may contain context placeholders that are resolved at execution time.
    pub url: String,
    /// Extra headers. Values are resolved against the context; each entry replaces any
    /// existing header of the same name, including a `Content-Type` derived from
    /// `body_content_type` or multipart form data.
    pub headers: HashMap<String, String>,
    /// Query parameters; values are resolved against the context.
    pub query_params: HashMap<String, String>,
    /// Multipart form fields; when non-empty the request is sent as `multipart/form-data`.
    pub form_data: HashMap<String, FormDataType>,
    /// Raw request body; resolved against the context.
    pub body: Option<String>,
    /// `Content-Type` sent with the body; still overridable through `headers`.
    pub body_content_type: Option<String>,
    /// Shell callbacks run after the response arrives. Entries of the form
    /// `command -> $VAR` store their output in `VAR`; the rest produce the request output.
    pub callback_src: Vec<String>,
    /// Values to capture from this response or from a dependency's response.
    pub response_captures: Vec<ResponseCapture>,
    /// Assertions evaluated against the response after captures and assignments.
    pub assertions: Vec<Assertion>,
    /// NOTE(review): not read in this file; presumably the dependencies as written by the
    /// user, before resolution — confirm.
    pub declared_dependencies: Vec<String>,
    /// NOTE(review): not read in this file; presumably the resolved dependency names — confirm.
    pub dependencies: Vec<String>,
    /// Request-local context values, layered over the inherited context at execution time.
    pub context: HashMap<String, String>,
    /// Directory handed to `parser::eval_shell_script` when running callbacks.
    pub working_dir: PathBuf,
    /// Set when this request is one iteration of a mapped request; its suffix is appended to
    /// non-reserved exported variable names.
    pub map_iteration: Option<MapIteration>,
}
/// Successful result of executing a [`Request`].
#[derive(Debug, Clone)]
pub struct RequestExecution {
    /// Newline-joined output of the regular (non-assignment) callbacks, or the raw response
    /// body when there were none.
    pub output: String,
    /// Exported variables: the merged context plus `RESPONSE`, `STATUS`, `STATUS_TEXT`,
    /// `DESCRIPTION`, captured values and callback assignments (suffixed per map iteration
    /// where applicable).
    pub export_env: HashMap<String, String>,
    /// Status, headers, NUL-stripped body and parsed JSON (if any) of the response.
    pub snapshot: ResponseSnapshot,
    /// One outcome per assertion, in declaration order (passed or skipped on success).
    pub assertions: Vec<AssertionOutcome>,
}
/// Result of evaluating one assertion of a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssertionOutcome {
    /// The assertion text as written (callers pass the assertion's raw source).
    pub assertion: String,
    /// Whether the assertion passed, failed or was skipped.
    pub status: AssertionStatus,
    /// Failure or skip reason; `None` for passing assertions.
    pub message: Option<String>,
}

impl AssertionOutcome {
    /// Common constructor used by the status-specific helpers below.
    fn with_status(assertion: String, status: AssertionStatus, message: Option<String>) -> Self {
        Self {
            assertion,
            status,
            message,
        }
    }

    /// Outcome for an assertion that held; carries no message.
    pub(crate) fn passed(assertion: impl Into<String>) -> Self {
        Self::with_status(assertion.into(), AssertionStatus::Passed, None)
    }

    /// Outcome for an assertion that failed, with the failure reason.
    pub(crate) fn failed(assertion: impl Into<String>, message: impl Into<String>) -> Self {
        Self::with_status(
            assertion.into(),
            AssertionStatus::Failed,
            Some(message.into()),
        )
    }

    /// Outcome for an assertion that was not evaluated, with the reason it was skipped.
    pub(crate) fn skipped(assertion: impl Into<String>, message: impl Into<String>) -> Self {
        Self::with_status(
            assertion.into(),
            AssertionStatus::Skipped,
            Some(message.into()),
        )
    }
}

/// Final state of an evaluated assertion.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AssertionStatus {
    /// The assertion held.
    Passed,
    /// The assertion, or the evaluation of its guard, failed.
    Failed,
    /// The assertion's guard evaluated to false, so it was not checked.
    Skipped,
}

impl AssertionStatus {
    /// Lower-case label for the status: `"passed"`, `"failed"` or `"skipped"`.
    pub(crate) fn as_str(&self) -> &'static str {
        match *self {
            AssertionStatus::Passed => "passed",
            AssertionStatus::Failed => "failed",
            AssertionStatus::Skipped => "skipped",
        }
    }
}
/// Error produced while executing a request. It may carry the assertion outcomes that were
/// recorded before execution stopped.
#[derive(Debug, Clone)]
pub struct RequestExecutionError {
    message: String,
    assertions: Vec<AssertionOutcome>,
}

impl RequestExecutionError {
    /// Error with no assertion outcomes attached.
    fn new(message: impl Into<String>) -> Self {
        Self::with_assertions(message, Vec::new())
    }

    /// Error that also reports the assertion outcomes gathered so far.
    fn with_assertions(message: impl Into<String>, assertions: Vec<AssertionOutcome>) -> Self {
        let message = message.into();
        Self {
            message,
            assertions,
        }
    }

    /// Consumes the error and returns its message and assertion outcomes.
    pub(crate) fn into_parts(self) -> (String, Vec<AssertionOutcome>) {
        let Self {
            message,
            assertions,
        } = self;
        (message, assertions)
    }
}

impl fmt::Display for RequestExecutionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for RequestExecutionError {}
/// One iteration of a mapped request.
#[derive(Debug, Clone)]
pub struct MapIteration {
    /// NOTE(review): presumably the (name, value) bindings for this iteration; not read in
    /// this file.
    pub variables: Vec<(String, String)>,
    /// Iteration label; rendered as a `[label]` suffix by `suffix`.
    pub label: String,
}
// Incremented once for every generated multipart boundary (see `generate_multipart_boundary`).
static MULTIPART_BOUNDARY_COUNTER: AtomicU64 = AtomicU64::new(0);
impl MapIteration {
    /// Returns `"[label]"`, or an empty string when the label is empty.
    pub fn suffix(&self) -> String {
        match self.label.as_str() {
            "" => String::new(),
            label => format!("[{label}]"),
        }
    }
}
impl Request {
pub async fn exec(
&self,
inherited_context: &HashMap<String, String>,
dependency_snapshots: &HashMap<String, ResponseSnapshot>,
observer: Option<&ExecutionObserver>,
) -> Result<RequestExecution, RequestExecutionError> {
let mut context_map = inherited_context.clone();
for (key, value) in &self.context {
context_map.insert(key.clone(), value.clone());
}
let client = reqwest::Client::new();
let request = self
.build_reqwest_request(&client, &context_map)
.map_err(|err| RequestExecutionError::new(err.to_string()))?;
log::debug!("REQUEST\n{:#?}", request);
let resp = client
.execute(request)
.await
.map_err(|err| RequestExecutionError::new(err.to_string()))?;
log::debug!("RESPONSE\n{:#?}", resp);
let status = resp.status();
let headers = resp.headers().clone();
let resp_text = String::from_utf8_lossy(
&resp
.bytes()
.await
.map_err(|err| RequestExecutionError::new(err.to_string()))?,
)
.into_owned();
let json_body = serde_json::from_str::<Value>(&resp_text).ok();
let sanitized_response = resp_text.replace('\0', "");
let snapshot = ResponseSnapshot {
status,
headers: headers.clone(),
body: sanitized_response.clone(),
json: json_body.clone(),
};
let mut captured_values: HashMap<String, String> = HashMap::new();
for capture in &self.response_captures {
let source_snapshot = match &capture.source {
response_capture::CaptureSource::Current => &snapshot,
response_capture::CaptureSource::Dependency(name) => dependency_snapshots
.get(name)
.ok_or_else(|| {
RequestExecutionError::new(
CaptureDependencyError {
dependency: name.clone(),
request: self.description.clone(),
}
.to_string(),
)
})?,
};
match capture.extract_from_snapshot(source_snapshot) {
Ok(Some((name, value))) => {
captured_values.insert(name, value);
}
Ok(None) => {}
Err(err) => return Err(RequestExecutionError::new(err.to_string())),
}
}
if !captured_values.is_empty() {
log::debug!("CAPTURED VALUES\n{:#?}", captured_values);
}
let status_code = status.as_str().to_string();
let mut export_env: HashMap<String, String> = context_map.clone();
export_env.insert("RESPONSE".to_string(), sanitized_response.clone());
export_env.insert("STATUS".to_string(), status_code.clone());
if let Some(reason) = status.canonical_reason() {
export_env.insert("STATUS_TEXT".to_string(), reason.to_string());
}
export_env.insert("DESCRIPTION".to_string(), self.description.clone());
for (key, value) in &captured_values {
export_env.insert(key.clone(), value.replace('\0', ""));
}
let mut assignment_callbacks: Vec<(String, String)> = Vec::new();
let mut regular_callbacks: Vec<&String> = Vec::new();
if !self.callback_src.is_empty() {
for src in &self.callback_src {
match parse_callback_assignment(src) {
Ok(Some((command, variable))) => assignment_callbacks.push((command, variable)),
Ok(None) => regular_callbacks.push(src),
Err(err) => return Err(RequestExecutionError::new(err.to_string())),
}
}
}
for (command, variable) in assignment_callbacks {
let output =
parser::eval_shell_script(&command, &self.working_dir, Some(export_env.clone()));
let sanitized = sanitize_callback_assignment_output(&output);
export_env.insert(variable, sanitized);
}
let mut assertion_results = Vec::with_capacity(self.assertions.len());
for assertion in &self.assertions {
let should_execute = match assertion.should_execute(
&export_env,
&snapshot,
dependency_snapshots,
) {
Ok(should_execute) => should_execute,
Err(err) => {
let message = err.to_string();
assertion_results.push(AssertionOutcome::failed(
assertion.raw.clone(),
message.clone(),
));
return Err(RequestExecutionError::with_assertions(
message,
assertion_results,
));
}
};
if !should_execute {
log::debug!("Skipping assertion '{}' due to guard", assertion.raw);
assertion_results.push(AssertionOutcome::skipped(
assertion.raw.clone(),
"guard evaluated to false",
));
continue;
}
if let Err(err) = assertion.evaluate(&export_env, &snapshot, dependency_snapshots) {
let message = err.to_string();
assertion_results.push(AssertionOutcome::failed(
assertion.raw.clone(),
message.clone(),
));
return Err(RequestExecutionError::with_assertions(
message,
assertion_results,
));
}
assertion_results.push(AssertionOutcome::passed(assertion.raw.clone()));
if let Some(callback) = observer {
callback(ExecutionEvent::AssertionPassed {
request: self.description.clone(),
assertion: assertion.raw.clone(),
});
}
}
let mut callback_resp: Vec<String> = vec![];
if !regular_callbacks.is_empty() {
for src in regular_callbacks {
let output =
parser::eval_shell_script(src, &self.working_dir, Some(export_env.clone()));
callback_resp.push(output);
}
log::debug!("CALLBACK RESPONSE\n{:#?}", callback_resp);
}
let output = if !callback_resp.is_empty() {
callback_resp.join("\n")
} else {
resp_text
};
let final_export_env = apply_map_suffix_to_exports(export_env, self.map_iteration.as_ref());
Ok(RequestExecution {
output,
export_env: final_export_env,
snapshot,
assertions: assertion_results,
})
}
fn build_reqwest_request(
&self,
client: &reqwest::Client,
context_map: &HashMap<String, String>,
) -> Result<reqwest::Request, Box<dyn std::error::Error>> {
let resolved_url = context::resolve_with_context(&self.url, context_map);
let mut request = client.request(
self.method.clone(),
context::try_inject_from_prompt(&resolved_url)?,
);
let multipart_payload = self.build_multipart_payload(context_map)?;
let mut query_params: HashMap<String, String> = HashMap::new();
for (key, value) in &self.query_params {
let resolved = context::resolve_with_context(value, context_map);
let resolved = context::try_inject_from_prompt(&resolved)?;
query_params.insert(key.clone(), resolved);
}
request = request.query(&query_params);
if let Some((body, boundary)) = multipart_payload {
request = request.header(
reqwest::header::CONTENT_TYPE,
format!("multipart/form-data; boundary={boundary}"),
);
request = request.body(body);
}
if let Some(body) = &self.body {
let resolved = context::resolve_with_context(body, context_map);
request = request.body(context::try_inject_from_prompt(&resolved)?);
}
let mut request = request.build()?;
if let Some(content_type) = &self.body_content_type {
let resolved = context::resolve_with_context(content_type, context_map);
request.headers_mut().remove(reqwest::header::CONTENT_TYPE);
request.headers_mut().insert(
reqwest::header::CONTENT_TYPE,
parse_header_value(
context::try_inject_from_prompt(&resolved)?,
"Content-Type".to_string(),
)?,
);
}
for (key, value) in &self.headers {
let resolved_value = context::resolve_with_context(value, context_map);
let header_name = parse_header_name(key)?;
request.headers_mut().remove(&header_name);
request.headers_mut().insert(
header_name,
parse_header_value(
context::try_inject_from_prompt(&resolved_value)?,
key.clone(),
)?,
);
}
Ok(request)
}
fn build_multipart_payload(
&self,
context_map: &HashMap<String, String>,
) -> Result<Option<(Vec<u8>, String)>, Box<dyn std::error::Error>> {
if self.form_data.is_empty() {
return Ok(None);
}
let boundary = generate_multipart_boundary();
let mut body = Vec::new();
for (name, value) in &self.form_data {
body.extend_from_slice(b"--");
body.extend_from_slice(boundary.as_bytes());
body.extend_from_slice(b"\r\n");
let escaped_name = escape_multipart_quoted_string(name);
match value {
FormDataType::Text(text) => {
let resolved = context::resolve_with_context(text, context_map);
let resolved = context::try_inject_from_prompt(&resolved)?;
body.extend_from_slice(
format!(
"Content-Disposition: form-data; name=\"{}\"\r\n\r\n",
escaped_name
)
.as_bytes(),
);
body.extend_from_slice(resolved.as_bytes());
}
FormDataType::File(filepath) => {
let filename = filepath
.file_name()
.and_then(|name| name.to_str())
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!(
"Multipart file names must be valid UTF-8: {}",
filepath.display()
),
)
})?;
let escaped_filename = escape_multipart_quoted_string(filename);
body.extend_from_slice(
format!(
"Content-Disposition: form-data; name=\"{}\"; filename=\"{}\"\r\n\r\n",
escaped_name, escaped_filename
)
.as_bytes(),
);
body.extend_from_slice(&std::fs::read(filepath)?);
}
}
body.extend_from_slice(b"\r\n");
}
body.extend_from_slice(b"--");
body.extend_from_slice(boundary.as_bytes());
body.extend_from_slice(b"--\r\n");
Ok(Some((body, boundary)))
}
pub fn as_curl(&self) -> String {
let mut curl = String::new();
let headers = self
.headers
.iter()
.map(|(key, value)| format!("-H '{}: {}'", key, value))
.collect::<Vec<String>>()
.join(" ");
let has_manual_content_type = self
.headers
.keys()
.any(|key| key.eq_ignore_ascii_case("content-type"));
let headers = match (&self.body_content_type, has_manual_content_type) {
(Some(content_type), false) => {
format!("-H 'Content-Type: {}' {}", content_type, headers)
}
_ => headers,
};
let query_params = self
.query_params
.iter()
.map(|(key, value)| format!("{}={}", key, value))
.collect::<Vec<String>>()
.join("&");
let query_params = match query_params.len() {
0 => "".to_string(),
_ => format!("?{}", query_params),
};
let form_data = self
.form_data
.iter()
.map(|(key, value)| match value {
FormDataType::Text(text) => format!("-F '{}={}'", key, text),
FormDataType::File(filepath) => format!("-F '{}=@{}'", key, filepath.display()),
})
.collect::<Vec<String>>()
.join(" ");
let body = match &self.body {
Some(body) => format!("-d ' {}'", body),
None => "".to_string(),
};
curl.push_str(&format!(
"curl -X {} '{}{}' {} {} {}",
self.method, self.url, query_params, headers, form_data, body
));
curl
}
}
/// Parses a callback of the form `<command> -> $<VARIABLE>`.
///
/// Returns `Ok(None)` when `src` contains no `->` (a regular callback), and
/// `Ok(Some((command, variable)))` for a well-formed assignment. The split happens at the
/// *last* `->`, because the assignment target is always the final token. This lets commands
/// that themselves contain an arrow (for example inside a quoted string) still parse.
///
/// # Errors
/// Returns a [`CallbackAssignmentError`] when an arrow is present but the assignment is
/// malformed:
/// - there is no command before the arrow;
/// - whitespace is missing on either side of the arrow;
/// - the target does not start with `$`;
/// - the variable name is empty.
fn parse_callback_assignment(
    src: &str,
) -> Result<Option<(String, String)>, CallbackAssignmentError> {
    let trimmed = src.trim();
    let Some(index) = trimmed.rfind("->") else {
        return Ok(None);
    };
    if index == 0 {
        return Err(CallbackAssignmentError::new(
            "callback assignment is missing a command before '->'",
        ));
    }
    let before_arrow = &trimmed[..index];
    let after_arrow = &trimmed[index + "->".len()..];
    let has_surrounding_whitespace = before_arrow.ends_with(char::is_whitespace)
        && after_arrow.starts_with(char::is_whitespace);
    if !has_surrounding_whitespace {
        return Err(CallbackAssignmentError::new(
            "callback assignment requires whitespace on both sides of '->'",
        ));
    }
    let command = before_arrow.trim();
    if command.is_empty() {
        return Err(CallbackAssignmentError::new(
            "callback assignment command cannot be empty",
        ));
    }
    let remainder = after_arrow.trim();
    if !remainder.starts_with('$') {
        return Err(CallbackAssignmentError::new(
            "callback assignment target must start with '$'",
        ));
    }
    let name = remainder.trim_start_matches('$').trim();
    if name.is_empty() {
        return Err(CallbackAssignmentError::new(
            "callback assignment target variable name cannot be empty",
        ));
    }
    Ok(Some((command.to_string(), name.to_string())))
}
/// Cleans a callback's output before it is stored in a variable.
///
/// First drops every trailing `\n`/`\r`, then removes all NUL bytes anywhere in the text.
/// Other trailing whitespace is preserved.
fn sanitize_callback_assignment_output(output: &str) -> String {
    let without_line_ending = output.trim_end_matches(|c: char| c == '\n' || c == '\r');
    without_line_ending.chars().filter(|&c| c != '\0').collect()
}
/// Escapes a form field name or filename for use inside a quoted `Content-Disposition`
/// parameter.
///
/// Backslashes and double quotes are backslash-escaped (quoted-pair). Carriage returns and
/// line feeds cannot be carried in a header quoted-pair, so they are percent-encoded as
/// `%0D` and `%0A`, following the HTML multipart/form-data encoding. Leaving them raw would
/// let a name inject extra part headers or break the body framing.
fn escape_multipart_quoted_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\r' => escaped.push_str("%0D"),
            '\n' => escaped.push_str("%0A"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Generates a curl-style multipart boundary: 24 dashes followed by 32 lowercase hex digits.
///
/// The first 16 hex digits come from a process-wide counter mixed with the process id.
/// XOR with a per-process constant is a bijection, so every boundary this process generates
/// is distinct. The last 16 come from the wall-clock time in nanoseconds, so boundaries also
/// vary between runs.
///
/// (The previous layout used `timestamp >> 64`, which is always zero, so half of the boundary
/// was a constant run of zeros.)
fn generate_multipart_boundary() -> String {
    // Truncating to u64 is intended: only the low bits change between calls.
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64;
    let sequence = MULTIPART_BOUNDARY_COUNTER.fetch_add(1, Ordering::Relaxed);
    let process_tag = u64::from(std::process::id()) << 32;
    format!(
        "------------------------{:016x}{:016x}",
        process_tag ^ sequence,
        timestamp,
    )
}
/// Error describing why a `command -> $VAR` callback assignment is malformed.
#[derive(Debug)]
struct CallbackAssignmentError {
    message: String,
}

impl CallbackAssignmentError {
    /// Wraps a human-readable description of the problem.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for CallbackAssignmentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for CallbackAssignmentError {}
/// Error raised when a capture reads from a dependency that has no response snapshot yet.
#[derive(Debug)]
struct CaptureDependencyError {
    dependency: String,
    request: String,
}

impl std::fmt::Display for CaptureDependencyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            dependency,
            request,
        } = self;
        write!(
            f,
            "Request '{request}' references dependency '{dependency}' in a capture but it has not executed yet."
        )
    }
}

impl std::error::Error for CaptureDependencyError {}
/// Error for a header name that is not a valid HTTP header name.
#[derive(Debug)]
struct InvalidRequestHeaderName {
    header: String,
}

impl std::fmt::Display for InvalidRequestHeaderName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { header } = self;
        write!(f, "Invalid header name '{header}'")
    }
}

impl std::error::Error for InvalidRequestHeaderName {}
/// Error for a value that cannot be used as the value of the named HTTP header.
#[derive(Debug)]
struct InvalidRequestHeaderValue {
    header: String,
    value: String,
}

impl std::fmt::Display for InvalidRequestHeaderValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { header, value } = self;
        write!(f, "Invalid value '{value}' for header '{header}'")
    }
}

impl std::error::Error for InvalidRequestHeaderValue {}
/// Parses `header` as an HTTP header name. Any parse failure is reported as
/// `InvalidRequestHeaderName`.
fn parse_header_name(header: &str) -> Result<HeaderName, Box<dyn std::error::Error>> {
    match header.parse::<HeaderName>() {
        Ok(name) => Ok(name),
        Err(_) => Err(Box::new(InvalidRequestHeaderName {
            header: header.to_owned(),
        })),
    }
}
/// Parses `value` as the value of header `header`. On failure, both strings are returned in
/// an `InvalidRequestHeaderValue`.
fn parse_header_value(
    value: String,
    header: String,
) -> Result<HeaderValue, Box<dyn std::error::Error>> {
    if let Ok(parsed) = value.parse::<HeaderValue>() {
        return Ok(parsed);
    }
    Err(Box::new(InvalidRequestHeaderValue { header, value }))
}
/// Appends the map-iteration suffix (e.g. `[label]`) to every exported variable name.
///
/// The response metadata keys `RESPONSE`, `STATUS`, `STATUS_TEXT` and `DESCRIPTION` keep
/// their names. The map is returned unchanged when there is no iteration or its suffix is
/// empty.
fn apply_map_suffix_to_exports(
    export_env: HashMap<String, String>,
    iteration: Option<&MapIteration>,
) -> HashMap<String, String> {
    const RESERVED: [&str; 4] = ["RESPONSE", "STATUS", "STATUS_TEXT", "DESCRIPTION"];
    let suffix = match iteration {
        Some(iteration) => iteration.suffix(),
        None => return export_env,
    };
    if suffix.is_empty() {
        return export_env;
    }
    export_env
        .into_iter()
        .map(|(key, value)| {
            if RESERVED.contains(&key.as_str()) {
                (key, value)
            } else {
                (format!("{key}{suffix}"), value)
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    // Unit tests for callback-assignment parsing, request building and curl rendering.
    use super::*;
    use std::{collections::HashMap, path::PathBuf};

    /// A POST to `https://example.com` with every optional part left empty.
    fn base_request() -> Request {
        Request {
            description: "Test Request".into(),
            base_description: "Test Request".into(),
            method: Method::POST,
            url: "https://example.com".into(),
            headers: HashMap::new(),
            query_params: HashMap::new(),
            form_data: HashMap::new(),
            body: None,
            body_content_type: None,
            callback_src: Vec::new(),
            response_captures: Vec::new(),
            assertions: Vec::new(),
            declared_dependencies: Vec::new(),
            dependencies: Vec::new(),
            context: HashMap::new(),
            working_dir: PathBuf::new(),
            map_iteration: None,
        }
    }

    /// `base_request` with a single text form field `name=hen`.
    fn multipart_request() -> Request {
        let mut request = base_request();
        request
            .form_data
            .insert("name".into(), FormDataType::Text("hen".into()));
        request
    }

    /// Builds `request` with a fresh client and an empty context.
    fn build(request: &Request) -> Result<reqwest::Request, Box<dyn std::error::Error>> {
        request.build_reqwest_request(&reqwest::Client::new(), &HashMap::new())
    }

    /// The `Content-Type` header of `built`, as text.
    fn content_type_of(built: &reqwest::Request) -> &str {
        built
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .expect("content-type header should exist")
            .to_str()
            .expect("content-type should be valid utf-8")
    }

    #[test]
    fn callback_assignment_parses_with_whitespace() {
        let parsed = parse_callback_assignment("echo foo -> $BAR").unwrap();
        let (command, variable) = parsed.expect("expected assignment");
        assert_eq!(command, "echo foo");
        assert_eq!(variable, "BAR");
    }

    #[test]
    fn callback_assignment_requires_whitespace() {
        let message = parse_callback_assignment("echo foo->$BAR")
            .unwrap_err()
            .to_string();
        assert!(message.contains("whitespace on both sides of '->'"));
    }

    #[test]
    fn callback_assignment_requires_dollar_prefix() {
        let message = parse_callback_assignment("echo foo -> BAR")
            .unwrap_err()
            .to_string();
        assert!(message.contains("callback assignment target must start with '$'"));
    }

    #[test]
    fn build_request_allows_manual_content_type_override_for_multipart() {
        let mut request = multipart_request();
        request
            .headers
            .insert("Content-Type".into(), "application/octet-stream".into());
        let built = build(&request).expect("request should build");
        assert_eq!(content_type_of(&built), "application/octet-stream");
        let occurrences = built
            .headers()
            .get_all(reqwest::header::CONTENT_TYPE)
            .iter()
            .count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn build_request_uses_curl_style_multipart_boundary_without_override() {
        let built = build(&multipart_request()).expect("request should build");
        let boundary = content_type_of(&built)
            .strip_prefix("multipart/form-data; boundary=")
            .expect("multipart boundary should be present");
        let bytes = built
            .body()
            .and_then(|body| body.as_bytes())
            .expect("multipart body should be buffered");
        let body = String::from_utf8(bytes.to_vec()).expect("text-only multipart should be utf-8");

        assert!(boundary.starts_with("------------------------"));
        assert_eq!(boundary.len(), 56);
        assert!(body.starts_with(&format!("--{boundary}\r\n")));
        assert!(body.contains("Content-Disposition: form-data; name=\"name\"\r\n\r\nhen\r\n"));
        assert!(body.ends_with(&format!("--{boundary}--\r\n")));
    }

    #[test]
    fn as_curl_omits_auto_content_type_when_manual_override_exists() {
        let mut request = base_request();
        request.body = Some("{}".into());
        request.body_content_type = Some("application/json".into());
        request
            .headers
            .insert("Content-Type".into(), "application/octet-stream".into());
        let curl = request.as_curl();
        assert!(curl.contains("-H 'Content-Type: application/octet-stream'"));
        assert!(!curl.contains("-H 'Content-Type: application/json'"));
    }

    #[test]
    fn build_request_returns_invalid_header_name_error() {
        let mut request = base_request();
        request.headers.insert("bad header".into(), "value".into());
        let err = build(&request).expect_err("request should fail");
        assert!(err.to_string().contains("Invalid header name"));
    }
}