Skip to main content

imp_core/
guardrails.rs

1use std::path::Path;
2use std::process::Stdio;
3
4use imp_llm::truncate_chars_with_suffix;
5use project_detect::{detect_walk, ProjectKind};
6use serde::{Deserialize, Serialize};
7use tokio::process::Command;
8
/// How strongly guardrail failures influence agent execution.
///
/// Variants are (de)serialized in kebab-case, i.e. `"advisory"` and
/// `"enforce"`. `Advisory` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum GuardrailLevel {
    /// Run checks and surface failures clearly, but do not block the turn.
    #[default]
    Advisory,
    /// Run checks and treat failures as blocking.
    Enforce,
}
19
/// Built-in guardrail starter profiles.
///
/// Variants are (de)serialized in kebab-case (`"auto"`, `"generic"`, `"zig"`,
/// ...). `TypeScript` carries an explicit rename so that it is spelled
/// `"typescript"` rather than the kebab-case form of its identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum GuardrailProfile {
    /// Infer the profile from the current project using `project-detect`.
    Auto,
    /// Language-neutral fallback profile.
    Generic,
    /// Zig starter profile.
    Zig,
    /// Rust starter profile.
    Rust,
    /// TypeScript starter profile.
    // Explicit name: overrides the container-level `rename_all` for this variant.
    #[serde(rename = "typescript")]
    TypeScript,
    /// C / C-family build-system starter profile.
    C,
    /// Go starter profile.
    Go,
    /// Elixir starter profile.
    Elixir,
    /// Kotlin starter profile.
    Kotlin,
}
44
45impl GuardrailProfile {
46    /// Concise prompt guidance for the agent, tailored to this profile.
47    #[must_use]
48    pub fn prompt_guidance(&self) -> &'static str {
49        match self {
50            Self::Auto => Self::Generic.prompt_guidance(),
51            Self::Generic => GUIDANCE_GENERIC,
52            Self::Zig => GUIDANCE_ZIG,
53            Self::Rust => GUIDANCE_RUST,
54            Self::TypeScript => GUIDANCE_TYPESCRIPT,
55            Self::C => GUIDANCE_C,
56            Self::Go => GUIDANCE_GO,
57            Self::Elixir => GUIDANCE_ELIXIR,
58            Self::Kotlin => GUIDANCE_KOTLIN,
59        }
60    }
61
62    /// Default after-write check commands for this profile.
63    #[must_use]
64    pub fn default_after_write(&self) -> &'static [&'static str] {
65        match self {
66            Self::Auto | Self::Generic => &[],
67            Self::Zig => &["zig fmt --check .", "zig build", "zig build test"],
68            Self::Rust => &[
69                "cargo fmt --check",
70                "cargo clippy -- -D warnings",
71                "cargo test",
72            ],
73            Self::TypeScript => &[],
74            Self::C => &[],
75            Self::Go => &["gofmt -l .", "go vet ./...", "go test ./..."],
76            Self::Elixir => &[
77                "mix format --check-formatted",
78                "mix compile --warnings-as-errors",
79                "mix test",
80            ],
81            Self::Kotlin => &[],
82        }
83    }
84
85    /// Resolve a detected project kind to the nearest built-in profile.
86    #[must_use]
87    pub fn from_project_kind(kind: &ProjectKind) -> Self {
88        match kind {
89            ProjectKind::Zig => Self::Zig,
90            ProjectKind::Cargo => Self::Rust,
91            ProjectKind::Go => Self::Go,
92            ProjectKind::Elixir { .. } => Self::Elixir,
93            ProjectKind::Kotlin { .. } | ProjectKind::Gradle { .. } | ProjectKind::Maven => {
94                Self::Kotlin
95            }
96            ProjectKind::Node { .. } => Self::TypeScript,
97            ProjectKind::CMake | ProjectKind::Meson | ProjectKind::Make => Self::C,
98            _ => Self::Generic,
99        }
100    }
101}
102
/// Configurable engineering guardrails for agent-time guidance and checks.
///
/// Every field is optional so that partial configs can be layered with
/// `merge`; the accessor methods on this type supply the defaults for unset
/// fields.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct GuardrailConfig {
    /// Master switch. `None` means "use the default" (guardrails disabled).
    pub enabled: Option<bool>,
    /// Advisory vs blocking behavior. `None` means the default level.
    pub level: Option<GuardrailLevel>,
    /// Built-in profile selection. `None` is treated as the generic profile.
    pub profile: Option<GuardrailProfile>,
    /// File globs that should trigger guardrail checks after writes.
    /// `None` or an empty list means every path triggers checks.
    pub critical_paths: Option<Vec<String>>,
    /// Commands to run after writes. `None` means use profile defaults.
    pub after_write: Option<Vec<String>>,
}
117
118impl GuardrailConfig {
119    /// Returns whether guardrails are enabled.
120    #[must_use]
121    pub fn is_enabled(&self) -> bool {
122        self.enabled.unwrap_or(false)
123    }
124
125    /// Returns the effective configured level.
126    #[must_use]
127    pub fn effective_level(&self) -> GuardrailLevel {
128        self.level.unwrap_or_default()
129    }
130
131    /// Returns the configured profile before auto-detection.
132    #[must_use]
133    pub fn configured_profile(&self) -> GuardrailProfile {
134        self.profile.unwrap_or(GuardrailProfile::Generic)
135    }
136
137    /// Resolve the effective profile for a path.
138    #[must_use]
139    pub fn resolve_effective_profile(&self, cwd: &Path) -> GuardrailProfile {
140        match self.configured_profile() {
141            GuardrailProfile::Auto => detect_walk(cwd)
142                .map(|(kind, _)| GuardrailProfile::from_project_kind(&kind))
143                .unwrap_or(GuardrailProfile::Generic),
144            profile => profile,
145        }
146    }
147
148    /// Check whether a file path should trigger guardrail after-write checks.
149    #[must_use]
150    pub fn should_check_path(&self, path: &Path) -> bool {
151        match &self.critical_paths {
152            None => true,
153            Some(patterns) if patterns.is_empty() => true,
154            Some(patterns) => {
155                let path_str = path.to_string_lossy();
156                patterns.iter().any(|pat| {
157                    glob::Pattern::new(pat)
158                        .map(|g| g.matches(&path_str))
159                        .unwrap_or(false)
160                })
161            }
162        }
163    }
164
165    /// Merge another guardrail config into this one.
166    pub fn merge(&mut self, other: GuardrailConfig) {
167        if other.enabled.is_some() {
168            self.enabled = other.enabled;
169        }
170        if other.level.is_some() {
171            self.level = other.level;
172        }
173        if other.profile.is_some() {
174            self.profile = other.profile;
175        }
176        if other.critical_paths.is_some() {
177            self.critical_paths = other.critical_paths;
178        }
179        if other.after_write.is_some() {
180            self.after_write = other.after_write;
181        }
182    }
183}
184
185/// Assemble the guardrails prompt layer for a resolved profile.
186#[must_use]
187pub fn guardrails_layer(profile: GuardrailProfile) -> String {
188    let mut s = String::from("## Engineering Guardrails\n\n");
189    s.push_str(profile.prompt_guidance());
190    s
191}
192
/// Result of running a single guardrail check command.
#[derive(Debug, Clone)]
pub struct CheckResult {
    /// The shell command line that was run.
    pub command: String,
    /// `true` when the command exited successfully; `false` when it exited
    /// with a failure status or could not be run at all.
    pub success: bool,
    /// Captured stdout followed by stderr (possibly truncated), or a
    /// `Failed to run: ...` message when the command could not be started.
    pub output: String,
}
200
201/// Run guardrail after-write check commands and collect results.
202pub async fn run_after_write_checks(
203    config: &GuardrailConfig,
204    effective_profile: GuardrailProfile,
205    cwd: &Path,
206) -> Vec<CheckResult> {
207    let commands: Vec<String> = match &config.after_write {
208        Some(cmds) if !cmds.is_empty() => cmds.clone(),
209        _ => effective_profile
210            .default_after_write()
211            .iter()
212            .map(|s| (*s).to_string())
213            .collect(),
214    };
215
216    let mut results = Vec::new();
217    for cmd in &commands {
218        let result = Command::new("sh")
219            .arg("-c")
220            .arg(cmd)
221            .current_dir(cwd)
222            .stdin(Stdio::null())
223            .stdout(Stdio::piped())
224            .stderr(Stdio::piped())
225            .output()
226            .await;
227
228        match result {
229            Ok(output) => {
230                let stdout = String::from_utf8_lossy(&output.stdout);
231                let stderr = String::from_utf8_lossy(&output.stderr);
232                let combined = if stderr.is_empty() {
233                    stdout.to_string()
234                } else {
235                    format!("{stdout}{stderr}")
236                };
237                // Truncate to avoid flooding context
238                let truncated = if combined.len() > 2000 {
239                    format!(
240                        "{}\n... (truncated)",
241                        truncate_chars_with_suffix(&combined, 2000, "")
242                    )
243                } else {
244                    combined
245                };
246                results.push(CheckResult {
247                    command: cmd.clone(),
248                    success: output.status.success(),
249                    output: truncated,
250                });
251            }
252            Err(e) => {
253                results.push(CheckResult {
254                    command: cmd.clone(),
255                    success: false,
256                    output: format!("Failed to run: {e}"),
257                });
258            }
259        }
260    }
261    results
262}
263
264/// Format check results into a message for the agent.
265#[must_use]
266pub fn format_check_results(results: &[CheckResult], level: GuardrailLevel) -> String {
267    if results.is_empty() {
268        return String::new();
269    }
270
271    let all_passed = results.iter().all(|r| r.success);
272    if all_passed {
273        return "Guardrail checks passed.".to_string();
274    }
275
276    let mut s = match level {
277        GuardrailLevel::Enforce => {
278            String::from("⚠ GUARDRAIL CHECK FAILED (enforce mode — fix before proceeding):\n")
279        }
280        GuardrailLevel::Advisory => {
281            String::from("⚠ Guardrail check failed (advisory — review before continuing):\n")
282        }
283    };
284
285    for r in results {
286        if !r.success {
287            s.push_str(&format!("\n  Command: {}\n", r.command));
288            if !r.output.is_empty() {
289                for line in r.output.lines().take(20) {
290                    s.push_str(&format!("    {line}\n"));
291                }
292            }
293        }
294    }
295    s
296}
297
// -- Prompt guidance text per profile ----------------------------------------
//
// Each constant is a newline-terminated Markdown bullet list. The `\` right
// after the opening quote is a line continuation, so the text starts directly
// at the first `-`.

/// Guidance for the language-neutral `Generic` profile (also used for `Auto`).
const GUIDANCE_GENERIC: &str = "\
- Prefer the smallest, local fix over a cross-file refactor.
- Search for existing patterns first; mirror naming, error handling, and conventions.
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and timeouts bounded.
- Make error handling explicit — don't silently ignore failures.
- Leave code warning-free and easy to verify.
- Don't add new dependencies without explicit user approval.
";

/// Guidance for the `Zig` profile.
const GUIDANCE_ZIG: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and buffers bounded.
- Handle errors explicitly with try/catch — avoid casual catch unreachable.
- Keep allocator ownership and lifetime clear.
- Prefer small, readable functions with minimal hidden control flow.
- Leave code formatted, buildable, and warning-free.
";

/// Guidance for the `Rust` profile.
const GUIDANCE_RUST: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and timeouts bounded.
- Use Result with meaningful error propagation — avoid unwrap() in non-test code.
- Keep async behavior bounded and timeouts explicit.
- Prefer small, focused changes over broad rewrites.
- Leave code clippy-clean with zero warnings.
";

/// Guidance for the `TypeScript` profile.
const GUIDANCE_TYPESCRIPT: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and timeouts bounded.
- Make error handling explicit — don't silently swallow rejections or errors.
- Use strict typing — avoid any unless justified.
- Keep async/Promise flows bounded and understandable.
- Leave typecheck and lint status clean.
";

/// Guidance for the `C` profile.
const GUIDANCE_C: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and buffer sizes bounded.
- Make error handling explicit — check return values.
- Keep pointer usage straightforward and well-scoped.
- Avoid preprocessor complexity when simpler code works.
- Leave build and test status clean.
";

/// Guidance for the `Go` profile.
const GUIDANCE_GO: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep loops, retries, and timeouts bounded.
- Check and propagate errors explicitly — don't ignore returned errors.
- Keep goroutine lifecycle and cancellation understandable.
- Prefer small functions and direct control flow.
- Leave formatting and vet status clean.
";

/// Guidance for the `Elixir` profile.
const GUIDANCE_ELIXIR: &str = "\
- Keep control flow straightforward and easy to follow.
- Keep retries and message flows bounded.
- Keep process and supervision boundaries clear.
- Handle {:ok, value} / {:error, reason} tuples explicitly.
- Avoid hiding important behavior in opaque control flow.
- Leave formatting and compilation warnings-free.
";

/// Guidance for the `Kotlin` profile.
const GUIDANCE_KOTLIN: &str = "\
- Keep control flow straightforward and easy to follow.
- Prefer val over var; keep mutation local and obvious.
- Treat nullability as part of the design — avoid !! outside tests or impossible states.
- Use structured concurrency; avoid GlobalScope and do not swallow CancellationException.
- Keep Gradle/Maven verification project-specific and use ./gradlew when available.
- Leave formatting, lint, and tests clean for the touched module.
";
372
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use tempfile::TempDir;

    /// Mirrors a config file that has a top-level `[guardrails]` table.
    #[derive(Debug, Deserialize)]
    struct GuardrailToml {
        guardrails: GuardrailConfig,
    }

    /// Create a temporary project directory holding the given
    /// `(file name, contents)` pairs.
    fn project_with(files: &[(&str, &str)]) -> TempDir {
        let dir = TempDir::new().unwrap();
        for &(name, contents) in files {
            std::fs::write(dir.path().join(name), contents).unwrap();
        }
        dir
    }

    /// Resolve the effective profile for `cwd` with the profile set to `Auto`.
    fn resolve_auto(cwd: &std::path::Path) -> GuardrailProfile {
        let config = GuardrailConfig {
            profile: Some(GuardrailProfile::Auto),
            ..Default::default()
        };
        config.resolve_effective_profile(cwd)
    }

    /// A one-element result list describing a single check run.
    fn single_check(command: &str, success: bool, output: &str) -> Vec<CheckResult> {
        vec![CheckResult {
            command: command.into(),
            success,
            output: output.into(),
        }]
    }

    #[test]
    fn guardrail_toml_deserializes() {
        let parsed: GuardrailToml = toml::from_str(
            r#"
[guardrails]
enabled = true
level = "enforce"
profile = "zig"
critical_paths = ["src/**", "lib/**"]
after_write = ["zig fmt --check .", "zig build"]
"#,
        )
        .unwrap();
        let cfg = parsed.guardrails;

        assert_eq!(cfg.enabled, Some(true));
        assert_eq!(cfg.level, Some(GuardrailLevel::Enforce));
        assert_eq!(cfg.profile, Some(GuardrailProfile::Zig));
        assert_eq!(
            cfg.critical_paths,
            Some(vec!["src/**".into(), "lib/**".into()])
        );
        assert_eq!(
            cfg.after_write,
            Some(vec!["zig fmt --check .".into(), "zig build".into()])
        );
    }

    #[test]
    fn guardrail_auto_profile_resolves_zig() {
        let dir = project_with(&[("build.zig", "")]);
        assert_eq!(resolve_auto(dir.path()), GuardrailProfile::Zig);
    }

    #[test]
    fn guardrail_auto_profile_resolves_rust_from_subdirectory() {
        let dir = project_with(&[("Cargo.toml", "[package]\nname='x'\nversion='0.1.0'\n")]);
        let nested = dir.path().join("src").join("nested");
        std::fs::create_dir_all(&nested).unwrap();

        assert_eq!(resolve_auto(&nested), GuardrailProfile::Rust);
    }

    #[test]
    fn guardrail_auto_profile_resolves_go() {
        let dir = project_with(&[("go.mod", "module example.com/test\n")]);
        assert_eq!(resolve_auto(dir.path()), GuardrailProfile::Go);
    }

    #[test]
    fn guardrail_auto_profile_resolves_elixir() {
        let dir = project_with(&[("mix.exs", "defmodule Demo.MixProject do end\n")]);
        assert_eq!(resolve_auto(dir.path()), GuardrailProfile::Elixir);
    }

    #[test]
    fn guardrail_auto_profile_resolves_kotlin_gradle() {
        let dir = project_with(&[
            ("settings.gradle.kts", "pluginManagement {}\n"),
            (
                "build.gradle.kts",
                "plugins { kotlin(\"jvm\") version \"2.0.0\" }\n",
            ),
        ]);
        assert_eq!(resolve_auto(dir.path()), GuardrailProfile::Kotlin);
    }

    #[test]
    fn guardrail_auto_profile_falls_back_to_generic() {
        let dir = project_with(&[]);
        assert_eq!(resolve_auto(dir.path()), GuardrailProfile::Generic);
    }

    #[test]
    fn guardrail_prompt_guidance_varies_by_profile() {
        let zig = GuardrailProfile::Zig.prompt_guidance();
        let rust = GuardrailProfile::Rust.prompt_guidance();
        let generic = GuardrailProfile::Generic.prompt_guidance();

        for needle in ["catch unreachable", "allocator"] {
            assert!(zig.contains(needle));
        }
        for needle in ["clippy", "unwrap"] {
            assert!(rust.contains(needle));
        }
        assert!(generic.contains("warning-free"));
        assert_ne!(zig, rust);
        assert_ne!(zig, generic);
    }

    #[test]
    fn guardrail_default_after_write_zig() {
        let cmds = GuardrailProfile::Zig.default_after_write();
        assert_eq!(cmds.len(), 3);
        assert!(cmds[0].contains("zig fmt"));
    }

    #[test]
    fn guardrail_default_after_write_generic_is_empty() {
        assert!(GuardrailProfile::Generic.default_after_write().is_empty());
    }

    #[test]
    fn guardrail_layer_contains_header() {
        let layer = guardrails_layer(GuardrailProfile::Zig);
        assert!(layer.starts_with("## Engineering Guardrails"));
        assert!(layer.contains("catch unreachable"));
    }

    #[test]
    fn guardrail_format_check_results_all_passed() {
        let results = single_check("zig build", true, "");
        let msg = format_check_results(&results, GuardrailLevel::Advisory);
        assert_eq!(msg, "Guardrail checks passed.");
    }

    #[test]
    fn guardrail_format_check_results_failure_enforce() {
        let results = single_check("cargo clippy", false, "warning: unused variable");
        let msg = format_check_results(&results, GuardrailLevel::Enforce);
        assert!(msg.contains("GUARDRAIL CHECK FAILED"));
        assert!(msg.contains("enforce"));
        assert!(msg.contains("cargo clippy"));
    }

    #[test]
    fn guardrail_format_check_results_failure_advisory() {
        let results = single_check("mix test", false, "1 test failed");
        let msg = format_check_results(&results, GuardrailLevel::Advisory);
        assert!(msg.contains("advisory"));
        assert!(msg.contains("mix test"));
    }

    #[test]
    fn guardrail_merge_only_overrides_present_fields() {
        let mut merged = GuardrailConfig {
            enabled: Some(true),
            level: Some(GuardrailLevel::Advisory),
            profile: Some(GuardrailProfile::Rust),
            critical_paths: Some(vec!["src/**".into()]),
            after_write: None,
        };

        // The overlay sets only `level` and `after_write`.
        merged.merge(GuardrailConfig {
            enabled: None,
            level: Some(GuardrailLevel::Enforce),
            profile: None,
            critical_paths: None,
            after_write: Some(vec!["cargo test".into()]),
        });

        let expected = GuardrailConfig {
            enabled: Some(true),
            level: Some(GuardrailLevel::Enforce),
            profile: Some(GuardrailProfile::Rust),
            critical_paths: Some(vec!["src/**".into()]),
            after_write: Some(vec!["cargo test".into()]),
        };
        assert_eq!(merged.enabled, expected.enabled);
        assert_eq!(merged.level, expected.level);
        assert_eq!(merged.profile, expected.profile);
        assert_eq!(merged.critical_paths, expected.critical_paths);
        assert_eq!(merged.after_write, expected.after_write);
    }
}