use crate::web::state::AppState;
use opentelemetry::KeyValue;
use tasker_shared::metrics::orchestration::api_requests_rejected_total;
use tasker_shared::types::web::{ApiError, ApiResult};
use tracing::warn;
pub use crate::api_common::circuit_breaker::{CircuitState, WebDatabaseCircuitBreaker};
pub async fn execute_with_circuit_breaker<T, E, F, Fut>(
state: &AppState,
operation: F,
) -> ApiResult<T>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::error::Error + Send + Sync + 'static,
{
use tracing::error;
if !state.is_database_healthy() {
record_circuit_breaker_rejection("unknown");
return Err(ApiError::CircuitBreakerOpen);
}
match operation().await {
Ok(result) => {
state.record_database_success();
Ok(result)
}
Err(e) => {
state.record_database_failure();
error!(error = %e, "Database operation failed");
Err(ApiError::database_error(format!("Operation failed: {e}")))
}
}
}
pub async fn execute_with_backpressure_check<T, E, F, Fut>(
state: &AppState,
endpoint: &str,
operation: F,
) -> ApiResult<T>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::error::Error + Send + Sync + 'static,
{
use tracing::error;
if let Some(backpressure_error) = state.check_backpressure_status() {
record_backpressure_rejection(endpoint, &backpressure_error);
return Err(backpressure_error);
}
match operation().await {
Ok(result) => {
state.record_database_success();
Ok(result)
}
Err(e) => {
state.record_database_failure();
error!(error = %e, endpoint = endpoint, "Operation failed");
Err(ApiError::database_error(format!("Operation failed: {e}")))
}
}
}
pub fn record_backpressure_rejection(endpoint: &str, error: &ApiError) {
let reason = match error {
ApiError::Backpressure { reason, .. } => reason.as_str(),
ApiError::CircuitBreakerOpen => "circuit_breaker",
_ => "unknown",
};
let counter = api_requests_rejected_total();
counter.add(
1,
&[
KeyValue::new("endpoint", endpoint.to_string()),
KeyValue::new("reason", reason.to_string()),
],
);
warn!(
endpoint = endpoint,
reason = reason,
"API request rejected due to backpressure"
);
}
/// Records a request rejected because the circuit breaker is open.
///
/// Increments `api_requests_rejected_total` by one with attributes
/// `endpoint = <endpoint>` and `reason = "circuit_breaker"`, then emits a
/// warning log carrying the same two fields.
pub fn record_circuit_breaker_rejection(endpoint: &str) {
    // Shared by the metric attribute and the log field so they cannot drift apart.
    const REASON: &str = "circuit_breaker";

    api_requests_rejected_total().add(
        1,
        &[
            KeyValue::new("endpoint", endpoint.to_string()),
            KeyValue::new("reason", REASON),
        ],
    );

    warn!(
        endpoint,
        reason = REASON,
        "API request rejected due to circuit breaker open"
    );
}
#[cfg(test)]
mod tests {
    //! Unit tests for the re-exported `WebDatabaseCircuitBreaker` and for the
    //! rejection-recording helpers in this module.

    use super::*;
    use std::time::Duration;

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = WebDatabaseCircuitBreaker::new(3, Duration::from_secs(5), "test");
        assert!(!cb.is_circuit_open());
        assert_eq!(cb.current_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_opens_after_threshold_failures() {
        let cb = WebDatabaseCircuitBreaker::new(3, Duration::from_secs(5), "test");
        cb.record_failure();
        cb.record_failure();
        assert!(!cb.is_circuit_open());
        assert_eq!(cb.current_state(), CircuitState::Closed);

        // The third failure reaches the threshold of 3.
        cb.record_failure();
        assert!(cb.is_circuit_open());
        assert_eq!(cb.current_state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_closes_on_success_via_half_open() {
        // With a zero recovery timeout, an open circuit stops rejecting on the
        // very next `is_circuit_open()` check; per the test name, this is
        // expected to be the half-open probe phase.
        let cb = WebDatabaseCircuitBreaker::new(2, Duration::ZERO, "test");
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Open);
        assert!(!cb.is_circuit_open());

        cb.record_success();
        cb.record_success();
        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert_eq!(cb.current_failures(), 0);
    }

    #[test]
    fn test_default_circuit_breaker_configuration() {
        let cb = WebDatabaseCircuitBreaker::default();
        assert_eq!(cb.component_name(), "web_database");
        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert_eq!(cb.current_failures(), 0);
        assert!(!cb.is_circuit_open());
    }

    #[test]
    fn test_component_name_accessor() {
        let cb = WebDatabaseCircuitBreaker::new(5, Duration::from_secs(30), "custom_component");
        assert_eq!(cb.component_name(), "custom_component");
    }

    #[test]
    fn test_failure_count_increments_correctly() {
        let cb = WebDatabaseCircuitBreaker::new(10, Duration::from_secs(30), "test");
        assert_eq!(cb.current_failures(), 0);
        cb.record_failure();
        assert_eq!(cb.current_failures(), 1);
        cb.record_failure();
        assert_eq!(cb.current_failures(), 2);
        cb.record_failure();
        assert_eq!(cb.current_failures(), 3);
    }

    #[test]
    fn test_success_resets_failure_count() {
        let cb = WebDatabaseCircuitBreaker::new(10, Duration::from_secs(30), "test");
        cb.record_failure();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.current_failures(), 3);
        cb.record_success();
        assert_eq!(cb.current_failures(), 0);
    }

    #[test]
    fn test_success_failure_success_sequence() {
        let cb = WebDatabaseCircuitBreaker::new(3, Duration::from_secs(30), "test");
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.current_failures(), 2);
        assert_eq!(cb.current_state(), CircuitState::Closed);
        cb.record_success();
        assert_eq!(cb.current_failures(), 0);
        cb.record_failure();
        assert_eq!(cb.current_failures(), 1);
        cb.record_success();
        assert_eq!(cb.current_failures(), 0);
    }

    #[test]
    fn test_circuit_breaker_exact_threshold() {
        let cb = WebDatabaseCircuitBreaker::new(5, Duration::from_secs(30), "test");
        for i in 1..5 {
            cb.record_failure();
            assert!(
                !cb.is_circuit_open(),
                "Circuit should be closed at {i} failures (threshold is 5)"
            );
        }
        cb.record_failure();
        assert!(
            cb.is_circuit_open(),
            "Circuit should be open at threshold (5 failures)"
        );
    }

    #[test]
    fn test_multiple_successes_keep_circuit_closed() {
        let cb = WebDatabaseCircuitBreaker::new(3, Duration::from_secs(30), "test");
        cb.record_success();
        cb.record_success();
        cb.record_success();
        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert_eq!(cb.current_failures(), 0);
    }

    #[test]
    fn test_open_circuit_stays_open_without_recovery() {
        let cb = WebDatabaseCircuitBreaker::new(2, Duration::from_secs(3600), "test");
        cb.record_failure();
        cb.record_failure();
        // Checked twice on purpose: a repeated check before the (one-hour)
        // timeout elapses must not move the circuit out of `Open`.
        assert!(cb.is_circuit_open());
        assert!(cb.is_circuit_open());
        assert_eq!(cb.current_state(), CircuitState::Open);
    }

    #[test]
    fn test_record_backpressure_rejection_runs_without_panic() {
        let error = ApiError::backpressure("test", 5);
        record_backpressure_rejection("/v1/tasks", &error);
    }

    #[test]
    fn test_record_backpressure_rejection_handles_circuit_breaker_error() {
        // Exercises the `CircuitBreakerOpen` arm of the reason mapping.
        record_backpressure_rejection("/v1/tasks", &ApiError::CircuitBreakerOpen);
    }

    #[test]
    fn test_record_backpressure_rejection_handles_other_errors() {
        // Exercises the fallback ("unknown") arm of the reason mapping.
        let error = ApiError::database_error(String::from("boom"));
        record_backpressure_rejection("/v1/tasks", &error);
    }

    #[test]
    fn test_record_circuit_breaker_rejection_runs_without_panic() {
        record_circuit_breaker_rejection("/v1/tasks");
    }
}