// imp_core/guardrails.rs

1use std::path::Path;
2use std::process::Stdio;
3
4use imp_llm::truncate_chars_with_suffix;
5use project_detect::{detect_walk, ProjectKind};
6use serde::{Deserialize, Serialize};
7use tokio::process::Command;
8
9/// How strongly guardrail failures influence agent execution.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
11#[serde(rename_all = "kebab-case")]
12pub enum GuardrailLevel {
13    /// Run checks and surface failures clearly, but do not block the turn.
14    #[default]
15    Advisory,
16    /// Run checks and treat failures as blocking.
17    Enforce,
18}
19
20/// Built-in guardrail starter profiles.
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(rename_all = "kebab-case")]
23pub enum GuardrailProfile {
24    /// Infer the profile from the current project using `project-detect`.
25    Auto,
26    /// Language-neutral fallback profile.
27    Generic,
28    /// Zig starter profile.
29    Zig,
30    /// Rust starter profile.
31    Rust,
32    /// TypeScript starter profile.
33    #[serde(rename = "typescript")]
34    TypeScript,
35    /// C / C-family build-system starter profile.
36    C,
37    /// Go starter profile.
38    Go,
39    /// Elixir starter profile.
40    Elixir,
41}
42
43impl GuardrailProfile {
44    /// Concise prompt guidance for the agent, tailored to this profile.
45    #[must_use]
46    pub fn prompt_guidance(&self) -> &'static str {
47        match self {
48            Self::Auto => Self::Generic.prompt_guidance(),
49            Self::Generic => GUIDANCE_GENERIC,
50            Self::Zig => GUIDANCE_ZIG,
51            Self::Rust => GUIDANCE_RUST,
52            Self::TypeScript => GUIDANCE_TYPESCRIPT,
53            Self::C => GUIDANCE_C,
54            Self::Go => GUIDANCE_GO,
55            Self::Elixir => GUIDANCE_ELIXIR,
56        }
57    }
58
59    /// Default after-write check commands for this profile.
60    #[must_use]
61    pub fn default_after_write(&self) -> &'static [&'static str] {
62        match self {
63            Self::Auto | Self::Generic => &[],
64            Self::Zig => &["zig fmt --check .", "zig build", "zig build test"],
65            Self::Rust => &[
66                "cargo fmt --check",
67                "cargo clippy -- -D warnings",
68                "cargo test",
69            ],
70            Self::TypeScript => &[],
71            Self::C => &[],
72            Self::Go => &["gofmt -l .", "go vet ./...", "go test ./..."],
73            Self::Elixir => &[
74                "mix format --check-formatted",
75                "mix compile --warnings-as-errors",
76                "mix test",
77            ],
78        }
79    }
80
81    /// Resolve a detected project kind to the nearest built-in profile.
82    #[must_use]
83    pub fn from_project_kind(kind: &ProjectKind) -> Self {
84        match kind {
85            ProjectKind::Zig => Self::Zig,
86            ProjectKind::Cargo => Self::Rust,
87            ProjectKind::Go => Self::Go,
88            ProjectKind::Elixir { .. } => Self::Elixir,
89            ProjectKind::Node { .. } => Self::TypeScript,
90            ProjectKind::CMake | ProjectKind::Meson | ProjectKind::Make => Self::C,
91            _ => Self::Generic,
92        }
93    }
94}
95
96/// Configurable engineering guardrails for agent-time guidance and checks.
97#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
98pub struct GuardrailConfig {
99    /// Master switch. `None` means "use the default".
100    pub enabled: Option<bool>,
101    /// Advisory vs blocking behavior.
102    pub level: Option<GuardrailLevel>,
103    /// Built-in profile selection.
104    pub profile: Option<GuardrailProfile>,
105    /// File globs that should trigger guardrail checks after writes.
106    pub critical_paths: Option<Vec<String>>,
107    /// Commands to run after writes. `None` means use profile defaults.
108    pub after_write: Option<Vec<String>>,
109}
110
111impl GuardrailConfig {
112    /// Returns whether guardrails are enabled.
113    #[must_use]
114    pub fn is_enabled(&self) -> bool {
115        self.enabled.unwrap_or(false)
116    }
117
118    /// Returns the effective configured level.
119    #[must_use]
120    pub fn effective_level(&self) -> GuardrailLevel {
121        self.level.unwrap_or_default()
122    }
123
124    /// Returns the configured profile before auto-detection.
125    #[must_use]
126    pub fn configured_profile(&self) -> GuardrailProfile {
127        self.profile.unwrap_or(GuardrailProfile::Generic)
128    }
129
130    /// Resolve the effective profile for a path.
131    #[must_use]
132    pub fn resolve_effective_profile(&self, cwd: &Path) -> GuardrailProfile {
133        match self.configured_profile() {
134            GuardrailProfile::Auto => detect_walk(cwd)
135                .map(|(kind, _)| GuardrailProfile::from_project_kind(&kind))
136                .unwrap_or(GuardrailProfile::Generic),
137            profile => profile,
138        }
139    }
140
141    /// Check whether a file path should trigger guardrail after-write checks.
142    #[must_use]
143    pub fn should_check_path(&self, path: &Path) -> bool {
144        match &self.critical_paths {
145            None => true,
146            Some(patterns) if patterns.is_empty() => true,
147            Some(patterns) => {
148                let path_str = path.to_string_lossy();
149                patterns.iter().any(|pat| {
150                    glob::Pattern::new(pat)
151                        .map(|g| g.matches(&path_str))
152                        .unwrap_or(false)
153                })
154            }
155        }
156    }
157
158    /// Merge another guardrail config into this one.
159    pub fn merge(&mut self, other: GuardrailConfig) {
160        if other.enabled.is_some() {
161            self.enabled = other.enabled;
162        }
163        if other.level.is_some() {
164            self.level = other.level;
165        }
166        if other.profile.is_some() {
167            self.profile = other.profile;
168        }
169        if other.critical_paths.is_some() {
170            self.critical_paths = other.critical_paths;
171        }
172        if other.after_write.is_some() {
173            self.after_write = other.after_write;
174        }
175    }
176}
177
178/// Assemble the guardrails prompt layer for a resolved profile.
179#[must_use]
180pub fn guardrails_layer(profile: GuardrailProfile) -> String {
181    let mut s = String::from("## Engineering Guardrails\n\n");
182    s.push_str(profile.prompt_guidance());
183    s
184}
185
/// Outcome of one after-write guardrail command.
#[derive(Debug, Clone)]
pub struct CheckResult {
    /// The shell command line exactly as it was handed to `sh -c`.
    pub command: String,
    /// `true` only when the command was spawned and exited with a success status.
    pub success: bool,
    /// Combined stdout/stderr (possibly truncated), or a "Failed to run" message when
    /// the command could not be spawned.
    pub output: String,
}
193
194/// Run guardrail after-write check commands and collect results.
195pub async fn run_after_write_checks(
196    config: &GuardrailConfig,
197    effective_profile: GuardrailProfile,
198    cwd: &Path,
199) -> Vec<CheckResult> {
200    let commands: Vec<String> = match &config.after_write {
201        Some(cmds) if !cmds.is_empty() => cmds.clone(),
202        _ => effective_profile
203            .default_after_write()
204            .iter()
205            .map(|s| (*s).to_string())
206            .collect(),
207    };
208
209    let mut results = Vec::new();
210    for cmd in &commands {
211        let result = Command::new("sh")
212            .arg("-c")
213            .arg(cmd)
214            .current_dir(cwd)
215            .stdin(Stdio::null())
216            .stdout(Stdio::piped())
217            .stderr(Stdio::piped())
218            .output()
219            .await;
220
221        match result {
222            Ok(output) => {
223                let stdout = String::from_utf8_lossy(&output.stdout);
224                let stderr = String::from_utf8_lossy(&output.stderr);
225                let combined = if stderr.is_empty() {
226                    stdout.to_string()
227                } else {
228                    format!("{stdout}{stderr}")
229                };
230                // Truncate to avoid flooding context
231                let truncated = if combined.len() > 2000 {
232                    format!(
233                        "{}\n... (truncated)",
234                        truncate_chars_with_suffix(&combined, 2000, "")
235                    )
236                } else {
237                    combined
238                };
239                results.push(CheckResult {
240                    command: cmd.clone(),
241                    success: output.status.success(),
242                    output: truncated,
243                });
244            }
245            Err(e) => {
246                results.push(CheckResult {
247                    command: cmd.clone(),
248                    success: false,
249                    output: format!("Failed to run: {e}"),
250                });
251            }
252        }
253    }
254    results
255}
256
257/// Format check results into a message for the agent.
258#[must_use]
259pub fn format_check_results(results: &[CheckResult], level: GuardrailLevel) -> String {
260    if results.is_empty() {
261        return String::new();
262    }
263
264    let all_passed = results.iter().all(|r| r.success);
265    if all_passed {
266        return "Guardrail checks passed.".to_string();
267    }
268
269    let mut s = match level {
270        GuardrailLevel::Enforce => {
271            String::from("⚠ GUARDRAIL CHECK FAILED (enforce mode — fix before proceeding):\n")
272        }
273        GuardrailLevel::Advisory => {
274            String::from("⚠ Guardrail check failed (advisory — review before continuing):\n")
275        }
276    };
277
278    for r in results {
279        if !r.success {
280            s.push_str(&format!("\n  Command: {}\n", r.command));
281            if !r.output.is_empty() {
282                for line in r.output.lines().take(20) {
283                    s.push_str(&format!("    {line}\n"));
284                }
285            }
286        }
287    }
288    s
289}
290
// -- Prompt guidance text per profile ----------------------------------------
//
// Each constant is a Markdown bullet list (one `- ` item per line, ending in a
// newline) that `guardrails_layer` appends below the guardrails heading.

/// Guidance for projects without a more specific profile.
const GUIDANCE_GENERIC: &str = "\
- Prefer the smallest, local fix over a cross-file refactor.
- Search for existing patterns first; mirror naming, error handling, and conventions.
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and timeouts bounded.
- Make error handling explicit — don't silently ignore failures.
- Leave code warning-free and easy to verify.
- Don't add new dependencies without explicit user approval.
";

/// Guidance for Zig projects.
const GUIDANCE_ZIG: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and buffers bounded.
- Handle errors explicitly with try/catch — avoid casual catch unreachable.
- Keep allocator ownership and lifetime clear.
- Prefer small, readable functions with minimal hidden control flow.
- Leave code formatted, buildable, and warning-free.
";

/// Guidance for Rust projects.
const GUIDANCE_RUST: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and timeouts bounded.
- Use Result with meaningful error propagation — avoid unwrap() in non-test code.
- Keep async behavior bounded and timeouts explicit.
- Prefer small, focused changes over broad rewrites.
- Leave code clippy-clean with zero warnings.
";

/// Guidance for TypeScript projects.
const GUIDANCE_TYPESCRIPT: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and timeouts bounded.
- Make error handling explicit — don't silently swallow rejections or errors.
- Use strict typing — avoid `any` unless justified.
- Keep async/Promise flows bounded and understandable.
- Leave typecheck and lint status clean.
";

/// Guidance for C-family projects.
const GUIDANCE_C: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and buffer sizes bounded.
- Make error handling explicit — check return values.
- Keep pointer usage straightforward and well-scoped.
- Avoid preprocessor complexity when simpler code works.
- Leave build and test status clean.
";

/// Guidance for Go projects.
const GUIDANCE_GO: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and timeouts bounded.
- Check and propagate errors explicitly — don't ignore returned errors.
- Keep goroutine lifecycle and cancellation understandable.
- Prefer small functions and direct control flow.
- Leave formatting and vet status clean.
";

/// Guidance for Elixir projects.
const GUIDANCE_ELIXIR: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep retries and message flows bounded.
- Keep process and supervision boundaries clear.
- Handle {:ok, value} / {:error, reason} tuples explicitly.
- Avoid hiding important behavior in opaque control flow.
- Leave formatting clean and compilation warning-free.
";
356
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use std::path::Path;
    use tempfile::TempDir;

    /// Mirrors a config file whose guardrail settings live under `[guardrails]`.
    #[derive(Debug, Deserialize)]
    struct GuardrailToml {
        guardrails: GuardrailConfig,
    }

    /// A config that requests auto-detection and leaves every other field unset.
    fn auto_config() -> GuardrailConfig {
        GuardrailConfig {
            profile: Some(GuardrailProfile::Auto),
            ..Default::default()
        }
    }

    #[test]
    fn guardrail_toml_deserializes() {
        let parsed: GuardrailToml = toml::from_str(
            r#"
[guardrails]
enabled = true
level = "enforce"
profile = "zig"
critical_paths = ["src/**", "lib/**"]
after_write = ["zig fmt --check .", "zig build"]
"#,
        )
        .unwrap();

        assert_eq!(parsed.guardrails.enabled, Some(true));
        assert_eq!(parsed.guardrails.level, Some(GuardrailLevel::Enforce));
        assert_eq!(parsed.guardrails.profile, Some(GuardrailProfile::Zig));
        assert_eq!(
            parsed.guardrails.critical_paths,
            Some(vec!["src/**".into(), "lib/**".into()])
        );
        assert_eq!(
            parsed.guardrails.after_write,
            Some(vec!["zig fmt --check .".into(), "zig build".into()])
        );
    }

    #[test]
    fn guardrail_defaults_when_unset() {
        let config = GuardrailConfig::default();
        assert!(!config.is_enabled());
        assert_eq!(config.effective_level(), GuardrailLevel::Advisory);
        // An unset profile is Generic; auto-detection must be requested explicitly.
        assert_eq!(config.configured_profile(), GuardrailProfile::Generic);
    }

    #[test]
    fn guardrail_auto_profile_resolves_zig() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("build.zig"), "").unwrap();

        assert_eq!(
            auto_config().resolve_effective_profile(dir.path()),
            GuardrailProfile::Zig
        );
    }

    #[test]
    fn guardrail_auto_profile_resolves_rust_from_subdirectory() {
        let dir = TempDir::new().unwrap();
        std::fs::write(
            dir.path().join("Cargo.toml"),
            "[package]\nname='x'\nversion='0.1.0'\n",
        )
        .unwrap();
        let nested = dir.path().join("src").join("nested");
        std::fs::create_dir_all(&nested).unwrap();

        assert_eq!(
            auto_config().resolve_effective_profile(&nested),
            GuardrailProfile::Rust
        );
    }

    #[test]
    fn guardrail_auto_profile_resolves_go() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("go.mod"), "module example.com/test\n").unwrap();

        assert_eq!(
            auto_config().resolve_effective_profile(dir.path()),
            GuardrailProfile::Go
        );
    }

    #[test]
    fn guardrail_auto_profile_resolves_elixir() {
        let dir = TempDir::new().unwrap();
        std::fs::write(
            dir.path().join("mix.exs"),
            "defmodule Demo.MixProject do end\n",
        )
        .unwrap();

        assert_eq!(
            auto_config().resolve_effective_profile(dir.path()),
            GuardrailProfile::Elixir
        );
    }

    #[test]
    fn guardrail_auto_profile_falls_back_to_generic() {
        let dir = TempDir::new().unwrap();

        assert_eq!(
            auto_config().resolve_effective_profile(dir.path()),
            GuardrailProfile::Generic
        );
    }

    #[test]
    fn guardrail_should_check_path_without_patterns_matches_everything() {
        let unset = GuardrailConfig::default();
        let empty = GuardrailConfig {
            critical_paths: Some(Vec::new()),
            ..Default::default()
        };
        for config in [&unset, &empty] {
            assert!(config.should_check_path(Path::new("anything/at/all.txt")));
        }
    }

    #[test]
    fn guardrail_should_check_path_honors_globs() {
        let config = GuardrailConfig {
            critical_paths: Some(vec!["src/**".into()]),
            ..Default::default()
        };
        assert!(config.should_check_path(Path::new("src/lib.rs")));
        assert!(config.should_check_path(Path::new("src/nested/mod.rs")));
        assert!(!config.should_check_path(Path::new("docs/readme.md")));
    }

    #[test]
    fn guardrail_prompt_guidance_varies_by_profile() {
        let zig = GuardrailProfile::Zig.prompt_guidance();
        let rust = GuardrailProfile::Rust.prompt_guidance();
        let generic = GuardrailProfile::Generic.prompt_guidance();

        assert!(zig.contains("catch unreachable"));
        assert!(zig.contains("allocator"));
        assert!(rust.contains("clippy"));
        assert!(rust.contains("unwrap"));
        assert!(generic.contains("warning-free"));
        assert_ne!(zig, rust);
        assert_ne!(zig, generic);
    }

    #[test]
    fn guardrail_default_after_write_zig() {
        let cmds = GuardrailProfile::Zig.default_after_write();
        assert_eq!(cmds.len(), 3);
        assert!(cmds[0].contains("zig fmt"));
    }

    #[test]
    fn guardrail_default_after_write_generic_is_empty() {
        assert!(GuardrailProfile::Generic.default_after_write().is_empty());
    }

    #[test]
    fn guardrail_layer_contains_header() {
        let layer = guardrails_layer(GuardrailProfile::Zig);
        assert!(layer.starts_with("## Engineering Guardrails"));
        assert!(layer.contains("catch unreachable"));
    }

    #[test]
    fn guardrail_format_check_results_all_passed() {
        let results = vec![CheckResult {
            command: "zig build".into(),
            success: true,
            output: String::new(),
        }];
        let msg = format_check_results(&results, GuardrailLevel::Advisory);
        assert_eq!(msg, "Guardrail checks passed.");
    }

    #[test]
    fn guardrail_format_check_results_failure_enforce() {
        let results = vec![CheckResult {
            command: "cargo clippy".into(),
            success: false,
            output: "warning: unused variable".into(),
        }];
        let msg = format_check_results(&results, GuardrailLevel::Enforce);
        assert!(msg.contains("GUARDRAIL CHECK FAILED"));
        assert!(msg.contains("enforce"));
        assert!(msg.contains("cargo clippy"));
    }

    #[test]
    fn guardrail_format_check_results_failure_advisory() {
        let results = vec![CheckResult {
            command: "mix test".into(),
            success: false,
            output: "1 test failed".into(),
        }];
        let msg = format_check_results(&results, GuardrailLevel::Advisory);
        assert!(msg.contains("advisory"));
        assert!(msg.contains("mix test"));
    }

    #[test]
    fn guardrail_merge_only_overrides_present_fields() {
        let mut base = GuardrailConfig {
            enabled: Some(true),
            level: Some(GuardrailLevel::Advisory),
            profile: Some(GuardrailProfile::Rust),
            critical_paths: Some(vec!["src/**".into()]),
            after_write: None,
        };

        let overlay = GuardrailConfig {
            enabled: None,
            level: Some(GuardrailLevel::Enforce),
            profile: None,
            critical_paths: None,
            after_write: Some(vec!["cargo test".into()]),
        };

        base.merge(overlay);

        assert_eq!(base.enabled, Some(true));
        assert_eq!(base.level, Some(GuardrailLevel::Enforce));
        assert_eq!(base.profile, Some(GuardrailProfile::Rust));
        assert_eq!(base.critical_paths, Some(vec!["src/**".into()]));
        assert_eq!(base.after_write, Some(vec!["cargo test".into()]));
    }
}