use axum::extract::MatchedPath;
use axum::{
body::Body,
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use crate::{
resil::{BreakerConfig, BreakerPolicyConfig, BreakerRegistry},
rest::{
RestError, RestResilienceConfig,
middleware::{MiddlewareMetrics, record_resilience_event},
},
};
pub async fn breaker_middleware(
registry: BreakerRegistry,
service: String,
config: RestResilienceConfig,
metrics: MiddlewareMetrics,
request: Request<Body>,
next: Next,
) -> Response {
let key = breaker_key(&service, &request);
let breaker = registry
.get_or_insert_with_policy(
key,
BreakerConfig {
failure_threshold: config.breaker_failure_threshold,
reset_timeout: config.breaker_reset_timeout,
},
breaker_policy(&config),
)
.await;
let guard = match breaker.allow().await {
Ok(guard) => {
record_resilience_event(&metrics, "breaker", "allowed");
guard
}
Err(error) => {
record_resilience_event(&metrics, "breaker", "dropped");
return RestError::ServiceUnavailable(error.to_string()).into_response();
}
};
let response = next.run(request).await;
if response.status().is_server_error() {
record_resilience_event(&metrics, "breaker", "failure");
guard.record_failure().await;
} else {
record_resilience_event(&metrics, "breaker", "success");
guard.record_success().await;
}
response
}
/// Builds the registry key that selects which breaker guards `request`.
///
/// The key has the form `"{service}:{METHOD}:{route}"`, where `route` is the
/// path stored in axum's `MatchedPath` request extension. When that extension
/// is absent, the literal `"unknown"` is used instead. As a result, every such
/// request for a given service and method shares a single breaker.
fn breaker_key(service: &str, request: &Request<Body>) -> String {
    let matched_route = match request.extensions().get::<MatchedPath>() {
        Some(path) => path.as_str(),
        None => "unknown",
    };
    let verb = request.method().as_str();
    format!("{}:{}:{}", service, verb, matched_route)
}
/// Reports whether `status` should be treated as a failure by the breaker.
///
/// Only 5xx server errors (codes 500 through 599) count as failures. 1xx–4xx
/// responses, including client errors, do not.
pub fn is_failure_status(status: StatusCode) -> bool {
    (500..600).contains(&status.as_u16())
}
/// Derives the breaker policy from the REST resilience settings.
///
/// The base policy depends on `breaker_sre_enabled`:
/// - when set, it is `BreakerPolicyConfig::google_sre()`;
/// - otherwise it is `BreakerPolicyConfig::default()`.
///
/// Only the selected constructor is called. In both cases `sre_k_millis` and
/// `sre_protection` are then overwritten with the values taken from `config`.
pub fn breaker_policy(config: &RestResilienceConfig) -> BreakerPolicyConfig {
    let mut policy = config
        .breaker_sre_enabled
        .then(BreakerPolicyConfig::google_sre)
        .unwrap_or_else(BreakerPolicyConfig::default);
    // Applied unconditionally, so these override whatever the chosen base set.
    policy.sre_protection = config.breaker_sre_protection;
    policy.sre_k_millis = config.breaker_sre_k_millis;
    policy
}