use std::collections::HashSet;
use std::fs;
use std::path::Path;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::prompt::PromptDoc;
/// Name given to the plan synthesized by [`default_plan_from_dir`].
pub const DEFAULT_PLAN_NAME: &str = "default";
/// A grind plan: the prompts to run (with optional per-prompt overrides), a
/// `max_parallel` setting, hooks, and budgets.
///
/// Deserialized from TOML with unknown keys rejected. `name` is never read
/// from the TOML body (it is `#[serde(skip)]`, so a `name` key is an unknown
/// field); the loader assigns it — see [`parse_plan_str`] and [`load_plan`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct GrindPlan {
/// Plan name; skipped by serde and filled in by the loader (the file stem
/// when loaded through [`load_plan`]).
#[serde(skip)]
pub name: String,
/// Prompt entries in file order; empty when the key is absent. Duplicate
/// names are rejected by [`parse_plan_str`].
#[serde(default)]
pub prompts: Vec<PlanPromptRef>,
/// Requested parallelism; 1 when absent (via `default_max_parallel`), and 0
/// is rejected at parse time. How it is honored is up to the runner.
#[serde(default = "default_max_parallel")]
pub max_parallel: u32,
/// Hook settings; every hook is unset when the `[hooks]` table is absent.
#[serde(default)]
pub hooks: Hooks,
/// Budget settings; every budget is unset when the `[budgets]` table is absent.
#[serde(default)]
pub budgets: PlanBudgets,
}
/// One `[[prompts]]` entry of a plan: a prompt name plus optional overrides.
///
/// Unknown keys are rejected. Unset overrides stay `None`; presumably the
/// prompt's own metadata value applies then — NOTE(review): confirm with the
/// scheduler that consumes plans.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PlanPromptRef {
/// Prompt name; [`GrindPlan::validate_against`] checks it against the known prompts.
pub name: String,
/// Optional weight override; `Some(0)` is rejected by [`parse_plan_str`].
#[serde(default)]
pub weight_override: Option<u32>,
/// Optional `every` override; `Some(0)` is rejected by [`parse_plan_str`].
#[serde(default)]
pub every_override: Option<u32>,
/// Optional max-runs override; not range-checked at parse time.
#[serde(default)]
pub max_runs_override: Option<u32>,
}
/// Optional hook strings attached to a plan (`[hooks]` table).
///
/// Every field defaults to `None` and unknown keys are rejected.
/// NOTE(review): these look like shell commands (the tests use `echo ...`);
/// how and when they run is decided outside this module.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Hooks {
pub pre_session: Option<String>,
pub post_session: Option<String>,
pub on_failure: Option<String>,
}
/// Optional limits attached to a plan (`[budgets]` table).
///
/// Every field defaults to `None` and unknown keys are rejected. Derives only
/// `PartialEq` (not `Eq`) because of the `f64` field. Presumably `None` means
/// "no limit of this kind" — enforcement lives outside this module.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct PlanBudgets {
pub max_iterations: Option<u32>,
/// Deadline as a UTC timestamp; written in TOML as an RFC 3339 string such
/// as `2026-05-01T00:00:00Z`.
pub until: Option<DateTime<Utc>>,
/// Cost cap, in US dollars per the field name.
pub max_cost_usd: Option<f64>,
pub max_tokens: Option<u64>,
}
/// Serde default for `GrindPlan::max_parallel` when the key is omitted.
fn default_max_parallel() -> u32 {
    const SERIAL: u32 = 1;
    SERIAL
}
/// Errors produced while loading or parsing a plan file.
///
/// Every variant carries `path`, the display form of the plan's source,
/// used to prefix the message.
#[derive(Debug, Error)]
pub enum PlanLoadError {
/// The plan file could not be read.
#[error("failed to read plan file {path}: {source}")]
Io {
path: String,
#[source]
source: std::io::Error,
},
/// The path has no file stem, or the stem is not valid UTF-8, so no plan
/// name can be derived from it.
#[error("plan path has no UTF-8 file stem: {path}")]
MissingName {
path: String,
},
/// The body is not valid TOML or does not match the plan schema (unknown
/// keys included); `message` is a condensed single-line parser message.
#[error("{path}: malformed plan: {message}")]
Malformed {
path: String,
message: String,
},
/// Two `[[prompts]]` entries share the same `name`.
#[error("{path}: duplicate prompt entry {name:?}")]
DuplicatePrompt {
path: String,
name: String,
},
/// The body parsed but a value is out of range (e.g. `max_parallel = 0`).
#[error("{path}: invalid plan: {message}")]
Invalid {
path: String,
message: String,
},
}
impl PartialEq for PlanLoadError {
fn eq(&self, other: &Self) -> bool {
use PlanLoadError::*;
match (self, other) {
(Io { path: a, .. }, Io { path: b, .. }) => a == b,
(MissingName { path: a }, MissingName { path: b }) => a == b,
(
Malformed {
path: a,
message: am,
},
Malformed {
path: b,
message: bm,
},
) => a == b && am == bm,
(DuplicatePrompt { path: a, name: an }, DuplicatePrompt { path: b, name: bn }) => {
a == b && an == bn
}
(
Invalid {
path: a,
message: am,
},
Invalid {
path: b,
message: bm,
},
) => a == b && am == bm,
_ => false,
}
}
}
/// Errors from checking a parsed plan against the set of available prompts
/// (see [`GrindPlan::validate_against`]).
#[derive(Debug, Error, PartialEq, Eq)]
pub enum PlanValidationError {
/// Plan `plan` has an entry naming `prompt`, which matches no known prompt.
#[error("plan {plan:?} references unknown prompt {prompt:?}")]
UnknownPrompt {
plan: String,
prompt: String,
},
}
/// Loads a plan from a TOML file and names it after the file stem
/// (`plans/nightly.toml` becomes the plan `nightly`).
///
/// # Errors
///
/// - [`PlanLoadError::Io`] if the file cannot be read (checked first).
/// - [`PlanLoadError::MissingName`] if the path has no file stem or the stem
///   is not valid UTF-8.
/// - Any error returned by [`parse_plan_str`].
pub fn load_plan(path: &Path) -> Result<GrindPlan, PlanLoadError> {
    let shown = path.display().to_string();
    let contents = match fs::read_to_string(path) {
        Ok(text) => text,
        Err(source) => return Err(PlanLoadError::Io { path: shown, source }),
    };
    let stem = match path.file_stem().and_then(|os| os.to_str()) {
        Some(stem) => stem.to_owned(),
        None => return Err(PlanLoadError::MissingName { path: shown }),
    };
    parse_plan_str(&contents, stem, &shown)
}
/// Parses a plan from TOML text and validates its values.
///
/// `name` becomes [`GrindPlan::name`] (the body itself may not set it).
/// `display` is used only to label errors, typically the source path.
///
/// # Errors
///
/// - [`PlanLoadError::Malformed`] if `raw` is not valid TOML or does not match
///   the plan schema (unknown keys included).
/// - [`PlanLoadError::Invalid`] if any of these hold:
///   - `max_parallel` is 0;
///   - a prompt's `weight_override` or `every_override` is 0;
///   - `budgets.max_cost_usd` is negative, NaN, or infinite.
/// - [`PlanLoadError::DuplicatePrompt`] if two `prompts` entries share a name.
///
/// Checks run in the order listed in the body, and the first failure is returned.
pub fn parse_plan_str(raw: &str, name: String, display: &str) -> Result<GrindPlan, PlanLoadError> {
    // Every range violation is reported the same way; only the message differs.
    let invalid = |message: String| PlanLoadError::Invalid {
        path: display.to_string(),
        message,
    };

    let mut plan: GrindPlan = toml::from_str(raw).map_err(|e| PlanLoadError::Malformed {
        path: display.to_string(),
        message: one_line(&e.to_string()),
    })?;
    plan.name = name;

    if plan.max_parallel == 0 {
        return Err(invalid("max_parallel must be >= 1".to_string()));
    }

    let mut seen: HashSet<&str> = HashSet::new();
    for entry in &plan.prompts {
        if !seen.insert(entry.name.as_str()) {
            return Err(PlanLoadError::DuplicatePrompt {
                path: display.to_string(),
                name: entry.name.clone(),
            });
        }
        if entry.weight_override == Some(0) {
            return Err(invalid(format!(
                "prompts[{:?}].weight_override must be >= 1",
                entry.name
            )));
        }
        if entry.every_override == Some(0) {
            return Err(invalid(format!(
                "prompts[{:?}].every_override must be >= 1",
                entry.name
            )));
        }
    }

    // TOML accepts `nan`, `inf` and negative floats, and all of them
    // deserialize into `f64`. Comparisons against NaN are always false, so a
    // NaN cap could never trip. A negative cap is treated as a configuration
    // mistake. This check runs last so the precedence of the earlier errors is
    // unchanged.
    if let Some(cost) = plan.budgets.max_cost_usd {
        if !cost.is_finite() || cost < 0.0 {
            return Err(invalid(format!(
                "budgets.max_cost_usd must be a finite number >= 0 (got {cost})"
            )));
        }
    }

    Ok(plan)
}
/// Condenses a possibly multi-line error message into a single line.
///
/// Blank lines are dropped, and so are source-snippet "gutter" lines: lines
/// that, once trimmed, start with `|` or with a line number followed by ` |`.
/// The remaining lines are trimmed and joined with `": "`. A single-line
/// message is simply trimmed.
///
/// Taking only the first line would return `""` when the message starts with
/// a newline. It would also lose the cause for `toml` errors rendered as a
/// location header, a `|`-framed snippet, and then the actual message.
/// NOTE(review): that rendering is what toml_edit-based `toml` releases
/// produce; re-check if the `toml` dependency changes.
fn one_line(s: &str) -> String {
    let is_gutter = |line: &str| {
        let rest = line.trim_start_matches(|c: char| c.is_ascii_digit());
        rest.starts_with('|') || rest.starts_with(" |")
    };
    s.lines()
        .map(str::trim)
        .filter(|&line| !line.is_empty() && !is_gutter(line))
        .collect::<Vec<_>>()
        .join(": ")
}
pub fn default_plan_from_dir(prompts: &[PromptDoc]) -> GrindPlan {
let mut seen: HashSet<&str> = HashSet::new();
let mut refs: Vec<PlanPromptRef> = Vec::new();
for p in prompts {
if seen.insert(p.meta.name.as_str()) {
refs.push(PlanPromptRef {
name: p.meta.name.clone(),
weight_override: None,
every_override: None,
max_runs_override: None,
});
}
}
GrindPlan {
name: DEFAULT_PLAN_NAME.to_string(),
prompts: refs,
max_parallel: 1,
hooks: Hooks::default(),
budgets: PlanBudgets::default(),
}
}
impl GrindPlan {
    /// Checks that every prompt this plan references exists in `prompts`
    /// (matched on `meta.name`).
    ///
    /// # Errors
    ///
    /// Returns [`PlanValidationError::UnknownPrompt`] for the first entry, in
    /// plan order, whose name matches none of `prompts`.
    pub fn validate_against(&self, prompts: &[PromptDoc]) -> Result<(), PlanValidationError> {
        let known: HashSet<&str> = prompts.iter().map(|doc| doc.meta.name.as_str()).collect();
        match self
            .prompts
            .iter()
            .find(|entry| !known.contains(entry.name.as_str()))
        {
            Some(missing) => Err(PlanValidationError::UnknownPrompt {
                plan: self.name.clone(),
                prompt: missing.name.clone(),
            }),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
// Unit tests for plan parsing, file loading, default-plan synthesis and
// validation against a prompt set.
use super::*;
use crate::grind::prompt::{PromptMeta, PromptSource};
use std::path::PathBuf;
// Builds a minimal PromptDoc whose only varying field is `name`; the
// remaining fields are fixed placeholder values.
fn fake_prompt(name: &str) -> PromptDoc {
PromptDoc {
meta: PromptMeta {
name: name.to_string(),
description: "desc".to_string(),
weight: 1,
every: 1,
max_runs: None,
verify: false,
parallel_safe: false,
tags: Vec::new(),
max_session_seconds: None,
max_session_cost_usd: None,
},
body: String::new(),
source_path: PathBuf::from(format!("/fixture/{name}.md")),
source_kind: PromptSource::Project,
}
}
// Parses `raw` as plan `name`; the fixture path only labels error messages.
fn parse(raw: &str, name: &str) -> Result<GrindPlan, PlanLoadError> {
parse_plan_str(raw, name.to_string(), "/fixture/plan.toml")
}
// Every supported key (top level, prompts, hooks, budgets) lands in the
// matching field. Despite the name, this only parses; the serialize/reload
// cycle is covered by `plan_round_trips_through_serialize_and_load`.
#[test]
fn full_plan_round_trips() {
let raw = r#"
max_parallel = 4
[[prompts]]
name = "fp-hunter"
weight_override = 5
every_override = 2
max_runs_override = 10
[[prompts]]
name = "triage"
[hooks]
pre_session = "echo start"
post_session = "echo done"
on_failure = "echo fail"
[budgets]
max_iterations = 50
until = "2026-05-01T00:00:00Z"
max_cost_usd = 5.0
max_tokens = 1000000
"#;
let plan = parse(raw, "fp-cleanup").expect("parse should succeed");
assert_eq!(plan.name, "fp-cleanup");
assert_eq!(plan.max_parallel, 4);
assert_eq!(plan.prompts.len(), 2);
assert_eq!(plan.prompts[0].name, "fp-hunter");
assert_eq!(plan.prompts[0].weight_override, Some(5));
assert_eq!(plan.prompts[0].every_override, Some(2));
assert_eq!(plan.prompts[0].max_runs_override, Some(10));
assert_eq!(plan.prompts[1].name, "triage");
assert_eq!(plan.prompts[1].weight_override, None);
assert_eq!(plan.hooks.pre_session.as_deref(), Some("echo start"));
assert_eq!(plan.hooks.post_session.as_deref(), Some("echo done"));
assert_eq!(plan.hooks.on_failure.as_deref(), Some("echo fail"));
assert_eq!(plan.budgets.max_iterations, Some(50));
assert_eq!(plan.budgets.max_cost_usd, Some(5.0));
assert_eq!(plan.budgets.max_tokens, Some(1_000_000));
assert!(plan.budgets.until.is_some());
}
// An empty body is a valid plan: every field takes its serde default.
#[test]
fn empty_body_yields_defaults_with_supplied_name() {
let plan = parse("", "empty").expect("parse should succeed");
assert_eq!(plan.name, "empty");
assert_eq!(plan.max_parallel, 1);
assert!(plan.prompts.is_empty());
assert_eq!(plan.hooks, Hooks::default());
assert_eq!(plan.budgets, PlanBudgets::default());
}
#[test]
fn duplicate_prompt_name_is_rejected() {
let raw = r#"
[[prompts]]
name = "fp-hunter"
[[prompts]]
name = "fp-hunter"
"#;
let err = parse(raw, "p").unwrap_err();
match err {
PlanLoadError::DuplicatePrompt { name, .. } => assert_eq!(name, "fp-hunter"),
other => panic!("expected DuplicatePrompt, got {other:?}"),
}
}
// `name` is `#[serde(skip)]`, so with `deny_unknown_fields` a `name` key in
// the body is rejected like any other unknown key.
#[test]
fn name_in_body_is_rejected_as_unknown_field() {
let raw = "name = \"oops\"\n";
let err = parse(raw, "real-name").unwrap_err();
assert!(matches!(err, PlanLoadError::Malformed { .. }));
}
#[test]
fn unknown_top_level_key_is_rejected() {
let raw = "frobnicate = 7\n";
let err = parse(raw, "p").unwrap_err();
assert!(matches!(err, PlanLoadError::Malformed { .. }));
}
#[test]
fn malformed_toml_is_rejected() {
let raw = "[[prompts\nname = 'broken'\n";
let err = parse(raw, "p").unwrap_err();
assert!(matches!(err, PlanLoadError::Malformed { .. }));
}
// Range checks: zero overrides and zero parallelism surface as `Invalid`
// with a message naming the offending key.
#[test]
fn weight_override_zero_is_rejected() {
let raw = r#"
[[prompts]]
name = "fp-hunter"
weight_override = 0
"#;
let err = parse(raw, "p").unwrap_err();
match err {
PlanLoadError::Invalid { message, .. } => {
assert!(message.contains("weight_override"), "msg: {message}");
}
other => panic!("expected Invalid, got {other:?}"),
}
}
#[test]
fn every_override_zero_is_rejected() {
let raw = r#"
[[prompts]]
name = "fp-hunter"
every_override = 0
"#;
let err = parse(raw, "p").unwrap_err();
match err {
PlanLoadError::Invalid { message, .. } => {
assert!(message.contains("every_override"), "msg: {message}");
}
other => panic!("expected Invalid, got {other:?}"),
}
}
#[test]
fn max_parallel_zero_is_rejected() {
let raw = "max_parallel = 0\n";
let err = parse(raw, "p").unwrap_err();
match err {
PlanLoadError::Invalid { message, .. } => {
assert!(message.contains("max_parallel"), "msg: {message}");
}
other => panic!("expected Invalid, got {other:?}"),
}
}
// `load_plan` derives the plan name from the file stem (extension dropped).
#[test]
fn load_plan_uses_file_stem_as_name() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("nightly-cleanup.toml");
std::fs::write(&path, "max_parallel = 2\n").unwrap();
let plan = load_plan(&path).expect("load should succeed");
assert_eq!(plan.name, "nightly-cleanup");
assert_eq!(plan.max_parallel, 2);
}
// Assumes `/no/such/plan.toml` does not exist on the test host.
#[test]
fn load_plan_reports_io_error_for_missing_path() {
let err = load_plan(Path::new("/no/such/plan.toml")).unwrap_err();
assert!(matches!(err, PlanLoadError::Io { .. }));
}
// The synthesized default plan lists every prompt, in input order, without overrides.
#[test]
fn default_plan_synthesizes_one_entry_per_prompt() {
let prompts = vec![
fake_prompt("alpha"),
fake_prompt("bravo"),
fake_prompt("charlie"),
];
let plan = default_plan_from_dir(&prompts);
assert_eq!(plan.name, DEFAULT_PLAN_NAME);
assert_eq!(plan.max_parallel, 1);
assert_eq!(plan.hooks, Hooks::default());
assert_eq!(plan.budgets, PlanBudgets::default());
let names: Vec<&str> = plan.prompts.iter().map(|r| r.name.as_str()).collect();
assert_eq!(names, vec!["alpha", "bravo", "charlie"]);
for r in &plan.prompts {
assert_eq!(r.weight_override, None);
assert_eq!(r.every_override, None);
assert_eq!(r.max_runs_override, None);
}
}
#[test]
fn default_plan_handles_empty_prompt_set() {
let plan = default_plan_from_dir(&[]);
assert_eq!(plan.name, DEFAULT_PLAN_NAME);
assert!(plan.prompts.is_empty());
}
#[test]
fn validate_against_accepts_known_prompts() {
let prompts = vec![fake_prompt("alpha"), fake_prompt("bravo")];
let plan = default_plan_from_dir(&prompts);
plan.validate_against(&prompts).unwrap();
}
#[test]
fn validate_against_rejects_unknown_prompt() {
let prompts = vec![fake_prompt("alpha")];
let raw = r#"
[[prompts]]
name = "ghost"
"#;
let plan = parse(raw, "p").unwrap();
let err = plan.validate_against(&prompts).unwrap_err();
assert_eq!(
err,
PlanValidationError::UnknownPrompt {
plan: "p".to_string(),
prompt: "ghost".to_string(),
}
);
}
// Serialize with `toml::to_string`, write to a file named after the plan,
// and reload. The result must equal the original, including `name`, which
// is not serialized and is restored from the file stem.
#[test]
fn plan_round_trips_through_serialize_and_load() {
let original = GrindPlan {
name: "round-trip".to_string(),
prompts: vec![
PlanPromptRef {
name: "alpha".to_string(),
weight_override: Some(2),
every_override: None,
max_runs_override: Some(7),
},
PlanPromptRef {
name: "bravo".to_string(),
weight_override: None,
every_override: Some(3),
max_runs_override: None,
},
],
max_parallel: 3,
hooks: Hooks {
pre_session: Some("setup".to_string()),
post_session: None,
on_failure: Some("page-oncall".to_string()),
},
budgets: PlanBudgets {
max_iterations: Some(20),
until: Some("2026-05-01T00:00:00Z".parse::<DateTime<Utc>>().unwrap()),
max_cost_usd: Some(2.25),
max_tokens: Some(500_000),
},
};
let body = toml::to_string(&original).expect("serialize plan");
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("round-trip.toml");
std::fs::write(&path, body).unwrap();
let reparsed = load_plan(&path).expect("load round-trip plan");
assert_eq!(reparsed, original);
}
}