caretta 0.16.4

caretta agent
// Copyright (c) 2026 Geoff Seemueller
//
// Licensed under the MIT License or Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// See LICENSE-MIT or LICENSE-APACHE for the full license text.
//
// Additionally, this file is subject to the Revenue Sharing Agreement terms
// as defined in REVENUE-SHARING.md for covered organizations.

use std::env;
use std::fs;
use std::path::PathBuf;

/// Environment flag that [`restore_grok_auth_from_env`] sets to `"1"` once it has
/// written `GROK_CREDENTIALS` into the Grok auth file (`$GROK_HOME/auth.json`,
/// defaulting to `~/.grok/auth.json`).
pub const GROK_AUTH_MANAGED_ENV: &str = "CARETTA_GROK_AUTH_MANAGED";

/// Directory holding the Grok CLI state.
///
/// Uses `GROK_HOME` (whitespace-trimmed) when it is set to a non-blank value.
/// Otherwise it uses `<home>/.grok`, where `<home>` falls back to the current
/// directory when the user's home directory cannot be determined.
fn grok_home_dir() -> PathBuf {
    let override_dir = env::var("GROK_HOME").ok();
    match override_dir.as_deref().map(str::trim) {
        Some(dir) if !dir.is_empty() => PathBuf::from(dir),
        _ => {
            let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
            home.join(".grok")
        }
    }
}

/// Location of `auth.json` inside [`grok_home_dir`].
fn grok_auth_json_path() -> PathBuf {
    let mut path = grok_home_dir();
    path.push("auth.json");
    path
}

/// Turn the raw `GROK_CREDENTIALS` value into the bytes to write to auth.json.
///
/// Surrounding whitespace is trimmed first. A value whose first character is
/// `{` is taken verbatim as JSON; anything else goes through [`decode_base64`].
///
/// # Errors
///
/// Fails when the value is empty or whitespace-only, or when base64 decoding
/// fails.
fn decode_grok_credentials(raw: &str) -> Result<Vec<u8>, String> {
    let payload = raw.trim();
    match payload.as_bytes().first() {
        None => Err("GROK_CREDENTIALS is empty".to_string()),
        Some(b'{') => Ok(payload.as_bytes().to_vec()),
        Some(_) => decode_base64(payload),
    }
}

/// Decode standard-alphabet base64 (`A-Z a-z 0-9 + /`) into raw bytes.
///
/// ASCII whitespace anywhere in `input` is skipped, so line-wrapped output
/// decodes as-is. Trailing `=` padding is optional. Once padding begins, only
/// more `=` or whitespace may follow. Any other byte is rejected instead of
/// being silently dropped, which would otherwise truncate the credential
/// payload. Unused low bits in the final quantum are not checked.
///
/// # Errors
///
/// Returns an error string when `input` contains a character outside the
/// alphabet, has data after padding, ends with a single dangling sextet, or
/// decodes to zero bytes.
fn decode_base64(input: &str) -> Result<Vec<u8>, String> {
    // Maps each byte to its 6-bit value; 255 marks "not in the alphabet".
    const TABLE: &[u8; 256] = &{
        let mut table = [255u8; 256];
        let alphabet = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut i = 0;
        while i < alphabet.len() {
            table[alphabet[i] as usize] = i as u8;
            i += 1;
        }
        table
    };

    let mut output = Vec::with_capacity(input.len() * 3 / 4);
    let mut buf = [0u8; 4];
    let mut buf_len = 0usize;
    let mut in_padding = false;

    for ch in input.bytes() {
        if ch.is_ascii_whitespace() {
            continue;
        }
        if ch == b'=' {
            in_padding = true;
            continue;
        }
        if in_padding {
            // Data after `=` means the secret was corrupted or concatenated;
            // decoding only the prefix would write a truncated auth.json.
            return Err("GROK_CREDENTIALS is not valid base64".to_string());
        }
        let value = TABLE[ch as usize];
        if value == 255 {
            return Err("GROK_CREDENTIALS is not valid base64".to_string());
        }
        buf[buf_len] = value;
        buf_len += 1;
        if buf_len == 4 {
            // Four sextets -> three bytes.
            output.push((buf[0] << 2) | (buf[1] >> 4));
            output.push(((buf[1] & 0x0f) << 4) | (buf[2] >> 2));
            output.push(((buf[2] & 0x03) << 6) | buf[3]);
            buf_len = 0;
        }
    }

    // Handle a final partial quantum. Two sextets yield one byte and three yield
    // two; a lone sextet cannot encode a full byte.
    match buf_len {
        0 => {}
        2 => output.push((buf[0] << 2) | (buf[1] >> 4)),
        3 => {
            output.push((buf[0] << 2) | (buf[1] >> 4));
            output.push(((buf[1] & 0x0f) << 4) | (buf[2] >> 2));
        }
        _ => return Err("GROK_CREDENTIALS is not valid base64".to_string()),
    }

    if output.is_empty() {
        return Err("GROK_CREDENTIALS decoded to empty auth.json".to_string());
    }
    Ok(output)
}

/// Restore SuperGrok / X Premium+ OIDC session for ephemeral CI runners.
///
/// `GROK_CREDENTIALS` holds base64-encoded `~/.grok/auth.json` from `grok login`
/// on a trusted machine. The official Grok CLI reads that file and refreshes
/// tokens via the stored `refresh_token`.
pub fn restore_grok_auth_from_env() -> bool {
    let raw = match env::var("GROK_CREDENTIALS") {
        Ok(value) => value,
        Err(_) => return false,
    };

    let auth_bytes = match decode_grok_credentials(&raw) {
        Ok(bytes) => bytes,
        Err(err) => {
            crate::agent::cmd::log(&format!("WARNING: {err}; skipping Grok auth restore"));
            return false;
        }
    };

    let grok_home = grok_home_dir();
    if let Err(err) = fs::create_dir_all(&grok_home) {
        crate::agent::cmd::log(&format!(
            "WARNING: failed to create {}: {err}; skipping Grok auth restore",
            grok_home.display()
        ));
        return false;
    }

    let auth_path = grok_auth_json_path();
    if let Err(err) = fs::write(&auth_path, &auth_bytes) {
        crate::agent::cmd::log(&format!(
            "WARNING: failed to write {}: {err}; skipping Grok auth restore",
            auth_path.display()
        ));
        return false;
    }

    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        if let Err(err) = fs::set_permissions(&auth_path, fs::Permissions::from_mode(0o600)) {
            crate::agent::cmd::log(&format!(
                "WARNING: failed to chmod {}: {err}",
                auth_path.display()
            ));
        }
        if let Err(err) = fs::set_permissions(&grok_home, fs::Permissions::from_mode(0o700)) {
            crate::agent::cmd::log(&format!(
                "WARNING: failed to chmod {}: {err}",
                grok_home.display()
            ));
        }
    }

    unsafe {
        env::set_var(GROK_AUTH_MANAGED_ENV, "1");
    }
    crate::agent::cmd::log(&format!(
        "Restored Grok auth.json for CI ({})",
        auth_path.display()
    ));
    true
}

/// Whether this process restored Grok subscription credentials itself.
///
/// True exactly when [`GROK_AUTH_MANAGED_ENV`] is set to the string `"1"`.
/// An unset or non-UTF-8 value counts as false, as does any other value.
pub fn grok_subscription_auth_active() -> bool {
    matches!(env::var(GROK_AUTH_MANAGED_ENV).as_deref(), Ok("1"))
}

#[cfg(test)]
mod tests {
    use super::{decode_base64, decode_grok_credentials};

    #[test]
    fn decodes_base64_grok_credentials() {
        let decoded = decode_grok_credentials("eyJzY29wZSI6InRva2VuIn0=").expect("decode");
        assert_eq!(decoded, br#"{"scope":"token"}"#);
    }

    #[test]
    fn accepts_raw_json_grok_credentials() {
        let decoded = decode_grok_credentials(r#"{"scope":"token"}"#).expect("decode");
        assert_eq!(decoded, br#"{"scope":"token"}"#);
    }

    #[test]
    fn trims_raw_json_grok_credentials() {
        let decoded = decode_grok_credentials("  {\"a\":1}\n").expect("decode");
        assert_eq!(decoded, br#"{"a":1}"#);
    }

    #[test]
    fn rejects_empty_grok_credentials() {
        assert!(decode_grok_credentials("").is_err());
        assert!(decode_grok_credentials("  \n\t").is_err());
    }

    #[test]
    fn decodes_unpadded_padded_and_wrapped_base64() {
        assert_eq!(decode_base64("QUJD").expect("decode"), b"ABC");
        assert_eq!(decode_base64("QUI=").expect("decode"), b"AB");
        assert_eq!(decode_base64("QQ").expect("decode"), b"A");
        assert_eq!(decode_base64("QU\nJD\n").expect("decode"), b"ABC");
    }

    #[test]
    fn rejects_invalid_base64() {
        assert!(decode_base64("not!!!").is_err());
        // A single leftover sextet cannot encode a whole byte.
        assert!(decode_base64("Q").is_err());
        // Padding alone decodes to nothing, which is not a usable auth.json.
        assert!(decode_base64("====").is_err());
    }
}