use std::collections::BTreeMap;
use indexmap::IndexMap;
use schemars::JsonSchema;
use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize};
/// A single hook command as read from configuration.
///
/// `name` is `Some(key)` when the command came from a named-table entry and `None` when
/// it came from a bare string. `template` is the command text as written. `expanded` is a
/// second copy of the text: `Command::new` initializes it to `template`, while
/// `Command::with_expansion` takes it from the caller. The expansion itself happens
/// outside this file (presumably template-variable substitution — confirm against callers).
#[derive(Debug, Clone, PartialEq)]
pub struct Command {
pub name: Option<String>,
pub template: String,
pub expanded: String,
}
impl Command {
    /// Creates a command whose `expanded` text starts out identical to `template`.
    pub fn new(name: Option<String>, template: String) -> Self {
        // The clone is taken before `template` is moved into the second argument slot.
        Self::with_expansion(name, template.clone(), template)
    }

    /// Creates a command with an explicitly supplied `expanded` text.
    pub fn with_expansion(name: Option<String>, template: String, expanded: String) -> Self {
        Self {
            name,
            template,
            expanded,
        }
    }
}
/// One step of a hook's ordered command sequence.
///
/// Steps are stored in order inside `CommandConfig`. How a `Concurrent` group is actually
/// executed is decided outside this file (NOTE(review): the name suggests its commands run
/// in parallel — confirm against the executor).
#[derive(Debug, Clone, PartialEq)]
pub enum HookStep {
/// Exactly one command: a bare string, or a one-entry table inside a pipeline list.
Single(Command),
/// A group of commands, typically built from a named table (in the table's key order).
Concurrent(Vec<Command>),
}
/// A hook's configured commands, held as an ordered list of steps.
///
/// Instances come from deserialization (string, named table, or pipeline list), from
/// `CommandConfig::single`, or from `CommandConfig::merge_append`.
#[derive(Debug, Clone, PartialEq)]
pub struct CommandConfig {
steps: Vec<HookStep>,
}
impl CommandConfig {
    /// Builds a config with one step holding one anonymous command.
    pub fn single(template: impl Into<String>) -> Self {
        let command = Command::new(None, template.into());
        Self {
            steps: vec![HookStep::Single(command)],
        }
    }

    /// Iterates over every command of every step, in order, flattening concurrent groups.
    pub fn commands(&self) -> impl Iterator<Item = &Command> {
        // Both arms yield a slice, so a lone command and a group flatten the same way.
        self.steps.iter().flat_map(|step| match step {
            HookStep::Single(command) => std::slice::from_ref(command),
            HookStep::Concurrent(group) => group.as_slice(),
        })
    }

    /// Returns `true` when the config has at least two sequential steps.
    pub fn is_pipeline(&self) -> bool {
        self.steps.len() >= 2
    }

    /// Borrows the ordered list of steps.
    pub fn steps(&self) -> &[HookStep] {
        self.steps.as_slice()
    }

    /// Returns a new config whose steps are `self`'s followed by `other`'s.
    ///
    /// Neither input is modified.
    pub fn merge_append(&self, other: &Self) -> Self {
        Self {
            steps: self.steps.iter().chain(&other.steps).cloned().collect(),
        }
    }
}
/// Rejects hook names that contain a colon.
///
/// # Errors
/// Returns a custom deserialization error naming the first key (in map order) that
/// contains `':'`.
fn validate_no_colons<E: serde::de::Error>(map: &IndexMap<String, String>) -> Result<(), E> {
    match map.keys().find(|name| name.contains(':')) {
        Some(name) => Err(E::custom(format!(
            "hook name '{name}' cannot contain colons"
        ))),
        None => Ok(()),
    }
}
fn map_to_step(map: IndexMap<String, String>) -> HookStep {
if map.len() == 1 {
let (name, template) = map.into_iter().next().unwrap();
HookStep::Single(Command::new(Some(name), template))
} else {
HookStep::Concurrent(
map.into_iter()
.map(|(name, template)| Command::new(Some(name), template))
.collect(),
)
}
}
/// Merges `additions` into `base`, key by key.
///
/// For a key present in both maps, the stored value becomes `base`'s steps followed by
/// `additions`' steps (via `CommandConfig::merge_append`). Keys present only in
/// `additions` are copied into `base`.
pub fn append_aliases(
    base: &mut BTreeMap<String, CommandConfig>,
    additions: &BTreeMap<String, CommandConfig>,
) {
    for (alias, extra) in additions {
        let combined = match base.get(alias) {
            Some(existing) => existing.merge_append(extra),
            None => extra.clone(),
        };
        base.insert(alias.clone(), combined);
    }
}
/// `expecting` text for `CommandConfig` deserialization.
///
/// serde splices it into type errors ("invalid type: integer `42`, expected …"), so it
/// lists every accepted top-level form. The exact text is pinned by a snapshot test.
const EXPECTING: &str = r#"a command in one of these forms:
- a string: "cargo build"
- a named table: { build = "cargo build", test = "cargo test" }
- a pipeline list: ["cargo build", { test = "cargo test" }]
run `wt hook --help` for details"#;
/// `expecting` text for a single element of a pipeline list (a string or a named table).
const EXPECTING_PIPELINE_ENTRY: &str =
r#"a command string "cargo build" or a named table { build = "cargo build" }"#;
/// One element of a pipeline list, before it is turned into a `HookStep`.
enum PipelineEntry {
/// A bare string: an unnamed command.
Anonymous(String),
/// A table of `name = "command"` entries, in the order they were read.
Named(IndexMap<String, String>),
}
impl<'de> Deserialize<'de> for PipelineEntry {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct PipelineEntryVisitor;
impl<'de> serde::de::Visitor<'de> for PipelineEntryVisitor {
type Value = PipelineEntry;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str(EXPECTING_PIPELINE_ENTRY)
}
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
Ok(PipelineEntry::Anonymous(v.to_string()))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut entries: IndexMap<String, String> = IndexMap::new();
while let Some(key) = map.next_key::<String>()? {
let value = map.next_value::<String>()?;
entries.insert(key, value);
}
Ok(PipelineEntry::Named(entries))
}
}
deserializer.deserialize_any(PipelineEntryVisitor)
}
}
impl<'de> Deserialize<'de> for CommandConfig {
    /// Deserializes a hook command from any of the accepted shapes:
    ///
    /// - a string becomes one anonymous `HookStep::Single`;
    /// - a table of strings becomes one `HookStep::Concurrent` step in table order, even
    ///   with a single entry (an empty table yields no steps);
    /// - a list is a pipeline with one step per entry, converted as follows:
    ///   - a string becomes an anonymous `Single`;
    ///   - a one-entry table becomes a named `Single`;
    ///   - a larger table becomes a `Concurrent`;
    ///   - an empty table is skipped.
    ///
    /// # Errors
    /// Fails when a hook name contains `:`, when a table value is not a string, or when
    /// the input is none of the shapes above (the message ends with `EXPECTING`).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct CommandConfigVisitor;

        impl<'de> serde::de::Visitor<'de> for CommandConfigVisitor {
            type Value = CommandConfig;

            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str(EXPECTING)
            }

            fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
                Ok(CommandConfig {
                    steps: vec![HookStep::Single(Command::new(None, v.to_string()))],
                })
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut steps = Vec::new();
                while let Some(entry) = seq.next_element::<PipelineEntry>()? {
                    match entry {
                        PipelineEntry::Anonymous(cmd) => {
                            steps.push(HookStep::Single(Command::new(None, cmd)));
                        }
                        // An empty table contributes no step.
                        PipelineEntry::Named(map) if map.is_empty() => {}
                        PipelineEntry::Named(map) => {
                            validate_no_colons(&map)?;
                            steps.push(map_to_step(map));
                        }
                    }
                }
                Ok(CommandConfig { steps })
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::MapAccess<'de>,
            {
                let mut entries: IndexMap<String, String> = IndexMap::new();
                while let Some(key) = map.next_key::<String>()? {
                    let value = map.next_value::<String>()?;
                    entries.insert(key, value);
                }
                validate_no_colons(&entries)?;
                // An empty table declares no commands. Treat it the way the pipeline path
                // treats an empty table (no step at all) instead of emitting an empty
                // concurrent step. Such a step would survive `merge_append` as an extra
                // step and make `is_pipeline()` report `true` after merging.
                if entries.is_empty() {
                    return Ok(CommandConfig { steps: Vec::new() });
                }
                let commands: Vec<Command> = entries
                    .into_iter()
                    .map(|(name, template)| Command::new(Some(name), template))
                    .collect();
                Ok(CommandConfig {
                    steps: vec![HookStep::Concurrent(commands)],
                })
            }
        }

        deserializer.deserialize_any(CommandConfigVisitor)
    }
}
impl JsonSchema for CommandConfig {
    fn schema_name() -> std::borrow::Cow<'static, str> {
        "CommandConfig".into()
    }

    /// Describes the accepted shapes: a string, a named table of strings, or a pipeline
    /// list whose items are strings or named tables.
    ///
    /// Named-table keys carry a `propertyNames` pattern forbidding `:`. This mirrors the
    /// colon check run during deserialization, so schema-aware tooling flags the same
    /// hook names the loader rejects.
    fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
        schemars::json_schema!({
            "oneOf": [
                { "type": "string" },
                {
                    "type": "object",
                    "propertyNames": { "pattern": "^[^:]*$" },
                    "additionalProperties": { "type": "string" }
                },
                {
                    "type": "array",
                    "items": {
                        "oneOf": [
                            { "type": "string" },
                            {
                                "type": "object",
                                "propertyNames": { "pattern": "^[^:]*$" },
                                "additionalProperties": { "type": "string" }
                            }
                        ]
                    }
                }
            ]
        })
    }
}
impl Serialize for CommandConfig {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if self.steps.len() == 1
&& let HookStep::Single(cmd) = &self.steps[0]
&& cmd.name.is_none()
{
return cmd.template.serialize(serializer);
}
if self.steps.len() == 1
&& let HookStep::Concurrent(cmds) = &self.steps[0]
{
return serialize_commands_as_map(cmds, serializer);
}
use serde::ser::SerializeSeq;
let mut seq = serializer.serialize_seq(Some(self.steps.len()))?;
for step in &self.steps {
match step {
HookStep::Single(cmd) => {
if let Some(name) = &cmd.name {
let mut map = IndexMap::new();
map.insert(name.as_str(), cmd.template.as_str());
seq.serialize_element(&map)?;
} else {
seq.serialize_element(&cmd.template)?;
}
}
HookStep::Concurrent(cmds) => {
let mut map = IndexMap::new();
let mut unnamed_counter = 0u32;
for c in cmds {
let key = match &c.name {
Some(name) => name.as_str().to_string(),
None => {
unnamed_counter += 1;
format!("_{unnamed_counter}")
}
};
map.insert(key, c.template.as_str());
}
seq.serialize_element(&map)?;
}
}
}
seq.end()
}
}
fn serialize_commands_as_map<S>(cmds: &[Command], serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(cmds.len()))?;
let mut unnamed_counter = 0u32;
for cmd in cmds {
let key = match &cmd.name {
Some(name) => name.clone(),
None => {
unnamed_counter += 1;
format!("_{unnamed_counter}")
}
};
map.serialize_entry(&key, &cmd.template)?;
}
map.end()
}
// Unit tests for `CommandConfig`: deserialization of every accepted shape, error
// messages (pinned with insta inline snapshots), serialization back to TOML,
// round-trips, command flattening, and `merge_append`.
#[cfg(test)]
mod tests {
use insta::assert_snapshot;
use super::*;
// A bare string becomes one anonymous `Single` step.
#[test]
fn test_deserialize_single_string() {
let toml_str = r#"command = "npm install""#;
#[derive(Deserialize)]
struct Wrapper {
command: CommandConfig,
}
let wrapper: Wrapper = toml::from_str(toml_str).unwrap();
let commands: Vec<_> = wrapper.command.commands().collect();
assert_eq!(commands.len(), 1);
assert_eq!(commands[0].name, None);
assert_eq!(commands[0].template, "npm install");
assert_eq!(wrapper.command.steps().len(), 1);
assert!(matches!(&wrapper.command.steps()[0], HookStep::Single(_)));
}
// A top-level named table becomes one `Concurrent` step holding every entry.
#[test]
fn test_deserialize_named_table() {
let toml_str = r#"
[command]
build = "cargo build"
test = "cargo test"
"#;
#[derive(Deserialize)]
struct Wrapper {
command: CommandConfig,
}
let wrapper: Wrapper = toml::from_str(toml_str).unwrap();
let commands: Vec<_> = wrapper.command.commands().collect();
assert_eq!(commands.len(), 2);
assert!(commands.iter().any(|c| c.name == Some("build".to_string())));
assert!(commands.iter().any(|c| c.name == Some("test".to_string())));
assert_eq!(wrapper.command.steps().len(), 1);
assert!(matches!(
&wrapper.command.steps()[0],
HookStep::Concurrent(cmds) if cmds.len() == 2
));
}
// Named-table entries keep the order in which they appear in the file.
#[test]
fn test_deserialize_preserves_order() {
let toml_str = r#"
[command]
first = "echo 1"
second = "echo 2"
third = "echo 3"
"#;
#[derive(Deserialize)]
struct Wrapper {
command: CommandConfig,
}
let wrapper: Wrapper = toml::from_str(toml_str).unwrap();
let commands: Vec<_> = wrapper.command.commands().collect();
assert_eq!(commands.len(), 3);
assert_eq!(commands[0].name, Some("first".to_string()));
assert_eq!(commands[1].name, Some("second".to_string()));
assert_eq!(commands[2].name, Some("third".to_string()));
}
// Hook names containing ':' are rejected in a top-level table.
#[test]
fn test_deserialize_rejects_colons_in_name() {
let toml_str = r#"
[command]
"my:server" = "npm start"
"#;
#[derive(Debug, Deserialize)]
struct Wrapper {
#[serde(rename = "command")]
_command: CommandConfig,
}
let result: Result<Wrapper, _> = toml::from_str(toml_str);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("cannot contain colons"),
"Expected colon rejection error: {}",
err
);
}
// A list of strings becomes one anonymous `Single` step per string.
#[test]
fn test_deserialize_pipeline_strings() {
let toml_str = r#"command = ["npm install", "npm run build"]"#;
#[derive(Deserialize)]
struct Wrapper {
command: CommandConfig,
}
let wrapper: Wrapper = toml::from_str(toml_str).unwrap();
assert_eq!(wrapper.command.steps().len(), 2);
assert!(
matches!(&wrapper.command.steps()[0], HookStep::Single(c) if c.template == "npm install")
);
assert!(
matches!(&wrapper.command.steps()[1], HookStep::Single(c) if c.template == "npm run build")
);
}
// A pipeline may mix strings and multi-entry tables; tables become `Concurrent` steps.
#[test]
fn test_deserialize_pipeline_mixed() {
let toml_str = r#"command = [
"npm install",
{ build = "npm run build", lint = "npm run lint" }
]"#;
#[derive(Deserialize)]
struct Wrapper {
command: CommandConfig,
}
let wrapper: Wrapper = toml::from_str(toml_str).unwrap();
assert_eq!(wrapper.command.steps().len(), 2);
assert!(matches!(&wrapper.command.steps()[0], HookStep::Single(c) if c.name.is_none()));
assert!(matches!(
&wrapper.command.steps()[1],
HookStep::Concurrent(cmds) if cmds.len() == 2
));
let commands: Vec<_> = wrapper.command.commands().collect();
assert_eq!(commands.len(), 3);
}
// A one-entry table inside a pipeline becomes a named `Single` step.
#[test]
fn test_deserialize_pipeline_named_single() {
let toml_str = r#"command = [
{ install = "npm install" },
{ build = "npm run build", lint = "npm run lint" }
]"#;
#[derive(Deserialize)]
struct Wrapper {
command: CommandConfig,
}
let wrapper: Wrapper = toml::from_str(toml_str).unwrap();
assert_eq!(wrapper.command.steps().len(), 2);
if let HookStep::Single(cmd) = &wrapper.command.steps()[0] {
assert_eq!(cmd.name.as_deref(), Some("install"));
assert_eq!(cmd.template, "npm install");
} else {
panic!("Expected Single step");
}
assert!(matches!(
&wrapper.command.steps()[1],
HookStep::Concurrent(cmds) if cmds.len() == 2
));
}
// Colon-containing names are also rejected inside pipeline tables.
#[test]
fn test_deserialize_pipeline_rejects_colons() {
let toml_str = r#"command = [{ "my:hook" = "npm start" }]"#;
#[derive(Debug, Deserialize)]
struct Wrapper {
#[serde(rename = "command")]
_command: CommandConfig,
}
let result: Result<Wrapper, _> = toml::from_str(toml_str);
assert!(result.is_err());
}
// Shared wrapper for the error-message tests below.
#[derive(Debug, Deserialize)]
struct CommandWrapper {
#[serde(rename = "command")]
_command: CommandConfig,
}
// Parses `toml_str` as a `CommandWrapper`, expects failure, and returns the rendered error.
fn deserialize_err(toml_str: &str) -> String {
toml::from_str::<CommandWrapper>(toml_str)
.unwrap_err()
.to_string()
}
// A wrong top-level type reports every accepted form (the `EXPECTING` text).
#[test]
fn test_error_lists_accepted_forms_at_top_level() {
assert_snapshot!(deserialize_err("command = 42"), @r#"
TOML parse error at line 1, column 11
|
1 | command = 42
| ^^
invalid type: integer `42`, expected a command in one of these forms:
- a string: "cargo build"
- a named table: { build = "cargo build", test = "cargo test" }
- a pipeline list: ["cargo build", { test = "cargo test" }]
run `wt hook --help` for details
"#);
}
// A non-string value in a named table is reported at that value's position.
#[test]
fn test_error_identifies_non_string_value_in_named_table() {
assert_snapshot!(
deserialize_err(
r#"[command]
build = "cargo build"
broken = 42
"#,
),
@r#"
TOML parse error at line 3, column 10
|
3 | broken = 42
| ^^
invalid type: integer `42`, expected a string
"#
);
}
// A wrong-typed pipeline entry reports the pipeline-entry forms (`EXPECTING_PIPELINE_ENTRY`).
#[test]
fn test_error_describes_pipeline_entry_forms_for_wrong_type() {
assert_snapshot!(deserialize_err("command = [42]"), @r#"
TOML parse error at line 1, column 12
|
1 | command = [42]
| ^^
invalid type: integer `42`, expected a command string "cargo build" or a named table { build = "cargo build" }
"#);
}
// A non-string value inside a pipeline table is reported at that value's position.
#[test]
fn test_error_identifies_non_string_value_in_pipeline_map() {
assert_snapshot!(
deserialize_err(
r#"command = [
{ build = "cargo build", ignore_exit = true }
]"#,
),
@r#"
TOML parse error at line 2, column 44
|
2 | { build = "cargo build", ignore_exit = true }
| ^^^^
invalid type: boolean `true`, expected a string
"#
);
}
// One anonymous step serializes as a plain string.
#[test]
fn test_serialize_single_unnamed() {
#[derive(Serialize)]
struct Wrapper {
cmd: CommandConfig,
}
let wrapper = Wrapper {
cmd: CommandConfig {
steps: vec![HookStep::Single(Command::new(
None,
"npm install".to_string(),
))],
},
};
assert_snapshot!(toml::to_string(&wrapper).unwrap(), @r#"cmd = "npm install""#);
}
// One concurrent step serializes as a TOML table.
#[test]
fn test_serialize_concurrent() {
#[derive(Serialize)]
struct Wrapper {
cmd: CommandConfig,
}
let wrapper = Wrapper {
cmd: CommandConfig {
steps: vec![HookStep::Concurrent(vec![
Command::new(Some("build".to_string()), "cargo build".to_string()),
Command::new(Some("test".to_string()), "cargo test".to_string()),
])],
},
};
assert_snapshot!(toml::to_string(&wrapper).unwrap(), @r#"
[cmd]
build = "cargo build"
test = "cargo test"
"#);
}
// Multiple steps serialize as an inline array of strings and tables.
#[test]
fn test_serialize_pipeline() {
#[derive(Serialize)]
struct Wrapper {
cmd: CommandConfig,
}
let wrapper = Wrapper {
cmd: CommandConfig {
steps: vec![
HookStep::Single(Command::new(None, "npm install".to_string())),
HookStep::Concurrent(vec![
Command::new(Some("build".to_string()), "npm run build".to_string()),
Command::new(Some("lint".to_string()), "npm run lint".to_string()),
]),
],
},
};
assert_snapshot!(toml::to_string(&wrapper).unwrap(), @r#"cmd = ["npm install", { build = "npm run build", lint = "npm run lint" }]"#);
}
// A single anonymous command survives a serialize/deserialize round-trip.
#[test]
fn test_serialize_deserialize_roundtrip_single() {
let config = CommandConfig {
steps: vec![HookStep::Single(Command::new(
None,
"echo hello".to_string(),
))],
};
#[derive(Serialize, Deserialize)]
struct Wrapper {
cmd: CommandConfig,
}
let wrapper = Wrapper { cmd: config };
let serialized = toml::to_string(&wrapper).unwrap();
let deserialized: Wrapper = toml::from_str(&serialized).unwrap();
assert_eq!(deserialized.cmd.commands().count(), 1);
assert_eq!(
deserialized.cmd.commands().next().unwrap().template,
"echo hello"
);
}
// A named concurrent group survives a serialize/deserialize round-trip.
#[test]
fn test_serialize_deserialize_roundtrip_named() {
let config = CommandConfig {
steps: vec![HookStep::Concurrent(vec![
Command::new(Some("a".to_string()), "echo a".to_string()),
Command::new(Some("b".to_string()), "echo b".to_string()),
])],
};
#[derive(Serialize, Deserialize)]
struct Wrapper {
cmd: CommandConfig,
}
let wrapper = Wrapper { cmd: config };
let serialized = toml::to_string(&wrapper).unwrap();
let deserialized: Wrapper = toml::from_str(&serialized).unwrap();
assert_eq!(deserialized.cmd.commands().count(), 2);
}
// `commands()` yields every command across steps, in order.
#[test]
fn test_commands_flattens_pipeline() {
let config = CommandConfig {
steps: vec![
HookStep::Single(Command::new(None, "cmd1".to_string())),
HookStep::Concurrent(vec![
Command::new(Some("a".to_string()), "cmd2".to_string()),
Command::new(Some("b".to_string()), "cmd3".to_string()),
]),
HookStep::Single(Command::new(None, "cmd4".to_string())),
],
};
let cmds: Vec<_> = config.commands().collect();
assert_eq!(cmds.len(), 4);
assert_eq!(cmds[0].template, "cmd1");
assert_eq!(cmds[1].template, "cmd2");
assert_eq!(cmds[2].template, "cmd3");
assert_eq!(cmds[3].template, "cmd4");
}
// `merge_append` places the overlay's steps after the base's steps.
#[test]
fn test_merge_append_steps() {
let base = CommandConfig {
steps: vec![HookStep::Single(Command::new(None, "step1".to_string()))],
};
let overlay = CommandConfig {
steps: vec![HookStep::Concurrent(vec![
Command::new(Some("a".to_string()), "step2a".to_string()),
Command::new(Some("b".to_string()), "step2b".to_string()),
])],
};
let merged = base.merge_append(&overlay);
assert_eq!(merged.steps.len(), 2);
assert!(matches!(&merged.steps[0], HookStep::Single(_)));
assert!(matches!(&merged.steps[1], HookStep::Concurrent(_)));
}
// Merging an anonymous step with a named group serializes as a mixed pipeline.
#[test]
fn test_serialize_mixed_named_unnamed_succeeds() {
#[derive(Serialize)]
struct Wrapper {
cmd: CommandConfig,
}
let global = CommandConfig {
steps: vec![HookStep::Single(Command::new(
None,
"npm install".to_string(),
))],
};
let per_project = CommandConfig {
steps: vec![HookStep::Concurrent(vec![Command::new(
Some("setup".to_string()),
"echo setup".to_string(),
)])],
};
let merged = global.merge_append(&per_project);
assert_eq!(merged.steps.len(), 2);
let wrapper = Wrapper { cmd: merged };
assert_snapshot!(toml::to_string(&wrapper).unwrap(), @r#"cmd = ["npm install", { setup = "echo setup" }]"#);
}
}