use std::collections::HashMap;
use crate::core::jwt::{self, AtiNamespace, TokenClaims};
use crate::TokenCommands;
/// Entry point for the `token` command family: routes the parsed subcommand
/// to the matching handler in this module and returns that handler's result.
pub fn execute(subcmd: &TokenCommands) -> Result<(), Box<dyn std::error::Error>> {
    match subcmd {
        TokenCommands::Keygen { algorithm: alg } => keygen(alg),
        TokenCommands::Inspect { token: raw_token } => inspect(raw_token),
        TokenCommands::Validate {
            token: raw_token,
            key: key_file,
            secret: shared_secret,
        } => validate(raw_token, key_file.as_deref(), shared_secret.as_deref()),
        TokenCommands::Issue {
            sub: subject,
            scope: scope_list,
            ttl: ttl_secs,
            aud: audience,
            iss: issuer,
            key: key_file,
            secret: shared_secret,
            rate: rate_specs,
        } => {
            // Optional owned strings are lent out as `Option<&str>`; the TTL is
            // copied out of the borrowed enum.
            issue(
                subject,
                scope_list,
                *ttl_secs,
                audience.as_deref(),
                issuer.as_deref(),
                key_file.as_deref(),
                shared_secret.as_deref(),
                rate_specs,
            )
        }
    }
}
/// Generates fresh key material for the requested JWT algorithm.
///
/// `algorithm` is compared after upper-casing, so `es256` and `ES256` are
/// equivalent.
///
/// * `ES256` — creates a P-256 key pair and prints the private key (PKCS#8
///   PEM) followed by the public key (SPKI PEM) on stdout.
/// * `HS256` — draws 32 random bytes and prints them hex-encoded on stdout.
///
/// Human-readable guidance goes to stderr so stdout carries only key material.
///
/// # Errors
///
/// Returns an error for any other algorithm name, or when the random source
/// or key generation fails.
fn keygen(algorithm: &str) -> Result<(), Box<dyn std::error::Error>> {
    let normalized = algorithm.to_uppercase();

    if normalized == "HS256" {
        use ring::rand::SecureRandom;

        let mut key_material = [0u8; 32];
        ring::rand::SystemRandom::new()
            .fill(&mut key_material)
            .map_err(|e| format!("Random generation failed: {e}"))?;
        let encoded = hex::encode(key_material);

        eprintln!("Generated HS256 shared secret");
        println!("{encoded}");
        eprintln!("Set ATI_JWT_SECRET=<secret above>");
        return Ok(());
    }

    if normalized != "ES256" {
        return Err(format!("Unsupported algorithm: {algorithm}. Use ES256 or HS256.").into());
    }

    use base64::engine::general_purpose::STANDARD as BASE64;
    use base64::Engine;
    use ring::signature::{EcdsaKeyPair, KeyPair, ECDSA_P256_SHA256_ASN1_SIGNING};

    let system_rng = ring::rand::SystemRandom::new();

    // Private half: ring hands back a PKCS#8 document, which is PEM-armoured as is.
    let pkcs8_doc = EcdsaKeyPair::generate_pkcs8(&ECDSA_P256_SHA256_ASN1_SIGNING, &system_rng)
        .map_err(|e| format!("Key generation failed: {e}"))?;
    let private_pem = format_pem("PRIVATE KEY", &BASE64.encode(pkcs8_doc.as_ref()));

    // Public half: reload the document to get at the public key bytes, then
    // wrap them in a SubjectPublicKeyInfo structure before PEM-armouring.
    let loaded_pair =
        EcdsaKeyPair::from_pkcs8(&ECDSA_P256_SHA256_ASN1_SIGNING, pkcs8_doc.as_ref(), &system_rng)
            .map_err(|e| format!("Failed to load generated key: {e}"))?;
    let spki_der = wrap_ec_public_key_spki(loaded_pair.public_key().as_ref());
    let public_pem = format_pem("PUBLIC KEY", &BASE64.encode(&spki_der));

    eprintln!("Generated ES256 key pair");
    eprintln!("=== Private Key (keep secret — for token issuance only) ===");
    println!("{private_pem}");
    eprintln!("=== Public Key (distribute — for token validation) ===");
    println!("{public_pem}");
    eprintln!("Save the private key to a file and set ATI_JWT_PRIVATE_KEY=<path>");
    eprintln!("Save the public key to a file and set ATI_JWT_PUBLIC_KEY=<path>");

    Ok(())
}
/// Builds and signs a token, printing the encoded token on stdout.
///
/// # Parameters
///
/// * `sub`, `scope` — copied into the `sub` / `scope` claims.
/// * `ttl` — lifetime in seconds; the expiry claim is `now + ttl`.
/// * `aud` — audience claim; defaults to `"ati-proxy"` when absent.
/// * `iss` — issuer claim; falls back to the `ATI_JWT_ISSUER` environment
///   variable when absent.
/// * `key_path` — path to a PEM file used for ES256 signing (checked first).
/// * `secret_hex` — hex-encoded shared secret (checked second).
/// * `rate_args` — `pattern=limit` pairs stored in the `ati.rate` map; each is
///   split at the first `=`.
///
/// When neither `key_path` nor `secret_hex` is given, the signing
/// configuration is taken from `jwt::config_from_env()`.
///
/// # Errors
///
/// Fails when a rate spec has no `=`, when `now + ttl` does not fit in the
/// expiry claim, when the key file cannot be read, when the secret is not
/// valid hex, when no signing key is available, or when signing fails.
fn issue(
    sub: &str,
    scope: &str,
    ttl: u64,
    aud: Option<&str>,
    iss: Option<&str>,
    key_path: Option<&str>,
    secret_hex: Option<&str>,
    rate_args: &[String],
) -> Result<(), Box<dyn std::error::Error>> {
    let now = jwt::now_secs();

    let mut rate_map = HashMap::with_capacity(rate_args.len());
    for arg in rate_args {
        // Split at the first '=' only, so the limit part may itself contain '='.
        let Some((pattern, limit)) = arg.split_once('=') else {
            return Err(format!(
                "Invalid rate spec '{}': expected pattern=count/unit (e.g. tool:github:*=10/hour)",
                arg
            )
            .into());
        };
        rate_map.insert(pattern.to_string(), limit.to_string());
    }

    // Unchecked addition would panic in debug builds and wrap in release
    // builds, yielding a token whose expiry is nonsense. Reject it instead.
    let exp = now
        .checked_add(ttl)
        .ok_or_else(|| format!("Invalid ttl {ttl}: expiry time overflows"))?;

    let claims = TokenClaims {
        iss: iss
            .map(String::from)
            .or_else(|| std::env::var("ATI_JWT_ISSUER").ok()),
        sub: sub.to_string(),
        aud: aud.unwrap_or("ati-proxy").to_string(),
        iat: now,
        exp,
        jti: Some(uuid::Uuid::new_v4().to_string()),
        scope: scope.to_string(),
        ati: Some(AtiNamespace {
            v: 1,
            rate: rate_map,
        }),
        job_id: None,
        sandbox_id: None,
    };

    // Key source priority: explicit PEM file, then explicit hex secret, then
    // whatever the environment provides.
    let config = if let Some(path) = key_path {
        let pem =
            std::fs::read(path).map_err(|e| format!("Cannot read private key {path}: {e}"))?;
        // NOTE(review): the same PEM is passed for both key arguments —
        // confirm `config_from_pem` accepts a private key in the first slot.
        jwt::config_from_pem(
            &pem,
            Some(&pem),
            jsonwebtoken::Algorithm::ES256,
            claims.iss.clone(),
            claims.aud.clone(),
        )?
    } else if let Some(hex_str) = secret_hex {
        let secret_bytes = hex::decode(hex_str).map_err(|e| format!("Invalid hex secret: {e}"))?;
        jwt::config_from_secret(&secret_bytes, claims.iss.clone(), claims.aud.clone())
    } else {
        jwt::config_from_env()?
            .ok_or("No signing key available. Provide --key <path>, --secret <hex>, or set ATI_JWT_PRIVATE_KEY / ATI_JWT_SECRET.")?
    };

    let token = jwt::issue(&claims, &config)?;
    println!("{token}");

    // Only show the hints when a human is watching stderr, so piped or
    // scripted use stays quiet.
    if atty_stderr() {
        eprintln!("Token issued (sub={sub}, scope={scope}, ttl={ttl}s)");
        eprintln!("Set: export ATI_SESSION_TOKEN=<token above>");
    }
    Ok(())
}
/// Decodes `token` via `jwt::inspect` and pretty-prints its claims as JSON on
/// stdout, adding a derived `scopes` entry produced by `TokenClaims::scopes()`.
///
/// NOTE(review): this command takes no key, so `jwt::inspect` presumably does
/// not verify the signature — confirm against its implementation.
///
/// # Errors
///
/// Propagates failures from decoding the token or serializing the JSON.
fn inspect(token: &str) -> Result<(), Box<dyn std::error::Error>> {
    let decoded = jwt::inspect(token)?;

    let claims_view = serde_json::json!({
        "iss": decoded.iss,
        "sub": decoded.sub,
        "aud": decoded.aud,
        "iat": decoded.iat,
        "exp": decoded.exp,
        "jti": decoded.jti,
        "scope": decoded.scope,
        "scopes": decoded.scopes(),
        "ati": decoded.ati,
    });

    let rendered = serde_json::to_string_pretty(&claims_view)?;
    println!("{rendered}");
    Ok(())
}
/// Validates `token` and reports the verdict.
///
/// The validation key is chosen in this order: `key_path` (PEM file, ES256),
/// then `secret_hex` (hex-encoded shared secret), then the configuration from
/// `jwt::config_from_env()`. For the two explicit forms the expected audience
/// comes from `ATI_JWT_AUDIENCE` (default `"ati-proxy"`) and the expected
/// issuer from `ATI_JWT_ISSUER` (none when unset).
///
/// A valid token logs `VALID` with its `sub`, `scope` and `exp` and returns
/// `Ok(())`. An invalid token logs `INVALID` and terminates the process with
/// exit status 1.
///
/// # Errors
///
/// Fails when the key file cannot be read, the secret is not valid hex, the
/// configuration cannot be built, or no validation key is available.
fn validate(
    token: &str,
    key_path: Option<&str>,
    secret_hex: Option<&str>,
) -> Result<(), Box<dyn std::error::Error>> {
    // Expected (issuer, audience) for explicitly supplied keys. Evaluated
    // lazily so the environment is consulted only on those two paths.
    let expected_from_env = || {
        let audience = std::env::var("ATI_JWT_AUDIENCE").unwrap_or_else(|_| "ati-proxy".into());
        let issuer = std::env::var("ATI_JWT_ISSUER").ok();
        (issuer, audience)
    };

    let config = match (key_path, secret_hex) {
        (Some(path), _) => {
            let pem = std::fs::read(path)
                .map_err(|e| format!("Cannot read public key {path}: {e}"))?;
            let (issuer, audience) = expected_from_env();
            jwt::config_from_pem(&pem, None, jsonwebtoken::Algorithm::ES256, issuer, audience)?
        }
        (None, Some(hex_str)) => {
            let secret_bytes =
                hex::decode(hex_str).map_err(|e| format!("Invalid hex secret: {e}"))?;
            let (issuer, audience) = expected_from_env();
            jwt::config_from_secret(&secret_bytes, issuer, audience)
        }
        (None, None) => jwt::config_from_env()?
            .ok_or("No validation key available. Provide --key <path>, --secret <hex>, or set ATI_JWT_PUBLIC_KEY / ATI_JWT_SECRET.")?,
    };

    let claims = match jwt::validate(token, &config) {
        Ok(claims) => claims,
        Err(e) => {
            tracing::error!("INVALID — {e}");
            // Exit directly so the failure is reflected in the exit status.
            std::process::exit(1);
        }
    };

    // NOTE(review): the verdict is emitted through `tracing`, unlike the other
    // subcommands which print directly — confirm a subscriber is installed.
    tracing::info!(
        sub = %claims.sub,
        scope = %claims.scope,
        exp = claims.exp,
        "VALID"
    );
    Ok(())
}
/// Reports whether standard error is attached to a terminal.
fn atty_stderr() -> bool {
    std::io::IsTerminal::is_terminal(&std::io::stderr())
}
/// Wraps base64 text in PEM armour.
///
/// Emits `-----BEGIN {label}-----`, then `b64` broken into lines of at most
/// 64 bytes, then `-----END {label}-----`; every line ends with `\n`.
///
/// The split is byte-based. A 64-byte slice that is not valid UTF-8 on its
/// own (only possible for non-ASCII input) is written as an empty line.
fn format_pem(label: &str, b64: &str) -> String {
    const LINE_WIDTH: usize = 64;

    let body_lines = b64
        .as_bytes()
        .chunks(LINE_WIDTH)
        .map(|raw| std::str::from_utf8(raw).unwrap_or(""));

    let mut out = String::new();
    out.push_str("-----BEGIN ");
    out.push_str(label);
    out.push_str("-----\n");
    for line in body_lines {
        out.push_str(line);
        out.push('\n');
    }
    out.push_str("-----END ");
    out.push_str(label);
    out.push_str("-----\n");
    out
}
fn wrap_ec_public_key_spki(public_key_bytes: &[u8]) -> Vec<u8> {
let ec_oid: &[u8] = &[0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01];
let p256_oid: &[u8] = &[0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07];
let alg_id_content_len = ec_oid.len() + p256_oid.len();
let mut alg_id = vec![0x30]; push_der_length(&mut alg_id, alg_id_content_len);
alg_id.extend_from_slice(ec_oid);
alg_id.extend_from_slice(p256_oid);
let bit_string_content_len = 1 + public_key_bytes.len(); let mut bit_string = vec![0x03]; push_der_length(&mut bit_string, bit_string_content_len);
bit_string.push(0x00); bit_string.extend_from_slice(public_key_bytes);
let total_content_len = alg_id.len() + bit_string.len();
let mut spki = vec![0x30]; push_der_length(&mut spki, total_content_len);
spki.extend_from_slice(&alg_id);
spki.extend_from_slice(&bit_string);
spki
}
/// Appends the DER encoding of the length `len` to `buf`.
///
/// * `len < 128` — short form: one byte holding the length.
/// * otherwise — long form: one byte `0x80 | n`, then the length as `n`
///   big-endian bytes with no leading zero bytes.
///
/// Every `usize` value is encoded correctly; lengths of 65536 and above use
/// three or more length bytes rather than being truncated to two.
fn push_der_length(buf: &mut Vec<u8>, len: usize) {
    if len < 0x80 {
        // Guarded by the comparison above, so the cast cannot truncate.
        buf.push(len as u8);
        return;
    }

    let be_bytes = len.to_be_bytes();
    // `len >= 0x80` guarantees a non-zero byte; the fallback is never taken.
    let first_significant = be_bytes
        .iter()
        .position(|&b| b != 0)
        .unwrap_or(be_bytes.len() - 1);
    let significant = &be_bytes[first_significant..];

    // `significant.len()` is at most `size_of::<usize>()`, far below the 0x7f
    // limit of the long-form count field.
    buf.push(0x80 | significant.len() as u8);
    buf.extend_from_slice(significant);
}