Skip to main content

agent_shell_parser/
lib.rs

1use serde::Deserialize;
2use std::path::Path;
3use std::process::Command;
4
5pub mod guard;
6
7#[derive(Debug, thiserror::Error)]
8pub enum Error {
9    #[error("failed to read stdin: {0}")]
10    Stdin(#[from] std::io::Error),
11    #[error("failed to parse JSON input: {0}")]
12    Json(#[from] serde_json::Error),
13    #[error("jj command failed: {0}")]
14    Jj(String),
15}
16
17#[derive(Debug, Deserialize)]
18pub struct WorktreeCreateInput {
19    pub name: String,
20    pub cwd: String,
21    #[serde(default)]
22    pub session_id: Option<String>,
23}
24
25#[derive(Debug, Deserialize)]
26pub struct WorktreeRemoveInput {
27    pub worktree_path: String,
28    #[serde(default)]
29    pub session_id: Option<String>,
30}
31
32#[derive(Debug, Deserialize)]
33pub struct PreToolUseInput {
34    pub tool_name: String,
35    #[serde(default)]
36    pub tool_input: serde_json::Value,
37    #[serde(default)]
38    pub cwd: Option<String>,
39}
40
41pub fn parse_input<T: serde::de::DeserializeOwned>() -> Result<T, Error> {
42    let input = std::io::read_to_string(std::io::stdin())?;
43    Ok(serde_json::from_str(&input)?)
44}
45
/// Returns the index of the first token that is the actual command being invoked,
/// skipping leading env-var assignments (`FOO=bar`, `FOO+=bar`).
///
/// An env-var assignment is a token matching `[A-Za-z_][A-Za-z0-9_]*\+?=.*`.
/// Returns `None` when `words` is empty or consists solely of assignments.
pub fn find_command_position(words: &[String]) -> Option<usize> {
    words.iter().position(|word| !is_env_assignment(word))
}

/// Returns `true` if `word` is a shell variable assignment such as `FOO=bar`.
///
/// The text before the first `=` must be a valid shell identifier: an ASCII
/// letter or `_`, followed by ASCII alphanumerics or `_`. A single `+` directly
/// before the `=` is also accepted, because bash and zsh treat `NAME+=value` as
/// an (appending) assignment prefix. Rejecting it would let `X+=1 git push` be
/// mistaken for a command named `X+=1`, hiding the real command.
fn is_env_assignment(word: &str) -> bool {
    let Some((name, _value)) = word.split_once('=') else {
        return false;
    };
    // `NAME+=value` appends in bash/zsh; the identifier is what precedes `+`.
    let name = name.strip_suffix('+').unwrap_or(name);
    let mut chars = name.chars();
    match chars.next() {
        Some(first) if first.is_ascii_alphabetic() || first == '_' => {
            chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
        }
        // Empty name (`=bar`, `+=bar`) or a name starting with a digit/symbol.
        _ => false,
    }
}
73
/// Returns `true` if running `jj root` inside `cwd` succeeds.
///
/// Both stdout and stderr of `jj` are discarded; only the exit status is
/// inspected. Any failure to spawn `jj` (binary missing, `cwd` not a valid
/// directory, …) is reported as `false`.
///
/// NOTE(review): as far as I know `jj root` succeeds in any jj workspace,
/// whether or not it is colocated with git. Despite the name, nothing here
/// checks for a `.git` next to `.jj`. Confirm callers only need "is this a
/// jj repo".
pub fn is_jj_colocated(cwd: &Path) -> bool {
    let outcome = Command::new("jj")
        .current_dir(cwd)
        .arg("root")
        .stderr(std::process::Stdio::null())
        .stdout(std::process::Stdio::null())
        .status();
    matches!(outcome, Ok(exit) if exit.success())
}
84
/// Returns the installed `jj` version as `(major, minor, patch)`.
///
/// Runs `jj --version` and parses its stdout, which is expected to look like
/// `jj 0.28.2`. Returns `None` if `jj` cannot be spawned, or if its output does
/// not start with `jj ` followed by at least `MAJOR.MINOR`. A missing or
/// non-numeric patch component is reported as `0`.
pub fn jj_version() -> Option<(u32, u32, u32)> {
    let output = Command::new("jj").arg("--version").output().ok()?;
    parse_jj_version(&String::from_utf8_lossy(&output.stdout))
}

/// Parses `jj --version` output of the form `jj MAJOR.MINOR[.PATCH[suffix]]`.
///
/// Only the leading ASCII digits of the patch component are used. Build
/// suffixes such as `0.25.1-<commit hash>` therefore still yield patch `1`,
/// instead of failing to parse and silently collapsing to `0`.
fn parse_jj_version(text: &str) -> Option<(u32, u32, u32)> {
    let version_str = text.strip_prefix("jj ")?;
    let mut parts = version_str.trim().split('.');
    let major = parts.next()?.parse().ok()?;
    let minor = parts.next()?.parse().ok()?;
    let patch = parts
        .next()
        .and_then(|p| {
            // Byte index of the first non-digit; '0'..='9' are single-byte,
            // so slicing here always lands on a char boundary.
            let digits_end = p.find(|c: char| !c.is_ascii_digit()).unwrap_or(p.len());
            p[..digits_end].parse().ok()
        })
        .unwrap_or(0);
    Some((major, minor, patch))
}
95
96pub fn require_jj_version(min_major: u32, min_minor: u32) -> Result<(), String> {
97    match jj_version() {
98        None => Err("jj-cli not found. Install with: cargo install --locked jj-cli".into()),
99        Some((major, minor, _))
100            if major < min_major || (major == min_major && minor < min_minor) =>
101        {
102            Err(format!(
103                "jj-cli {major}.{minor} found, but >= {min_major}.{min_minor} required. \
104                 Upgrade with: cargo install --locked jj-cli"
105            ))
106        }
107        Some(_) => Ok(()),
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned token list from string literals.
    fn words(tokens: &[&str]) -> Vec<String> {
        tokens.iter().map(|t| (*t).to_owned()).collect()
    }

    #[test]
    fn command_position_simple() {
        assert_eq!(find_command_position(&words(&["git", "push"])), Some(0));
    }

    #[test]
    fn command_position_with_env_vars() {
        let tokens = words(&["FOO=bar", "git", "push"]);
        assert_eq!(find_command_position(&tokens), Some(1));
    }

    #[test]
    fn command_position_multiple_env_vars() {
        let tokens = words(&["A=1", "B=2", "git", "push"]);
        assert_eq!(find_command_position(&tokens), Some(2));
    }

    #[test]
    fn command_position_jj() {
        let tokens = words(&["jj", "git", "push"]);
        assert_eq!(find_command_position(&tokens), Some(0));
    }

    #[test]
    fn command_position_empty() {
        assert_eq!(find_command_position(&[]), None);
    }

    #[test]
    fn command_position_only_assignments() {
        assert_eq!(find_command_position(&words(&["FOO=bar"])), None);
    }

    #[test]
    fn env_assignment_valid() {
        for word in ["FOO=bar", "_X=1", "GIT_CONFIG_GLOBAL=~/.gitconfig.ai"] {
            assert!(is_env_assignment(word), "expected an assignment: {word}");
        }
    }

    #[test]
    fn env_assignment_invalid() {
        for word in ["git", "--flag=value", "=bar", "123=bar"] {
            assert!(!is_env_assignment(word), "expected a non-assignment: {word}");
        }
    }

    #[test]
    fn parse_worktree_create_input() {
        let json = r#"{"name": "test-feature", "cwd": "/tmp/repo", "session_id": "abc123"}"#;
        let parsed: WorktreeCreateInput = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.name, "test-feature");
        assert_eq!(parsed.cwd, "/tmp/repo");
        assert_eq!(parsed.session_id.as_deref(), Some("abc123"));
    }

    #[test]
    fn parse_worktree_remove_input() {
        let json = r#"{"worktree_path": "/tmp/repo/.claude/worktrees/test-feature"}"#;
        let parsed: WorktreeRemoveInput = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.worktree_path, "/tmp/repo/.claude/worktrees/test-feature");
        assert!(parsed.session_id.is_none());
    }

    #[test]
    fn parse_pre_tool_use_input() {
        let json = r#"{"tool_name": "Bash", "tool_input": {"command": "git commit -m test"}, "cwd": "/tmp/repo"}"#;
        let parsed: PreToolUseInput = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.tool_name, "Bash");
        assert_eq!(parsed.cwd.as_deref(), Some("/tmp/repo"));
    }
}
189        let input: PreToolUseInput = serde_json::from_str(json).unwrap();
190        assert_eq!(input.tool_name, "Bash");
191        assert_eq!(input.cwd.as_deref(), Some("/tmp/repo"));
192    }
193}