use crate::api::DiscourseClient;
use crate::cli::ListFormat;
use crate::config::{Config, DiscourseConfig};
use anyhow::{Result, anyhow};
use serde::Serialize;
use std::collections::{HashMap, VecDeque};
use std::io::Write;
use std::process::{Command, Stdio};
use std::sync::{Arc, Mutex, mpsc};
use std::thread;
const DEFAULT_PARALLEL_CHECK_WORKERS: usize = 8;
#[derive(Serialize)]
struct CheckReport {
name: String,
baseurl: String,
api: CheckStatus,
ssh: Option<CheckStatus>,
}
#[derive(Serialize)]
struct CheckStatus {
ok: bool,
detail: String,
}
pub fn config_check(
config: &Config,
format: ListFormat,
skip_ssh: bool,
parallel: bool,
max: Option<usize>,
) -> Result<()> {
if config.discourse.is_empty() {
return Err(anyhow!("no discourses configured"));
}
let text = matches!(format, ListFormat::Text);
let name_width = config
.discourse
.iter()
.map(|d| d.name.len())
.max()
.unwrap_or(0)
.max(4);
let workers = if parallel {
check_worker_count(max, config.discourse.len())
} else {
1
};
eprintln!(
"Checking {} discourse(s) via API{}{}... contacts each over the network and can take a while.",
config.discourse.len(),
if skip_ssh { "" } else { " + SSH" },
if workers > 1 {
format!(" ({workers} parallel)")
} else {
String::new()
}
);
let on_done = |report: &CheckReport| {
if text {
print_report_text(report, name_width);
let _ = std::io::stdout().flush();
}
};
let mut reports = if parallel && workers > 1 {
check_parallel(&config.discourse, skip_ssh, workers, on_done)
} else {
let mut out = Vec::with_capacity(config.discourse.len());
for discourse in &config.discourse {
let report = check_one(discourse, skip_ssh);
on_done(&report);
out.push(report);
}
out
};
if !text {
let order: HashMap<&str, usize> = config
.discourse
.iter()
.enumerate()
.map(|(i, d)| (d.name.as_str(), i))
.collect();
reports.sort_by_key(|r| order.get(r.name.as_str()).copied().unwrap_or(usize::MAX));
}
let failed = reports
.iter()
.filter(|r| !(r.api.ok && r.ssh.as_ref().map(|s| s.ok).unwrap_or(true)))
.count();
match format {
ListFormat::Text => eprintln!(
"Done: {} ok, {} failed (of {}).",
reports.len() - failed,
failed,
reports.len()
),
ListFormat::Json => println!("{}", serde_json::to_string_pretty(&reports)?),
ListFormat::Yaml => println!("{}", serde_yaml::to_string(&reports)?),
}
if failed == 0 {
Ok(())
} else {
Err(anyhow!("{failed} discourse(s) failed checks"))
}
}
fn check_one(discourse: &DiscourseConfig, skip_ssh: bool) -> CheckReport {
let api = check_api(discourse);
let ssh = if skip_ssh {
None
} else {
discourse
.ssh_host
.as_deref()
.filter(|h| !h.trim().is_empty())
.map(check_ssh)
};
CheckReport {
name: discourse.name.clone(),
baseurl: discourse.baseurl.clone(),
api,
ssh,
}
}
fn check_parallel(
discourses: &[DiscourseConfig],
skip_ssh: bool,
workers: usize,
mut on_done: impl FnMut(&CheckReport),
) -> Vec<CheckReport> {
let queue: Arc<Mutex<VecDeque<DiscourseConfig>>> =
Arc::new(Mutex::new(discourses.iter().cloned().collect()));
let (tx, rx) = mpsc::channel::<CheckReport>();
let mut handles = Vec::with_capacity(workers);
for _ in 0..workers {
let queue = Arc::clone(&queue);
let tx = tx.clone();
handles.push(thread::spawn(move || {
loop {
let next = queue.lock().unwrap().pop_front();
let Some(discourse) = next else { break };
if tx.send(check_one(&discourse, skip_ssh)).is_err() {
break;
}
}
}));
}
drop(tx);
let mut reports = Vec::with_capacity(discourses.len());
for report in rx {
on_done(&report);
reports.push(report);
}
for handle in handles {
let _ = handle.join();
}
reports
}
fn check_worker_count(max: Option<usize>, count: usize) -> usize {
max.unwrap_or(DEFAULT_PARALLEL_CHECK_WORKERS)
.max(1)
.min(count.max(1))
}
fn check_api(discourse: &DiscourseConfig) -> CheckStatus {
if discourse.baseurl.trim().is_empty() {
return CheckStatus {
ok: false,
detail: "missing baseurl".to_string(),
};
}
let client = match DiscourseClient::new(discourse) {
Ok(client) => client,
Err(err) => {
return CheckStatus {
ok: false,
detail: format!("client init failed: {err}"),
};
}
};
match client.get("/about.json") {
Ok(response) => {
let status = response.status();
if status.is_success() {
CheckStatus {
ok: true,
detail: format!("{} OK", status.as_u16()),
}
} else if status == reqwest::StatusCode::UNAUTHORIZED
|| status == reqwest::StatusCode::FORBIDDEN
{
CheckStatus {
ok: false,
detail: format!("{} — check apikey/api_username", status.as_u16()),
}
} else {
CheckStatus {
ok: false,
detail: format!("HTTP {}", status.as_u16()),
}
}
}
Err(err) => CheckStatus {
ok: false,
detail: format!("request failed: {err}"),
},
}
}
fn check_ssh(host: &str) -> CheckStatus {
let output = Command::new("ssh")
.args([
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
"-o",
"StrictHostKeyChecking=accept-new",
host,
"true",
])
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::piped())
.output();
match output {
Ok(out) if out.status.success() => CheckStatus {
ok: true,
detail: format!("ssh {} OK", host),
},
Ok(out) => {
let stderr = String::from_utf8_lossy(&out.stderr);
let first = stderr.lines().next().unwrap_or("").trim();
CheckStatus {
ok: false,
detail: if first.is_empty() {
format!("ssh failed (exit {})", out.status)
} else {
format!("ssh: {}", first)
},
}
}
Err(err) => CheckStatus {
ok: false,
detail: format!("ssh spawn failed: {err}"),
},
}
}
#[cfg(test)]
mod tests {
use super::{check_api, check_worker_count};
use crate::config::DiscourseConfig;
#[test]
fn default_check_workers_is_eight() {
assert_eq!(check_worker_count(None, 20), 8);
}
#[test]
fn check_workers_capped_by_forum_count() {
assert_eq!(check_worker_count(Some(8), 3), 3);
assert_eq!(check_worker_count(None, 2), 2);
}
#[test]
fn check_workers_floor_of_one() {
assert_eq!(check_worker_count(Some(0), 5), 1);
}
#[test]
fn check_api_flags_empty_baseurl() {
let discourse = DiscourseConfig {
name: "empty".to_string(),
baseurl: String::new(),
..DiscourseConfig::default()
};
let status = check_api(&discourse);
assert!(!status.ok);
assert!(
status.detail.contains("missing baseurl"),
"expected missing-baseurl detail, got {:?}",
status.detail
);
}
#[test]
fn check_api_flags_whitespace_baseurl() {
let discourse = DiscourseConfig {
name: "ws".to_string(),
baseurl: " ".to_string(),
..DiscourseConfig::default()
};
let status = check_api(&discourse);
assert!(!status.ok);
assert!(status.detail.contains("missing baseurl"));
}
}
fn print_report_text(r: &CheckReport, name_width: usize) {
let api_mark = if r.api.ok { "ok " } else { "FAIL" };
println!(
"{:<width$} api {:<4} {}",
r.name,
api_mark,
r.api.detail,
width = name_width
);
if let Some(ssh) = &r.ssh {
let mark = if ssh.ok { "ok " } else { "FAIL" };
println!(
"{:<width$} ssh {:<4} {}",
r.name,
mark,
ssh.detail,
width = name_width
);
}
}