use http::Method;
use reqwest::header::HeaderMap;
use serde_json::Value;
use std::{collections::HashMap, path::PathBuf};
use crate::parser::{self, context};
mod assertion;
mod planner;
mod response_capture;
mod runner;
mod template;
pub use assertion::Assertion;
pub use planner::RequestPlanner;
pub use response_capture::{ResponseCapture, ResponseSnapshot};
pub use runner::{
execute_plan as execute_request_plan,
execute_plan_with_observer as execute_request_plan_with_observer, ExecutionEvent,
ExecutionObserver, ExecutionOptions, ExecutionRecord, RequestFailure, RequestFailureKind,
};
pub use template::{expand_templates, RequestTemplate, TemplateError, VariableStore};
/// Value of a single multipart form field on a [`Request`].
#[derive(Debug, Clone)]
pub enum FormDataType {
/// Text field. `Request::exec` resolves it against the request context before
/// adding it to the multipart form; `Request::as_curl` renders it as `-F 'key=text'`.
Text(String),
/// Path of a file to upload. `Request::exec` reads the file's bytes with
/// `std::fs::read` and attaches them under the path's file name;
/// `Request::as_curl` renders it as `-F 'key=@path'`.
File(PathBuf),
}
/// A single HTTP request together with everything needed to execute it and
/// post-process its response (captures, assertions and shell callbacks).
#[derive(Debug, Clone)]
pub struct Request {
/// Name of the request. Exported as `DESCRIPTION`, reported in
/// `ExecutionEvent::AssertionPassed` and used in capture-dependency errors.
pub description: String,
/// NOTE(review): not read in this part of the file; presumably the description
/// without any map-iteration decoration — confirm against the planner.
pub base_description: String,
/// HTTP method used when sending.
pub method: Method,
/// Target URL; resolved against the context before sending.
pub url: String,
/// Header name -> value. Values are resolved against the context; names are
/// parsed as `reqwest::header::HeaderName` at execution time.
pub headers: HashMap<String, String>,
/// Query parameter name -> value. Values are resolved against the context.
pub query_params: HashMap<String, String>,
/// Multipart form fields. When non-empty the request is sent as multipart.
pub form_data: HashMap<String, FormDataType>,
/// Raw request body; resolved against the context before sending.
pub body: Option<String>,
/// When set, sent as the `Content-Type` header.
pub body_content_type: Option<String>,
/// Shell script sources run after the response is received. Entries of the
/// form `command -> $NAME` store the command's output in the exported
/// variable `NAME`; all other entries contribute to the execution output.
pub callback_src: Vec<String>,
/// Captures extracted from this response or from a dependency's snapshot.
pub response_captures: Vec<ResponseCapture>,
/// Assertions evaluated after captures and assignment callbacks have run.
pub assertions: Vec<Assertion>,
/// NOTE(review): not read in this part of the file; presumably the
/// dependencies written explicitly by the user — confirm.
pub declared_dependencies: Vec<String>,
/// NOTE(review): not read in this part of the file; presumably the full
/// dependency list used for ordering — confirm.
pub dependencies: Vec<String>,
/// Request-local variables; these override entries of the inherited context.
pub context: HashMap<String, String>,
/// Directory handed to `parser::eval_shell_script` when running callbacks.
pub working_dir: PathBuf,
/// When set and its label is non-empty, exported variable names (other than
/// the reserved ones) get the iteration suffix appended.
pub map_iteration: Option<MapIteration>,
}
/// Result of a successful `Request::exec` call.
#[derive(Debug, Clone)]
pub struct RequestExecution {
/// Outputs of the non-assignment callbacks joined with `\n`, or the raw
/// response text when no such callback produced output.
pub output: String,
/// Variables exported by the request: the merged context, `RESPONSE`,
/// `STATUS`, `STATUS_TEXT` (when the status has a canonical reason),
/// `DESCRIPTION`, captured values and callback assignments, with the
/// map-iteration suffix applied to non-reserved names.
pub export_env: HashMap<String, String>,
/// Status, headers, body and parsed JSON of the response.
pub snapshot: ResponseSnapshot,
}
/// Describes one iteration of a request that is executed once per mapped value.
#[derive(Debug, Clone)]
pub struct MapIteration {
    /// Name/value pairs attached to this iteration.
    /// NOTE(review): not read in this part of the file — confirm usage in the planner.
    pub variables: Vec<(String, String)>,
    /// Label rendered by [`MapIteration::suffix`]; may be empty.
    pub label: String,
}

impl MapIteration {
    /// Returns the label wrapped in square brackets (`[label]`), or an empty
    /// string when the label itself is empty.
    pub fn suffix(&self) -> String {
        match self.label.as_str() {
            "" => String::new(),
            label => format!("[{}]", label),
        }
    }
}
impl Request {
pub async fn exec(
&self,
inherited_context: &HashMap<String, String>,
dependency_snapshots: &HashMap<String, ResponseSnapshot>,
observer: Option<&ExecutionObserver>,
) -> Result<RequestExecution, Box<dyn std::error::Error>> {
let client = reqwest::Client::new();
let mut context_map = inherited_context.clone();
for (key, value) in &self.context {
context_map.insert(key.clone(), value.clone());
}
let resolved_url = context::resolve_with_context(&self.url, &context_map);
let mut request = client.request(
self.method.clone(),
context::inject_from_prompt(&resolved_url),
);
let mut header_map = HeaderMap::new();
if let Some(content_type) = &self.body_content_type {
let resolved = context::resolve_with_context(content_type, &context_map);
header_map.insert(
reqwest::header::CONTENT_TYPE,
context::inject_from_prompt(&resolved)
.parse::<reqwest::header::HeaderValue>()
.unwrap(),
);
}
for (key, value) in &self.headers {
let resolved_value = context::resolve_with_context(value, &context_map);
header_map.insert(
key.parse::<reqwest::header::HeaderName>()
.expect(format!("Invalid header name: {}", key).as_str()),
context::inject_from_prompt(&resolved_value)
.parse::<reqwest::header::HeaderValue>()
.unwrap(),
);
}
let mut form = reqwest::multipart::Form::new();
for (key, value) in &self.form_data {
match value {
FormDataType::Text(text) => {
let resolved = context::resolve_with_context(text, &context_map);
form = form.text(key.clone(), context::inject_from_prompt(&resolved));
}
FormDataType::File(filepath) => {
let file = reqwest::multipart::Part::bytes(std::fs::read(filepath)?)
.file_name(filepath.file_name().unwrap().to_str().unwrap().to_string());
form = form.part(key.clone(), file);
}
}
}
let mut query_params: HashMap<String, String> = HashMap::new();
self.query_params.iter().for_each(|(key, value)| {
let resolved = context::resolve_with_context(value, &context_map);
query_params.insert(key.clone(), context::inject_from_prompt(&resolved));
});
request = match self.form_data.len() > 0 {
true => request
.headers(header_map)
.query(&query_params)
.multipart(form),
false => request.headers(header_map).query(&query_params),
};
if let Some(body) = &self.body {
let resolved = context::resolve_with_context(body, &context_map);
request = request.body(context::inject_from_prompt(&resolved));
}
log::debug!("REQUEST\n{:#?}", request);
let resp = request.send().await?;
log::debug!("RESPONSE\n{:#?}", resp);
let status = resp.status();
let headers = resp.headers().clone();
let resp_text = resp.text().await.unwrap();
let json_body = serde_json::from_str::<Value>(&resp_text).ok();
let sanitized_response = resp_text.replace('\0', "");
let snapshot = ResponseSnapshot {
status,
headers: headers.clone(),
body: sanitized_response.clone(),
json: json_body.clone(),
};
let mut captured_values: HashMap<String, String> = HashMap::new();
for capture in &self.response_captures {
let source_snapshot = match &capture.source {
response_capture::CaptureSource::Current => &snapshot,
response_capture::CaptureSource::Dependency(name) => {
dependency_snapshots.get(name).ok_or_else(|| {
Box::<dyn std::error::Error>::from(CaptureDependencyError {
dependency: name.clone(),
request: self.description.clone(),
})
})?
}
};
match capture.extract_from_snapshot(source_snapshot) {
Ok(Some((name, value))) => {
captured_values.insert(name, value);
}
Ok(None) => {}
Err(e) => return Err(Box::new(e)),
}
}
if !captured_values.is_empty() {
log::debug!("CAPTURED VALUES\n{:#?}", captured_values);
}
let status_code = status.as_str().to_string();
let mut export_env: HashMap<String, String> = context_map.clone();
export_env.insert("RESPONSE".to_string(), sanitized_response.clone());
export_env.insert("STATUS".to_string(), status_code.clone());
if let Some(reason) = status.canonical_reason() {
export_env.insert("STATUS_TEXT".to_string(), reason.to_string());
}
export_env.insert("DESCRIPTION".to_string(), self.description.clone());
for (key, value) in &captured_values {
export_env.insert(key.clone(), value.replace('\0', ""));
}
let mut assignment_callbacks: Vec<(String, String)> = Vec::new();
let mut regular_callbacks: Vec<&String> = Vec::new();
if !self.callback_src.is_empty() {
for src in &self.callback_src {
match parse_callback_assignment(src) {
Ok(Some((command, variable))) => assignment_callbacks.push((command, variable)),
Ok(None) => regular_callbacks.push(src),
Err(err) => return Err(Box::new(err)),
}
}
}
for (command, variable) in assignment_callbacks {
let output =
parser::eval_shell_script(&command, &self.working_dir, Some(export_env.clone()));
let sanitized = sanitize_callback_assignment_output(&output);
export_env.insert(variable, sanitized);
}
for assertion in &self.assertions {
assertion
.evaluate(&export_env, &snapshot, &dependency_snapshots)
.map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?;
if let Some(callback) = observer {
callback(ExecutionEvent::AssertionPassed {
request: self.description.clone(),
assertion: assertion.raw.clone(),
});
}
}
let mut callback_resp: Vec<String> = vec![];
if !regular_callbacks.is_empty() {
for src in regular_callbacks {
let output =
parser::eval_shell_script(src, &self.working_dir, Some(export_env.clone()));
callback_resp.push(output);
}
log::debug!("CALLBACK RESPONSE\n{:#?}", callback_resp);
}
let output = if callback_resp.len() > 0 {
callback_resp.join("\n")
} else {
resp_text
};
let final_export_env = apply_map_suffix_to_exports(export_env, self.map_iteration.as_ref());
Ok(RequestExecution {
output,
export_env: final_export_env,
snapshot,
})
}
/// Renders this request as a one-line `curl` command.
///
/// Uses the raw (unresolved) `url`, `headers`, `query_params`, `form_data`, `body` and
/// `body_content_type` fields. Every interpolated value sits inside a single-quoted shell
/// word, so embedded single quotes are escaped as `'\''`; values without a single quote
/// are rendered unchanged.
///
/// Header, query-parameter and form-field order follows `HashMap` iteration order.
pub fn as_curl(&self) -> String {
    /// Escapes `'` for use inside a single-quoted POSIX shell word:
    /// close the quote, emit a backslash-escaped quote, reopen the quote.
    fn quote(text: &str) -> String {
        text.replace('\'', "'\\''")
    }

    let headers = self
        .headers
        .iter()
        .map(|(key, value)| format!("-H '{}: {}'", quote(key), quote(value)))
        .collect::<Vec<String>>()
        .join(" ");
    let headers = match &self.body_content_type {
        Some(content_type) => {
            format!("-H 'Content-Type: {}' {}", quote(content_type), headers)
        }
        None => headers,
    };

    let query_params = self
        .query_params
        .iter()
        .map(|(key, value)| format!("{}={}", quote(key), quote(value)))
        .collect::<Vec<String>>()
        .join("&");
    let query_params = if query_params.is_empty() {
        String::new()
    } else {
        format!("?{}", query_params)
    };

    let form_data = self
        .form_data
        .iter()
        .map(|(key, value)| match value {
            FormDataType::Text(text) => format!("-F '{}={}'", quote(key), quote(text)),
            FormDataType::File(filepath) => format!(
                "-F '{}=@{}'",
                quote(key),
                quote(&filepath.display().to_string())
            ),
        })
        .collect::<Vec<String>>()
        .join(" ");

    // NOTE(review): the space after the opening quote is kept from the existing output
    // format — confirm whether it is intentional.
    let body = match &self.body {
        Some(body) => format!("-d ' {}'", quote(body)),
        None => String::new(),
    };

    format!(
        "curl -X {} '{}{}' {} {} {}",
        self.method,
        quote(&self.url),
        query_params,
        headers,
        form_data,
        body
    )
}
}
/// Parses a callback source of the form `command -> $NAME`.
///
/// Returns `Ok(None)` when the trimmed source contains no `->` at all, and
/// `Ok(Some((command, name)))` for a well-formed assignment, where `name` has its leading
/// `$` characters removed. The split happens at the first `->`.
///
/// # Errors
///
/// Returns a `CallbackAssignmentError` when `->` is present but the command is missing,
/// the arrow is not surrounded by whitespace, the target does not start with `$`, or the
/// variable name is empty.
fn parse_callback_assignment(
    src: &str,
) -> Result<Option<(String, String)>, CallbackAssignmentError> {
    let source = src.trim();

    let (lhs, rhs) = match source.split_once("->") {
        None => return Ok(None),
        Some((lhs, _)) if lhs.is_empty() => {
            return Err(CallbackAssignmentError::new(
                "callback assignment is missing a command before '->'",
            ));
        }
        Some(parts) => parts,
    };

    // The arrow must be separated from both operands by at least one whitespace character.
    let arrow_is_spaced =
        lhs.ends_with(char::is_whitespace) && rhs.starts_with(char::is_whitespace);
    if !arrow_is_spaced {
        return Err(CallbackAssignmentError::new(
            "callback assignment requires whitespace on both sides of '->'",
        ));
    }

    let command = lhs.trim();
    if command.is_empty() {
        return Err(CallbackAssignmentError::new(
            "callback assignment command cannot be empty",
        ));
    }

    let target = rhs.trim();
    if !target.starts_with('$') {
        return Err(CallbackAssignmentError::new(
            "callback assignment target must start with '$'",
        ));
    }

    let variable = target.trim_start_matches('$').trim();
    if variable.is_empty() {
        return Err(CallbackAssignmentError::new(
            "callback assignment target variable name cannot be empty",
        ));
    }

    Ok(Some((command.to_string(), variable.to_string())))
}
/// Cleans the output of an assignment callback before it is stored as a variable.
///
/// NUL characters are removed first and trailing `\n` / `\r` characters are trimmed
/// afterwards, so line endings followed by a stray NUL (e.g. `"value\n\0"`) are trimmed
/// too. Interior line breaks are preserved.
fn sanitize_callback_assignment_output(output: &str) -> String {
    let mut cleaned: String = output.chars().filter(|&c| c != '\0').collect();
    // Trim in place: the trimmed slice is a prefix, so its length is a valid cut point.
    let kept_len = cleaned.trim_end_matches(['\n', '\r']).len();
    cleaned.truncate(kept_len);
    cleaned
}
/// Error produced when a `command -> $NAME` callback assignment is malformed.
#[derive(Debug)]
struct CallbackAssignmentError {
    /// Human-readable description of what is wrong with the assignment.
    message: String,
}

impl CallbackAssignmentError {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for CallbackAssignmentError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for CallbackAssignmentError {}
/// Error raised when a response capture refers to a dependency for which no snapshot
/// is available.
#[derive(Debug)]
struct CaptureDependencyError {
    /// Name of the dependency the capture refers to.
    dependency: String,
    /// Description of the request that owns the capture.
    request: String,
}

impl std::fmt::Display for CaptureDependencyError {
    /// Formats the error, naming both the request and the missing dependency.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            dependency,
            request,
        } = self;
        write!(
            f,
            "Request '{}' references dependency '{}' in a capture but it has not executed yet.",
            request, dependency
        )
    }
}

impl std::error::Error for CaptureDependencyError {}
fn apply_map_suffix_to_exports(
export_env: HashMap<String, String>,
iteration: Option<&MapIteration>,
) -> HashMap<String, String> {
let Some(iteration) = iteration else {
return export_env;
};
let suffix = iteration.suffix();
if suffix.is_empty() {
return export_env;
}
let reserved = ["RESPONSE", "STATUS", "STATUS_TEXT", "DESCRIPTION"];
let mut transformed: HashMap<String, String> = HashMap::with_capacity(export_env.len());
for (key, value) in export_env.into_iter() {
if reserved.iter().any(|reserved_key| reserved_key == &key) {
transformed.insert(key, value);
} else {
transformed.insert(format!("{}{}", key, suffix), value);
}
}
transformed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed `command -> $NAME` source yields the trimmed command and bare name.
    #[test]
    fn callback_assignment_parses_with_whitespace() {
        let parsed = parse_callback_assignment("echo foo -> $BAR").unwrap();
        let (command, variable) = parsed.expect("expected assignment");
        assert_eq!(command, "echo foo");
        assert_eq!(variable, "BAR");
    }

    /// An arrow without surrounding whitespace is rejected.
    #[test]
    fn callback_assignment_requires_whitespace() {
        let message = parse_callback_assignment("echo foo->$BAR")
            .unwrap_err()
            .to_string();
        assert!(message.contains("whitespace on both sides of '->'"));
    }

    /// A target that lacks the `$` prefix is rejected.
    #[test]
    fn callback_assignment_requires_dollar_prefix() {
        let message = parse_callback_assignment("echo foo -> BAR")
            .unwrap_err()
            .to_string();
        assert!(message.contains("callback assignment target must start with '$'"));
    }
}