agentic_jujutsu/
wrapper.rs

1//! Main wrapper for Jujutsu operations
2
3use crate::{
4    config::JJConfig,
5    error::{JJError, Result},
6    operations::{JJOperation, JJOperationLog, OperationType},
7    types::{JJBranch, JJCommit, JJConflict, JJDiff, JJResult},
8};
9use chrono::Utc;
10use std::sync::{Arc, Mutex};
11use std::time::Instant;
12use wasm_bindgen::prelude::*;
13
14/// Validate command arguments to prevent command injection
15fn validate_command_args(args: &[&str]) -> Result<()> {
16    for arg in args {
17        // Block shell metacharacters that could enable command injection
18        if arg.contains(&['$', '`', '&', '|', ';', '\n', '>', '<'][..]) {
19            return Err(JJError::InvalidConfig(format!(
20                "Invalid character in argument: {}. Shell metacharacters are not allowed.",
21                arg
22            )));
23        }
24        // Block null bytes
25        if arg.contains('\0') {
26            return Err(JJError::InvalidConfig(
27                "Null bytes are not allowed in arguments".to_string(),
28            ));
29        }
30    }
31    Ok(())
32}
33
34// Import the appropriate execute_jj_command based on target architecture
35#[cfg(not(target_arch = "wasm32"))]
36use crate::native::execute_jj_command;
37
38#[cfg(target_arch = "wasm32")]
39use crate::wasm::execute_jj_command;
40
41/// Main wrapper for Jujutsu operations
42#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
43#[derive(Clone)]
44pub struct JJWrapper {
45    config: JJConfig,
46    #[cfg_attr(target_arch = "wasm32", wasm_bindgen(skip))]
47    operation_log: Arc<Mutex<JJOperationLog>>,
48}
49
50#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
51impl JJWrapper {
52    /// Create a new JJWrapper with default configuration
53    #[cfg_attr(target_arch = "wasm32", wasm_bindgen(constructor))]
54    pub fn new() -> JJWrapper {
55        Self::with_config(JJConfig::default()).unwrap_or_else(|_| {
56            // Fallback to a basic wrapper if default config fails
57            JJWrapper {
58                config: JJConfig::default(),
59                operation_log: Arc::new(Mutex::new(JJOperationLog::new(1000))),
60            }
61        })
62    }
63
64    /// Create a new JJWrapper with custom configuration
65    pub fn with_config(config: JJConfig) -> Result<JJWrapper> {
66        let operation_log = Arc::new(Mutex::new(JJOperationLog::new(config.max_log_entries)));
67
68        Ok(JJWrapper {
69            config,
70            operation_log,
71        })
72    }
73
74    /// Get the current configuration
75    #[cfg_attr(target_arch = "wasm32", wasm_bindgen(js_name = getConfig))]
76    pub fn get_config(&self) -> JJConfig {
77        self.config.clone()
78    }
79
80    /// Get operation log statistics as JSON string
81    #[cfg_attr(target_arch = "wasm32", wasm_bindgen(js_name = getStats))]
82    pub fn get_stats(&self) -> String {
83        let log = self.operation_log.lock().unwrap();
84        serde_json::json!({
85            "total_operations": log.count(),
86            "avg_duration_ms": log.avg_duration_ms(),
87            "success_rate": log.success_rate(),
88        })
89        .to_string()
90    }
91
92    /// Execute a jj command and return the result
93    pub async fn execute(&self, args: &[&str]) -> Result<JJResult> {
94        // Validate arguments for security
95        validate_command_args(args)?;
96
97        let start = Instant::now();
98        let command = format!("jj {}", args.join(" "));
99
100        #[cfg(not(target_arch = "wasm32"))]
101        let result = {
102            let timeout = std::time::Duration::from_millis(self.config.timeout_ms);
103            match execute_jj_command(&self.config.jj_path(), args, timeout).await {
104                Ok(output) => {
105                    JJResult::new(output, String::new(), 0, start.elapsed().as_millis() as u64)
106                }
107                Err(e) => {
108                    return Err(JJError::CommandFailed(e.to_string()));
109                }
110            }
111        };
112
113        #[cfg(target_arch = "wasm32")]
114        let result = {
115            let timeout = std::time::Duration::from_millis(self.config.timeout_ms);
116            match execute_jj_command(&self.config.jj_path(), args, timeout).await {
117                Ok(output) => {
118                    JJResult::new(output, String::new(), 0, start.elapsed().as_millis() as u64)
119                }
120                Err(e) => {
121                    return Err(JJError::CommandFailed(e.to_string()));
122                }
123            }
124        };
125
126        // Log the operation
127        let hostname = std::env::var("HOSTNAME").unwrap_or_else(|_| "unknown".to_string());
128        let username = std::env::var("USER").unwrap_or_else(|_| "unknown".to_string());
129
130        let mut operation = JJOperation::new(
131            format!("{}@{}", Utc::now().timestamp(), hostname),
132            command.clone(),
133            username,
134            hostname,
135        );
136
137        operation.operation_type = Self::detect_operation_type(args);
138        operation.success = result.success();
139        operation.duration_ms = result.execution_time_ms;
140
141        self.operation_log.lock().unwrap().add_operation(operation);
142
143        Ok(result)
144    }
145
146    /// Detect operation type from command arguments
147    fn detect_operation_type(args: &[&str]) -> OperationType {
148        if args.is_empty() {
149            return OperationType::Unknown;
150        }
151
152        match args[0] {
153            "describe" => OperationType::Describe,
154            "new" => OperationType::New,
155            "edit" => OperationType::Edit,
156            "abandon" => OperationType::Abandon,
157            "rebase" => OperationType::Rebase,
158            "squash" => OperationType::Squash,
159            "resolve" => OperationType::Resolve,
160            "branch" => OperationType::Branch,
161            "bookmark" => OperationType::Bookmark,
162            "git" if args.len() > 1 && args[1] == "fetch" => OperationType::GitFetch,
163            "git" if args.len() > 1 && args[1] == "push" => OperationType::GitPush,
164            "undo" => OperationType::Undo,
165            "restore" => OperationType::Restore,
166            _ => OperationType::Unknown,
167        }
168    }
169
170    /// Get operations from the operation log
171    pub fn get_operations(&self, limit: usize) -> Result<Vec<JJOperation>> {
172        Ok(self.operation_log.lock().unwrap().get_recent(limit))
173    }
174
175    /// Get user-initiated operations (exclude snapshots)
176    pub fn get_user_operations(&self, limit: usize) -> Result<Vec<JJOperation>> {
177        Ok(self
178            .operation_log
179            .lock()
180            .unwrap()
181            .get_user_operations(limit))
182    }
183
184    /// Get conflicts in the current commit or specified commit
185    pub async fn get_conflicts(&self, commit: Option<&str>) -> Result<Vec<JJConflict>> {
186        let args = if let Some(c) = commit {
187            vec!["resolve", "--list", "-r", c]
188        } else {
189            vec!["resolve", "--list"]
190        };
191
192        let result = self.execute(&args).await?;
193        Self::parse_conflicts(&result.stdout)
194    }
195
196    /// Parse conflict list output
197    fn parse_conflicts(output: &str) -> Result<Vec<JJConflict>> {
198        let mut conflicts = Vec::new();
199
200        for line in output.lines() {
201            let line = line.trim();
202            if line.is_empty() || line.starts_with("No conflicts") {
203                continue;
204            }
205
206            // Parse format: "path/to/file    2-sided conflict"
207            let parts: Vec<&str> = line.split_whitespace().collect();
208            if parts.len() >= 2 {
209                let path = parts[0].to_string();
210                let conflict_info = parts[1..].join(" ");
211
212                let num_conflicts = conflict_info
213                    .split('-')
214                    .next()
215                    .and_then(|s| s.trim().parse::<usize>().ok())
216                    .unwrap_or(1);
217
218                let mut conflict = JJConflict::new(path, num_conflicts, "content".to_string());
219
220                // Extract number of sides
221                if conflict_info.contains("sided") {
222                    for _ in 0..num_conflicts {
223                        conflict.add_side(format!("side-{}", conflicts.len()));
224                    }
225                }
226
227                conflicts.push(conflict);
228            }
229        }
230
231        Ok(conflicts)
232    }
233
234    /// Describe the current commit with a message
235    pub async fn describe(&self, message: &str) -> Result<JJOperation> {
236        let args = vec!["describe", "-m", message];
237        let result = self.execute(&args).await?;
238
239        if !result.success() {
240            return Err(JJError::CommandFailed(result.stderr));
241        }
242
243        // Return the most recent operation
244        self.get_operations(1)?
245            .into_iter()
246            .next()
247            .ok_or_else(|| JJError::OperationNotFound("describe".to_string()))
248    }
249
250    /// Get repository status
251    pub async fn status(&self) -> Result<JJResult> {
252        self.execute(&["status"]).await
253    }
254
255
256    /// Get diff between two commits
257    pub async fn diff(&self, from: &str, to: &str) -> Result<JJDiff> {
258        let args = vec!["diff", "--from", from, "--to", to];
259        let result = self.execute(&args).await?;
260
261        Self::parse_diff(&result.stdout)
262    }
263
264    /// Parse diff output
265    fn parse_diff(output: &str) -> Result<JJDiff> {
266        let mut diff = JJDiff::new();
267        diff.content = output.to_string();
268
269        for line in output.lines() {
270            if line.starts_with("+++") {
271                // Added file
272                if let Some(path) = line.strip_prefix("+++ ") {
273                    let path = path.trim_start_matches("b/");
274                    if path != "/dev/null" {
275                        diff.added.push(path.to_string());
276                    }
277                }
278            } else if line.starts_with("---") {
279                // Deleted file
280                if let Some(path) = line.strip_prefix("--- ") {
281                    let path = path.trim_start_matches("a/");
282                    if path != "/dev/null" {
283                        diff.deleted.push(path.to_string());
284                    }
285                }
286            } else if line.starts_with("+") && !line.starts_with("+++") {
287                diff.additions += 1;
288            } else if line.starts_with("-") && !line.starts_with("---") {
289                diff.deletions += 1;
290            }
291        }
292
293        Ok(diff)
294    }
295
296    /// Create a new commit (renamed from 'new' to avoid confusion with constructor)
297    pub async fn new_commit(&self, message: Option<&str>) -> Result<JJResult> {
298        let mut args = vec!["new"];
299        if let Some(msg) = message {
300            args.extend(&["-m", msg]);
301        }
302        self.execute(&args).await
303    }
304
305    /// Edit a commit
306    pub async fn edit(&self, revision: &str) -> Result<JJResult> {
307        self.execute(&["edit", revision]).await
308    }
309
310    /// Abandon a commit
311    pub async fn abandon(&self, revision: &str) -> Result<JJResult> {
312        self.execute(&["abandon", revision]).await
313    }
314
315    /// Squash commits
316    pub async fn squash(&self, from: Option<&str>, to: Option<&str>) -> Result<JJResult> {
317        let mut args = vec!["squash"];
318        if let Some(f) = from {
319            args.extend(&["-r", f]);
320        }
321        if let Some(t) = to {
322            args.extend(&["--into", t]);
323        }
324        self.execute(&args).await
325    }
326
327    /// Rebase commits
328    pub async fn rebase(&self, source: &str, destination: &str) -> Result<JJResult> {
329        self.execute(&["rebase", "-s", source, "-d", destination])
330            .await
331    }
332
333    /// Resolve conflicts
334    pub async fn resolve(&self, path: Option<&str>) -> Result<JJResult> {
335        let mut args = vec!["resolve"];
336        if let Some(p) = path {
337            args.push(p);
338        }
339        self.execute(&args).await
340    }
341
342    /// Create a branch
343    pub async fn branch_create(&self, name: &str, revision: Option<&str>) -> Result<JJResult> {
344        let mut args = vec!["branch", "create", name];
345        if let Some(rev) = revision {
346            args.extend(&["-r", rev]);
347        }
348        self.execute(&args).await
349    }
350
351    /// Delete a branch
352    pub async fn branch_delete(&self, name: &str) -> Result<JJResult> {
353        self.execute(&["branch", "delete", name]).await
354    }
355
356    /// List branches
357    pub async fn branch_list(&self) -> Result<Vec<JJBranch>> {
358        let result = self.execute(&["branch", "list"]).await?;
359        Self::parse_branches(&result.stdout)
360    }
361
362    /// Parse branch list output
363    fn parse_branches(output: &str) -> Result<Vec<JJBranch>> {
364        let mut branches = Vec::new();
365
366        for line in output.lines() {
367            let line = line.trim();
368            if line.is_empty() {
369                continue;
370            }
371
372            // Parse format: "branch-name: commit-id"
373            let parts: Vec<&str> = line.split(':').collect();
374            if parts.len() >= 2 {
375                let name = parts[0].trim().to_string();
376                let target = parts[1]
377                    .trim()
378                    .split_whitespace()
379                    .next()
380                    .unwrap_or("")
381                    .to_string();
382
383                let is_remote = name.contains('/');
384                let mut branch = JJBranch::new(name.clone(), target, is_remote);
385
386                if is_remote {
387                    if let Some((remote, _)) = name.split_once('/') {
388                        branch.set_remote(remote.to_string());
389                    }
390                }
391
392                branches.push(branch);
393            }
394        }
395
396        Ok(branches)
397    }
398
399    /// Undo the last operation
400    pub async fn undo(&self) -> Result<JJResult> {
401        self.execute(&["undo"]).await
402    }
403
404    /// Restore files
405    pub async fn restore(&self, paths: &[&str]) -> Result<JJResult> {
406        let mut args = vec!["restore"];
407        args.extend(paths);
408        self.execute(&args).await
409    }
410
411    /// Show commit log
412    pub async fn log(&self, limit: Option<usize>) -> Result<Vec<JJCommit>> {
413        let mut args = vec!["log"];
414        let limit_str;
415        if let Some(l) = limit {
416            limit_str = l.to_string();
417            args.extend(&["--limit", &limit_str]);
418        }
419        let result = self.execute(&args).await?;
420        Self::parse_log(&result.stdout)
421    }
422
423    /// Parse log output
424    fn parse_log(output: &str) -> Result<Vec<JJCommit>> {
425        let mut commits = Vec::new();
426
427        // Simple parser - in production, use `jj log --template` with JSON output
428        for block in output.split("\n\n") {
429            let lines: Vec<&str> = block.lines().collect();
430            if lines.is_empty() {
431                continue;
432            }
433
434            let mut commit = JJCommit::new(
435                "unknown".to_string(),
436                "unknown".to_string(),
437                String::new(),
438                "unknown".to_string(),
439                "unknown@example.com".to_string(),
440            );
441
442            for line in lines {
443                if let Some(id) = line.strip_prefix("Commit ID: ") {
444                    commit.id = id.trim().to_string();
445                } else if let Some(change) = line.strip_prefix("Change ID: ") {
446                    commit.change_id = change.trim().to_string();
447                } else if let Some(author) = line.strip_prefix("Author: ") {
448                    let parts: Vec<&str> = author.split('<').collect();
449                    if parts.len() == 2 {
450                        commit.author = parts[0].trim().to_string();
451                        commit.author_email = parts[1].trim_end_matches('>').trim().to_string();
452                    }
453                }
454            }
455
456            commits.push(commit);
457        }
458
459        Ok(commits)
460    }
461
462    /// Clear operation log
463    pub fn clear_log(&self) {
464        self.operation_log.lock().unwrap().clear();
465    }
466}
467
468// Non-WASM impl block for Rust-only methods
469impl JJWrapper {
470    /// Create wrapper with config (Rust-only)
471    pub fn with_config_checked(config: JJConfig) -> Result<JJWrapper> {
472        Self::with_config(config)
473    }
474}
475
476impl Default for JJWrapper {
477    fn default() -> Self {
478        Self::new()
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wrapper_creation() {
        // `new` is infallible (it returns `JJWrapper`, not a `Result`); a fresh
        // wrapper starts with an empty operation log.
        let wrapper = JJWrapper::new();
        assert!(wrapper.get_operations(10).unwrap().is_empty());

        let config = JJConfig::default().with_verbose(true);
        let wrapper = JJWrapper::with_config(config);
        assert!(wrapper.is_ok());
    }

    #[test]
    fn test_validate_command_args() {
        assert!(validate_command_args(&["describe", "-m", "hello world"]).is_ok());
        assert!(validate_command_args(&["log; rm -rf /"]).is_err());
        assert!(validate_command_args(&["$(whoami)"]).is_err());
        assert!(validate_command_args(&["a\0b"]).is_err());
    }

    #[test]
    fn test_detect_operation_type() {
        assert_eq!(
            JJWrapper::detect_operation_type(&["describe", "-m", "test"]),
            OperationType::Describe
        );
        assert_eq!(
            JJWrapper::detect_operation_type(&["new"]),
            OperationType::New
        );
        assert_eq!(
            JJWrapper::detect_operation_type(&["git", "fetch"]),
            OperationType::GitFetch
        );
    }

    #[test]
    fn test_parse_conflicts() {
        let output = "file1.txt    2-sided conflict\nfile2.rs    3-sided conflict";
        let conflicts = JJWrapper::parse_conflicts(output).unwrap();

        assert_eq!(conflicts.len(), 2);
        assert_eq!(conflicts[0].path, "file1.txt");
        assert_eq!(conflicts[0].num_conflicts, 2);
        assert_eq!(conflicts[1].path, "file2.rs");
        assert_eq!(conflicts[1].num_conflicts, 3);
    }

    #[test]
    fn test_parse_diff() {
        let output = r#"
+++ b/new.txt
--- a/deleted.txt
+Added line
-Removed line
        "#;

        let diff = JJWrapper::parse_diff(output).unwrap();
        assert_eq!(diff.additions, 1);
        assert_eq!(diff.deletions, 1);
    }

    #[test]
    fn test_parse_branches() {
        let output = "main: abc123\norigin/main: def456";
        let branches = JJWrapper::parse_branches(output).unwrap();

        assert_eq!(branches.len(), 2);
        assert_eq!(branches[0].name, "main");
        assert!(!branches[0].is_remote);
        assert_eq!(branches[1].name, "origin/main");
        assert!(branches[1].is_remote);
    }

    #[test]
    fn test_parse_log() {
        let output = "Commit ID: abc123\nChange ID: xyz789\nAuthor: Alice <alice@example.com>";
        let commits = JJWrapper::parse_log(output).unwrap();

        assert_eq!(commits.len(), 1);
        assert_eq!(commits[0].id, "abc123");
        assert_eq!(commits[0].change_id, "xyz789");
        assert_eq!(commits[0].author, "Alice");
        assert_eq!(commits[0].author_email, "alice@example.com");
    }
}