treadmill-cli 0.3.1

CLI client for the Treadmill distributed hardware testbed
// auth.rs

use anyhow::{Context, Result, anyhow, bail};
use base64::Engine; // brings `Engine::decode` into scope
use ssh_key::{LineEnding, PrivateKey};
use std::fs;
use std::io::Write; // `write_all` on files opened via `OpenOptions`
use std::os::unix::fs::OpenOptionsExt; // `OpenOptions::mode` for owner-only file creation
use std::os::unix::fs::PermissionsExt; // `Permissions::from_mode` / `set_mode`
use std::path::PathBuf;
use treadmill_rs::api::switchboard::{AuthToken, JobSshEndpoint};
use xdg::BaseDirectories;

pub fn save_token(token: &AuthToken) -> Result<()> {
    let token_path = get_token_path()?;
    fs::write(&token_path, serde_json::to_string(token)?)
        .with_context(|| format!("Failed to write token to {token_path:?}"))?;
    Ok(())
}

/// Return the API token, encoded with `AuthToken::encode_for_http` (the form
/// passed to `bearer_auth` when talking to the switchboard).
///
/// If the `TML_API_TOKEN` environment variable is set, its value must be the
/// standard Base64 encoding of exactly 128 bytes. Leading and trailing
/// whitespace (e.g. a stray newline or space from copy-paste) is ignored.
/// If the variable is unset, the token is read from the JSON file that
/// `save_token` writes.
///
/// # Errors
/// Fails if `TML_API_TOKEN` is not valid UTF-8, is not valid Base64, or does
/// not decode to exactly 128 bytes. When the variable is unset, fails if the
/// token file cannot be located, read, or parsed.
pub fn get_token() -> Result<String> {
    match std::env::var("TML_API_TOKEN") {
        Ok(token_b64) => {
            // Whitespace is never valid Base64, so trimming cannot turn a valid
            // value into a different token; it only rescues padded inputs.
            let token_bytes = base64::engine::general_purpose::STANDARD
                .decode(token_b64.trim())
                .context("Decoding Base64-encoded TML_API_TOKEN variable")?;

            let token_array: [u8; 128] = token_bytes.try_into().map_err(|vec: Vec<u8>| {
                anyhow!(
                    "TML_API_TOKEN has invalid length ({} bytes instead of 128 bytes)",
                    vec.len()
                )
            })?;

            Ok(AuthToken(token_array).encode_for_http())
        }
        Err(std::env::VarError::NotUnicode(_)) => {
            bail!("Supplied TML_API_TOKEN is not valid UTF-8");
        }
        Err(std::env::VarError::NotPresent) => {
            let token_path = get_token_path()?;
            let token_str = fs::read_to_string(&token_path)
                .with_context(|| format!("Failed to read token from {token_path:?}"))?;
            let token: AuthToken = serde_json::from_str(&token_str)
                .with_context(|| format!("Failed to parse token JSON in {token_path:?}"))?;
            Ok(token.encode_for_http())
        }
    }
}

/// Locate `token.json` inside the `treadmill-tb` XDG data directory.
///
/// Resolution goes through `BaseDirectories::place_data_file`, which (per the
/// `xdg` crate docs) also creates any missing parent directories, so the
/// returned path is ready to be written.
fn get_token_path() -> Result<PathBuf> {
    let token_file = BaseDirectories::with_prefix("treadmill-tb")
        .context("Failed to initialize XDG base directories")?
        .place_data_file("token.json")
        .context("Failed to determine token file path")?;
    Ok(token_file)
}

/// Return the location of the Treadmill CLI's SSH private key: `ssh-key` in
/// the `treadmill-tb` XDG data directory.
///
/// Resolution goes through `BaseDirectories::place_data_file`, which (per the
/// `xdg` crate docs) also creates any missing parent directories.
pub fn ssh_private_key_path() -> Result<PathBuf> {
    let key_file_name = "ssh-key";
    BaseDirectories::with_prefix("treadmill-tb")
        .context("Failed to initialize XDG base directories")
        .and_then(|data_dirs| {
            data_dirs
                .place_data_file(key_file_name)
                .context("Failed to determine SSH key path")
        })
}

// Save the private key to the Treadmill data directory
fn save_private_key(private_key: &PrivateKey) -> Result<PathBuf> {
    let key_path = ssh_private_key_path()?;

    let openssh_private_key = private_key
        .to_openssh(LineEnding::LF)
        .map_err(|e| anyhow!("Failed to convert private key to OpenSSH format: {}", e))?;

    fs::write(&key_path, openssh_private_key)?;
    let mut perms = fs::metadata(&key_path)?.permissions();
    perms.set_mode(0o600);
    fs::set_permissions(&key_path, perms)?;

    Ok(key_path)
}

// Generate a new Ed25519 key pair for jobs
pub fn generate_or_load_job_ssh_key() -> Result<String> {
    let key_path = ssh_private_key_path()?;

    let private_key = if key_path.exists() {
        let openssh_pem_bytes =
            std::fs::read(key_path).context("Reading Treadmill CLI SSH private key file")?;
        PrivateKey::from_openssh(&openssh_pem_bytes)
            .context("Parsing Treadmill CLI OpenSSH PEM-formatted private key file")?
    } else {
        println!("Generating new SSH keypair...");
        let private_key = PrivateKey::random(&mut rand_core::OsRng, ssh_key::Algorithm::Ed25519)
            .map_err(|e| anyhow!("Failed to generate Ed25519 key: {}", e))?;
        save_private_key(&private_key)?;
        private_key
    };

    let public_key = private_key.public_key();
    let public_key_str = public_key
        .to_openssh()
        .map_err(|e| anyhow!("Failed to convert public key to OpenSSH format: {}", e))?;

    Ok(public_key_str)
}

// Read the job's SSH endpoints from the switchboard API
pub async fn get_job_ssh_user_endpoints(
    client: &reqwest::Client,
    config: &crate::config::Config,
    job_id: uuid::Uuid,
) -> Result<(Option<String>, Vec<JobSshEndpoint>)> {
    let token = get_token()?;

    let response = client
        .get(format!("{}/api/v1/jobs/{}/status", config.api.url, job_id))
        .bearer_auth(token)
        .send()
        .await?;

    if response.status().is_success() {
        let status: treadmill_rs::api::switchboard::jobs::status::Response =
            response.json().await?;
        match status {
            treadmill_rs::api::switchboard::jobs::status::Response::Ok {
                job_status: status,
                ..
            } => Ok((
                status.state.ssh_user,
                status
                    .state
                    .ssh_endpoints
                    .unwrap_or_else(std::vec::Vec::new),
            )),
            _ => Err(anyhow!("Failed to get job status")),
        }
    } else {
        Err(anyhow!("Failed to get job status: {}", response.status()))
    }
}