Skip to main content

tracevault_cli/commands/
check.rs

1use crate::api_client::{resolve_credentials, ApiClient, CheckPoliciesRequest, SessionCheckData};
2use crate::config::TracevaultConfig;
3use std::collections::HashSet;
4use std::fs;
5use std::path::Path;
6use std::process::Command;
7
/// Returns the name of the git repository that contains `project_root`.
///
/// The name is the final path component of the directory reported by
/// `git rev-parse --show-toplevel`, run with `project_root` as the working
/// directory. Falls back to `"unknown"` when git cannot be spawned, exits
/// unsuccessfully, or reports a path with no usable final component (for
/// example empty output).
fn git_repo_name(project_root: &Path) -> String {
    Command::new("git")
        .args(["rev-parse", "--show-toplevel"])
        .current_dir(project_root)
        .output()
        .ok()
        .filter(|o| o.status.success())
        .and_then(|o| repo_name_from_toplevel(&String::from_utf8_lossy(&o.stdout)))
        .unwrap_or_else(|| "unknown".into())
}

/// Extracts the final path component from one line of `--show-toplevel` output.
///
/// Returns `None` for empty input or a path with no normal final component
/// (e.g. `/`).
fn repo_name_from_toplevel(toplevel: &str) -> Option<String> {
    // `Path::file_name` ignores a trailing separator and, on Windows, accepts
    // both `/` and `\`; a manual split on '/' would yield "" in those cases.
    Path::new(toplevel.trim())
        .file_name()
        .map(|name| name.to_string_lossy().into_owned())
}
21
22fn collect_session_data(session_dir: &Path) -> Option<SessionCheckData> {
23    let session_id = session_dir.file_name()?.to_string_lossy().to_string();
24
25    // Read events.jsonl for files_modified
26    let events_path = session_dir.join("events.jsonl");
27    let mut files_modified = Vec::new();
28    let mut files_seen = HashSet::new();
29
30    if events_path.exists() {
31        if let Ok(content) = fs::read_to_string(&events_path) {
32            for line in content.lines() {
33                let event: serde_json::Value = match serde_json::from_str(line) {
34                    Ok(v) => v,
35                    Err(_) => continue,
36                };
37                if let Some(path) = event
38                    .get("tool_input")
39                    .and_then(|v| v.get("file_path"))
40                    .and_then(|v| v.as_str())
41                {
42                    if files_seen.insert(path.to_string()) {
43                        files_modified.push(path.to_string());
44                    }
45                }
46            }
47        }
48    }
49
50    // Read transcript for tool_calls
51    let meta_path = session_dir.join("metadata.json");
52    let metadata: Option<serde_json::Value> = meta_path
53        .exists()
54        .then(|| fs::read_to_string(&meta_path).ok())
55        .flatten()
56        .and_then(|c| serde_json::from_str(&c).ok());
57
58    let transcript_path = metadata
59        .as_ref()
60        .and_then(|m| m.get("transcript_path"))
61        .and_then(|v| v.as_str())
62        .map(|s| s.to_string());
63
64    let mut tool_calls_map: std::collections::HashMap<String, i32> =
65        std::collections::HashMap::new();
66    let mut total_tool_calls: i32 = 0;
67
68    if let Some(path) = &transcript_path {
69        if let Ok(content) = fs::read_to_string(path) {
70            for line in content.lines() {
71                let entry: serde_json::Value = match serde_json::from_str(line) {
72                    Ok(v) => v,
73                    Err(_) => continue,
74                };
75
76                if entry.get("type").and_then(|v| v.as_str()) == Some("assistant") {
77                    if let Some(content_arr) = entry
78                        .get("message")
79                        .and_then(|m| m.get("content"))
80                        .and_then(|c| c.as_array())
81                    {
82                        for block in content_arr {
83                            if block.get("type").and_then(|v| v.as_str()) == Some("tool_use") {
84                                if let Some(name) = block.get("name").and_then(|v| v.as_str()) {
85                                    *tool_calls_map.entry(name.to_string()).or_insert(0) += 1;
86                                    total_tool_calls += 1;
87                                }
88                            }
89                        }
90                    }
91                }
92            }
93        }
94    }
95
96    let tool_calls = if tool_calls_map.is_empty() {
97        None
98    } else {
99        serde_json::to_value(&tool_calls_map).ok()
100    };
101
102    Some(SessionCheckData {
103        session_id,
104        tool_calls,
105        files_modified: if files_modified.is_empty() {
106            None
107        } else {
108            Some(files_modified)
109        },
110        total_tool_calls: if total_tool_calls > 0 {
111            Some(total_tool_calls)
112        } else {
113            None
114        },
115    })
116}
117
118pub async fn check_policies(project_root: &Path) -> Result<(), Box<dyn std::error::Error>> {
119    let (server_url, token) = resolve_credentials(project_root);
120
121    let server_url = match server_url {
122        Some(url) => url,
123        None => {
124            return Err("No server URL configured. Run 'tracevault login' first.".into());
125        }
126    };
127
128    if token.is_none() {
129        return Err("Not logged in. Run 'tracevault login' to check policies.".into());
130    }
131
132    let org_slug = TracevaultConfig::load(project_root)
133        .and_then(|c| c.org_slug)
134        .ok_or("No org_slug in config. Run 'tracevault init' first.")?;
135
136    let client = ApiClient::new(&server_url, token.as_deref());
137
138    // Resolve repo_id by name
139    let repo_name = git_repo_name(project_root);
140    let repos = client.list_repos(&org_slug).await?;
141    let repo = repos.iter().find(|r| r.name == repo_name).ok_or_else(|| {
142        format!(
143            "Repo '{}' not found on server. Run 'tracevault sync' first.",
144            repo_name
145        )
146    })?;
147
148    // Collect session data from unpushed sessions
149    let sessions_dir = project_root.join(".tracevault").join("sessions");
150    let mut sessions = Vec::new();
151
152    if sessions_dir.exists() {
153        for entry in fs::read_dir(&sessions_dir)? {
154            let entry = entry?;
155            if !entry.file_type()?.is_dir() {
156                continue;
157            }
158            let session_dir = entry.path();
159            let pushed_marker = session_dir.join(".pushed");
160            if pushed_marker.exists() {
161                continue;
162            }
163            if let Some(data) = collect_session_data(&session_dir) {
164                sessions.push(data);
165            }
166        }
167    }
168
169    if sessions.is_empty() {
170        println!("No unpushed sessions to check.");
171        return Ok(());
172    }
173
174    println!("Checking {} session(s) against policies...", sessions.len());
175
176    let result = client
177        .check_policies(&org_slug, &repo.id, CheckPoliciesRequest { sessions })
178        .await?;
179
180    // Print results
181    for r in &result.results {
182        let icon = match r.result.as_str() {
183            "pass" => "\x1b[32m✓\x1b[0m",                             // green
184            "fail" if r.action == "block_push" => "\x1b[31m✗\x1b[0m", // red
185            "fail" => "\x1b[33m!\x1b[0m",                             // yellow
186            _ => " ",
187        };
188        println!(
189            "  {} [{}] {} — {}",
190            icon, r.severity, r.rule_name, r.details
191        );
192    }
193
194    if result.blocked {
195        eprintln!("\n\x1b[31mPolicy check failed: push blocked.\x1b[0m");
196        std::process::exit(1);
197    } else if result.passed {
198        println!("\n\x1b[32mAll policy checks passed.\x1b[0m");
199    } else {
200        println!("\n\x1b[33mPolicy warnings found (push not blocked).\x1b[0m");
201    }
202
203    Ok(())
204}