// hen 0.9.0
//
// Run API collections from the command line.
//! Parses the contents of a collection file into a `Collection` struct.
use std::{collections::HashMap, hash::Hash, path::PathBuf};

use pest::Parser;
use pest_derive::Parser;

pub mod context;
mod preprocessor;

use crate::{
    collection::Collection,
    request::{Assertion, FormDataType, Request, ResponseCapture},
};

/// Marker type for the pest parser derived from `src/parser/grammar.pest`.
///
/// The struct carries no data; it exists so `CollectionParser::parse` can be
/// called with a `Rule` from the grammar (see `parse_collection`).
#[derive(Parser)]
#[grammar = "src/parser/grammar.pest"]
struct CollectionParser;

/// Parses the text of a collection file into a [`Collection`].
///
/// `input` is run through the preprocessor first and the result is parsed
/// with the `request_collection` grammar rule. Collection-level variables,
/// headers, queries and callbacks are accumulated in the order they appear,
/// and a snapshot of everything seen so far is handed to each request found
/// in a `requests` section.
///
/// Variable values that start with `$(` are executed as shell scripts in
/// `working_dir` and replaced by their trimmed stdout; every other variable,
/// header and query value is passed through `context::inject_from_prompt`.
///
/// # Errors
///
/// Returns a pest error when preprocessing fails (reported as a custom error
/// spanning the whole input) or when the preprocessed text does not match the
/// grammar.
///
/// # Panics
///
/// Panics if the parse tree contains a rule this function does not handle,
/// and propagates any panic raised by [`parse_request`] or
/// [`eval_shell_script`].
pub fn parse_collection(
    input: &str,
    working_dir: PathBuf,
) -> Result<Collection, pest::error::Error<Rule>> {
    let preprocessed = preprocessor::preprocess(input, working_dir.clone()).map_err(|e| {
        // The preprocessor error carries no position, so attach it to the whole input.
        pest::error::Error::new_from_span(
            pest::error::ErrorVariant::CustomError {
                message: e.to_string(),
            },
            pest::Span::new(input, 0, input.len()).unwrap(),
        )
    })?;

    log::debug!("PREPROCESSED COMPLETE:\n{}", preprocessed);

    let mut pairs = CollectionParser::parse(Rule::request_collection, preprocessed.as_str())?;

    // A successful parse always yields the top-level `request_collection` pair.
    let collection = pairs.next().unwrap();

    let mut name = String::new();
    let mut description = String::new();
    let mut requests = Vec::new();
    let mut context = HashMap::new();
    let mut global_headers: HashMap<String, String> = HashMap::new();
    let mut global_queries: HashMap<String, String> = HashMap::new();
    let mut global_callbacks: Vec<String> = vec![];

    for pair in collection.into_inner() {
        match pair.as_rule() {
            Rule::collection_name => {
                name = pair.as_str().trim().to_string();
            }

            Rule::collection_description => {
                description.push_str(pair.as_str().trim());
            }

            Rule::variable => {
                let mut inner_pairs = pair.into_inner();
                let key = inner_pairs.next().unwrap().as_str().trim().to_string();
                let raw_value = inner_pairs.next().unwrap().as_str().to_string();

                let value = match raw_value.strip_prefix("$(") {
                    // Shell substitution: remove exactly one closing parenthesis so
                    // that scripts which themselves end in `)` (nested `$(...)`,
                    // subshells) are passed to the shell intact.
                    Some(rest) => {
                        let script = rest.strip_suffix(')').unwrap_or(rest);
                        eval_shell_script(script, &working_dir, None)
                            .trim()
                            .to_string()
                    }
                    None => context::inject_from_prompt(&raw_value),
                };

                context.insert(key, value);
            }

            Rule::header => {
                let mut inner_pairs = pair.into_inner();
                let key = inner_pairs.next().unwrap().as_str().trim().to_string();
                let value = inner_pairs.next().unwrap().as_str().trim().to_string();

                global_headers.insert(key, context::inject_from_prompt(&value));
            }

            Rule::query => {
                let mut inner_pairs = pair.into_inner();
                let key = inner_pairs.next().unwrap().as_str().trim().to_string();
                let value = inner_pairs.next().unwrap().as_str().trim().to_string();

                global_queries.insert(key, context::inject_from_prompt(&value));
            }

            Rule::callback => {
                // drop the leading "!" character
                global_callbacks.push(pair.as_str().strip_prefix('!').unwrap().to_string());
            }

            Rule::requests => {
                // Each request receives its own copy of the collection-level state
                // gathered so far; it may extend or override it locally.
                for request_pair in pair.into_inner() {
                    requests.push(parse_request(
                        request_pair,
                        context.clone(),
                        global_headers.clone(),
                        global_queries.clone(),
                        global_callbacks.clone(),
                        &working_dir,
                    ));
                }
            }

            _ => {
                unreachable!("unexpected rule: {:?}", pair.as_rule());
            }
        }
    }

    Ok(Collection {
        name,
        description,
        requests,
    })
}

pub fn parse_request(
    pair: pest::iterators::Pair<Rule>,
    collection_context: HashMap<String, String>,
    global_headers: HashMap<String, String>,
    global_queries: HashMap<String, String>,
    global_callbacks: Vec<String>,
    working_dir: &PathBuf,
) -> Request {
    let mut method = None;
    let mut url_raw: Option<String> = None;
    let mut header_raw: HashMap<String, String> = HashMap::new();
    let mut query_raw: HashMap<String, String> = HashMap::new();
    let mut form_text_raw: HashMap<String, String> = HashMap::new();
    let mut form_file_raw: HashMap<String, String> = HashMap::new();
    let mut body_raw: Option<String> = None;
    let mut body_content_type_raw: Option<String> = None;
    let mut description = String::new();
    let mut callback_src: Vec<String> = global_callbacks;
    let mut response_captures: Vec<ResponseCapture> = Vec::new();
    let mut dependencies: Vec<String> = Vec::new();
    let mut assertions: Vec<Assertion> = Vec::new();
    let mut request_context = collection_context.clone();
    for pair in pair.into_inner() {
        match pair.as_rule() {
            Rule::description => {
                // push description to string
                description.push_str(pair.as_str().trim());
            }
            Rule::variable => {
                let mut inner_pairs = pair.into_inner();
                let key = inner_pairs.next().unwrap().as_str().trim().to_string();
                let raw_value = inner_pairs.next().unwrap().as_str().to_string();

                let value = if raw_value.starts_with("$(") {
                    let script = raw_value.trim_start_matches("$(").trim_end_matches(")");
                    eval_shell_script(script, working_dir, None)
                        .trim()
                        .to_string()
                } else {
                    context::inject_from_variable(raw_value.as_str(), &request_context)
                };

                request_context.insert(key, value);
            }
            Rule::http_method => {
                method = Some(pair.as_str().parse().unwrap());
            }
            Rule::url => {
                url_raw = Some(pair.as_str().trim().to_string());
            }
            Rule::header => {
                let mut inner_pairs = pair.into_inner();
                let key = inner_pairs.next().unwrap().as_str().trim().to_string();
                let value = inner_pairs.next().unwrap().as_str().trim().to_string();

                header_raw.insert(key, value);
            }
            Rule::query => {
                let mut inner_pairs = pair.into_inner();
                let key = inner_pairs.next().unwrap().as_str().trim().to_string();
                let value = inner_pairs.next().unwrap().as_str().trim().to_string();
                query_raw.insert(key, value);
            }
            Rule::form => {
                let mut inner_pairs = pair.into_inner();
                let key = inner_pairs.next().unwrap().as_str().trim().to_string();

                // the value can be either a File of Text type
                let value = inner_pairs.next().unwrap();
                match value.as_rule() {
                    Rule::file => {
                        // drop the first character from the filepath which is a "@" symbol
                        let trimmed = value.as_str().trim_start_matches('@').trim().to_string();
                        form_file_raw.insert(key, trimmed);
                    }
                    Rule::text => {
                        form_text_raw.insert(key, value.as_str().to_string());
                    }
                    _ => {
                        unreachable!("unexpected rule: {:?}", value.as_rule());
                    }
                }
            }
            Rule::body => {
                body_raw = Some(pair.as_str().to_string());
            }
            Rule::body_content_type => {
                body_content_type_raw = Some(pair.as_str().trim().to_string());
            }
            Rule::dependency => {
                let dependency = parse_dependency(pair.as_str());
                dependencies.push(dependency);
            }
            Rule::callback => {
                // drop the leading "!" character
                callback_src.push(pair.as_str().strip_prefix('!').unwrap().to_string());
            }
            Rule::response_capture => {
                let raw = pair.as_str();
                let capture = ResponseCapture::parse(raw, &request_context)
                    .map_err(|e| {
                        pest::error::Error::<Rule>::new_from_span(
                            pest::error::ErrorVariant::CustomError { message: e.0 },
                            pair.as_span(),
                        )
                    })
                    .unwrap();

                response_captures.push(capture);
            }
            Rule::assertion => {
                let raw = pair.as_str();
                let assertion = Assertion::parse(raw, &request_context)
                    .map_err(|e| {
                        pest::error::Error::<Rule>::new_from_span(
                            pest::error::ErrorVariant::CustomError { message: e.0 },
                            pair.as_span(),
                        )
                    })
                    .unwrap();

                assertions.push(assertion);
            }
            _ => {
                unreachable!("unexpected rule: {:?}", pair.as_rule());
            }
        }
    }

    // if the description is empty, set it to "No description"
    if description.is_empty() {
        description = "[No Description]".to_string();
    }

    let url = context::inject_from_variable(
        url_raw.as_ref().expect("Request is missing a URL"),
        &request_context,
    );

    let mut headers = HashMap::new();
    for (key, value) in global_headers {
        let resolved = context::inject_from_variable(value.as_str(), &request_context);
        headers.insert(key, resolved);
    }

    for (key, raw_value) in header_raw {
        let resolved = context::inject_from_variable(raw_value.as_str(), &request_context);
        headers.insert(key, resolved);
    }

    let mut query_params = HashMap::new();
    for (key, value) in global_queries {
        let resolved = context::inject_from_variable(value.as_str(), &request_context);
        query_params.insert(key, resolved);
    }

    for (key, raw_value) in query_raw {
        let resolved = context::inject_from_variable(raw_value.as_str(), &request_context);
        query_params.insert(key, resolved);
    }

    let mut form_data: HashMap<String, FormDataType> = HashMap::new();
    for (key, raw_value) in form_text_raw {
        let resolved = context::inject_from_variable(raw_value.as_str(), &request_context);
        form_data.insert(key, FormDataType::Text(resolved));
    }

    for (key, raw_value) in form_file_raw {
        let resolved = context::inject_from_variable(raw_value.as_str(), &request_context);
        let candidate_path = PathBuf::from(&resolved);
        let abs_path = if candidate_path.is_absolute() {
            candidate_path
        } else {
            working_dir.join(candidate_path)
        };
        form_data.insert(key, FormDataType::File(abs_path));
    }

    let body = body_raw
        .as_ref()
        .map(|raw| context::inject_from_variable(raw.as_str(), &request_context));

    let body_content_type = body_content_type_raw
        .as_ref()
        .map(|raw| context::inject_from_variable(raw.as_str(), &request_context));

    Request {
        description,
        method: method.unwrap(),
        url,
        headers,
        query_params,
        form_data,
        body,
        body_content_type,
        callback_src,
        response_captures,
        assertions,
        dependencies,
        context: request_context,
        working_dir: working_dir.clone(),
    }
}

/// Extracts the dependency name from a raw dependency declaration.
///
/// Surrounding whitespace is ignored. An optional leading `>` marker, an
/// optional `requires` keyword and an optional `:` separator are removed in
/// that order, each followed by any amount of whitespace. If what remains is
/// wrapped in double quotes, the quotes are removed and the quoted text is
/// trimmed; otherwise the remaining text is returned as is.
fn parse_dependency(raw: &str) -> String {
    // Trim before looking for the marker so that leading whitespace does not
    // hide the `>` prefix (and with it the `requires` keyword).
    let text = raw.trim();
    let text = text.strip_prefix('>').map_or(text, str::trim_start);
    let text = text.strip_prefix("requires").map_or(text, str::trim_start);
    let text = text.strip_prefix(':').map_or(text, str::trim_start);
    let name = text.trim();

    // A lone `"` has no matching closing quote and is therefore kept verbatim.
    match name.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(quoted) => quoted.trim().to_string(),
        None => name.to_string(),
    }
}

/// Runs `script` with `sh -c` inside `working_dir` and returns its stdout.
///
/// `env`, when given, is added to the environment of the child process.
/// The output is returned untrimmed; bytes that are not valid UTF-8 are
/// replaced with U+FFFD instead of aborting. The exit status and stderr of
/// the script are not inspected.
///
/// The script text is handed to the shell verbatim, so it must come from a
/// trusted source.
///
/// # Panics
///
/// Panics if the `sh` process cannot be started.
pub fn eval_shell_script(
    script: &str,
    working_dir: &PathBuf,
    env: Option<HashMap<String, String>>,
) -> String {
    let env = env.unwrap_or_default();
    log::debug!("evaluating shell script: {}", script);
    log::debug!("using directory {:?}", working_dir);
    let output = std::process::Command::new("sh")
        .current_dir(working_dir)
        .arg("-c")
        .arg(script)
        .envs(env)
        .output()
        .expect("failed to execute process");

    // Decode lossily: a script printing non-UTF-8 bytes should not crash the caller.
    String::from_utf8_lossy(&output.stdout).into_owned()
}