use crate::parser::parse;
use crate::Command;
use std::error::Error;
use std::io::Write as _;
/// Executes goldenscript commands, with optional hooks around scripts, blocks
/// and commands.
///
/// Only [`Runner::run`] must be implemented. Every hook defaults to a no-op
/// that succeeds without producing output. A hook error aborts output
/// generation.
pub trait Runner {
    /// Executes a single command and returns its output.
    ///
    /// If the command is marked as expected to fail, an error is rendered as
    /// `Error: <error>` in the output. Otherwise an error aborts output
    /// generation.
    fn run(&mut self, command: &Command) -> Result<String, Box<dyn Error>>;

    /// Called once, before any block of the script is processed.
    fn start_script(&mut self) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Called once, after every block of the script has been processed
    /// successfully.
    fn end_script(&mut self) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Called before the commands of each block that contains commands.
    ///
    /// The returned text is placed at the start of the block's output.
    fn start_block(&mut self) -> Result<String, Box<dyn Error>> {
        Ok(String::new())
    }

    /// Called after the commands of each block that contains commands.
    ///
    /// The returned text is placed at the end of the block's output.
    fn end_block(&mut self) -> Result<String, Box<dyn Error>> {
        Ok(String::new())
    }

    /// Called before each command is run.
    ///
    /// The returned text precedes the command's own output. It is subject to
    /// the same silent and prefix handling as the command output.
    fn start_command(&mut self, command: &Command) -> Result<String, Box<dyn Error>> {
        // The default hook ignores the command.
        let _ = command;
        Ok(String::new())
    }

    /// Called after each command has run.
    ///
    /// The returned text follows the command's own output. It is subject to
    /// the same silent and prefix handling as the command output.
    fn end_command(&mut self, command: &Command) -> Result<String, Box<dyn Error>> {
        // The default hook ignores the command.
        let _ = command;
        Ok(String::new())
    }
}
/// Runs the goldenscript file at `path` with `runner`.
///
/// Reads the script, generates its output with [`generate`], and writes that
/// output through a [`goldenfile::Mint`] rooted at the script's directory.
/// The goldenfile uses the same file name as the script, so the script file
/// itself is presumably the expected output (the comparison and update
/// behavior is the goldenfile crate's).
///
/// # Errors
///
/// - Returns an [`std::io::ErrorKind::InvalidInput`] error if `path` has no
///   parent directory or no file name.
/// - Returns an error if the script can't be read or written.
/// - Returns any error produced by [`generate`].
pub fn run<R: Runner, P: AsRef<std::path::Path>>(runner: &mut R, path: P) -> std::io::Result<()> {
    let path = path.as_ref();
    // The mint needs the directory and the file name separately.
    let (Some(dir), Some(filename)) = (path.parent(), path.file_name()) else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            // Use Display: Debug formatting of a Path adds its own quotes.
            format!("invalid path '{}'", path.display()),
        ));
    };
    let input = std::fs::read_to_string(dir.join(filename))?;
    let output = generate(runner, &input)?;
    goldenfile::Mint::new(dir).new_goldenfile(filename)?.write_all(output.as_bytes())
}
/// Generates golden output for the goldenscript `input`, using `runner` to
/// execute its commands.
///
/// How each block is emitted:
///
/// - Blocks without commands are copied to the output verbatim.
/// - Every other block is emitted as its literal text, then `---`, then the
///   combined output of the block hooks, the command hooks and the commands.
/// - A block with no output at all emits `ok`.
/// - If a block's output starts with, or contains, an empty line, every line
///   of it is prefixed with `> `.
///
/// If the input contains any `\r\n`, then `\r\n` is used for every line ending
/// the function generates; otherwise `\n` is used.
///
/// # Errors
///
/// - Returns an [`std::io::ErrorKind::InvalidInput`] error if the input can't
///   be parsed.
/// - Returns an [`std::io::ErrorKind::Other`] error if any of these happen:
///   - a runner hook fails;
///   - a command fails without being marked as expected to fail;
///   - a command marked as expected to fail succeeds.
///
/// # Panics
///
/// A panic in [`Runner::run`] is propagated, with one exception. If the
/// command is marked as expected to fail and the panic payload is a string,
/// it is rendered as `Panic: <message>` instead.
pub fn generate<R: Runner>(runner: &mut R, input: &str) -> std::io::Result<String> {
    let mut output = String::with_capacity(input.len());

    // Preserve the input's line endings in the generated output.
    let eol = if input.contains("\r\n") { "\r\n" } else { "\n" };

    let blocks = parse(input).map_err(|e| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "parse error at line {} column {} for {:?}:\n{}\n{}^",
                e.input.location_line(),
                e.input.get_column(),
                e.code,
                String::from_utf8_lossy(e.input.get_line_beginning()),
                // The column is 1-based; indent the caret so it points at it.
                " ".repeat(e.input.get_utf8_column().saturating_sub(1))
            ),
        )
    })?;

    runner.start_script().map_err(|e| {
        std::io::Error::new(std::io::ErrorKind::Other, format!("start_script failed: {e}"))
    })?;

    for (i, block) in blocks.iter().enumerate() {
        // Blocks without commands are copied through unchanged.
        if block.commands.is_empty() {
            output.push_str(&block.literal);
            continue;
        }

        let mut block_output = String::new();
        block_output.push_str(&ensure_eol(
            runner.start_block().map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("start_block failed at line {}: {e}", block.line_number),
                )
            })?,
            eol,
        ));

        for command in &block.commands {
            let mut command_output = String::new();
            command_output.push_str(&ensure_eol(
                runner.start_command(command).map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::Other,
                        format!("start_command failed at line {}: {e}", command.line_number),
                    )
                })?,
                eol,
            ));

            // Catch panics, so that commands expected to fail can report
            // them as output.
            let run = std::panic::AssertUnwindSafe(|| runner.run(command));
            command_output.push_str(&match std::panic::catch_unwind(run) {
                Ok(Ok(output)) if command.fail => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        format!(
                            "expected command '{}' to fail at line {}, succeeded with: {output}",
                            command.name, command.line_number
                        ),
                    ))
                }
                Ok(Ok(output)) => output,
                Ok(Err(e)) if command.fail => format!("Error: {e}"),
                Ok(Err(e)) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        format!(
                            "command '{}' failed at line {}: {e}",
                            command.name, command.line_number
                        ),
                    ))
                }
                Err(panic) if command.fail => {
                    // Only string payloads can be rendered; any other payload
                    // is re-raised.
                    let message = panic
                        .downcast_ref::<&str>()
                        .map(|s| s.to_string())
                        .or_else(|| panic.downcast_ref::<String>().cloned())
                        .unwrap_or_else(|| std::panic::resume_unwind(panic));
                    format!("Panic: {message}")
                }
                Err(panic) => std::panic::resume_unwind(panic),
            });
            command_output = ensure_eol(command_output, eol);

            command_output.push_str(&ensure_eol(
                runner.end_command(command).map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::Other,
                        format!("end_command failed at line {}: {e}", command.line_number),
                    )
                })?,
                eol,
            ));

            // Silent commands still run all hooks, but all of their output,
            // including hook output, is discarded.
            if command.silent {
                command_output.clear();
            }

            // Prefix every output line with "<prefix>: ".
            if let Some(prefix) = &command.prefix {
                if !command_output.is_empty() {
                    // Non-empty output always ends with '\n' here, courtesy of
                    // ensure_eol. That may be a bare '\n' rather than `eol` if
                    // the runner emitted LF endings into a CRLF script, so
                    // strip either form to avoid an empty "<prefix>: " line.
                    let body = command_output
                        .strip_suffix(eol)
                        .or_else(|| command_output.strip_suffix('\n'))
                        .unwrap_or(command_output.as_str());
                    command_output = format!(
                        "{prefix}: {}{eol}",
                        body.replace('\n', &format!("\n{prefix}: "))
                    );
                }
            }

            block_output.push_str(&command_output);
        }

        block_output.push_str(&ensure_eol(
            runner.end_block().map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("end_block failed at line {}: {e}", block.line_number),
                )
            })?,
            eol,
        ));

        // If the block produced no output at all, emit "ok", using the same
        // line ending as the rest of the output.
        if block_output.is_empty() {
            block_output.push_str("ok");
            block_output.push_str(eol);
        }

        // If the output starts with, or contains, an empty line, prefix every
        // line with "> ".
        if block_output.starts_with('\n')
            || block_output.starts_with("\r\n")
            || block_output.contains("\n\n")
            || block_output.contains("\n\r\n")
        {
            block_output = format!("> {}", block_output.replace('\n', "\n> "));
            // block_output is non-empty and ends with '\n', so the replacement
            // leaves a trailing "> " that must be removed.
            block_output.pop();
            block_output.pop();
        }

        output.push_str(&block.literal);
        output.push_str("---");
        output.push_str(eol);
        output.push_str(&block_output);
        // Follow the block's output with an empty line, except after the
        // final block.
        if i < blocks.len() - 1 {
            output.push_str(eol);
        }
    }

    runner.end_script().map_err(|e| {
        std::io::Error::new(std::io::ErrorKind::Other, format!("end_script failed: {e}"))
    })?;

    Ok(output)
}
/// Appends `eol` to `s` unless `s` is empty or already ends with a newline.
///
/// Any trailing `'\n'` counts as a line ending, so `"x\n"` is returned
/// unchanged even when `eol` is `"\r\n"`.
fn ensure_eol(mut s: String, eol: &str) -> String {
    let needs_eol = !s.is_empty() && !s.ends_with('\n');
    if needs_eol {
        s.push_str(eol);
    }
    s
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A runner that produces no output and counts every hook invocation.
    #[derive(Default)]
    struct CountingRunner {
        script_starts: usize,
        script_ends: usize,
        block_starts: usize,
        block_ends: usize,
        command_starts: usize,
        command_ends: usize,
    }

    impl Runner for CountingRunner {
        fn run(&mut self, _command: &Command) -> Result<String, Box<dyn Error>> {
            Ok(String::new())
        }

        fn start_script(&mut self) -> Result<(), Box<dyn Error>> {
            self.script_starts += 1;
            Ok(())
        }

        fn end_script(&mut self) -> Result<(), Box<dyn Error>> {
            self.script_ends += 1;
            Ok(())
        }

        fn start_block(&mut self) -> Result<String, Box<dyn Error>> {
            self.block_starts += 1;
            Ok(String::new())
        }

        fn end_block(&mut self) -> Result<String, Box<dyn Error>> {
            self.block_ends += 1;
            Ok(String::new())
        }

        fn start_command(&mut self, _command: &Command) -> Result<String, Box<dyn Error>> {
            self.command_starts += 1;
            Ok(String::new())
        }

        fn end_command(&mut self, _command: &Command) -> Result<String, Box<dyn Error>> {
            self.command_ends += 1;
            Ok(String::new())
        }
    }

    /// Each hook fires once per script, block or command, as appropriate.
    #[test]
    fn hooks() {
        let script = r#"
command
---
command
command
---
"#;
        let mut runner = CountingRunner::default();
        generate(&mut runner, script).unwrap();

        // One script.
        assert_eq!(runner.script_starts, 1);
        assert_eq!(runner.script_ends, 1);
        // Two command blocks.
        assert_eq!(runner.block_starts, 2);
        assert_eq!(runner.block_ends, 2);
        // Three commands in total.
        assert_eq!(runner.command_starts, 3);
        assert_eq!(runner.command_ends, 3);
    }
}