use std::env;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use crate::HatchetError;
/// TLS mode requested for the client connection.
///
/// Parsed from the strategy strings `"none"` and `"tls"` by `HatchetConfig::new`.
/// Fieldless, so it is cheap to copy and compare by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum TlsStrategy {
    /// Selected by the string `"none"`; presumably a plaintext connection — confirm in the transport layer.
    None,
    /// Selected by the string `"tls"`, which is also the `from_env` default.
    Tls,
}
/// Connection settings for a Hatchet client, derived from an API token.
///
/// `Debug` is implemented by hand (below) so the bearer token is never
/// written out when the config is logged or debug-printed.
#[derive(Clone)]
pub(crate) struct HatchetConfig {
    /// The raw API token exactly as supplied to `HatchetConfig::new`.
    pub(crate) api_token: String,
    /// Taken from the token payload's `grpc_broadcast_address` claim.
    pub(crate) grpc_address: String,
    /// Taken from the token payload's `server_url` claim.
    pub(crate) server_url: String,
    /// TLS mode selected for the connection.
    pub(crate) tls_strategy: TlsStrategy,
    /// Taken from the token payload's `sub` claim, if present.
    pub(crate) tenant_id: Option<String>,
}

impl std::fmt::Debug for HatchetConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The API token is a credential: never expose it through `{:?}`.
        f.debug_struct("HatchetConfig")
            .field("api_token", &format_args!("<redacted>"))
            .field("grpc_address", &self.grpc_address)
            .field("server_url", &self.server_url)
            .field("tls_strategy", &self.tls_strategy)
            .field("tenant_id", &self.tenant_id)
            .finish()
    }
}
impl HatchetConfig {
    /// Builds a configuration from a Hatchet API token and a TLS strategy name.
    ///
    /// The token must consist of exactly three `.`-separated segments. The middle
    /// segment is decoded as unpadded URL-safe base64 and parsed as JSON. From that
    /// JSON the gRPC address, the server URL and the optional tenant id are read.
    ///
    /// # Errors
    ///
    /// - `HatchetError::InvalidTokenFormat` if the token does not have exactly three segments.
    /// - A base64 decoding error if the payload segment is not valid unpadded base64url.
    /// - `HatchetError::JsonDecodeError` if the decoded payload is not valid JSON.
    /// - `HatchetError::MissingTokenField` if `grpc_broadcast_address` or `server_url`
    ///   is absent or not a string.
    /// - `HatchetError::InvalidTlsStrategy` if `tls_strategy` is neither `"none"` nor `"tls"`.
    pub fn new(token: &str, tls_strategy: &str) -> Result<Self, HatchetError> {
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 {
            return Err(HatchetError::InvalidTokenFormat);
        }

        let payload_json = Self::decode_token(parts[1])?;
        let (grpc_address, server_url, tenant_id) = Self::parse_token(&payload_json)?;

        let tls_strategy = match tls_strategy {
            "none" => TlsStrategy::None,
            "tls" => TlsStrategy::Tls,
            other => return Err(HatchetError::InvalidTlsStrategy(other.to_string())),
        };

        Ok(Self {
            api_token: token.to_string(),
            grpc_address,
            server_url,
            tls_strategy,
            tenant_id,
        })
    }

    /// Builds a configuration from environment variables.
    ///
    /// Reads the token from `HATCHET_CLIENT_TOKEN` and the strategy from
    /// `HATCHET_CLIENT_TLS_STRATEGY`. If the strategy variable is unset, it
    /// defaults to `"tls"`.
    ///
    /// # Errors
    ///
    /// - `HatchetError::MissingEnvVar` if `HATCHET_CLIENT_TOKEN` is unset or not valid Unicode.
    /// - Any error returned by `HatchetConfig::new`. This includes
    ///   `InvalidTlsStrategy` when the strategy variable is set to a
    ///   non-UTF-8 value.
    pub fn from_env() -> Result<Self, HatchetError> {
        let token = env::var("HATCHET_CLIENT_TOKEN")
            .map_err(|_| HatchetError::MissingEnvVar(String::from("HATCHET_CLIENT_TOKEN")))?;

        // Only an *unset* variable falls back to TLS. A set-but-non-UTF-8 value
        // is passed through lossily so `new` rejects it. Silently switching it to
        // "tls" would hide the misconfiguration.
        let tls_strategy = match env::var("HATCHET_CLIENT_TLS_STRATEGY") {
            Ok(value) => value,
            Err(env::VarError::NotPresent) => String::from("tls"),
            Err(env::VarError::NotUnicode(raw)) => raw.to_string_lossy().into_owned(),
        };

        Self::new(&token, &tls_strategy)
    }

    /// Decodes the unpadded base64url payload segment of a JWT-style token into JSON.
    fn decode_token(token_payload: &str) -> Result<serde_json::Value, HatchetError> {
        let payload_bytes = URL_SAFE_NO_PAD.decode(token_payload)?;
        serde_json::from_slice(&payload_bytes)
            .map_err(|e| HatchetError::JsonDecodeError(e.to_string()))
    }

    /// Extracts `(grpc_broadcast_address, server_url, sub)` from a decoded token payload.
    ///
    /// The first two claims are required strings. `sub` is optional, and a
    /// non-string `sub` is treated as absent.
    fn parse_token(
        payload_json: &serde_json::Value,
    ) -> Result<(String, String, Option<String>), HatchetError> {
        let string_claim = |field: &str| {
            payload_json
                .get(field)
                .and_then(serde_json::Value::as_str)
                .map(str::to_owned)
        };
        let required_claim = |field: &'static str| {
            string_claim(field).ok_or(HatchetError::MissingTokenField(field))
        };

        let grpc_address = required_claim("grpc_broadcast_address")?;
        let server_url = required_claim("server_url")?;
        let tenant_id = string_claim("sub");

        Ok((grpc_address, server_url, tenant_id))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base64url of `{"server_url":"https://hatchet.com","grpc_broadcast_address":"engine.hatchet.com","sub":"test-tenant"}`.
    const FULL_PAYLOAD: &str = "eyJzZXJ2ZXJfdXJsIjoiaHR0cHM6Ly9oYXRjaGV0LmNvbSIsImdycGNfYnJvYWRjYXN0X2FkZHJlc3MiOiJlbmdpbmUuaGF0Y2hldC5jb20iLCJzdWIiOiJ0ZXN0LXRlbmFudCJ9";

    /// Wraps a payload segment in dummy header and signature segments.
    fn token_with_payload(payload: &str) -> String {
        format!("header.{}.sig", payload)
    }

    #[test]
    fn test_token_without_three_parts_raises_error() {
        let config = HatchetConfig::new("part0.part1.part2.part3", "tls");
        assert!(matches!(config, Err(HatchetError::InvalidTokenFormat)));
        let config = HatchetConfig::new("part0.part1", "tls");
        assert!(matches!(config, Err(HatchetError::InvalidTokenFormat)));
        let config = HatchetConfig::new("", "tls");
        assert!(matches!(config, Err(HatchetError::InvalidTokenFormat)));
    }

    #[test]
    fn test_invalid_base64_raises_error() {
        let config = HatchetConfig::new("part0.part1.part2", "tls");
        assert!(matches!(config, Err(HatchetError::Base64DecodeError(_))));
    }

    #[test]
    fn test_token_decoded_into_config() {
        let token = token_with_payload(FULL_PAYLOAD);
        let config = HatchetConfig::new(&token, "tls").unwrap();
        assert_eq!(config.server_url, "https://hatchet.com");
        assert_eq!(config.grpc_address, "engine.hatchet.com");
        assert_eq!(config.tenant_id, Some("test-tenant".to_string()));
        assert_eq!(config.api_token, token);
        assert!(matches!(config.tls_strategy, TlsStrategy::Tls));
    }

    #[test]
    fn test_tenant_id_extracted_from_token() {
        let payload = "eyJzZXJ2ZXJfdXJsIjoiaHR0cHM6Ly9oYXRjaGV0LmNvbSIsImdycGNfYnJvYWRjYXN0X2FkZHJlc3MiOiJlbmdpbmUuaGF0Y2hldC5jb20iLCJzdWIiOiI3MDdkMDg1NS04MGFiLTRlMWYtYTE1Ni1mMWM0NTQ2Y2JmNTIifQ";
        let config = HatchetConfig::new(&token_with_payload(payload), "tls").unwrap();
        assert_eq!(
            config.tenant_id,
            Some("707d0855-80ab-4e1f-a156-f1c4546cbf52".to_string())
        );
    }

    #[test]
    fn test_missing_sub_claim_returns_none_tenant() {
        let payload = "eyJzZXJ2ZXJfdXJsIjoiaHR0cHM6Ly9oYXRjaGV0LmNvbSIsImdycGNfYnJvYWRjYXN0X2FkZHJlc3MiOiJlbmdpbmUuaGF0Y2hldC5jb20ifQ";
        let config = HatchetConfig::new(&token_with_payload(payload), "tls").unwrap();
        assert_eq!(config.tenant_id, None);
    }

    #[test]
    fn test_missing_required_claim_raises_error() {
        // Base64url of {"server_url":"https://hatchet.com"}: no grpc_broadcast_address.
        let payload = "eyJzZXJ2ZXJfdXJsIjoiaHR0cHM6Ly9oYXRjaGV0LmNvbSJ9";
        let config = HatchetConfig::new(&token_with_payload(payload), "tls");
        assert!(matches!(
            config,
            Err(HatchetError::MissingTokenField("grpc_broadcast_address"))
        ));
    }

    #[test]
    fn test_none_tls_strategy_selected() {
        let config = HatchetConfig::new(&token_with_payload(FULL_PAYLOAD), "none").unwrap();
        assert!(matches!(config.tls_strategy, TlsStrategy::None));
    }

    #[test]
    fn test_invalid_tls_strategy_raises_error() {
        let config = HatchetConfig::new(&token_with_payload(FULL_PAYLOAD), "bad_strategy");
        assert!(matches!(
            &config,
            Err(HatchetError::InvalidTlsStrategy(s)) if s.as_str() == "bad_strategy"
        ));
    }
}