Skip to main content

ward/cli/
doctor.rs

1use anyhow::Result;
2use clap::Args;
3use console::style;
4
5use crate::config::auth;
6use crate::config::manifest::Manifest;
7
8#[derive(Args)]
9pub struct DoctorCommand;
10
/// Outcome of a single diagnostic: a fixed label, a severity, and a
/// human-readable explanation printed next to it.
#[derive(Debug)]
struct Check {
    /// Label shown in the left column of the report.
    name: &'static str,
    /// Severity; selects the icon and which summary counter is incremented.
    status: CheckStatus,
    /// Free-form explanation rendered (dimmed) after the label.
    detail: String,
}

/// Severity of a [`Check`].
///
/// `Fail` is something that must be fixed for Ward to work, `Warn` something
/// that works but could be improved, and `Pass` a healthy item.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CheckStatus {
    Pass,
    Warn,
    Fail,
}
22
23impl DoctorCommand {
24    pub async fn run(&self, config_path: Option<&str>) -> Result<()> {
25        println!();
26        println!("  {}", style("Ward Doctor").bold());
27        println!("  {}", style("Diagnosing your setup...").dim());
28        println!();
29
30        let mut checks = vec![
31            check_config(config_path),
32            check_token(),
33            check_gh_cli(),
34            check_templates_dir(),
35            check_audit_log(),
36        ];
37
38        let manifest = Manifest::load(config_path).ok();
39
40        if let Some(ref m) = manifest {
41            checks.push(check_org(m));
42            checks.push(check_systems(m));
43            checks.push(check_policies(m));
44            checks.push(check_api_connectivity(config_path).await);
45        }
46
47        let mut pass = 0;
48        let mut warn = 0;
49        let mut fail = 0;
50
51        for check in &checks {
52            let icon = match check.status {
53                CheckStatus::Pass => style("[ok]").green().bold(),
54                CheckStatus::Warn => style("[!!]").yellow().bold(),
55                CheckStatus::Fail => style("[x]").red().bold(),
56            };
57            println!(
58                "  {} {:<30} {}",
59                icon,
60                check.name,
61                style(&check.detail).dim()
62            );
63
64            match check.status {
65                CheckStatus::Pass => pass += 1,
66                CheckStatus::Warn => warn += 1,
67                CheckStatus::Fail => fail += 1,
68            }
69        }
70
71        println!();
72        println!(
73            "  {} passed, {} warnings, {} errors",
74            style(pass).green().bold(),
75            style(warn).yellow().bold(),
76            style(fail).red().bold(),
77        );
78
79        if fail > 0 {
80            println!();
81            println!(
82                "  {}",
83                style("Fix the errors above to get Ward working.").red()
84            );
85        } else if warn > 0 {
86            println!();
87            println!(
88                "  {}",
89                style("Ward is functional but some things could be improved.").yellow()
90            );
91        } else {
92            println!();
93            println!("  {}", style("Everything looks good.").green());
94        }
95
96        println!();
97        Ok(())
98    }
99}
100
101fn check_config(path: Option<&str>) -> Check {
102    let config_path = path.unwrap_or("ward.toml");
103    if std::path::Path::new(config_path).exists() {
104        match std::fs::read_to_string(config_path) {
105            Ok(content) => match toml::from_str::<Manifest>(&content) {
106                Ok(_) => Check {
107                    name: "Configuration",
108                    status: CheckStatus::Pass,
109                    detail: format!("{config_path} found and valid"),
110                },
111                Err(e) => Check {
112                    name: "Configuration",
113                    status: CheckStatus::Fail,
114                    detail: format!("parse error: {e}"),
115                },
116            },
117            Err(e) => Check {
118                name: "Configuration",
119                status: CheckStatus::Fail,
120                detail: format!("cannot read: {e}"),
121            },
122        }
123    } else {
124        Check {
125            name: "Configuration",
126            status: CheckStatus::Fail,
127            detail: format!("{config_path} not found -- run 'ward init'"),
128        }
129    }
130}
131
132fn check_token() -> Check {
133    match auth::resolve_token() {
134        Ok(token) => {
135            let prefix = &token[..std::cmp::min(8, token.len())];
136            let source = if std::env::var("GH_TOKEN").is_ok() {
137                "GH_TOKEN"
138            } else if std::env::var("GITHUB_TOKEN").is_ok() {
139                "GITHUB_TOKEN"
140            } else {
141                "gh auth token"
142            };
143            Check {
144                name: "GitHub token",
145                status: CheckStatus::Pass,
146                detail: format!("{prefix}... via {source}"),
147            }
148        }
149        Err(e) => Check {
150            name: "GitHub token",
151            status: CheckStatus::Fail,
152            detail: format!("{e}"),
153        },
154    }
155}
156
157fn check_gh_cli() -> Check {
158    match std::process::Command::new("gh").arg("--version").output() {
159        Ok(output) if output.status.success() => {
160            let version = String::from_utf8_lossy(&output.stdout);
161            let version_line = version.lines().next().unwrap_or("unknown").trim();
162            Check {
163                name: "GitHub CLI",
164                status: CheckStatus::Pass,
165                detail: version_line.to_string(),
166            }
167        }
168        _ => Check {
169            name: "GitHub CLI",
170            status: CheckStatus::Warn,
171            detail: "not installed (optional, used for token fallback)".to_string(),
172        },
173    }
174}
175
176fn check_templates_dir() -> Check {
177    let dir = dirs_path("templates");
178    if dir.exists() {
179        let count = std::fs::read_dir(&dir)
180            .map(|entries| entries.filter_map(|e| e.ok()).count())
181            .unwrap_or(0);
182        Check {
183            name: "Custom templates",
184            status: CheckStatus::Pass,
185            detail: format!("{} custom templates in {}", count, dir.display()),
186        }
187    } else {
188        Check {
189            name: "Custom templates",
190            status: CheckStatus::Pass,
191            detail: "no custom templates directory (using built-ins only)".to_string(),
192        }
193    }
194}
195
196fn check_audit_log() -> Check {
197    let log = dirs_path("audit.log");
198    if log.exists() {
199        match std::fs::metadata(&log) {
200            Ok(meta) => {
201                let size_kb = meta.len() / 1024;
202                let detail = if size_kb > 10_000 {
203                    format!("{} KB -- consider rotating", size_kb)
204                } else {
205                    format!("{} KB", size_kb)
206                };
207                Check {
208                    name: "Audit log",
209                    status: if size_kb > 10_000 {
210                        CheckStatus::Warn
211                    } else {
212                        CheckStatus::Pass
213                    },
214                    detail,
215                }
216            }
217            Err(_) => Check {
218                name: "Audit log",
219                status: CheckStatus::Pass,
220                detail: "exists but unreadable".to_string(),
221            },
222        }
223    } else {
224        Check {
225            name: "Audit log",
226            status: CheckStatus::Pass,
227            detail: "not yet created (will be on first apply)".to_string(),
228        }
229    }
230}
231
232fn check_org(manifest: &Manifest) -> Check {
233    if manifest.org.name.is_empty() {
234        Check {
235            name: "Organization",
236            status: CheckStatus::Fail,
237            detail: "org.name is empty in ward.toml".to_string(),
238        }
239    } else {
240        Check {
241            name: "Organization",
242            status: CheckStatus::Pass,
243            detail: manifest.org.name.clone(),
244        }
245    }
246}
247
248fn check_systems(manifest: &Manifest) -> Check {
249    let count = manifest.systems.len();
250    if count == 0 {
251        Check {
252            name: "Systems",
253            status: CheckStatus::Warn,
254            detail: "no systems defined -- add [[systems]] to ward.toml".to_string(),
255        }
256    } else {
257        let names: Vec<&str> = manifest.systems.iter().map(|s| s.id.as_str()).collect();
258        Check {
259            name: "Systems",
260            status: CheckStatus::Pass,
261            detail: format!("{count} defined ({})", names.join(", ")),
262        }
263    }
264}
265
266fn check_policies(manifest: &Manifest) -> Check {
267    let count = manifest.policies.len();
268    if count == 0 {
269        Check {
270            name: "Policies",
271            status: CheckStatus::Pass,
272            detail: "none defined (optional)".to_string(),
273        }
274    } else {
275        Check {
276            name: "Policies",
277            status: CheckStatus::Pass,
278            detail: format!("{count} rules configured"),
279        }
280    }
281}
282
283async fn check_api_connectivity(config_path: Option<&str>) -> Check {
284    let manifest = match Manifest::load(config_path) {
285        Ok(m) => m,
286        Err(_) => {
287            return Check {
288                name: "API connectivity",
289                status: CheckStatus::Fail,
290                detail: "cannot load config".to_string(),
291            };
292        }
293    };
294
295    let token = match auth::resolve_token() {
296        Ok(t) => t,
297        Err(_) => {
298            return Check {
299                name: "API connectivity",
300                status: CheckStatus::Fail,
301                detail: "no token available".to_string(),
302            };
303        }
304    };
305
306    let client = match reqwest::Client::builder()
307        .default_headers({
308            let mut headers = reqwest::header::HeaderMap::new();
309            headers.insert(
310                reqwest::header::AUTHORIZATION,
311                reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
312                    .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("")),
313            );
314            headers.insert(
315                reqwest::header::ACCEPT,
316                reqwest::header::HeaderValue::from_static("application/vnd.github+json"),
317            );
318            headers.insert(
319                reqwest::header::USER_AGENT,
320                reqwest::header::HeaderValue::from_static("ward-cli/doctor"),
321            );
322            headers
323        })
324        .build()
325    {
326        Ok(c) => c,
327        Err(_) => {
328            return Check {
329                name: "API connectivity",
330                status: CheckStatus::Fail,
331                detail: "cannot build HTTP client".to_string(),
332            };
333        }
334    };
335
336    let url = format!("https://api.github.com/orgs/{}", manifest.org.name);
337    match client.get(&url).send().await {
338        Ok(resp) => {
339            let status = resp.status();
340            let remaining = resp
341                .headers()
342                .get("x-ratelimit-remaining")
343                .and_then(|v| v.to_str().ok())
344                .unwrap_or("?");
345
346            if status.is_success() {
347                Check {
348                    name: "API connectivity",
349                    status: CheckStatus::Pass,
350                    detail: format!(
351                        "authenticated to {} (rate limit: {} remaining)",
352                        manifest.org.name, remaining
353                    ),
354                }
355            } else if status.as_u16() == 401 {
356                Check {
357                    name: "API connectivity",
358                    status: CheckStatus::Fail,
359                    detail: "401 Unauthorized -- token is invalid or expired".to_string(),
360                }
361            } else if status.as_u16() == 403 {
362                Check {
363                    name: "API connectivity",
364                    status: CheckStatus::Fail,
365                    detail: format!(
366                        "403 Forbidden -- token lacks access to {}",
367                        manifest.org.name
368                    ),
369                }
370            } else if status.as_u16() == 404 {
371                Check {
372                    name: "API connectivity",
373                    status: CheckStatus::Fail,
374                    detail: format!(
375                        "org '{}' not found -- check org.name in ward.toml",
376                        manifest.org.name
377                    ),
378                }
379            } else {
380                Check {
381                    name: "API connectivity",
382                    status: CheckStatus::Warn,
383                    detail: format!("unexpected status: {status}"),
384                }
385            }
386        }
387        Err(e) => Check {
388            name: "API connectivity",
389            status: CheckStatus::Fail,
390            detail: format!("connection failed: {e}"),
391        },
392    }
393}
394
/// Returns `$HOME/.ward/<name>`, falling back to `./.ward/<name>` when `HOME`
/// is unset or not valid Unicode.
fn dirs_path(name: &str) -> std::path::PathBuf {
    let base = match std::env::var("HOME") {
        Ok(home_dir) => std::path::PathBuf::from(home_dir),
        Err(_) => std::path::PathBuf::from("."),
    };
    base.join(".ward").join(name)
}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal system entry whose id and display name are both `id`.
    fn system(id: &str) -> crate::config::manifest::SystemConfig {
        crate::config::manifest::SystemConfig {
            id: id.to_string(),
            name: id.to_string(),
            exclude: vec![],
            repos: vec![],
            security: None,
            teams: vec![],
        }
    }

    #[test]
    fn test_check_config_missing_file() {
        let check = check_config(Some("/nonexistent/path/ward.toml"));
        assert!(matches!(check.status, CheckStatus::Fail));
        assert!(check.detail.contains("not found"));
    }

    #[test]
    fn test_check_config_valid_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("ward.toml");
        std::fs::write(&path, "[org]\nname = \"test-org\"\n").unwrap();
        let check = check_config(Some(path.to_str().unwrap()));
        assert!(matches!(check.status, CheckStatus::Pass));
        assert!(check.detail.contains("found and valid"));
    }

    #[test]
    fn test_check_config_invalid_toml() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("ward.toml");
        std::fs::write(&path, "this is not valid toml [[[").unwrap();
        let check = check_config(Some(path.to_str().unwrap()));
        assert!(matches!(check.status, CheckStatus::Fail));
        assert!(check.detail.contains("parse error"));
    }

    #[test]
    fn test_check_config_unreadable_path() {
        // A directory exists but cannot be read as a config file; it must be
        // reported as unreadable, not as missing.
        let dir = tempfile::tempdir().unwrap();
        let check = check_config(Some(dir.path().to_str().unwrap()));
        assert!(matches!(check.status, CheckStatus::Fail));
        assert!(check.detail.contains("cannot read"));
        assert!(!check.detail.contains("not found"));
    }

    #[test]
    fn test_check_org_empty() {
        let manifest = Manifest::default();
        let check = check_org(&manifest);
        assert!(matches!(check.status, CheckStatus::Fail));
        assert!(check.detail.contains("org.name is empty"));
    }

    #[test]
    fn test_check_org_valid() {
        let mut manifest = Manifest::default();
        manifest.org.name = "my-org".to_string();
        let check = check_org(&manifest);
        assert!(matches!(check.status, CheckStatus::Pass));
        assert!(check.detail.contains("my-org"));
    }

    #[test]
    fn test_check_systems_none() {
        let manifest = Manifest::default();
        let check = check_systems(&manifest);
        assert!(matches!(check.status, CheckStatus::Warn));
    }

    #[test]
    fn test_check_systems_present() {
        let mut manifest = Manifest::default();
        manifest
            .systems
            .push(crate::config::manifest::SystemConfig {
                id: "backend".to_string(),
                name: "Backend".to_string(),
                exclude: vec![],
                repos: vec![],
                security: None,
                teams: vec![],
            });
        let check = check_systems(&manifest);
        assert!(matches!(check.status, CheckStatus::Pass));
        assert!(check.detail.contains("backend"));
    }

    #[test]
    fn test_check_systems_lists_every_id_in_order() {
        let mut manifest = Manifest::default();
        manifest.systems.push(system("backend"));
        manifest.systems.push(system("frontend"));
        let check = check_systems(&manifest);
        assert!(matches!(check.status, CheckStatus::Pass));
        assert_eq!(check.detail, "2 defined (backend, frontend)");
    }

    #[test]
    fn test_check_policies_none() {
        let manifest = Manifest::default();
        let check = check_policies(&manifest);
        assert!(matches!(check.status, CheckStatus::Pass));
        assert!(check.detail.contains("none"));
    }

    #[test]
    fn test_dirs_path_layout() {
        let path = dirs_path("audit.log");
        assert!(path.ends_with(".ward/audit.log"));
    }
}