use std::time::Duration;
use anyhow::{Context, Result};
use base64::Engine;
use serde::{Deserialize, Serialize};
/// JSON request body that `call_exec` POSTs to the server's `/exec` endpoint.
///
/// Both optional fields are left out of the serialized output entirely when
/// they are `None`, rather than being sent as `null`.
#[derive(Debug, Clone, Serialize)]
pub struct ExecRequest {
/// Command string forwarded to the exec server.
pub command: String,
/// Optional timeout in seconds; skipped during serialization when `None`.
/// NOTE(review): presumably a server-side limit on the command itself,
/// distinct from the HTTP client timeout — confirm against the server.
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout_secs: Option<u64>,
/// Optional working directory; skipped during serialization when `None`.
/// NOTE(review): presumably the directory the server runs the command in — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub working_dir: Option<String>,
}
/// Result payload that `call_exec` decodes from the JSON body of a successful
/// `/exec` response.
///
/// All five fields are required when deserializing; none carries a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecResponse {
/// NOTE(review): presumably the command's captured standard output — confirm.
pub stdout: String,
/// NOTE(review): presumably the command's captured standard error — confirm.
pub stderr: String,
/// NOTE(review): presumably the process exit status reported by the server — confirm.
pub exit_code: i32,
/// NOTE(review): presumably elapsed run time in milliseconds as measured by the server — confirm.
pub duration_ms: u64,
/// NOTE(review): presumably `true` when the command hit its timeout — confirm.
pub timed_out: bool,
}
/// Builds the base URL (`scheme://host:port[/path]`) used to reach the server.
///
/// * A bare host (no `http://` / `https://` prefix) is returned as
///   `http://{host}:{port}`.
/// * An input that already carries an `http://` or `https://` scheme keeps
///   that scheme. If its authority already names a port, that port wins and
///   `port` is ignored; otherwise `port` is inserted directly after the host,
///   *before* any path component.
/// * Surrounding whitespace and trailing `/` characters are removed so callers
///   can append a route such as `/exec` without producing `//`.
///
/// Bracketed IPv6 literals (`http://[::1]`) are recognised: the colons inside
/// the brackets are not mistaken for a port separator.
fn normalize_endpoint(host: &str, port: u16) -> String {
    let trimmed = host.trim();

    let (scheme, rest) = if let Some(rest) = trimmed.strip_prefix("https://") {
        ("https", rest)
    } else if let Some(rest) = trimmed.strip_prefix("http://") {
        ("http", rest)
    } else {
        // No recognised scheme: treat the whole input as a bare host.
        return format!("http://{}:{}", trimmed, port);
    };

    let rest = rest.trim_end_matches('/');

    // Separate the authority (`host[:port]`) from any path so the default
    // port lands next to the host instead of at the end of the path.
    let (authority, path) = match rest.find('/') {
        Some(idx) => rest.split_at(idx),
        None => (rest, ""),
    };

    // For a bracketed IPv6 literal only a ':' right after the closing ']'
    // introduces a port; otherwise any ':' in the authority does.
    let has_port = match authority.rfind(']') {
        Some(close) => authority[close + 1..].starts_with(':'),
        None => authority.contains(':'),
    };

    if has_port {
        format!("{}://{}{}", scheme, authority, path)
    } else {
        format!("{}://{}:{}{}", scheme, authority, port, path)
    }
}
/// Returns the value for an HTTP `Authorization` header using the Basic scheme.
///
/// The result is the literal prefix `Basic ` followed by the standard
/// (padded) base64 encoding of `user:pass`.
pub fn basic_auth(user: &str, pass: &str) -> String {
    use base64::engine::general_purpose::STANDARD;

    let token = STANDARD.encode(format!("{}:{}", user, pass));
    format!("Basic {}", token)
}
/// Sends `command` to the server's `/exec` endpoint and returns its decoded reply.
///
/// The target URL is `normalize_endpoint(host, port)` with `/exec` appended.
/// The request is a JSON-encoded [`ExecRequest`] sent via POST with an
/// `Authorization` header produced by [`basic_auth`]. `total_timeout` is
/// installed as the reqwest client timeout; `timeout_secs` and `working_dir`
/// are only forwarded inside the request body.
///
/// # Errors
///
/// Returns an error when the HTTP client cannot be built, when the request
/// fails to send, when the server answers with a non-success status (the
/// response body is included in the message, or an empty string if it cannot
/// be read), or when the response body does not decode as [`ExecResponse`].
pub async fn call_exec(
    host: &str,
    port: u16,
    user: &str,
    pass: &str,
    command: &str,
    timeout_secs: Option<u64>,
    working_dir: Option<&str>,
    total_timeout: Duration,
) -> Result<ExecResponse> {
    let endpoint = format!("{}/exec", normalize_endpoint(host, port));
    let payload = ExecRequest {
        command: String::from(command),
        timeout_secs,
        working_dir: working_dir.map(String::from),
    };

    let http = reqwest::Client::builder()
        .timeout(total_timeout)
        .build()
        .context("failed to build reqwest client")?;

    let request = http
        .post(&endpoint)
        .header("Authorization", basic_auth(user, pass))
        .json(&payload);
    let response = request
        .send()
        .await
        .with_context(|| format!("POST {} failed", endpoint))?;

    let status = response.status();
    if status.is_success() {
        return response
            .json::<ExecResponse>()
            .await
            .context("exec server response was not the expected JSON shape");
    }

    // Best effort: an unreadable error body is reported as an empty string.
    let detail = response.text().await.unwrap_or_default();
    anyhow::bail!("exec server returned HTTP {}: {}", status, detail);
}
/// Probes the server's `/health` endpoint with an unauthenticated GET.
///
/// The target URL is `normalize_endpoint(host, port)` with `/health` appended,
/// and `total_timeout` is installed as the reqwest client timeout. The
/// response body is not inspected; only the status code matters.
///
/// # Errors
///
/// Returns an error when the HTTP client cannot be built, when the request
/// fails to send, or when the server answers with a non-success status.
pub async fn call_health(host: &str, port: u16, total_timeout: Duration) -> Result<()> {
    let endpoint = format!("{}/health", normalize_endpoint(host, port));

    let http = reqwest::Client::builder()
        .timeout(total_timeout)
        .build()
        .context("failed to build reqwest client")?;

    let status = http
        .get(&endpoint)
        .send()
        .await
        .with_context(|| format!("GET {} failed", endpoint))?
        .status();

    if status.is_success() {
        Ok(())
    } else {
        anyhow::bail!("health check returned HTTP {}", status)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each endpoint test lists `(input host, default port, expected URL)` rows.

    #[test]
    fn endpoint_bare_host_gets_http_prefix() {
        let cases = [
            ("1.2.3.4", 8080, "http://1.2.3.4:8080"),
            ("example.com", 8080, "http://example.com:8080"),
        ];
        for &(host, port, expected) in &cases {
            assert_eq!(normalize_endpoint(host, port), expected);
        }
    }

    #[test]
    fn endpoint_keeps_explicit_scheme() {
        let cases = [
            ("http://1.2.3.4", 8080, "http://1.2.3.4:8080"),
            ("https://example.com", 9090, "https://example.com:9090"),
        ];
        for &(host, port, expected) in &cases {
            assert_eq!(normalize_endpoint(host, port), expected);
        }
    }

    #[test]
    fn endpoint_keeps_explicit_port_in_url() {
        // A port already present in the URL wins over the default port.
        let url = normalize_endpoint("http://1.2.3.4:7777", 8080);
        assert_eq!(url, "http://1.2.3.4:7777");
    }

    #[test]
    fn endpoint_strips_trailing_slash() {
        let url = normalize_endpoint("http://example.com/", 8080);
        assert_eq!(url, "http://example.com:8080");
    }

    #[test]
    fn basic_auth_matches_servers_python_format() {
        let header = basic_auth("root", "hunter2");
        assert_eq!(header, "Basic cm9vdDpodW50ZXIy");
    }

    #[test]
    fn exec_request_omits_optional_fields_when_none() {
        let request = ExecRequest {
            command: "ls".to_string(),
            timeout_secs: None,
            working_dir: None,
        };
        let json = serde_json::to_string(&request).unwrap();

        // Unset optionals must be absent, not serialized as `null`.
        for absent in &["timeout_secs", "working_dir"] {
            assert!(!json.contains(absent));
        }
        assert!(json.contains(r#""command":"ls""#));
    }
}