use std::sync::{Arc, LazyLock};
use regex::Regex;
use thiserror::Error;
/// Matches text shaped like an API key: `<id>.<secret>`, where the id is at least 3 and the
/// secret at least 10 characters long, both drawn from ASCII letters, digits, `_` and `-`.
///
/// The check is purely structural: any dotted text of that shape matches, whether or not it
/// really is a key.
static API_KEY_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
// The pattern is a literal, so a failure to compile it is a programming error.
Regex::new(r"\b[a-zA-Z0-9_-]{3,}\.[a-zA-Z0-9_-]{10,}\b").expect("invalid regex")
});
/// Ordered `(pattern, replacement)` pairs used to redact labelled secrets.
///
/// Every pattern is case-insensitive and captures the label (plus its separator) in group 1,
/// so the replacement keeps the label as written and swaps only the value for `[FILTERED]`.
/// For the `key = value` style entries a value runs up to the next whitespace or comma.
static SENSITIVE_PATTERNS: LazyLock<Vec<(Regex, &'static str)>> = LazyLock::new(|| {
    vec![
        (
            Regex::new(r"(?i)(api[_-]?key\s*[=:]\s*)[^\s,]+").expect("invalid regex"),
            "$1[FILTERED]",
        ),
        (
            Regex::new(r"(?i)(password\s*[=:]\s*)[^\s,]+").expect("invalid regex"),
            "$1[FILTERED]",
        ),
        (
            Regex::new(r"(?i)(token\s*[=:]\s*)[^\s,]+").expect("invalid regex"),
            "$1[FILTERED]",
        ),
        (
            Regex::new(r"(?i)(secret\s*[=:]\s*)[^\s,]+").expect("invalid regex"),
            "$1[FILTERED]",
        ),
        (
            // Bare `bearer <token>`. Only the scheme is captured, so its original casing is
            // preserved, and every dot-separated segment is consumed so that three-part
            // tokens (for example JWTs) do not leave their last segment behind.
            // At least one dot is still required, which keeps ordinary prose such as
            // "the bearer of news" untouched.
            Regex::new(r"(?i)(bearer\s+)[a-zA-Z0-9_-]+(?:\.[a-zA-Z0-9_-]+)+")
                .expect("invalid regex"),
            "$1[FILTERED]",
        ),
        (
            Regex::new(r"(?i)(authorization\s*:\s*Bearer\s+)[^\s,]+").expect("invalid regex"),
            "$1[FILTERED]",
        ),
    ]
});
/// Case-insensitive patterns that signal the presence of a labelled secret.
///
/// These only look for the label and its separator (`password=`, `token:` ...); unlike the
/// masking patterns they do not require a value to follow.
static CONTAINS_SENSITIVE_PATTERNS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
    vec![
        Regex::new(r"(?i)api[_-]?key\s*[=:]").expect("invalid regex"),
        Regex::new(r"(?i)password\s*[=:]").expect("invalid regex"),
        Regex::new(r"(?i)token\s*[=:]").expect("invalid regex"),
        Regex::new(r"(?i)secret\s*[=:]").expect("invalid regex"),
        Regex::new(r"(?i)authorization\s*:\s*Bearer").expect("invalid regex"),
        // `SENSITIVE_PATTERNS` redacts a bare `bearer <a>.<b>` token even without an
        // `Authorization:` prefix, so detection has to recognise the same shape; otherwise
        // text that masking would change is reported as containing nothing sensitive.
        Regex::new(r"(?i)bearer\s+[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+").expect("invalid regex"),
    ]
});
/// Returns a copy of `text` with API-key-shaped substrings and labelled secrets redacted.
///
/// Matches of `API_KEY_PATTERN` are replaced with `[FILTERED]` first; each entry of
/// `SENSITIVE_PATTERNS` is then applied, in order, to the running result.
pub fn mask_sensitive_info(text: &str) -> String {
    let keys_masked = API_KEY_PATTERN
        .replace_all(text, "[FILTERED]")
        .into_owned();

    SENSITIVE_PATTERNS
        .iter()
        .fold(keys_masked, |masked, (pattern, replacement)| {
            pattern.replace_all(&masked, *replacement).into_owned()
        })
}
/// Returns a copy of `text` in which every match of `API_KEY_PATTERN` is replaced with
/// `[FILTERED]`. Labelled secrets (`password=...` and the like) are left alone.
pub fn mask_api_key(text: &str) -> String {
    let masked = API_KEY_PATTERN.replace_all(text, "[FILTERED]");
    masked.into_owned()
}
/// Reports whether `text` contains something that looks like a secret.
///
/// True when `text` matches `API_KEY_PATTERN` or any entry of `CONTAINS_SENSITIVE_PATTERNS`.
pub fn contains_sensitive_info(text: &str) -> bool {
    let has_labelled_secret = || {
        CONTAINS_SENSITIVE_PATTERNS
            .iter()
            .any(|pattern| pattern.is_match(text))
    };

    // The closure keeps the label scan lazy: it only runs when no key-shaped text was found.
    API_KEY_PATTERN.is_match(text) || has_labelled_secret()
}
/// Checks that `api_key` has the shape `<id>.<secret>`.
///
/// Requirements, checked in this order:
/// 1. the key is not empty;
/// 2. it contains exactly one `.`;
/// 3. neither side of the dot is empty;
/// 4. both sides consist only of ASCII letters, ASCII digits, `_` and `-`;
/// 5. the id is at least 3 and the secret at least 10 characters long.
///
/// # Errors
///
/// Returns [`ZaiError::ApiError`] describing the first requirement that fails. The code is
/// `1001` for a wrong number of dots and `1200` for every other failure.
pub fn validate_api_key(api_key: &str) -> ZaiResult<()> {
    // Every rejection has the same shape; only the code and the text differ.
    let reject = |code: u16, message: &str| -> ZaiResult<()> {
        Err(ZaiError::ApiError {
            code,
            message: message.to_string(),
        })
    };

    if api_key.is_empty() {
        return reject(1200, "API key cannot be empty");
    }

    // `split_once` cuts at the first dot; the remainder must not contain another one.
    let (id, secret) = match api_key.split_once('.') {
        Some((id, secret)) if !secret.contains('.') => (id, secret),
        _ => return reject(1001, "API key must be in format '<id>.<secret>'"),
    };

    if id.is_empty() || secret.is_empty() {
        return reject(1200, "API key id and secret must not be empty");
    }

    // ASCII only: the Unicode-aware `char::is_alphanumeric` would also accept letters such
    // as 'é' or digits from other scripts, which are not part of the key alphabet and would
    // make the byte-length checks below count more than one unit per character.
    let has_only_key_chars = |part: &str| {
        part.bytes()
            .all(|byte| byte.is_ascii_alphanumeric() || byte == b'_' || byte == b'-')
    };
    if !has_only_key_chars(id) || !has_only_key_chars(secret) {
        return reject(1200, "API key contains invalid characters");
    }

    // Both parts are pure ASCII at this point, so byte length equals character count.
    if id.len() < 3 {
        return reject(1200, "API key id is too short");
    }
    if secret.len() < 10 {
        return reject(1200, "API key secret is too short");
    }

    Ok(())
}
/// Error type shared by the functions in this file.
///
/// Most variants carry a numeric code and a human-readable message; the two wrapper variants
/// hold the underlying library error behind an `Arc`, which is what lets the hand-written
/// `Clone` impl duplicate them.
#[derive(Error, Debug)]
pub enum ZaiError {
/// Failure identified by its HTTP status (such as 400, 401, 404, 429 or 500).
#[error("HTTP error [{status}]: {message}")]
HttpError { status: u16, message: String },
/// Authentication failure; `from_api_response` maps API codes 1000..=1004 and 1100 here.
#[error("Authentication error [{code}]: {message}")]
AuthError { code: u16, message: String },
/// Account-level failure; `from_api_response` maps API codes 1110..=1121 here.
#[error("Account error [{code}]: {message}")]
AccountError { code: u16, message: String },
/// API-level failure; `from_api_response` maps API codes 1200..=1234 here. Also used for
/// local validation failures (`validate_api_key`, `validator::ValidationErrors`).
#[error("API error [{code}]: {message}")]
ApiError { code: u16, message: String },
/// Rate limiting reported through an API code; `from_api_response` maps 1300..=1309 here.
#[error("Rate limit error [{code}]: {message}")]
RateLimitError { code: u16, message: String },
/// Content-policy failure. Nothing in this file constructs it.
#[error("Content policy error [{code}]: {message}")]
ContentPolicyError { code: u16, message: String },
/// File-related failure. Nothing in this file constructs it.
#[error("File error [{code}]: {message}")]
FileError { code: u16, message: String },
/// A `reqwest` error that carried no HTTP status.
#[error("Network error: {0}")]
NetworkError(Arc<reqwest::Error>),
/// A `serde_json` error.
#[error("JSON error: {0}")]
JsonError(Arc<serde_json::Error>),
/// Fallback for anything unclassified; `std::io::Error` converts to this with code 0.
#[error("Unknown error [{code}]: {message}")]
Unknown { code: u16, message: String },
}
impl ZaiError {
    /// Builds an error from an HTTP status plus the API's own error code and message.
    ///
    /// Resolution order:
    /// 1. A recognised HTTP status (400, 401, 404, 429, 434, 435, 500) yields
    ///    [`ZaiError::HttpError`]. For 400 and 429 a non-empty `api_message` is kept; the
    ///    other statuses always use a fixed description.
    /// 2. Otherwise a recognised `api_code` range selects the auth, account, API or
    ///    rate-limit variant.
    /// 3. Otherwise, if there is no API code (`0`) but `status` is an HTTP failure
    ///    (`>= 400`), the status is kept as an `HttpError` instead of being discarded.
    /// 4. Everything else becomes [`ZaiError::Unknown`] carrying `api_code`.
    pub fn from_api_response(status: u16, api_code: u16, api_message: String) -> Self {
        // Prefer the server-provided text and fall back to a canned description.
        fn or_fallback(message: String, fallback: &str) -> String {
            if message.is_empty() {
                fallback.to_string()
            } else {
                message
            }
        }

        let http = |message: String| ZaiError::HttpError { status, message };

        match status {
            400 => http(or_fallback(
                api_message,
                "Bad request - check your parameters",
            )),
            401 => http("Unauthorized - check your API key".to_string()),
            404 => http("Not found - requested resource doesn't exist".to_string()),
            429 => http(or_fallback(
                api_message,
                "Too many requests - rate limit exceeded",
            )),
            434 => http("No API permission - feature not available".to_string()),
            435 => http("File size exceeds 100MB limit".to_string()),
            500 => http("Internal server error - try again later".to_string()),
            _ => match api_code {
                1000..=1004 | 1100 => ZaiError::AuthError {
                    code: api_code,
                    message: api_message,
                },
                1110..=1121 => ZaiError::AccountError {
                    code: api_code,
                    message: api_message,
                },
                1200..=1234 => ZaiError::ApiError {
                    code: api_code,
                    message: api_message,
                },
                1300..=1309 => ZaiError::RateLimitError {
                    code: api_code,
                    message: api_message,
                },
                // No API code at all, but the transport reported a failure status that is
                // not special-cased above (403, 502, 503, ...). Keeping the status lets
                // `code`, `is_client_error` and `is_server_error` classify it; turning it
                // into `Unknown { code: 0 }` would throw that information away.
                0 if status >= 400 => http(or_fallback(api_message, "HTTP request failed")),
                _ => ZaiError::Unknown {
                    code: api_code,
                    message: or_fallback(api_message, "Unknown error"),
                },
            },
        }
    }

    /// Returns `true` only for [`ZaiError::RateLimitError`].
    ///
    /// An HTTP 429 is represented as [`ZaiError::HttpError`] and therefore reports `false`.
    pub fn is_rate_limit(&self) -> bool {
        matches!(self, ZaiError::RateLimitError { .. })
    }

    /// Returns `true` only for [`ZaiError::AuthError`].
    pub fn is_auth_error(&self) -> bool {
        matches!(self, ZaiError::AuthError { .. })
    }

    /// Returns `true` for HTTP 4xx statuses and for every API-coded variant.
    pub fn is_client_error(&self) -> bool {
        match self {
            ZaiError::HttpError { status, .. } => (400..500).contains(status),
            ZaiError::AuthError { .. }
            | ZaiError::AccountError { .. }
            | ZaiError::ApiError { .. }
            | ZaiError::RateLimitError { .. }
            | ZaiError::ContentPolicyError { .. }
            | ZaiError::FileError { .. } => true,
            ZaiError::NetworkError(_) | ZaiError::JsonError(_) | ZaiError::Unknown { .. } => false,
        }
    }

    /// Returns `true` for HTTP statuses of 500 and above, and for `Unknown` errors whose code
    /// lies in the HTTP 5xx range.
    pub fn is_server_error(&self) -> bool {
        match self {
            ZaiError::HttpError { status, .. } => *status >= 500,
            // Only the 5xx band counts: an unrecognised API code (for example 9999) says
            // nothing about where the fault lies and must not be reported as a server error.
            ZaiError::Unknown { code, .. } => (500..600).contains(code),
            ZaiError::AuthError { .. }
            | ZaiError::AccountError { .. }
            | ZaiError::ApiError { .. }
            | ZaiError::RateLimitError { .. }
            | ZaiError::ContentPolicyError { .. }
            | ZaiError::FileError { .. }
            | ZaiError::NetworkError(_)
            | ZaiError::JsonError(_) => false,
        }
    }

    /// Renders the error as a short `KIND[code]: message` string (`KIND: message` for the
    /// wrapper variants, which have no code).
    pub fn compact(&self) -> String {
        match self {
            ZaiError::HttpError { status, message } => format!("HTTP[{status}]: {message}"),
            ZaiError::AuthError { code, message } => format!("AUTH[{code}]: {message}"),
            ZaiError::AccountError { code, message } => format!("ACCOUNT[{code}]: {message}"),
            ZaiError::ApiError { code, message } => format!("API[{code}]: {message}"),
            ZaiError::RateLimitError { code, message } => {
                format!("RATE_LIMIT[{code}]: {message}")
            }
            ZaiError::ContentPolicyError { code, message } => format!("POLICY[{code}]: {message}"),
            ZaiError::FileError { code, message } => format!("FILE[{code}]: {message}"),
            ZaiError::NetworkError(err) => format!("NETWORK: {err}"),
            ZaiError::JsonError(err) => format!("JSON: {err}"),
            ZaiError::Unknown { code, message } => format!("UNKNOWN[{code}]: {message}"),
        }
    }

    /// Returns the numeric code: the HTTP status for `HttpError`, the stored code for the
    /// other coded variants, and `None` for the network and JSON wrappers.
    pub fn code(&self) -> Option<u16> {
        match self {
            ZaiError::HttpError { status, .. } => Some(*status),
            ZaiError::AuthError { code, .. }
            | ZaiError::AccountError { code, .. }
            | ZaiError::ApiError { code, .. }
            | ZaiError::RateLimitError { code, .. }
            | ZaiError::ContentPolicyError { code, .. }
            | ZaiError::FileError { code, .. }
            | ZaiError::Unknown { code, .. } => Some(*code),
            ZaiError::NetworkError(_) | ZaiError::JsonError(_) => None,
        }
    }

    /// Returns an owned copy of the message; wrapper variants return the wrapped error's
    /// `Display` text.
    pub fn message(&self) -> String {
        match self {
            ZaiError::HttpError { message, .. }
            | ZaiError::AuthError { message, .. }
            | ZaiError::AccountError { message, .. }
            | ZaiError::ApiError { message, .. }
            | ZaiError::RateLimitError { message, .. }
            | ZaiError::ContentPolicyError { message, .. }
            | ZaiError::FileError { message, .. }
            | ZaiError::Unknown { message, .. } => message.clone(),
            ZaiError::NetworkError(err) => err.to_string(),
            ZaiError::JsonError(err) => err.to_string(),
        }
    }
}
impl Clone for ZaiError {
fn clone(&self) -> Self {
match self {
ZaiError::HttpError { status, message } => ZaiError::HttpError {
status: *status,
message: message.clone(),
},
ZaiError::AuthError { code, message } => ZaiError::AuthError {
code: *code,
message: message.clone(),
},
ZaiError::AccountError { code, message } => ZaiError::AccountError {
code: *code,
message: message.clone(),
},
ZaiError::ApiError { code, message } => ZaiError::ApiError {
code: *code,
message: message.clone(),
},
ZaiError::RateLimitError { code, message } => ZaiError::RateLimitError {
code: *code,
message: message.clone(),
},
ZaiError::ContentPolicyError { code, message } => ZaiError::ContentPolicyError {
code: *code,
message: message.clone(),
},
ZaiError::FileError { code, message } => ZaiError::FileError {
code: *code,
message: message.clone(),
},
ZaiError::NetworkError(err) => ZaiError::NetworkError(Arc::clone(err)),
ZaiError::JsonError(err) => ZaiError::JsonError(Arc::clone(err)),
ZaiError::Unknown { code, message } => ZaiError::Unknown {
code: *code,
message: message.clone(),
},
}
}
}
/// Shorthand for a `Result` whose error type is [`ZaiError`].
pub type ZaiResult<T> = Result<T, ZaiError>;
/// Converts a `reqwest` error.
///
/// When the error carries an HTTP status it is routed through
/// [`ZaiError::from_api_response`] with API code `0` and the error's text as the message;
/// otherwise it is wrapped as [`ZaiError::NetworkError`].
impl From<reqwest::Error> for ZaiError {
    fn from(err: reqwest::Error) -> Self {
        match err.status() {
            Some(status) => Self::from_api_response(status.as_u16(), 0, err.to_string()),
            None => Self::NetworkError(Arc::new(err)),
        }
    }
}
/// Wraps a `serde_json` error as [`ZaiError::JsonError`].
impl From<serde_json::Error> for ZaiError {
    fn from(err: serde_json::Error) -> Self {
        let shared = Arc::new(err);
        Self::JsonError(shared)
    }
}
impl From<validator::ValidationErrors> for ZaiError {
fn from(err: validator::ValidationErrors) -> Self {
ZaiError::ApiError {
code: 1200,
message: format!("Validation error: {:?}", err),
}
}
}
impl From<std::io::Error> for ZaiError {
fn from(err: std::io::Error) -> Self {
ZaiError::Unknown {
code: 0,
message: err.to_string(),
}
}
}
/// Unit tests for error construction and classification, API-key validation and
/// secret masking.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_api_response_bad_request() {
        let err = ZaiError::from_api_response(400, 0, "Invalid input".to_string());
        assert!(err.is_client_error());
        assert!(!err.is_server_error());
        assert_eq!(err.code(), Some(400));
    }

    #[test]
    fn test_from_api_response_unauthorized() {
        let err = ZaiError::from_api_response(401, 0, "".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.message(), "Unauthorized - check your API key");
    }

    #[test]
    fn test_from_api_response_rate_limit() {
        // A recognised HTTP status wins over the API code, so 429 is an `HttpError`.
        let err = ZaiError::from_api_response(429, 1301, "Too many requests".to_string());
        assert!(err.is_client_error());
        assert!(!err.is_rate_limit());
        assert_eq!(err.code(), Some(429));

        // With an unlisted status the API code decides the variant.
        let err = ZaiError::from_api_response(200, 1301, "Too many requests".to_string());
        assert!(err.is_client_error());
        assert!(err.is_rate_limit());
        assert_eq!(err.code(), Some(1301));
    }

    #[test]
    fn test_from_api_response_server_error() {
        let err = ZaiError::from_api_response(500, 0, "".to_string());
        assert!(!err.is_client_error());
        assert!(err.is_server_error());
    }

    #[test]
    fn test_from_api_response_auth_error_code() {
        let err = ZaiError::from_api_response(200, 1001, "Invalid API key".to_string());
        assert!(err.is_auth_error());
        assert_eq!(err.code(), Some(1001));
        assert_eq!(err.message(), "Invalid API key");
    }

    #[test]
    fn test_from_api_response_account_error() {
        let err = ZaiError::from_api_response(200, 1110, "Account expired".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.code(), Some(1110));
    }

    #[test]
    fn test_from_api_response_api_error() {
        let err = ZaiError::from_api_response(200, 1200, "Invalid parameters".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.code(), Some(1200));
    }

    #[test]
    fn test_from_api_response_unknown_code() {
        let err = ZaiError::from_api_response(200, 9999, "Unknown error".to_string());
        assert!(!err.is_client_error());
        assert_eq!(err.code(), Some(9999));
    }

    #[test]
    fn test_compact() {
        let err = ZaiError::HttpError {
            status: 404,
            message: "Not found".to_string(),
        };
        assert_eq!(err.compact(), "HTTP[404]: Not found");

        let err = ZaiError::AuthError {
            code: 1001,
            message: "Invalid key".to_string(),
        };
        assert_eq!(err.compact(), "AUTH[1001]: Invalid key");
    }

    #[test]
    fn test_code() {
        let io_err =
            std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "connection refused");
        let err = ZaiError::from(io_err);
        assert_eq!(err.code(), Some(0));

        let err = ZaiError::JsonError(std::sync::Arc::new(serde_json::Error::io(
            std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid JSON"),
        )));
        assert!(err.code().is_none());

        let err = ZaiError::HttpError {
            status: 500,
            message: "Server error".to_string(),
        };
        assert_eq!(err.code(), Some(500));
    }

    #[test]
    fn test_message() {
        let err = ZaiError::RateLimitError {
            code: 1300,
            message: "Too many requests".to_string(),
        };
        assert_eq!(err.message(), "Too many requests");
    }

    // Renamed from `test_from_reqwest_error_with_status`: no `reqwest` error is involved,
    // the test checks the `std::io::Error` conversion.
    #[test]
    fn test_from_io_error_maps_to_unknown() {
        let io_err = std::io::Error::other("test error");
        let zai_err = ZaiError::from(io_err);
        match zai_err {
            ZaiError::Unknown { .. } => {},
            _ => panic!("Expected Unknown error for io::Error"),
        }
    }

    #[test]
    fn test_validate_api_key_valid() {
        assert!(validate_api_key("abc123.abcdefghijklmnopqrstuvwxyz").is_ok());
    }

    #[test]
    fn test_validate_api_key_empty() {
        let result = validate_api_key("");
        assert!(result.is_err());
        match result {
            Err(ZaiError::ApiError { code, .. }) => {
                assert_eq!(code, 1200);
            },
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_validate_api_key_no_dot() {
        let result = validate_api_key("invalid");
        assert!(result.is_err());
        match result {
            Err(ZaiError::ApiError { code, message }) => {
                assert_eq!(code, 1001);
                assert!(message.contains("format"));
            },
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_validate_api_key_multiple_dots() {
        let result = validate_api_key("id.secret.extra");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1001));
    }

    #[test]
    fn test_validate_api_key_empty_id() {
        let result = validate_api_key(".secret123456789");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_empty_secret() {
        let result = validate_api_key("id123.");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_invalid_chars() {
        let result = validate_api_key("id$123.secret@456");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_id_too_short() {
        let result = validate_api_key("ab.abcdefghijklmn");
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("id is too short"));
    }

    #[test]
    fn test_validate_api_key_secret_too_short() {
        let result = validate_api_key("id123.short");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .message()
                .contains("secret is too short")
        );
    }

    #[test]
    fn test_mask_sensitive_info_api_key() {
        let text = "API key: abc123.abcdefghijklmnopqrstuvwxyz12345";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
        assert!(!filtered.contains("abcdefghijklmnopqrstuvwxyz"));
    }

    #[test]
    fn test_mask_sensitive_info_password() {
        let text = "password: secret123, other text";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("secret123"));
    }

    #[test]
    fn test_mask_sensitive_info_token() {
        let text = "token=abc123xyz, other content";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123xyz"));
    }

    #[test]
    fn test_mask_sensitive_info_bearer() {
        let text = "Authorization: Bearer abc123.abc1234567890";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
    }

    #[test]
    fn test_mask_sensitive_info_multiple() {
        let text = "api_key=abc123.xyz456, password=secret123";
        let filtered = mask_sensitive_info(text);
        let filtered_count = filtered.matches("[FILTERED]").count();
        assert_eq!(filtered_count, 2);
    }

    #[test]
    fn test_mask_sensitive_info_no_sensitive() {
        let text = "Regular text without sensitive information";
        let filtered = mask_sensitive_info(text);
        assert_eq!(filtered, text);
    }

    #[test]
    fn test_mask_api_key() {
        let text = "API key: abc123.abcdefghijklmnopqrstuvwxyz12345";
        let filtered = mask_api_key(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
    }

    #[test]
    fn test_contains_sensitive_info_api_key() {
        assert!(contains_sensitive_info("api_key: abc123.abc1234567890"));
        assert!(!contains_sensitive_info("regular text"));
    }

    #[test]
    fn test_contains_sensitive_info_password() {
        assert!(contains_sensitive_info("password: secret"));
        assert!(contains_sensitive_info("password=123"));
        // The label alone, or a label not followed by a separator, is not a hit.
        assert!(!contains_sensitive_info("password"));
        assert!(!contains_sensitive_info("word:password"));
    }

    #[test]
    fn test_contains_sensitive_info_token() {
        assert!(contains_sensitive_info("token=abc123"));
        assert!(contains_sensitive_info("token: xyz123"));
        assert!(!contains_sensitive_info("token"));
        assert!(!contains_sensitive_info("tokenize this"));
    }
}