use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use sha2::{Sha256, Digest};
use thiserror::Error;
use ureq::tls::{TlsConfig, RootCerts, PemItem, parse_pem};
/// Errors produced while fetching, verifying, or caching a remote schema.
#[derive(Error, Debug)]
pub enum RemoteError {
    /// The HTTP request could not be completed or the response body could not be read.
    #[error("HTTP request failed: {0}")]
    Network(String),

    /// The URL does not use a supported scheme, or could not be parsed/resolved.
    #[error("invalid URL: {0}")]
    InvalidUrl(String),

    /// The server answered with a status other than 200.
    #[error("HTTP error: {0}")]
    HttpStatus(String),

    /// Reading from or writing to the on-disk cache failed.
    #[error("cache error: {0}")]
    Cache(String),

    /// A plain `http://` URL was supplied; only HTTPS is accepted.
    #[error("only HTTPS URLs are allowed for security")]
    HttpNotAllowed,

    /// The SHA-256 digest of the content did not match the expected value.
    #[error("hash verification failed: expected {expected}, got {actual}")]
    HashMismatch { expected: String, actual: String },

    /// The URL was fetched too recently; `seconds` is the remaining wait time.
    #[error("rate limited: wait {seconds} seconds before fetching again")]
    RateLimited { seconds: u64 },

    /// The custom CA certificate file could not be read or contained no usable certificate.
    #[error("failed to load CA certificate: {0}")]
    CertificateError(String),
}
/// Maximum age, in seconds, of a cached schema file before it is treated as stale.
pub const CACHE_TTL_SECS: u64 = 3600;

/// Rate-limit window, in seconds, used by [`SecurityOptions::default`].
pub const DEFAULT_RATE_LIMIT_SECS: u64 = 60;

/// Security-related settings applied when fetching a remote schema.
#[derive(Debug, Clone)]
pub struct SecurityOptions {
    /// Expected SHA-256 hex digest of the schema content; `None` skips verification.
    pub verify_hash: Option<String>,
    /// Path to a PEM file with custom CA certificate(s); `None` uses the default TLS configuration.
    pub ca_cert: Option<String>,
    /// Rate-limit window in seconds; `0` disables the rate-limit check.
    pub rate_limit_seconds: u64,
}

impl Default for SecurityOptions {
    /// No hash verification, no custom CA, and the default rate-limit window.
    fn default() -> Self {
        SecurityOptions {
            rate_limit_seconds: DEFAULT_RATE_LIMIT_SECS,
            ca_cert: None,
            verify_hash: None,
        }
    }
}

impl SecurityOptions {
    /// Creates options with the default values; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the options with the expected content hash replaced by `hash`.
    pub fn with_hash(self, hash: Option<String>) -> Self {
        Self {
            verify_hash: hash,
            ..self
        }
    }

    /// Returns the options with the custom CA certificate path replaced by `path`.
    pub fn with_ca_cert(self, path: Option<String>) -> Self {
        Self {
            ca_cert: path,
            ..self
        }
    }

    /// Returns the options with the rate-limit window replaced by `seconds`.
    pub fn with_rate_limit(self, seconds: u64) -> Self {
        Self {
            rate_limit_seconds: seconds,
            ..self
        }
    }
}
/// Returns `true` when `path` begins with the `https://` or `http://` prefix.
///
/// The comparison is a case-sensitive prefix match; no URL parsing is performed.
pub fn is_remote_url(path: &str) -> bool {
    ["https://", "http://"]
        .iter()
        .any(|scheme| path.starts_with(*scheme))
}
/// Fetches a remote schema using the default [`SecurityOptions`].
///
/// Thin convenience wrapper around [`fetch_remote_schema_secure`]; `no_cache`
/// is forwarded unchanged.
#[allow(dead_code)]
pub fn fetch_remote_schema(url: &str, no_cache: bool) -> Result<String, RemoteError> {
    let default_security = SecurityOptions::default();
    fetch_remote_schema_secure(url, no_cache, &default_security)
}
/// Fetches a schema over HTTPS, applying caching, rate limiting, and optional
/// content-hash verification.
///
/// Order of operations:
/// 1. Reject `http://` URLs and anything that is not `https://`.
/// 2. Unless `no_cache` is set, return the cached copy when one is still fresh
///    (verifying it against `security.verify_hash` when a hash is configured).
/// 3. Unless `no_cache` is set, enforce the rate limit — only now, because only
///    a real network fetch should count against it.
/// 4. Download the content, verify its hash if requested, and store it in the
///    cache. A cache write failure is reported on stderr but is not fatal.
///
/// # Errors
///
/// * [`RemoteError::HttpNotAllowed`] for `http://` URLs.
/// * [`RemoteError::InvalidUrl`] for any other non-`https://` input.
/// * [`RemoteError::RateLimited`] when a network fetch would happen inside the
///   rate-limit window.
/// * [`RemoteError::HashMismatch`] when cached or downloaded content does not
///   match `security.verify_hash`.
/// * [`RemoteError::Cache`] when an existing cache entry cannot be read.
/// * Any error returned by the underlying HTTPS fetch.
pub fn fetch_remote_schema_secure(
    url: &str,
    no_cache: bool,
    security: &SecurityOptions,
) -> Result<String, RemoteError> {
    if url.starts_with("http://") {
        return Err(RemoteError::HttpNotAllowed);
    }
    if !url.starts_with("https://") {
        return Err(RemoteError::InvalidUrl(url.to_string()));
    }

    if !no_cache {
        // Serve from the cache before consulting the rate limiter: a cache hit
        // performs no network access, so it must not be rejected as
        // "rate limited" just because the entry was fetched moments ago.
        if let Some(cached) = read_cache(url)? {
            if let Some(expected_hash) = security.verify_hash.as_deref() {
                verify_content_hash(&cached, expected_hash)?;
            }
            return Ok(cached);
        }

        // Cache miss: a network request is about to happen, so the rate limit applies.
        if security.rate_limit_seconds > 0 {
            check_rate_limit(url, security.rate_limit_seconds)?;
        }
    }

    let content = fetch_url_secure(url, security.ca_cert.as_deref())?;
    if let Some(expected_hash) = security.verify_hash.as_deref() {
        verify_content_hash(&content, expected_hash)?;
    }

    // Caching is best-effort: the freshly downloaded content is still returned.
    if let Err(e) = write_cache_with_metadata(url, &content) {
        eprintln!("warning: failed to cache schema: {}", e);
    }
    Ok(content)
}
/// Verifies that the SHA-256 digest of `content` matches `expected_hash`.
///
/// `expected_hash` is compared case-insensitively after trimming surrounding
/// whitespace, and may be either the full 64-character hex digest or a leading
/// prefix of it.
///
/// An empty (or whitespace-only) `expected_hash` is always rejected: the empty
/// string is a prefix of every digest, so accepting it would silently disable
/// verification.
///
/// # Errors
///
/// Returns [`RemoteError::HashMismatch`] carrying the caller-supplied
/// expectation and the actual lowercase hex digest when they do not match.
pub fn verify_content_hash(content: &str, expected_hash: &str) -> Result<(), RemoteError> {
    let mut hasher = Sha256::new();
    hasher.update(content.as_bytes());
    let actual_hash = format!("{:x}", hasher.finalize());

    let expected_lower = expected_hash.trim().to_ascii_lowercase();

    // `starts_with` also covers the exact-match case; the emptiness guard is
    // required because `starts_with("")` is true for every string.
    let is_match = !expected_lower.is_empty() && actual_hash.starts_with(&expected_lower);

    if is_match {
        Ok(())
    } else {
        Err(RemoteError::HashMismatch {
            expected: expected_hash.to_string(),
            actual: actual_hash,
        })
    }
}
/// Returns the SHA-256 digest of `content` as a lowercase hexadecimal string.
pub fn compute_content_hash(content: &str) -> String {
    let digest = Sha256::digest(content.as_bytes());
    format!("{:x}", digest)
}
/// Downloads `url` with the default TLS configuration (no custom CA certificate).
#[allow(dead_code)]
fn fetch_url(url: &str) -> Result<String, RemoteError> {
    let no_custom_ca: Option<&str> = None;
    fetch_url_secure(url, no_custom_ca)
}
/// Performs an HTTP GET on `url` and returns the response body as a string.
///
/// The request uses a 30-second global timeout and the TLS configuration
/// produced by `build_tls_config(ca_cert_path)`.
///
/// # Errors
///
/// * Whatever `build_tls_config` returns when the CA file cannot be loaded.
/// * [`RemoteError::Network`] when the request or reading the body fails.
/// * [`RemoteError::HttpStatus`] when the response status is not 200.
fn fetch_url_secure(url: &str, ca_cert_path: Option<&str>) -> Result<String, RemoteError> {
    let tls_config = build_tls_config(ca_cert_path)?;

    let agent_config = ureq::Agent::config_builder()
        .timeout_global(Some(Duration::from_secs(30)))
        .tls_config(tls_config)
        .build();
    let agent = agent_config.new_agent();

    let mut response = match agent.get(url).call() {
        Ok(resp) => resp,
        Err(err) => return Err(RemoteError::Network(err.to_string())),
    };

    let status = response.status();
    if status != 200 {
        let detail = format!("status {} for {}", status, url);
        return Err(RemoteError::HttpStatus(detail));
    }

    match response.body_mut().read_to_string() {
        Ok(body) => Ok(body),
        Err(err) => Err(RemoteError::Network(err.to_string())),
    }
}
/// Builds the TLS configuration used for schema downloads.
///
/// With `None`, the default [`TlsConfig`] is returned. With a path, the file is
/// read as PEM, every certificate item in it is collected (other PEM item
/// kinds are ignored), and the result is installed as the root certificates.
/// A notice naming the file and the number of certificates is written to stderr.
///
/// # Errors
///
/// Returns [`RemoteError::CertificateError`] when the file cannot be read, a
/// PEM item fails to parse, or the file contains no certificate at all.
fn build_tls_config(ca_cert_path: Option<&str>) -> Result<TlsConfig, RemoteError> {
    let Some(ca_path) = ca_cert_path else {
        return Ok(TlsConfig::default());
    };

    let pem_data = match fs::read(ca_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(RemoteError::CertificateError(format!(
                "failed to read {}: {}",
                ca_path, e
            )))
        }
    };

    let mut certs = Vec::new();
    for item in parse_pem(&pem_data) {
        let item = item.map_err(|e| {
            RemoteError::CertificateError(format!("failed to parse PEM from {}: {}", ca_path, e))
        })?;
        // Only certificate entries are kept; any other PEM item kind is skipped.
        if let PemItem::Certificate(cert) = item {
            certs.push(cert);
        }
    }

    if certs.is_empty() {
        return Err(RemoteError::CertificateError(format!(
            "no valid certificates found in {}",
            ca_path
        )));
    }

    let root_certs = RootCerts::new_with_certs(&certs);
    eprintln!(
        "zenv: using CA certificate from {} ({} cert(s))",
        ca_path,
        certs.len()
    );

    let config = TlsConfig::builder().root_certs(root_certs).build();
    Ok(config)
}
/// Rejects the call when `url` was fetched less than `rate_limit_seconds` ago.
///
/// The time of the last fetch is read from the URL's cache metadata file. When
/// the metadata path cannot be determined, the file does not exist, or it
/// cannot be read or parsed, the check passes.
///
/// # Errors
///
/// Returns [`RemoteError::RateLimited`] with the number of seconds left in the
/// window.
fn check_rate_limit(url: &str, rate_limit_seconds: u64) -> Result<(), RemoteError> {
    let Some(meta_file) = metadata_path_for_url(url) else {
        return Ok(());
    };
    if !meta_file.exists() {
        return Ok(());
    }

    // Unreadable or malformed metadata is treated as "no previous fetch".
    let recorded = fs::read_to_string(&meta_file)
        .ok()
        .and_then(|raw| serde_json::from_str::<CacheMetadata>(&raw).ok());
    let Some(recorded) = recorded else {
        return Ok(());
    };

    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // Saturating: a timestamp in the future counts as zero elapsed seconds.
    let elapsed = now.saturating_sub(recorded.fetched_at);

    if elapsed >= rate_limit_seconds {
        return Ok(());
    }
    Err(RemoteError::RateLimited {
        seconds: rate_limit_seconds - elapsed,
    })
}
/// Metadata stored as JSON next to a cached schema file.
#[derive(serde::Serialize, serde::Deserialize)]
struct CacheMetadata {
    /// URL the cached content was downloaded from.
    url: String,
    /// Time of the download, in seconds since the Unix epoch.
    fetched_at: u64,
    /// SHA-256 hex digest of the cached content.
    content_hash: String,
}
/// Returns the path of the metadata file belonging to `url`'s cache entry.
///
/// The name is the cache file name with its `.json` suffix replaced by
/// `.meta`. Returns `None` when no cache directory is available.
fn metadata_path_for_url(url: &str) -> Option<PathBuf> {
    let dir = cache_dir()?;
    let data_name = cache_filename(url);
    let stem = data_name.trim_end_matches(".json");
    Some(dir.join(format!("{}.meta", stem)))
}
/// Stores `content` in the cache and records a metadata file for it.
///
/// The metadata holds the URL, the current time in seconds since the Unix
/// epoch, and the SHA-256 digest of the content. When no metadata path can be
/// determined, only the content is written.
///
/// # Errors
///
/// Returns [`RemoteError::Cache`] when writing the content, serializing the
/// metadata, or writing the metadata file fails.
fn write_cache_with_metadata(url: &str, content: &str) -> Result<(), RemoteError> {
    write_cache(url, content)?;

    let Some(meta_file) = metadata_path_for_url(url) else {
        return Ok(());
    };

    let fetched_at = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let record = CacheMetadata {
        url: url.to_string(),
        fetched_at,
        content_hash: compute_content_hash(content),
    };

    serde_json::to_string(&record)
        .map_err(|e| RemoteError::Cache(e.to_string()))
        .and_then(|json| {
            fs::write(&meta_file, json).map_err(|e| RemoteError::Cache(e.to_string()))
        })
}
/// Returns the `zorath-env` directory inside the platform cache directory
/// reported by `dirs::cache_dir`, or `None` when that is unavailable.
///
/// The directory is not created by this function.
pub fn cache_dir() -> Option<PathBuf> {
    let base = dirs::cache_dir()?;
    Some(base.join("zorath-env"))
}
/// Derives the cache file name for `url`: 16 lowercase hex digits plus `.json`.
///
/// The name is the 64-bit FNV-1a hash of the URL's bytes. FNV-1a is used
/// because it is deterministic across builds and Rust releases (file names
/// must stay stable on disk) and mixes every byte into the whole state, so
/// similar URLs do not share a name. The previous position-weighted byte sum
/// produced only a few hundred thousand distinct values for typical URLs and
/// collided for simple character swaps, letting one URL's cached schema be
/// served for another.
///
/// FNV-1a is not a cryptographic hash; it guards against accidental
/// collisions, not deliberately constructed ones.
pub fn cache_filename(url: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = url.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{:016x}.json", hash)
}
fn read_cache(url: &str) -> Result<Option<String>, RemoteError> {
let cache_dir = match cache_dir() {
Some(dir) => dir,
None => return Ok(None),
};
let cache_path = cache_dir.join(cache_filename(url));
if !cache_path.exists() {
return Ok(None);
}
let metadata = fs::metadata(&cache_path).map_err(|e| RemoteError::Cache(e.to_string()))?;
let modified = metadata
.modified()
.map_err(|e| RemoteError::Cache(e.to_string()))?;
let age = SystemTime::now()
.duration_since(modified)
.unwrap_or(Duration::MAX);
if age.as_secs() > CACHE_TTL_SECS {
return Ok(None);
}
let content = fs::read_to_string(&cache_path).map_err(|e| RemoteError::Cache(e.to_string()))?;
Ok(Some(content))
}
fn write_cache(url: &str, content: &str) -> Result<(), RemoteError> {
let cache_dir = match cache_dir() {
Some(dir) => dir,
None => return Ok(()), };
fs::create_dir_all(&cache_dir).map_err(|e| RemoteError::Cache(e.to_string()))?;
let cache_path = cache_dir.join(cache_filename(url));
let mut file = fs::File::create(&cache_path).map_err(|e| RemoteError::Cache(e.to_string()))?;
file.write_all(content.as_bytes())
.map_err(|e| RemoteError::Cache(e.to_string()))?;
Ok(())
}
/// Resolves `relative_path` against `base_url`.
///
/// A `relative_path` that already starts with `https://` or `http://` is
/// returned unchanged without looking at `base_url`. Otherwise `base_url` is
/// parsed and `relative_path` is joined onto it.
///
/// # Errors
///
/// Returns [`RemoteError::InvalidUrl`] when `base_url` cannot be parsed or the
/// join fails.
pub fn resolve_relative_url(base_url: &str, relative_path: &str) -> Result<String, RemoteError> {
    let already_absolute = ["https://", "http://"]
        .iter()
        .any(|scheme| relative_path.starts_with(*scheme));
    if already_absolute {
        return Ok(relative_path.to_string());
    }

    url::Url::parse(base_url)
        .and_then(|base| base.join(relative_path))
        .map(|resolved| resolved.to_string())
        .map_err(|e| RemoteError::InvalidUrl(e.to_string()))
}
/// Unit tests for URL classification, hashing, cache naming, security options,
/// and TLS configuration loading. None of these tests perform network access.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_remote_url() {
        assert!(is_remote_url("https://example.com/schema.json"));
        assert!(is_remote_url("http://example.com/schema.json"));
        assert!(!is_remote_url("env.schema.json"));
        assert!(!is_remote_url("./schemas/env.schema.json"));
        assert!(!is_remote_url("/absolute/path/schema.json"));
    }

    #[test]
    fn test_http_rejected() {
        // The scheme check runs before any cache or network access.
        let result = fetch_remote_schema("http://example.com/schema.json", true);
        assert!(matches!(result, Err(RemoteError::HttpNotAllowed)));
    }

    #[test]
    fn test_cache_filename() {
        let name1 = cache_filename("https://example.com/a.json");
        let name2 = cache_filename("https://example.com/b.json");
        assert_ne!(name1, name2);
        assert!(name1.ends_with(".json"));
    }

    #[test]
    fn test_resolve_relative_url() {
        let base = "https://example.com/schemas/prod.json";
        let resolved = resolve_relative_url(base, "base.json").unwrap();
        assert_eq!(resolved, "https://example.com/schemas/base.json");
        let resolved = resolve_relative_url(base, "../common.json").unwrap();
        assert_eq!(resolved, "https://example.com/common.json");
        let resolved = resolve_relative_url(base, "https://other.com/schema.json").unwrap();
        assert_eq!(resolved, "https://other.com/schema.json");
    }

    #[test]
    fn test_compute_content_hash() {
        let content = r#"{"FOO": {"type": "string"}}"#;
        let hash = compute_content_hash(content);
        assert_eq!(hash.len(), 64);
        assert_eq!(hash, compute_content_hash(content));
    }

    #[test]
    fn test_verify_content_hash_matches() {
        let content = "test content";
        let hash = compute_content_hash(content);
        assert!(verify_content_hash(content, &hash).is_ok());
        assert!(verify_content_hash(content, &hash.to_uppercase()).is_ok());
        assert!(verify_content_hash(content, &hash[..16]).is_ok());
    }

    #[test]
    fn test_verify_content_hash_mismatch() {
        let content = "test content";
        let wrong_hash = "0000000000000000000000000000000000000000000000000000000000000000";
        let result = verify_content_hash(content, wrong_hash);
        assert!(matches!(result, Err(RemoteError::HashMismatch { .. })));
    }

    #[test]
    fn test_security_options_builder() {
        let opts = SecurityOptions::new()
            .with_hash(Some("abc123".to_string()))
            .with_ca_cert(Some("/path/to/cert.pem".to_string()))
            .with_rate_limit(120);
        assert_eq!(opts.verify_hash, Some("abc123".to_string()));
        assert_eq!(opts.ca_cert, Some("/path/to/cert.pem".to_string()));
        assert_eq!(opts.rate_limit_seconds, 120);
    }

    #[test]
    fn test_security_options_defaults() {
        let opts = SecurityOptions::default();
        assert_eq!(opts.verify_hash, None);
        assert_eq!(opts.ca_cert, None);
        assert_eq!(opts.rate_limit_seconds, DEFAULT_RATE_LIMIT_SECS);
    }

    #[test]
    fn test_security_options_new() {
        let opts = SecurityOptions::new();
        assert_eq!(opts.verify_hash, None);
        assert_eq!(opts.ca_cert, None);
        assert_eq!(opts.rate_limit_seconds, DEFAULT_RATE_LIMIT_SECS);
    }

    #[test]
    fn test_cache_metadata_serialization() {
        let metadata = CacheMetadata {
            url: "https://example.com/schema.json".to_string(),
            fetched_at: 1234567890,
            content_hash: "abc123".to_string(),
        };
        let json = serde_json::to_string(&metadata).unwrap();
        let parsed: CacheMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.url, metadata.url);
        assert_eq!(parsed.fetched_at, metadata.fetched_at);
        assert_eq!(parsed.content_hash, metadata.content_hash);
    }

    #[test]
    fn test_http_rejected_secure() {
        let security = SecurityOptions::new();
        let result = fetch_remote_schema_secure("http://example.com/schema.json", true, &security);
        assert!(matches!(result, Err(RemoteError::HttpNotAllowed)));
    }

    #[test]
    fn test_invalid_ca_cert_path() {
        let result = build_tls_config(Some("/nonexistent/path/ca.pem"));
        assert!(matches!(result, Err(RemoteError::CertificateError(_))));
    }

    #[test]
    fn test_verify_hash_empty_content() {
        let content = "";
        let hash = compute_content_hash(content);
        assert!(verify_content_hash(content, &hash).is_ok());
        assert_eq!(hash.len(), 64);
    }

    #[test]
    fn test_verify_hash_unicode_content() {
        let content = r#"{"description": "Unicode test"}"#;
        let hash = compute_content_hash(content);
        assert!(verify_content_hash(content, &hash).is_ok());
    }

    #[test]
    fn test_verify_hash_with_newlines() {
        let content = "line1\nline2\nline3";
        let hash = compute_content_hash(content);
        assert!(verify_content_hash(content, &hash).is_ok());
        // Line endings are part of the hashed bytes, so CRLF must hash differently.
        let content_crlf = "line1\r\nline2\r\nline3";
        let hash_crlf = compute_content_hash(content_crlf);
        assert_ne!(hash, hash_crlf);
    }

    #[test]
    fn test_compute_hash_deterministic() {
        let content = r#"{"PORT": {"type": "int", "required": true}}"#;
        let hash1 = compute_content_hash(content);
        let hash2 = compute_content_hash(content);
        let hash3 = compute_content_hash(content);
        assert_eq!(hash1, hash2);
        assert_eq!(hash2, hash3);
    }

    #[test]
    fn test_compute_hash_different_content_different_hash() {
        let content1 = r#"{"FOO": "bar"}"#;
        let content2 = r#"{"FOO": "baz"}"#;
        let hash1 = compute_content_hash(content1);
        let hash2 = compute_content_hash(content2);
        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_compute_hash_special_characters() {
        let content = r#"{"key": "value with $pecial & <chars>"}"#;
        let hash = compute_content_hash(content);
        assert_eq!(hash.len(), 64);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_verify_hash_short_prefix() {
        let content = "test";
        let hash = compute_content_hash(content);
        assert!(verify_content_hash(content, &hash[..8]).is_ok());
    }

    #[test]
    fn test_build_tls_config_with_none() {
        let result = build_tls_config(None);
        assert!(result.is_ok(), "Should succeed with no CA cert");
    }

    #[test]
    fn test_build_tls_config_empty_file() {
        let temp_file = tempfile::NamedTempFile::new().unwrap();
        let result = build_tls_config(Some(temp_file.path().to_str().unwrap()));
        // An empty file contains no certificate, so loading it must fail.
        assert!(
            matches!(result, Err(RemoteError::CertificateError(_))),
            "Empty CA file should be rejected"
        );
    }

    #[test]
    fn test_build_tls_config_invalid_pem_content() {
        use std::io::Write;
        let mut temp_file = tempfile::NamedTempFile::new().unwrap();
        writeln!(temp_file, "This is not a valid PEM certificate").unwrap();
        let result = build_tls_config(Some(temp_file.path().to_str().unwrap()));
        // Whether the parser reports an error or simply yields no certificate,
        // both paths end in a CertificateError.
        assert!(
            matches!(result, Err(RemoteError::CertificateError(_))),
            "Non-PEM CA file should be rejected"
        );
    }

    #[test]
    fn test_rate_limit_with_zero_seconds() {
        let opts = SecurityOptions::new().with_rate_limit(0);
        assert_eq!(opts.rate_limit_seconds, 0);
    }

    #[test]
    fn test_cache_metadata_with_current_time() {
        use std::time::{SystemTime, UNIX_EPOCH};
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let metadata = CacheMetadata {
            url: "https://example.com/test.json".to_string(),
            fetched_at: now,
            content_hash: compute_content_hash("test"),
        };
        let json = serde_json::to_string(&metadata).unwrap();
        let parsed: CacheMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.fetched_at, now);
    }

    #[test]
    fn test_cache_dir_returns_path() {
        let dir = cache_dir();
        assert!(dir.is_some(), "Should return a cache directory path");
        if let Some(path) = dir {
            assert!(!path.as_os_str().is_empty());
        }
    }

    #[test]
    fn test_cache_filename_consistent() {
        let url = "https://example.com/schemas/env.schema.json";
        let name1 = cache_filename(url);
        let name2 = cache_filename(url);
        assert_eq!(name1, name2, "Same URL should produce same cache filename");
    }

    #[test]
    fn test_cache_filename_different_for_different_urls() {
        let url1 = "https://example.com/a.json";
        let url2 = "https://example.com/b.json";
        let url3 = "https://other.com/a.json";
        let name1 = cache_filename(url1);
        let name2 = cache_filename(url2);
        let name3 = cache_filename(url3);
        assert_ne!(name1, name2);
        assert_ne!(name1, name3);
        assert_ne!(name2, name3);
    }

    #[test]
    fn test_is_remote_url_various_schemes() {
        assert!(is_remote_url("https://example.com/schema.json"));
        assert!(is_remote_url("http://example.com/schema.json"));
        assert!(!is_remote_url("ftp://example.com/schema.json"));
        assert!(!is_remote_url("./schema.json"));
        assert!(!is_remote_url("/path/to/schema.json"));
        assert!(!is_remote_url("C:\\path\\schema.json"));
    }

    #[test]
    fn test_resolve_relative_url_edge_cases() {
        let base = "https://example.com/schemas/";
        let relative = "child.json";
        let result = resolve_relative_url(base, relative).unwrap();
        assert!(result.contains("example.com"));
        assert!(result.contains("child.json"));
        let base2 = "https://example.com/schemas/parent.json";
        let relative2 = "child.json";
        let result2 = resolve_relative_url(base2, relative2).unwrap();
        assert!(result2.contains("child.json"));
    }

    #[test]
    fn test_security_options_all_fields() {
        let opts = SecurityOptions::new()
            .with_hash(Some("abc123".to_string()))
            .with_ca_cert(Some("/path/to/cert.pem".to_string()))
            .with_rate_limit(120);
        assert_eq!(opts.verify_hash, Some("abc123".to_string()));
        assert_eq!(opts.ca_cert, Some("/path/to/cert.pem".to_string()));
        assert_eq!(opts.rate_limit_seconds, 120);
    }

    #[test]
    fn test_security_options_chaining() {
        let opts = SecurityOptions::new()
            .with_hash(None)
            .with_ca_cert(None)
            .with_rate_limit(0);
        assert!(opts.verify_hash.is_none());
        assert!(opts.ca_cert.is_none());
        assert_eq!(opts.rate_limit_seconds, 0);
    }

    #[test]
    fn test_cache_filename_url_encoded_chars() {
        let url1 = "https://example.com/schema%20with%20spaces.json";
        let url2 = "https://example.com/schema?query=value&other=123";
        let name1 = cache_filename(url1);
        let name2 = cache_filename(url2);
        assert!(!name1.contains('%'));
        assert!(!name1.contains(' '));
        assert!(!name2.contains('?'));
        assert!(!name2.contains('&'));
    }

    #[test]
    fn test_verify_content_hash_case_insensitive() {
        let content = "test content";
        let hash_lower = compute_content_hash(content).to_lowercase();
        let hash_upper = hash_lower.to_uppercase();
        let result_lower = verify_content_hash(content, &hash_lower);
        assert!(result_lower.is_ok());
        // The upper-case spelling of the same digest must be accepted as well.
        let result_upper = verify_content_hash(content, &hash_upper);
        assert!(result_upper.is_ok());
    }

    #[test]
    fn test_compute_hash_consistency_across_calls() {
        let content = "consistent content";
        let hash1 = compute_content_hash(content);
        let hash2 = compute_content_hash(content);
        let hash3 = compute_content_hash(content);
        assert_eq!(hash1, hash2);
        assert_eq!(hash2, hash3);
    }
}