use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use serde_json::Value;
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
use ureq::Agent;
use zeroize::Zeroizing;
use crate::error::{Error, JsonResultExt, Result, UreqResultExt};
/// Service-account authenticator for an instance API.
///
/// Holds the service account's Ed25519 key seed and lazily fetches and caches both the
/// instance's cloud API domain and an instance access token.
pub struct Auth {
    // Configuration fixed at construction time.
    service_account_id: String,
    instance_api_url: String,
    agent: Agent,
    /// Raw 32-byte key seed; the `Zeroizing` wrapper wipes it when dropped.
    key_bytes: Zeroizing<Vec<u8>>,

    // Lazily populated caches, guarded so `&self` methods can fill them.
    cloud_api_domain: Mutex<Option<String>>,
    cached_token: Mutex<Option<CachedToken>>,
}
/// An instance access token paired with the expiry reported alongside it.
struct CachedToken {
    /// Value of the token-exchange response's `expiration` field; compared against the
    /// current Unix time in seconds when deciding whether `token` can be reused.
    expires_at: u64,
    /// Access token returned by the token-exchange endpoint.
    token: String,
}
impl Auth {
pub fn new(
service_account_id: String,
key_b64: &str,
instance_api_url: String,
agent: Agent,
) -> Result<Self> {
let key_bytes = Zeroizing::new(
URL_SAFE_NO_PAD
.decode(key_b64.trim())
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(key_b64.trim()))
.map_err(|_| Error::InvalidServiceAccountKeyEncoding)?,
);
if key_bytes.len() != 32 {
return Err(Error::InvalidServiceAccountKeyLength {
got: key_bytes.len(),
});
}
Ok(Self {
service_account_id,
key_bytes,
instance_api_url,
agent,
cloud_api_domain: Mutex::new(None),
cached_token: Mutex::new(None),
})
}
pub fn service_account_id(&self) -> &str {
&self.service_account_id
}
pub fn get_token(&self) -> Result<String> {
let mut guard = self.cached_token.lock().map_err(|_| Error::MutexPoisoned {
name: "token cache",
})?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|source| Error::SystemClockBeforeUnixEpoch { source })?
.as_secs();
if let Some(ref cached) = *guard
&& cached.expires_at > now + 300
{
return Ok(cached.token.clone());
}
let domain = self.get_cloud_api_domain()?;
let sa_jwt = self.sign_jwt(now, &domain)?;
let (instance_token, expires_at) = self.exchange_token(&sa_jwt)?;
*guard = Some(CachedToken {
token: instance_token.clone(),
expires_at,
});
Ok(instance_token)
}
fn get_cloud_api_domain(&self) -> Result<String> {
let mut guard = self
.cloud_api_domain
.lock()
.map_err(|_| Error::MutexPoisoned {
name: "cloud_api_domain",
})?;
if let Some(ref domain) = *guard {
return Ok(domain.clone());
}
let url = format!("{}/api/v1/auth/config", self.instance_api_url);
let mut resp = self
.agent
.get(&url)
.call()
.with_request_context(format!("failed to fetch auth config ({url})"))?;
let status = resp.status().as_u16();
let body: String = resp
.body_mut()
.read_to_string()
.with_response_read_context("auth config response")?;
if !(200..300).contains(&status) {
return Err(Error::ApiError {
status,
message: body,
});
}
let json: Value = serde_json::from_str(&body).with_json_parse_context("auth config")?;
let domain = json
.get("cloud_api_domain")
.and_then(|v| v.as_str())
.ok_or_else(|| Error::MissingField {
context: "auth config",
field: "cloud_api_domain",
})?
.to_string();
*guard = Some(domain.clone());
Ok(domain)
}
fn sign_jwt(&self, now: u64, cloud_api_domain: &str) -> Result<String> {
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA);
let claims = serde_json::json!({
"sub": self.service_account_id,
"aud": format!("https://{cloud_api_domain}/auth/token"),
"exp": now + 300,
"iat": now,
"name": self.service_account_id,
"service_account": self.service_account_id,
});
let der_key = ed25519_seed_to_pkcs8_der(&self.key_bytes)?;
let key = jsonwebtoken::EncodingKey::from_ed_der(&der_key);
jsonwebtoken::encode(&header, &claims, &key)
.map_err(|source| Error::JwtSignFailed { source })
}
fn exchange_token(&self, sa_jwt: &str) -> Result<(String, u64)> {
let url = format!("{}/api/v1/auth/token", self.instance_api_url);
let mut resp = self
.agent
.post(&url)
.header("Authorization", &format!("Bearer {sa_jwt}"))
.send_empty()
.with_request_context(format!("failed to exchange token ({url})"))?;
let status = resp.status().as_u16();
let resp_body: String = resp
.body_mut()
.read_to_string()
.with_response_read_context("token exchange response")?;
if !(200..300).contains(&status) {
return Err(Error::ApiError {
status,
message: resp_body,
});
}
let json: Value =
serde_json::from_str(&resp_body).with_json_parse_context("token exchange response")?;
let token = json
.get("access_token")
.and_then(|v| v.as_str())
.ok_or_else(|| Error::MissingField {
context: "token exchange response",
field: "access_token",
})?
.to_string();
let expires_at = json
.get("expiration")
.and_then(|v| v.as_u64())
.ok_or_else(|| Error::MissingField {
context: "token exchange response",
field: "expiration",
})?;
Ok((token, expires_at))
}
}
impl std::fmt::Debug for Auth {
    /// Shows the non-secret configuration only: the key is replaced by a fixed placeholder,
    /// and the HTTP agent and the caches are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Auth");
        out.field("service_account_id", &self.service_account_id);
        out.field("instance_api_url", &self.instance_api_url);
        // Never render `key_bytes` (not even its length).
        out.field("service_account_key", &"[REDACTED]");
        out.finish()
    }
}
fn ed25519_seed_to_pkcs8_der(seed: &[u8]) -> Result<Zeroizing<Vec<u8>>> {
if seed.len() != 32 {
return Err(Error::InvalidEd25519SeedLength { got: seed.len() });
}
let prefix: &[u8] = &[
0x30, 0x2e, 0x02, 0x01, 0x00, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x04, 0x22, 0x04, 0x20, ];
let mut der = Zeroizing::new(Vec::with_capacity(prefix.len() + 32));
der.extend_from_slice(prefix);
der.extend_from_slice(seed);
Ok(der)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Service-account id used by `test_auth`.
    const SA_ID: &str = "asc-sa-test";
    /// Instance URL given to every `Auth` built here; `Auth::new` makes no requests.
    const INSTANCE_URL: &str = "https://ascend.io";
    /// Cloud API domain handed straight to `sign_jwt`.
    const CLOUD_DOMAIN: &str = "api.cloud.ascend.io";
    /// PKCS#8 header that must precede every 32-byte Ed25519 seed.
    const DER_PREFIX: [u8; 16] = [
        0x30, 0x2e, 0x02, 0x01, 0x00, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x04, 0x22,
        0x04, 0x20,
    ];

    fn test_agent() -> Agent {
        crate::new_agent()
    }

    fn test_seed() -> [u8; 32] {
        [42u8; 32]
    }

    /// Builds an `Auth` with id "sa" from an already-encoded key, leaving the result unwrapped.
    fn auth_from_b64(key_b64: &str) -> Result<Auth> {
        Auth::new("sa".into(), key_b64, INSTANCE_URL.into(), test_agent())
    }

    /// Builds an `Auth` for `SA_ID` from a raw seed, panicking if construction fails.
    fn test_auth(seed: &[u8; 32]) -> Auth {
        let encoded = URL_SAFE_NO_PAD.encode(seed);
        Auth::new(SA_ID.into(), &encoded, INSTANCE_URL.into(), test_agent())
            .expect("a 32-byte URL-safe key is valid")
    }

    #[test]
    fn pkcs8_der_output_is_48_bytes() {
        let der = ed25519_seed_to_pkcs8_der(&[0u8; 32]).unwrap();
        assert_eq!(der.len(), DER_PREFIX.len() + 32);
    }

    #[test]
    fn pkcs8_der_has_correct_prefix() {
        let der = ed25519_seed_to_pkcs8_der(&[0u8; 32]).unwrap();
        assert!(der.starts_with(&DER_PREFIX));
    }

    #[test]
    fn pkcs8_der_embeds_seed() {
        let seed: Vec<u8> = (0u8..32).collect();
        let der = ed25519_seed_to_pkcs8_der(&seed).unwrap();
        assert_eq!(&der[DER_PREFIX.len()..], seed.as_slice());
    }

    #[test]
    fn pkcs8_der_rejects_wrong_length() {
        for len in [0usize, 16, 64] {
            let seed = vec![0u8; len];
            assert!(
                ed25519_seed_to_pkcs8_der(&seed).is_err(),
                "a {len}-byte seed must be rejected"
            );
        }
    }

    #[test]
    fn pkcs8_der_roundtrip_with_jsonwebtoken() {
        let der = ed25519_seed_to_pkcs8_der(&test_seed()).unwrap();
        let signed = jsonwebtoken::encode(
            &jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA),
            &serde_json::json!({"sub": "test"}),
            &jsonwebtoken::EncodingKey::from_ed_der(&der),
        );
        assert!(signed.is_ok());
    }

    #[test]
    fn auth_new_accepts_url_safe_base64() {
        let encoded = URL_SAFE_NO_PAD.encode(test_seed());
        assert!(auth_from_b64(&encoded).is_ok());
    }

    #[test]
    fn auth_new_accepts_standard_base64() {
        let encoded = base64::engine::general_purpose::STANDARD.encode([0xFF_u8; 32]);
        assert!(auth_from_b64(&encoded).is_ok());
    }

    #[test]
    fn auth_new_rejects_wrong_key_length() {
        let err = auth_from_b64(&URL_SAFE_NO_PAD.encode([0u8; 16]))
            .expect_err("a 16-byte key must be rejected");
        assert!(err.to_string().contains("32 bytes"));

        assert!(auth_from_b64(&URL_SAFE_NO_PAD.encode([0u8; 64])).is_err());
    }

    #[test]
    fn auth_new_rejects_invalid_base64() {
        assert!(auth_from_b64("!!!invalid!!!").is_err());
    }

    #[test]
    fn auth_new_trims_whitespace() {
        let surrounded = format!(" {} \n", URL_SAFE_NO_PAD.encode(test_seed()));
        assert!(auth_from_b64(&surrounded).is_ok());
    }

    #[test]
    fn sign_jwt_produces_three_part_token() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let jwt = test_auth(&test_seed()).sign_jwt(now, CLOUD_DOMAIN).unwrap();
        // header.payload.signature => exactly two separators
        assert_eq!(jwt.matches('.').count(), 2);
    }

    #[test]
    fn sign_jwt_has_correct_claims() {
        let now = 1_700_000_000u64;
        let jwt = test_auth(&test_seed()).sign_jwt(now, CLOUD_DOMAIN).unwrap();
        let payload = jwt.split('.').nth(1).expect("JWT has a payload segment");
        let raw_claims = URL_SAFE_NO_PAD.decode(payload).unwrap();
        let claims: Value = serde_json::from_slice(&raw_claims).unwrap();

        for key in ["sub", "name", "service_account"] {
            assert_eq!(claims[key], SA_ID, "claim `{key}`");
        }
        assert_eq!(claims["aud"], "https://api.cloud.ascend.io/auth/token");
        assert_eq!(claims["exp"], now + 300);
        assert_eq!(claims["iat"], now);
    }

    #[test]
    fn debug_redacts_key() {
        let seed = test_seed();
        let rendered = format!("{:?}", test_auth(&seed));
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains(&URL_SAFE_NO_PAD.encode(seed)));
    }
}