use std::{
collections::HashMap,
fmt,
sync::{OnceLock, RwLock},
};
use pest::Parser;
use pest_derive::Parser;
/// Placeholder parser generated by `pest_derive` from the grammar file
/// `src/parser/context.pest`.
///
/// The derive also generates the `Rule` enum. The functions in this module match on its
/// `text` (entry rule), `word`, `var` and `input` variants.
#[derive(Parser)]
#[grammar = "src/parser/context.pest"]
struct VarPlaceholderParser;
/// Selects what [`try_inject_from_prompt`] does with a `[[ ... ]]` placeholder that has
/// no stored answer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PromptMode {
    /// Ask the user for the value through a terminal prompt.
    Interactive,
    /// Never ask; the missing value is reported as a [`PromptInputError`].
    NonInteractive,
}
/// Error reported when no value is available for a `[[ ... ]]` prompt placeholder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptInputError {
    /// Key of the placeholder whose value was missing.
    prompt: String,
    /// Default declared in the placeholder (`[[ key = default ]]`), if any.
    default: Option<String>,
}

impl PromptInputError {
    /// Builds the error for placeholder `prompt` together with its optional default.
    fn missing(prompt: String, default: Option<String>) -> Self {
        PromptInputError { prompt, default }
    }

    /// Key of the placeholder that had no value.
    pub fn prompt(&self) -> &str {
        self.prompt.as_str()
    }

    /// Default declared in the placeholder, if one was given.
    pub fn default(&self) -> Option<&str> {
        self.default.as_deref()
    }
}

impl fmt::Display for PromptInputError {
    /// Writes `Missing value for prompt '<key>'`, followed by ` (default: <value>)` when
    /// the placeholder declared a default.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Missing value for prompt '{}'", self.prompt)?;
        if let Some(fallback) = self.default.as_deref() {
            write!(f, " (default: {})", fallback)?;
        }
        Ok(())
    }
}

impl std::error::Error for PromptInputError {}
/// Resolves prompt placeholders like [`try_inject_from_prompt`], but panics on failure.
///
/// # Panics
///
/// Panics with `prompt resolution failed: <error>` whenever [`try_inject_from_prompt`]
/// returns an error. One example is a missing value in [`PromptMode::NonInteractive`].
pub fn inject_from_prompt(input: &str) -> String {
    match try_inject_from_prompt(input) {
        Ok(resolved) => resolved,
        Err(err) => panic!("prompt resolution failed: {:?}", err),
    }
}
pub fn try_inject_from_prompt(input: &str) -> Result<String, PromptInputError> {
let mut output = String::new();
let mut pairs = VarPlaceholderParser::parse(Rule::text, input).unwrap();
let pair = pairs.next().unwrap();
for inner_pair in pair.into_inner() {
match inner_pair.as_rule() {
Rule::word => {
output.push_str(inner_pair.as_str());
}
Rule::input => {
let (key, default) = parse_input_pair(inner_pair.as_str());
if let Some(value) = prompt_inputs().read().unwrap().get(&key).cloned() {
output.push_str(value.as_str());
continue;
}
if prompt_mode() == PromptMode::NonInteractive {
return Err(PromptInputError::missing(key, default));
}
let prompt = match &default {
Some(def) => format!("Provide a value for \"{}\" (default: {})", key, def),
None => format!("Provide a value for \"{}\"", key),
};
let mut dialog = dialoguer::Input::new().with_prompt(prompt);
if let Some(def) = &default {
dialog = dialog.default(def.to_string());
}
let input: String = dialog.interact().unwrap();
prompt_inputs()
.write()
.unwrap()
.insert(key.clone(), input.clone());
output.push_str(input.as_str());
}
Rule::var => {
output.push_str(format!("{{{{{}}}}}", inner_pair.as_str()).as_str());
}
_ => {
unreachable!("unexpected rule: {:?}", inner_pair.as_rule());
}
}
}
Ok(output)
}
/// Performs one pass of `{{name}}` substitution over `input` using `context`.
///
/// Behavior for each part of the input:
///
/// - Plain text is copied through unchanged.
/// - A variable placeholder whose name is in `context` is replaced by its value.
/// - A variable placeholder whose name is not in `context` is re-emitted as `{{name}}`.
/// - `[[ ... ]]` prompt placeholders are re-emitted in `[[...]]` form.
///
/// Substituted values are not scanned again here; see [`resolve_with_context`] for
/// transitive resolution.
///
/// # Panics
///
/// Panics if `input` is rejected by the placeholder grammar.
pub fn inject_from_variable(input: &str, context: &HashMap<String, String>) -> String {
    let mut output = String::with_capacity(input.len());
    let mut pairs = VarPlaceholderParser::parse(Rule::text, input).unwrap();
    let pair = pairs.next().unwrap();
    for inner_pair in pair.into_inner() {
        match inner_pair.as_rule() {
            Rule::word => output.push_str(inner_pair.as_str()),
            Rule::var => {
                // `HashMap<String, _>` can be queried with `&str`; no owned key is needed.
                let key = inner_pair.as_str();
                match context.get(key) {
                    Some(value) => output.push_str(value),
                    None => {
                        output.push_str("{{");
                        output.push_str(key);
                        output.push_str("}}");
                    }
                }
            }
            Rule::input => {
                output.push_str("[[");
                output.push_str(inner_pair.as_str());
                output.push_str("]]");
            }
            _ => {
                unreachable!("unexpected rule: {:?}", inner_pair.as_rule());
            }
        }
    }
    output
}
/// Substitutes `{{name}}` placeholders from `context` repeatedly, so values that contain
/// placeholders of their own are expanded too. For example, with `a = "{{b}}"` and
/// `b = "x"`, the input `{{a}}` becomes `x`.
///
/// Passes stop as soon as one leaves the text unchanged.
///
/// The number of passes is capped, so cyclic definitions (`a = "{{b}}"`, `b = "{{a}}"`)
/// and self-referencing ones (`a = "x{{a}}"`) terminate instead of looping forever. For
/// acyclic definitions, each reference chain has at most `context.len()` links, so the
/// cap of `context.len() + 1` passes after the first is enough to reach the fixed point.
/// When the cap is hit, the text from the last pass is returned with the cycle only
/// partially expanded.
///
/// Placeholders without an entry in `context` are left in place, and `[[ ... ]]` prompt
/// placeholders are not resolved here.
pub fn resolve_with_context(input: &str, context: &HashMap<String, String>) -> String {
    let max_extra_passes = context.len() + 1;
    let mut current = inject_from_variable(input, context);
    for _ in 0..max_extra_passes {
        let next = inject_from_variable(current.as_str(), context);
        if next == current {
            break;
        }
        current = next;
    }
    current
}
/// Returns the names of the top-level `{{name}}` variable placeholders in `input`.
///
/// Names come back in order of appearance, with duplicates kept. Prompt placeholders and
/// plain text are ignored.
///
/// # Panics
///
/// Panics if `input` is rejected by the placeholder grammar.
pub fn extract_placeholders(input: &str) -> Vec<String> {
    let root = VarPlaceholderParser::parse(Rule::text, input)
        .unwrap()
        .next()
        .unwrap();
    root.into_inner()
        .filter(|node| node.as_rule() == Rule::var)
        .map(|node| node.as_str().to_string())
        .collect()
}
/// Replaces all stored prompt answers with `inputs`, discarding any previously stored ones.
pub fn set_prompt_inputs(inputs: HashMap<String, String>) {
    let mut stored = prompt_inputs().write().unwrap();
    *stored = inputs;
}
/// Sets the globally shared [`PromptMode`] returned by [`prompt_mode`].
pub fn set_prompt_mode(mode: PromptMode) {
    let mut current = prompt_mode_lock().write().unwrap();
    *current = mode;
}
/// Returns the [`PromptMode`] currently in effect.
pub fn prompt_mode() -> PromptMode {
    let guard = prompt_mode_lock().read().unwrap();
    *guard
}
/// Collects `(key, default)` pairs for every `[[ ... ]]` prompt placeholder in `input`.
///
/// Unlike the grammar-based functions in this module, this scans the raw text for `[[`
/// and the next `]]`. It may therefore treat unusual input differently from the grammar.
///
/// - Placeholders whose body is blank are skipped.
/// - Duplicates are kept, in order of appearance.
/// - Scanning stops at an opening `[[` that has no closing `]]`.
pub fn extract_prompt_placeholders(input: &str) -> Vec<(String, Option<String>)> {
    let mut found = Vec::new();
    let mut rest = input;
    loop {
        let Some((_, after_open)) = rest.split_once("[[") else {
            break;
        };
        let Some((body, after_close)) = after_open.split_once("]]") else {
            break;
        };
        let body = body.trim();
        if !body.is_empty() {
            found.push(parse_input_pair(body));
        }
        rest = after_close;
    }
    found
}
/// Reports whether a prompt answer is already stored for `key`.
pub fn has_prompt_input(key: &str) -> bool {
    let stored = prompt_inputs().read().unwrap();
    stored.contains_key(key)
}
/// Lazily created, globally shared store of prompt answers, keyed by placeholder name.
/// It starts out empty.
fn prompt_inputs() -> &'static RwLock<HashMap<String, String>> {
    static STORE: OnceLock<RwLock<HashMap<String, String>>> = OnceLock::new();
    STORE.get_or_init(Default::default)
}
/// Lazily created, globally shared [`PromptMode`] setting. It starts as `Interactive`.
fn prompt_mode_lock() -> &'static RwLock<PromptMode> {
    static MODE: OnceLock<RwLock<PromptMode>> = OnceLock::new();
    let initial = || RwLock::new(PromptMode::Interactive);
    MODE.get_or_init(initial)
}
/// Splits a prompt placeholder body of the form `key` or `key = default` at its first `=`.
///
/// Both parts are trimmed. Everything after the first `=`, including any further `=`
/// characters, becomes the default.
fn parse_input_pair(raw: &str) -> (String, Option<String>) {
    match raw.split_once('=') {
        Some((key, fallback)) => (key.trim().to_string(), Some(fallback.trim().to_string())),
        None => (raw.trim().to_string(), None),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that read or write the globally shared prompt state (the stored
    /// prompt inputs and the prompt mode).
    ///
    /// A test that panics while holding the guard poisons the mutex. Recovering the guard
    /// with `PoisonError::into_inner` keeps that single failure from turning every later
    /// test into a spurious `PoisonError` panic.
    fn test_guard() -> std::sync::MutexGuard<'static, ()> {
        static TEST_GUARD: OnceLock<std::sync::Mutex<()>> = OnceLock::new();
        TEST_GUARD
            .get_or_init(|| std::sync::Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn should_replace_variables() {
        let _guard = test_guard();
        let input = "this is a test with a {{variable}}";
        let mut context = HashMap::new();
        context.insert("variable".to_string(), "value".to_string());
        let output = inject_from_variable(input, &context);
        assert_eq!(output, "this is a test with a value");
    }

    #[test]
    fn should_keep_unknown_variables() {
        let output = inject_from_variable("keep {{missing}}", &HashMap::new());
        assert_eq!(output, "keep {{missing}}");
    }

    #[test]
    fn should_resolve_chained_variables() {
        let mut context = HashMap::new();
        context.insert("a".to_string(), "{{b}}".to_string());
        context.insert("b".to_string(), "done".to_string());
        assert_eq!(resolve_with_context("x {{a}}", &context), "x done");
    }

    #[test]
    fn should_extract_variable_placeholders() {
        assert_eq!(
            extract_placeholders("{{a}} and {{b}}"),
            vec!["a".to_string(), "b".to_string()]
        );
    }

    #[test]
    fn should_use_provided_prompt_inputs() {
        let _guard = test_guard();
        let mut inputs = HashMap::new();
        inputs.insert("foo".to_string(), "bar".to_string());
        set_prompt_mode(PromptMode::Interactive);
        set_prompt_inputs(inputs);
        let output = inject_from_prompt("value [[ foo ]]");
        assert_eq!(output, "value bar");
        set_prompt_inputs(HashMap::new());
        set_prompt_mode(PromptMode::Interactive);
    }

    #[test]
    fn should_extract_prompt_placeholders_from_multiline_source() {
        let prompts = extract_prompt_placeholders(
            "GET https://example.com/[[ token ]]\n? region = [[ region = us-east-1 ]]",
        );
        assert_eq!(
            prompts,
            vec![
                ("token".to_string(), None),
                ("region".to_string(), Some("us-east-1".to_string()))
            ]
        );
    }

    #[test]
    fn should_split_prompt_key_and_default() {
        assert_eq!(
            parse_input_pair(" key = a=b "),
            ("key".to_string(), Some("a=b".to_string()))
        );
        assert_eq!(parse_input_pair(" key "), ("key".to_string(), None));
    }

    #[test]
    fn should_describe_missing_prompt() {
        let err = PromptInputError::missing("foo".to_string(), Some("bar".to_string()));
        assert_eq!(err.to_string(), "Missing value for prompt 'foo' (default: bar)");
        let err = PromptInputError::missing("foo".to_string(), None);
        assert_eq!(err.to_string(), "Missing value for prompt 'foo'");
    }

    #[test]
    fn should_fail_on_missing_prompt_when_noninteractive() {
        let _guard = test_guard();
        set_prompt_mode(PromptMode::NonInteractive);
        set_prompt_inputs(HashMap::new());
        let err = try_inject_from_prompt("value [[ foo = bar ]]").unwrap_err();
        assert_eq!(err.prompt(), "foo");
        assert_eq!(err.default(), Some("bar"));
        set_prompt_mode(PromptMode::Interactive);
    }
}