use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::RwLock;
use tracing::{info, warn};
use scirs2_core::random::secure::SecureRandom;
use sha2::Sha256;
/// Cross-Origin Resource Sharing (CORS) policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests; a `"*"` entry allows any origin.
    pub allowed_origins: Vec<String>,
    /// HTTP methods allowed for cross-origin requests.
    pub allowed_methods: Vec<String>,
    /// Request headers a cross-origin client may send.
    pub allowed_headers: Vec<String>,
    /// Response headers exposed to cross-origin scripts.
    pub exposed_headers: Vec<String>,
    /// Whether credentialed cross-origin requests are allowed.
    pub allow_credentials: bool,
    /// Preflight cache lifetime (presumably seconds — confirm against the HTTP layer).
    pub max_age: Option<u64>,
    /// When true, `allowed_origins` entries may contain a single `*` wildcard.
    pub wildcard_origins: bool,
}
impl Default for CorsConfig {
    /// Permissive defaults: any origin, GET/POST/OPTIONS, no credentials,
    /// `max_age` of 3600, and wildcard patterns disabled.
    fn default() -> Self {
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_string()).collect()
        }

        Self {
            allowed_origins: owned(&["*"]),
            allowed_methods: owned(&["GET", "POST", "OPTIONS"]),
            allowed_headers: owned(&["Content-Type", "Authorization", "X-Requested-With"]),
            exposed_headers: owned(&["Content-Length"]),
            allow_credentials: false,
            max_age: Some(3600),
            wildcard_origins: false,
        }
    }
}
impl CorsConfig {
    /// Builds a locked-down policy for production deployments.
    ///
    /// Only the given `origins` are accepted, credentials are allowed, and
    /// `max_age` is 86400. Because credentials are enabled, `origins` should
    /// not contain `"*"`: `is_origin_allowed` treats a `"*"` entry as
    /// "allow every origin".
    pub fn production(origins: Vec<String>) -> Self {
        Self {
            allowed_origins: origins,
            allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
            allowed_headers: vec![
                "Content-Type".to_string(),
                "Authorization".to_string(),
                "X-Request-ID".to_string(),
                "X-API-Key".to_string(),
            ],
            exposed_headers: vec!["Content-Length".to_string(), "X-Request-ID".to_string()],
            allow_credentials: true,
            max_age: Some(86400),
            wildcard_origins: false,
        }
    }

    /// Returns whether `origin` may make cross-origin requests under this policy.
    ///
    /// A literal `"*"` entry allows every origin. When `wildcard_origins` is
    /// set, entries may contain one `*` (see `wildcard_match`). In all cases an
    /// entry that equals `origin` exactly also allows it.
    pub fn is_origin_allowed(&self, origin: &str) -> bool {
        if self.allowed_origins.iter().any(|allowed| allowed == "*") {
            return true;
        }
        if self.wildcard_origins
            && self
                .allowed_origins
                .iter()
                .any(|allowed| Self::wildcard_match(allowed, origin))
        {
            return true;
        }
        self.allowed_origins.iter().any(|allowed| allowed == origin)
    }

    /// Matches `origin` against `pattern`, which may contain at most one `*`.
    ///
    /// The `*` must stand for at least one character, and the literal prefix
    /// and suffix may not overlap inside `origin`. This stops
    /// `https://*.example.com` from matching `https://.example.com`, and
    /// `https://app*p.com` from matching `https://app.com`. Patterns with more
    /// than one `*` never match; patterns without `*` require exact equality.
    fn wildcard_match(pattern: &str, origin: &str) -> bool {
        match pattern.split_once('*') {
            None => pattern == origin,
            Some((_, suffix)) if suffix.contains('*') => false,
            Some((prefix, suffix)) => matches!(
                origin
                    .strip_prefix(prefix)
                    .and_then(|rest| rest.strip_suffix(suffix)),
                Some(middle) if !middle.is_empty()
            ),
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtConfig {
pub secret: String,
pub algorithm: JwtAlgorithm,
pub expiration: u64,
pub issuer: Option<String>,
pub audience: Option<Vec<String>>,
pub required_claims: Vec<String>,
pub enable_refresh: bool,
pub refresh_expiration: u64,
}
impl Default for JwtConfig {
fn default() -> Self {
Self {
secret: "".to_string(), algorithm: JwtAlgorithm::HS256,
expiration: 3600, issuer: None,
audience: None,
required_claims: vec!["sub".to_string()],
enable_refresh: true,
refresh_expiration: 604800, }
}
}
/// JWS signing algorithm identifiers, as written in a token's `alg` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum JwtAlgorithm {
    HS256,
    HS384,
    HS512,
    RS256,
    RS384,
    RS512,
}
impl fmt::Display for JwtAlgorithm {
    /// Writes the algorithm's canonical name (e.g. `HS256`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Self::HS256 => "HS256",
            Self::HS384 => "HS384",
            Self::HS512 => "HS512",
            Self::RS256 => "RS256",
            Self::RS384 => "RS384",
            Self::RS512 => "RS512",
        };
        f.write_str(name)
    }
}
/// Registered and custom claims carried in a token payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
/// Subject the token was issued for.
pub sub: String,
/// Expiry time in seconds since the Unix epoch.
pub exp: u64,
/// Issued-at time in seconds since the Unix epoch.
pub iat: u64,
/// Optional issuer; omitted from the serialized payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub iss: Option<String>,
/// Optional audiences; omitted from the serialized payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub aud: Option<Vec<String>>,
/// Additional claims, flattened into the top level of the payload object.
#[serde(flatten)]
pub custom: HashMap<String, serde_json::Value>,
}
impl JwtClaims {
    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// # Panics
    /// Panics if the system clock reports a time before 1970-01-01.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("SystemTime should be after UNIX_EPOCH")
            .as_secs()
    }

    /// Creates claims for `sub` that are issued now and expire `expiration`
    /// seconds from now.
    ///
    /// The expiry saturates at `u64::MAX`, so a very large `expiration`
    /// ("never expires") cannot overflow. Plain addition would panic there in
    /// debug builds.
    pub fn new(sub: String, expiration: u64) -> Self {
        let now = Self::now_secs();
        Self {
            sub,
            exp: now.saturating_add(expiration),
            iat: now,
            iss: None,
            aud: None,
            custom: HashMap::new(),
        }
    }

    /// Adds (or replaces) a custom claim and returns the updated claims.
    pub fn with_claim(mut self, key: String, value: serde_json::Value) -> Self {
        self.custom.insert(key, value);
        self
    }

    /// Returns `true` once the current time has reached `exp`.
    pub fn is_expired(&self) -> bool {
        Self::now_secs() >= self.exp
    }
}
pub struct JwtManager {
config: JwtConfig,
}
impl JwtManager {
pub fn new(config: JwtConfig) -> Result<Self> {
if config.secret.is_empty() {
return Err(anyhow!("JWT secret cannot be empty"));
}
Ok(Self { config })
}
pub fn generate_secure_secret() -> String {
let mut rng = SecureRandom::new();
rng.random_alphanumeric(64)
}
pub fn generate_token(&self, claims: &JwtClaims) -> Result<String> {
if self.config.secret.is_empty() {
return Err(anyhow!("JWT secret not configured"));
}
let header = serde_json::json!({
"typ": "JWT",
"alg": self.config.algorithm.to_string()
});
let header_b64 = base64_url_encode(&serde_json::to_string(&header)?);
let payload_b64 = base64_url_encode(&serde_json::to_string(claims)?);
let signing_input = format!("{}.{}", header_b64, payload_b64);
let signature = self.sign_hmac_sha256(&signing_input)?;
let signature_b64 = base64_url_encode(&hex::encode(&signature));
Ok(format!("{}.{}", signing_input, signature_b64))
}
fn sign_hmac_sha256(&self, data: &str) -> Result<Vec<u8>> {
use hmac::{Hmac, Mac};
type HmacSha256 = Hmac<Sha256>;
let mut mac = HmacSha256::new_from_slice(self.config.secret.as_bytes())
.map_err(|e| anyhow!("Invalid key length: {}", e))?;
mac.update(data.as_bytes());
Ok(mac.finalize().into_bytes().to_vec())
}
pub fn verify_token(&self, token: &str) -> Result<JwtClaims> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(anyhow!(
"Invalid token format: expected 3 parts, got {}",
parts.len()
));
}
let signing_input = format!("{}.{}", parts[0], parts[1]);
let expected_signature = self.sign_hmac_sha256(&signing_input)?;
let expected_signature_b64 = base64_url_encode(&hex::encode(&expected_signature));
if parts[2] != expected_signature_b64 {
return Err(anyhow!("Invalid token signature"));
}
let payload_json = base64_url_decode(parts[1])?;
let claims: JwtClaims = serde_json::from_str(&payload_json)
.map_err(|e| anyhow!("Invalid token payload: {}", e))?;
if claims.is_expired() {
return Err(anyhow!("Token expired at {}", claims.exp));
}
Ok(claims)
}
}
/// Encodes the UTF-8 bytes of `data` as unpadded base64url (RFC 4648 §5).
///
/// Each group of `n` input bytes (1 ≤ n ≤ 3) yields `n + 1` output characters;
/// no `=` padding is emitted.
fn base64_url_encode(data: &str) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    let input = data.as_bytes();
    let mut encoded = String::with_capacity((input.len() * 4 / 3) + 4);
    for group in input.chunks(3) {
        // Zero-fill a short final group; only the sextets covering real input are emitted.
        let mut padded = [0u8; 3];
        padded[..group.len()].copy_from_slice(group);
        let word = (u32::from(padded[0]) << 16) | (u32::from(padded[1]) << 8) | u32::from(padded[2]);
        let emitted = group.len() + 1;
        for sextet in (0..emitted).map(|i| (word >> (18 - 6 * i)) & 0x3F) {
            encoded.push(char::from(ALPHABET[sextet as usize]));
        }
    }
    encoded
}
/// Decodes unpadded base64url text into a UTF-8 `String`.
///
/// The URL-safe alphabet is mapped back to the standard one, `=` padding is
/// restored, and the standard decoder does the rest.
///
/// # Errors
/// Propagates decoder errors and fails if the decoded bytes are not valid UTF-8.
fn base64_url_decode(data: &str) -> Result<String> {
    let standard: String = data
        .chars()
        .map(|c| match c {
            '-' => '+',
            '_' => '/',
            other => other,
        })
        .collect();
    let pad = match standard.len() % 4 {
        2 => "==",
        3 => "=",
        _ => "",
    };
    let padded = standard + pad;
    let raw = decode_base64_standard(&padded)?;
    String::from_utf8(raw).map_err(|e| anyhow!("Invalid UTF-8 in decoded data: {}", e))
}
fn decode_base64_standard(data: &str) -> Result<Vec<u8>> {
const DECODE_TABLE: [u8; 256] = {
let mut table = [0xFF; 256];
let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let mut i = 0;
while i < 64 {
table[chars[i] as usize] = i as u8;
i += 1;
}
table
};
let bytes = data.as_bytes();
let mut result = Vec::with_capacity((bytes.len() * 3) / 4);
let mut buffer = 0u32;
let mut bits = 0;
for &byte in bytes {
if byte == b'=' {
break;
}
let value = DECODE_TABLE[byte as usize];
if value == 0xFF {
continue; }
buffer = (buffer << 6) | (value as u32);
bits += 6;
if bits >= 8 {
bits -= 8;
result.push((buffer >> bits) as u8);
buffer &= (1 << bits) - 1;
}
}
Ok(result)
}
/// OpenTelemetry tracing/export settings for the service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenTelemetryConfig {
    /// Service name reported with telemetry.
    pub service_name: String,
    /// Service version reported with telemetry.
    pub service_version: String,
    /// OTLP collector endpoint, if exporting to one.
    pub otlp_endpoint: Option<String>,
    /// Fraction of traces to sample (default 1.0).
    pub sampling_rate: f64,
    /// Whether telemetry export is enabled.
    pub enable_export: bool,
    /// Export timeout (presumably seconds; default 30 — confirm).
    pub export_timeout: u64,
    /// Batching parameters for export.
    pub batch_config: BatchConfig,
}
impl Default for OpenTelemetryConfig {
    /// Uses the name `oxirs-gql`, this crate's package version, full sampling,
    /// export enabled, and no endpoint.
    fn default() -> Self {
        let batch_config = BatchConfig::default();
        let service_version = String::from(env!("CARGO_PKG_VERSION"));
        Self {
            service_name: String::from("oxirs-gql"),
            service_version,
            otlp_endpoint: None,
            sampling_rate: 1.0,
            enable_export: true,
            export_timeout: 30,
            batch_config,
        }
    }
}
/// Batching parameters for telemetry export (not consumed within this file —
/// presumably forwarded to an OTLP batch processor; confirm at the call site).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Maximum number of queued items.
    pub max_queue_size: usize,
    /// Maximum number of items per export batch.
    pub max_export_batch_size: usize,
    /// Delay between scheduled exports, in milliseconds.
    pub scheduled_delay_millis: u64,
    /// Export timeout, in milliseconds.
    pub max_export_timeout_millis: u64,
}
impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            scheduled_delay_millis: 5_000,
            max_export_timeout_millis: 30_000,
            max_queue_size: 2_048,
            max_export_batch_size: 512,
        }
    }
}
/// Distributed-tracing context propagated between services.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceContext {
/// Identifier shared by every span in a trace (normally a hex string).
pub trace_id: String,
/// Identifier of the current span (normally a hex string).
pub span_id: String,
/// Span id of the parent; set only by `create_child`.
pub parent_span_id: Option<String>,
/// Trace flags byte, rendered as two hex digits in `traceparent`.
pub trace_flags: u8,
/// Key/value baggage, copied into child contexts.
pub baggage: HashMap<String, String>,
}
impl Default for TraceContext {
/// Same as `TraceContext::new`: a fresh root context.
fn default() -> Self {
Self::new()
}
}
impl TraceContext {
    /// Creates a root context with a fresh trace id and span id and the
    /// sampled flag (`0x01`) set.
    pub fn new() -> Self {
        Self {
            trace_id: generate_trace_id(),
            span_id: generate_span_id(),
            parent_span_id: None,
            trace_flags: 1,
            baggage: HashMap::new(),
        }
    }

    /// Derives a child context with the same trace id, flags and baggage, a
    /// fresh span id, and this context's span as its parent.
    pub fn create_child(&self) -> Self {
        Self {
            trace_id: self.trace_id.clone(),
            span_id: generate_span_id(),
            parent_span_id: Some(self.span_id.clone()),
            trace_flags: self.trace_flags,
            baggage: self.baggage.clone(),
        }
    }

    /// Parses a W3C Trace Context `traceparent` header
    /// (`version-traceid-parentid-flags`).
    ///
    /// The header's parent id becomes this context's `span_id`,
    /// `parent_span_id` is `None`, and baggage starts empty. Hex fields are
    /// normalised to lowercase.
    ///
    /// # Errors
    /// The header comes from the network, so it is validated before use.
    /// Returns an error unless all of the following hold:
    /// - there are exactly four `-`-separated fields;
    /// - the version is two hex digits other than `ff`;
    /// - the trace id is 32 hex digits and not all zeros;
    /// - the parent id is 16 hex digits and not all zeros;
    /// - the flags are two hex digits.
    pub fn from_traceparent(header: &str) -> Result<Self> {
        let parts: Vec<&str> = header.split('-').collect();
        if parts.len() != 4 {
            return Err(anyhow!("Invalid traceparent format"));
        }
        let (version, trace_id, span_id, flags) = (parts[0], parts[1], parts[2], parts[3]);
        if !Self::is_hex_field(version, 2) || version.eq_ignore_ascii_case("ff") {
            return Err(anyhow!("Invalid traceparent version"));
        }
        if !Self::is_hex_field(trace_id, 32) || Self::is_all_zeros(trace_id) {
            return Err(anyhow!("Invalid traceparent trace-id"));
        }
        if !Self::is_hex_field(span_id, 16) || Self::is_all_zeros(span_id) {
            return Err(anyhow!("Invalid traceparent parent-id"));
        }
        if !Self::is_hex_field(flags, 2) {
            return Err(anyhow!("Invalid traceparent trace-flags"));
        }
        Ok(Self {
            trace_id: trace_id.to_ascii_lowercase(),
            span_id: span_id.to_ascii_lowercase(),
            parent_span_id: None,
            trace_flags: u8::from_str_radix(flags, 16)?,
            baggage: HashMap::new(),
        })
    }

    /// Formats this context as a version-`00` `traceparent` header value.
    pub fn to_traceparent(&self) -> String {
        format!(
            "00-{}-{}-{:02x}",
            self.trace_id, self.span_id, self.trace_flags
        )
    }

    /// Returns `true` if `field` is exactly `len` ASCII hex digits.
    fn is_hex_field(field: &str, len: usize) -> bool {
        field.len() == len && field.bytes().all(|b| b.is_ascii_hexdigit())
    }

    /// Returns `true` if `field` consists only of `'0'` characters.
    fn is_all_zeros(field: &str) -> bool {
        field.bytes().all(|b| b == b'0')
    }
}
/// Returns a random trace id: a v4 UUID rendered as 32 hex characters
/// (its hyphenated form with the hyphens removed).
fn generate_trace_id() -> String {
    let hyphenated = uuid::Uuid::new_v4().to_string();
    hyphenated.chars().filter(|&c| c != '-').collect()
}
/// Returns a random span id as 16 zero-padded lowercase hex digits.
///
/// The value is drawn from `1..` because the all-zero id is invalid under
/// W3C Trace Context.
fn generate_span_id() -> String {
    format!("{:016x}", fastrand::u64(1..))
}
/// Sizing and lifecycle settings for `ConnectionPool`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionPoolConfig {
    /// Floor that `adjust_pool_size` will not shrink below.
    pub min_connections: usize,
    /// Ceiling that `adjust_pool_size` will not grow beyond.
    pub max_connections: usize,
    /// Connection acquisition timeout (presumably seconds — confirm).
    pub connection_timeout: u64,
    /// Idle timeout (presumably seconds — confirm).
    pub idle_timeout: u64,
    /// Maximum connection lifetime (presumably seconds — confirm).
    pub max_lifetime: u64,
    /// Whether periodic health checks are enabled.
    pub enable_health_check: bool,
    /// Interval between health checks (presumably seconds — confirm).
    pub health_check_interval: u64,
}
impl Default for ConnectionPoolConfig {
    fn default() -> Self {
        Self {
            health_check_interval: 30,
            enable_health_check: true,
            max_lifetime: 1_800,
            idle_timeout: 600,
            connection_timeout: 30,
            max_connections: 100,
            min_connections: 5,
        }
    }
}
/// Snapshot of connection-pool counters, as returned by `ConnectionPool::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolStats {
/// Current pool size; the value `ConnectionPool::adjust_pool_size` changes.
pub total_connections: usize,
/// Connections currently in use.
pub active_connections: usize,
/// Connections currently idle.
pub idle_connections: usize,
/// Requests waiting for a connection.
pub waiting_requests: usize,
/// Cumulative request count.
pub total_requests: u64,
/// Cumulative count of failed connection attempts.
pub failed_connections: u64,
}
/// Tracks connection-pool statistics and adapts the target pool size to load.
pub struct ConnectionPool {
    config: ConnectionPoolConfig,
    stats: Arc<RwLock<PoolStats>>,
    /// Time of the last evaluation; used to rate-limit `adjust_pool_size`.
    last_adjusted: Arc<RwLock<Instant>>,
}
impl ConnectionPool {
    /// Creates a pool with all counters at zero.
    pub fn new(config: ConnectionPoolConfig) -> Self {
        Self {
            config,
            stats: Arc::new(RwLock::new(PoolStats {
                total_connections: 0,
                active_connections: 0,
                idle_connections: 0,
                waiting_requests: 0,
                total_requests: 0,
                failed_connections: 0,
            })),
            last_adjusted: Arc::new(RwLock::new(Instant::now())),
        }
    }

    /// Returns a copy of the current statistics.
    pub async fn get_stats(&self) -> PoolStats {
        self.stats.read().await.clone()
    }

    /// Resizes the pool according to `load_factor` (expected in `0.0..=1.0`).
    ///
    /// Calls within 60 seconds of the previous evaluation (or of construction)
    /// are ignored. Above 0.8 load the pool grows by 20%, and at least to
    /// `min_connections`. Below 0.3 load it shrinks by 10% while above
    /// `min_connections`. Each step is at least one connection, because
    /// truncating the percentage alone yields 0 for small pools: an empty pool
    /// could never grow, and pools under 10 could never shrink. The result is
    /// clamped to `max_connections` on growth and to `min_connections` on
    /// shrink.
    pub async fn adjust_pool_size(&self, load_factor: f64) -> Result<()> {
        let mut last_adjusted = self.last_adjusted.write().await;
        let now = Instant::now();
        if now.duration_since(*last_adjusted) < Duration::from_secs(60) {
            return Ok(());
        }
        let mut stats = self.stats.write().await;
        let current_size = stats.total_connections;
        let min = self.config.min_connections;
        let max = self.config.max_connections;
        let target_size = if load_factor > 0.8 {
            let step = ((current_size as f64 * 0.2) as usize).max(1);
            current_size.saturating_add(step).max(min).min(max)
        } else if load_factor < 0.3 && current_size > min {
            let step = ((current_size as f64 * 0.1) as usize).max(1);
            current_size.saturating_sub(step).max(min)
        } else {
            current_size
        };
        if target_size != current_size {
            info!(
                "Adjusting connection pool size: {} -> {} (load: {:.2})",
                current_size, target_size, load_factor
            );
            stats.total_connections = target_size;
        }
        *last_adjusted = now;
        Ok(())
    }
}
/// Health of a component or of the whole service, from best to worst.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HealthStatus {
    Healthy,
    Degraded,
    Unhealthy,
}
impl fmt::Display for HealthStatus {
    /// Writes the lowercase status name (`healthy`, `degraded`, `unhealthy`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Healthy => "healthy",
            Self::Degraded => "degraded",
            Self::Unhealthy => "unhealthy",
        };
        f.write_str(label)
    }
}
/// Aggregated health report produced by `HealthChecker::check_health`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckResult {
/// Worst status among all component checks (`Healthy` if there are none).
pub status: HealthStatus,
/// Crate version (`CARGO_PKG_VERSION`).
pub version: String,
/// Seconds since the `HealthChecker` was created.
pub uptime: u64,
/// Report time as an RFC 3339 UTC timestamp.
pub timestamp: String,
/// Per-component results keyed by the registered check name.
pub checks: HashMap<String, ComponentHealth>,
}
/// Result of a single registered health check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentHealth {
/// Status of this component.
pub status: HealthStatus,
/// Optional probe latency in milliseconds.
pub response_time_ms: Option<u64>,
/// Optional error description.
pub error: Option<String>,
/// Free-form extra details.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Runs named component probes and combines them into one health report.
pub struct HealthChecker {
    start_time: Instant,
    checks: HashMap<String, Box<dyn Fn() -> ComponentHealth + Send + Sync>>,
}
impl Default for HealthChecker {
    fn default() -> Self {
        Self::new()
    }
}
impl HealthChecker {
    /// Creates a checker with no probes; uptime is measured from this call.
    pub fn new() -> Self {
        let checks = HashMap::new();
        Self {
            start_time: Instant::now(),
            checks,
        }
    }

    /// Registers `check` under `name`, replacing any probe with the same name.
    pub fn register_check<F>(&mut self, name: String, check: F)
    where
        F: Fn() -> ComponentHealth + Send + Sync + 'static,
    {
        let boxed: Box<dyn Fn() -> ComponentHealth + Send + Sync> = Box::new(check);
        self.checks.insert(name, boxed);
    }

    /// Runs every registered probe once and reports the worst status seen.
    pub fn check_health(&self) -> HealthCheckResult {
        let checks: HashMap<String, ComponentHealth> = self
            .checks
            .iter()
            .map(|(name, probe)| (name.clone(), probe()))
            .collect();
        // Unhealthy dominates Degraded, which dominates Healthy.
        let status = checks
            .values()
            .map(|component| component.status)
            .fold(HealthStatus::Healthy, |worst, next| match (worst, next) {
                (HealthStatus::Unhealthy, _) | (_, HealthStatus::Unhealthy) => {
                    HealthStatus::Unhealthy
                }
                (HealthStatus::Degraded, _) | (_, HealthStatus::Degraded) => {
                    HealthStatus::Degraded
                }
                _ => HealthStatus::Healthy,
            });
        HealthCheckResult {
            status,
            version: env!("CARGO_PKG_VERSION").to_string(),
            uptime: self.start_time.elapsed().as_secs(),
            timestamp: chrono::Utc::now().to_rfc3339(),
            checks,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestLog {
pub request_id: String,
pub method: String,
pub path: String,
pub query_string: Option<String>,
pub status_code: u16,
pub response_time_ms: u64,
pub request_size: u64,
pub response_size: u64,
pub client_ip: String,
pub user_agent: Option<String>,
pub timestamp: String,
pub trace_context: Option<TraceContext>,
}
impl RequestLog {
pub fn new(request_id: String, method: String, path: String, client_ip: String) -> Self {
Self {
request_id,
method,
path,
query_string: None,
status_code: 200,
response_time_ms: 0,
request_size: 0,
response_size: 0,
client_ip,
user_agent: None,
timestamp: chrono::Utc::now().to_rfc3339(),
trace_context: None,
}
}
pub fn to_json(&self) -> Result<String> {
Ok(serde_json::to_string(self)?)
}
}
/// In-memory service counters that can be rendered in Prometheus text format.
#[derive(Debug, Clone)]
pub struct Metrics {
    /// Total requests served.
    pub requests_total: u64,
    /// Total requests that ended in error.
    pub errors_total: u64,
    /// Recorded request durations in milliseconds (not included in `to_prometheus`).
    pub request_duration_ms: Vec<u64>,
    /// Currently active connections.
    pub active_connections: usize,
    /// Current connection-pool size.
    pub pool_size: usize,
}
impl Default for Metrics {
    fn default() -> Self {
        Self::new()
    }
}
impl Metrics {
    /// Creates a zeroed metrics set.
    pub fn new() -> Self {
        Self {
            request_duration_ms: Vec::new(),
            requests_total: 0,
            errors_total: 0,
            active_connections: 0,
            pool_size: 0,
        }
    }

    /// Renders the counters and gauges in Prometheus text exposition format,
    /// with `# HELP` and `# TYPE` lines before each sample.
    pub fn to_prometheus(&self) -> String {
        // (metric name, HELP text, TYPE, current value)
        let series = [
            (
                "oxirs_gql_requests_total",
                "Total number of requests",
                "counter",
                self.requests_total.to_string(),
            ),
            (
                "oxirs_gql_errors_total",
                "Total number of errors",
                "counter",
                self.errors_total.to_string(),
            ),
            (
                "oxirs_gql_active_connections",
                "Current active connections",
                "gauge",
                self.active_connections.to_string(),
            ),
            (
                "oxirs_gql_pool_size",
                "Current connection pool size",
                "gauge",
                self.pool_size.to_string(),
            ),
        ];
        series
            .iter()
            .map(|(name, help, kind, value)| {
                format!(
                    "# HELP {0} {1}\n# TYPE {0} {2}\n{0} {3}\n",
                    name, help, kind, value
                )
            })
            .collect()
    }
}
/// Operating state of a `CircuitBreaker`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Requests flow normally; a success resets the failure count.
Closed,
/// Requests are rejected until the configured timeout has elapsed since the breaker opened.
Open,
/// Trial requests are allowed; enough successes close the breaker, a failure re-opens it.
HalfOpen,
}
/// Thresholds and timing for `CircuitBreaker`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Failures (without an intervening success) that open the breaker.
    pub failure_threshold: usize,
    /// Successes in the half-open state required to close the breaker.
    pub success_threshold: usize,
    /// Seconds to stay open before allowing a half-open trial.
    pub timeout: u64,
    /// Window length (not read by `CircuitBreaker` in this file).
    pub window_size: u64,
}
impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            timeout: 60,
            window_size: 60,
            success_threshold: 2,
            failure_threshold: 5,
        }
    }
}
pub struct CircuitBreaker {
config: CircuitBreakerConfig,
state: Arc<RwLock<CircuitState>>,
failures: Arc<RwLock<usize>>,
successes: Arc<RwLock<usize>>,
last_failure_time: Arc<RwLock<Option<Instant>>>,
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
state: Arc::new(RwLock::new(CircuitState::Closed)),
failures: Arc::new(RwLock::new(0)),
successes: Arc::new(RwLock::new(0)),
last_failure_time: Arc::new(RwLock::new(None)),
}
}
pub async fn is_allowed(&self) -> bool {
let state = *self.state.read().await;
match state {
CircuitState::Closed => true,
CircuitState::Open => {
if let Some(last_failure) = *self.last_failure_time.read().await {
if last_failure.elapsed() > Duration::from_secs(self.config.timeout) {
*self.state.write().await = CircuitState::HalfOpen;
true
} else {
false
}
} else {
false
}
}
CircuitState::HalfOpen => true,
}
}
pub async fn record_success(&self) {
let state = *self.state.read().await;
match state {
CircuitState::HalfOpen => {
let mut successes = self.successes.write().await;
*successes += 1;
if *successes >= self.config.success_threshold {
*self.state.write().await = CircuitState::Closed;
*successes = 0;
*self.failures.write().await = 0;
info!("Circuit breaker closed after successful recovery");
}
}
CircuitState::Closed => {
*self.failures.write().await = 0;
}
_ => {}
}
}
pub async fn record_failure(&self) {
let state = *self.state.read().await;
match state {
CircuitState::Closed | CircuitState::HalfOpen => {
let mut failures = self.failures.write().await;
*failures += 1;
if *failures >= self.config.failure_threshold {
*self.state.write().await = CircuitState::Open;
*self.last_failure_time.write().await = Some(Instant::now());
*self.successes.write().await = 0;
warn!("Circuit breaker opened due to failures");
}
}
_ => {}
}
}
pub async fn get_state(&self) -> CircuitState {
*self.state.read().await
}
}
// Unit tests for the production helpers in this module.
#[cfg(test)]
mod tests {
use super::*;
// --- CORS ---
// The default policy contains "*", so any origin is allowed.
#[test]
fn test_cors_config_default() {
let config = CorsConfig::default();
assert!(config.is_origin_allowed("https://example.com"));
assert_eq!(config.max_age, Some(3600));
}
#[test]
fn test_cors_origin_check() {
let config = CorsConfig::production(vec!["https://example.com".to_string()]);
assert!(config.is_origin_allowed("https://example.com"));
assert!(!config.is_origin_allowed("https://evil.com"));
}
// A `*.example.com` pattern must match subdomains but not the bare apex domain.
#[test]
fn test_cors_wildcard_matching() {
let config = CorsConfig {
allowed_origins: vec!["https://*.example.com".to_string()],
wildcard_origins: true,
..Default::default()
};
assert!(config.is_origin_allowed("https://app.example.com"));
assert!(config.is_origin_allowed("https://api.example.com"));
assert!(!config.is_origin_allowed("https://example.com"));
}
// --- JWT ---
#[test]
fn test_jwt_algorithm_display() {
assert_eq!(JwtAlgorithm::HS256.to_string(), "HS256");
assert_eq!(JwtAlgorithm::RS512.to_string(), "RS512");
}
#[test]
fn test_jwt_claims_creation() {
let claims = JwtClaims::new("user123".to_string(), 3600);
assert_eq!(claims.sub, "user123");
assert!(!claims.is_expired());
}
#[test]
fn test_jwt_claims_custom() {
let claims = JwtClaims::new("user123".to_string(), 3600)
.with_claim("role".to_string(), serde_json::json!("admin"));
assert_eq!(
claims.custom.get("role").expect("should succeed"),
&serde_json::json!("admin")
);
}
// --- Trace context ---
#[test]
fn test_trace_context_creation() {
let ctx = TraceContext::new();
assert!(!ctx.trace_id.is_empty());
assert!(!ctx.span_id.is_empty());
assert_eq!(ctx.trace_flags, 1);
}
// A child shares the trace id, gets a new span id, and points at the parent span.
#[test]
fn test_trace_context_child() {
let parent = TraceContext::new();
let child = parent.create_child();
assert_eq!(child.trace_id, parent.trace_id);
assert_ne!(child.span_id, parent.span_id);
assert_eq!(child.parent_span_id, Some(parent.span_id.clone()));
}
#[test]
fn test_trace_context_traceparent() {
let ctx = TraceContext::new();
let header = ctx.to_traceparent();
assert!(header.starts_with("00-"));
assert_eq!(header.split('-').count(), 4);
}
// --- Connection pool ---
#[test]
fn test_connection_pool_config() {
let config = ConnectionPoolConfig::default();
assert_eq!(config.min_connections, 5);
assert_eq!(config.max_connections, 100);
}
#[tokio::test]
async fn test_connection_pool_stats() {
let pool = ConnectionPool::new(ConnectionPoolConfig::default());
let stats = pool.get_stats().await;
assert_eq!(stats.total_connections, 0);
assert_eq!(stats.active_connections, 0);
}
// --- Health checks ---
#[test]
fn test_health_status_display() {
assert_eq!(HealthStatus::Healthy.to_string(), "healthy");
assert_eq!(HealthStatus::Degraded.to_string(), "degraded");
assert_eq!(HealthStatus::Unhealthy.to_string(), "unhealthy");
}
#[test]
fn test_health_checker() {
let mut checker = HealthChecker::new();
checker.register_check("database".to_string(), || ComponentHealth {
status: HealthStatus::Healthy,
response_time_ms: Some(5),
error: None,
metadata: HashMap::new(),
});
let result = checker.check_health();
assert_eq!(result.status, HealthStatus::Healthy);
assert!(result.checks.contains_key("database"));
}
// --- Request logging ---
#[test]
fn test_request_log_creation() {
let log = RequestLog::new(
"req-123".to_string(),
"POST".to_string(),
"/graphql".to_string(),
"127.0.0.1".to_string(),
);
assert_eq!(log.request_id, "req-123");
assert_eq!(log.method, "POST");
assert_eq!(log.status_code, 200);
}
// --- Metrics ---
#[test]
fn test_metrics_prometheus_export() {
let mut metrics = Metrics::new();
metrics.requests_total = 100;
metrics.errors_total = 5;
metrics.active_connections = 10;
let output = metrics.to_prometheus();
assert!(output.contains("oxirs_gql_requests_total 100"));
assert!(output.contains("oxirs_gql_errors_total 5"));
assert!(output.contains("oxirs_gql_active_connections 10"));
}
// --- Circuit breaker ---
#[test]
fn test_circuit_breaker_config() {
let config = CircuitBreakerConfig::default();
assert_eq!(config.failure_threshold, 5);
assert_eq!(config.success_threshold, 2);
}
#[tokio::test]
async fn test_circuit_breaker_state() {
let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
assert!(breaker.is_allowed().await);
assert_eq!(breaker.get_state().await, CircuitState::Closed);
}
// Two failures reach the threshold of 2, so the breaker opens and rejects requests.
#[tokio::test]
async fn test_circuit_breaker_failure() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
..Default::default()
};
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
assert_eq!(breaker.get_state().await, CircuitState::Closed);
breaker.record_failure().await;
assert_eq!(breaker.get_state().await, CircuitState::Open);
assert!(!breaker.is_allowed().await);
}
// Full cycle Closed -> Open -> HalfOpen -> Closed. A zero timeout plus a short
// sleep guarantees `elapsed() > timeout` when `is_allowed` is called.
#[tokio::test]
async fn test_circuit_breaker_recovery() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
success_threshold: 1,
timeout: 0, ..Default::default()
};
let breaker = CircuitBreaker::new(config);
breaker.record_failure().await;
assert_eq!(breaker.get_state().await, CircuitState::Open);
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(breaker.is_allowed().await);
assert_eq!(breaker.get_state().await, CircuitState::HalfOpen);
breaker.record_success().await;
assert_eq!(breaker.get_state().await, CircuitState::Closed);
}
}