use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use crate::metrics;
pub mod bucket;
pub mod circuit_breaker;
pub mod cost;
pub mod deadlines;
pub mod policy;
pub mod registry;
pub mod retry_budget;
pub use bucket::{TokenBucket, TokenBucketConfig};
pub use circuit_breaker::{
BreakerState, BreakerVerdict, CircuitBreaker, CircuitBreakerConfig, DeploymentProfile,
HealthObservation, OpenReason,
};
pub use cost::{cost_units, CostInputs, CostParams, DEFAULT_E_EXPAND};
pub use deadlines::{
run_with_deadline, run_with_deadline_or_cancel, Cancelled, DeadlineBudget, DeadlineConfig,
DeadlineError, RecallStage,
};
pub use policy::{
PolicyResolver, ProvisionalDefaults, QuotaPolicy, QuotaScope, ScopeKind, FALLBACK_DEFAULTS,
PROVISIONAL_DEFAULTS,
};
pub use registry::{BucketDimension, BucketKey, BucketRegistry, ConsumeOutcome};
pub use retry_budget::{
build_guidance, compute_retry_after, RetryBudget, RetryBudgetConfig, RetryGuidance,
};
/// Default for [`AdmissionConfig::hard_top_k_cap`]: the largest `top_k` that
/// [`check_top_k`] accepts when this value is passed as its cap.
pub const HARD_TOP_K_CAP: usize = 1000;
/// Default for [`AdmissionConfig::max_request_body_bytes`] (64 KiB).
pub const MAX_REQUEST_BODY_BYTES: usize = 64 * 1024;
/// Default for [`AdmissionConfig::max_concurrent_expanded_recall`], i.e. the number of permits
/// the expanded-recall semaphore starts with.
pub const DEFAULT_MAX_CONCURRENT_EXPANDED_RECALL: usize = 4;
/// Default for [`AdmissionConfig::max_in_flight_recall`], i.e. the number of permits the
/// in-flight recall semaphore starts with.
pub const DEFAULT_MAX_IN_FLIGHT_RECALL: usize = 64;
/// Longest time [`AdmissionState::acquire_recall_permits`] waits on a single semaphore before
/// rejecting the request as saturated. The bound applies per semaphore, not per request.
pub const PERMIT_ACQUIRE_TIMEOUT: Duration = Duration::from_millis(100);
/// Static limits applied when admitting recall requests.
///
/// [`Default`] fills every field from the module-level constants.
#[derive(Debug, Clone)]
pub struct AdmissionConfig {
/// Permit count of the expanded-recall semaphore; bounds how many requests with
/// `expand_entities = true` can hold permits at the same time.
pub max_concurrent_expanded_recall: usize,
/// Permit count of the in-flight semaphore; every admitted recall request holds one.
pub max_in_flight_recall: usize,
/// Request body size limit in bytes.
// NOTE(review): not read anywhere in this file — presumably enforced by the HTTP layer; confirm.
pub max_request_body_bytes: usize,
/// Upper bound for `top_k`.
// NOTE(review): presumably handed to `check_top_k` as `cap` by the caller; confirm.
pub hard_top_k_cap: usize,
}
impl Default for AdmissionConfig {
fn default() -> Self {
Self {
max_concurrent_expanded_recall: DEFAULT_MAX_CONCURRENT_EXPANDED_RECALL,
max_in_flight_recall: DEFAULT_MAX_IN_FLIGHT_RECALL,
max_request_body_bytes: MAX_REQUEST_BODY_BYTES,
hard_top_k_cap: HARD_TOP_K_CAP,
}
}
}
/// Shared admission-control state: the configured limits plus the semaphores that enforce them.
///
/// Every field is behind an [`Arc`], so clones of this struct share the same configuration and
/// the same semaphores.
#[derive(Clone)]
pub struct AdmissionState {
/// The limits this state was built from.
pub cfg: Arc<AdmissionConfig>,
/// Semaphore created with `cfg.max_concurrent_expanded_recall` permits; one permit is taken
/// per request that asks for entity expansion.
pub expanded_recall: Arc<Semaphore>,
/// Semaphore created with `cfg.max_in_flight_recall` permits; one permit is taken per
/// admitted recall request.
pub in_flight_recall: Arc<Semaphore>,
}
impl AdmissionState {
pub fn new(cfg: AdmissionConfig) -> Self {
Self {
expanded_recall: Arc::new(Semaphore::new(cfg.max_concurrent_expanded_recall)),
in_flight_recall: Arc::new(Semaphore::new(cfg.max_in_flight_recall)),
cfg: Arc::new(cfg),
}
}
pub async fn acquire_recall_permits(
&self,
expand_entities: bool,
) -> Result<RecallPermits, RejectReason> {
let in_flight = match tokio::time::timeout(
PERMIT_ACQUIRE_TIMEOUT,
self.in_flight_recall.clone().acquire_owned(),
)
.await
{
Ok(Ok(p)) => p,
Ok(Err(_closed)) => return Err(RejectReason::ServerShutdown),
Err(_timeout) => {
metrics::increment_recall_rejected("in_flight_saturated");
return Err(RejectReason::InFlightSaturated);
}
};
let expanded = if expand_entities {
match tokio::time::timeout(
PERMIT_ACQUIRE_TIMEOUT,
self.expanded_recall.clone().acquire_owned(),
)
.await
{
Ok(Ok(p)) => Some(p),
Ok(Err(_closed)) => return Err(RejectReason::ServerShutdown),
Err(_timeout) => {
metrics::increment_recall_rejected("expanded_saturated");
return Err(RejectReason::ExpandedSaturated);
}
}
} else {
None
};
metrics::set_recall_in_flight_gauge(
(self.cfg.max_in_flight_recall - self.in_flight_recall.available_permits()) as i64,
);
if expand_entities {
metrics::set_expansion_concurrent_gauge(
(self.cfg.max_concurrent_expanded_recall - self.expanded_recall.available_permits())
as i64,
);
}
Ok(RecallPermits {
_in_flight: in_flight,
_expanded: expanded,
})
}
}
/// Guard returned by [`AdmissionState::acquire_recall_permits`] that owns the semaphore permits
/// reserved for one recall request.
///
/// The fields are never read; they exist only to keep the permits alive. Dropping this value
/// drops the owned permits, which returns them to their semaphores.
pub struct RecallPermits {
// Permit from the in-flight semaphore; always present.
_in_flight: tokio::sync::OwnedSemaphorePermit,
// Permit from the expanded-recall semaphore; `Some` only when the request was admitted with
// `expand_entities = true`.
_expanded: Option<tokio::sync::OwnedSemaphorePermit>,
}
/// Why a recall request was refused at admission time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RejectReason {
    /// The requested `top_k` is above the configured cap.
    TopKTooLarge,
    /// No in-flight recall permit became available in time.
    InFlightSaturated,
    /// No expanded-recall permit became available in time.
    ExpandedSaturated,
    /// The request body is larger than the configured limit.
    BodyTooLarge,
    /// The server is shutting down.
    ServerShutdown,
}

impl RejectReason {
    /// Short, fixed identifier for this reason, suitable as a metric label value.
    pub fn metric_label(&self) -> &'static str {
        match *self {
            Self::TopKTooLarge => "top_k_cap",
            Self::InFlightSaturated => "in_flight_saturated",
            Self::ExpandedSaturated => "expanded_saturated",
            Self::BodyTooLarge => "body_too_large",
            Self::ServerShutdown => "server_shutdown",
        }
    }

    /// Human-readable explanation of the rejection, including a remedy where one exists.
    pub fn message(&self) -> &'static str {
        match *self {
            Self::TopKTooLarge => {
                "top_k exceeds hard cap; reduce top_k or use the v2 scan endpoint when available"
            }
            Self::InFlightSaturated => {
                "server in-flight recall capacity exhausted; retry after a short backoff"
            }
            Self::ExpandedSaturated => {
                "server expanded-recall capacity exhausted; retry, or set expand_entities=false for cheap recall"
            }
            Self::BodyTooLarge => "request body exceeds limit",
            Self::ServerShutdown => "server shutting down",
        }
    }

    /// HTTP status code for this rejection: 400 for problems with the request itself, 503 when
    /// the server cannot take the request right now.
    pub fn http_status(&self) -> u16 {
        match *self {
            // The client sent something the server will never accept as-is.
            Self::TopKTooLarge | Self::BodyTooLarge => 400,
            // Capacity or lifecycle conditions on the server side.
            Self::InFlightSaturated | Self::ExpandedSaturated | Self::ServerShutdown => 503,
        }
    }
}
/// Validates a requested `top_k` against `cap`.
///
/// A value equal to `cap` is accepted; only values strictly above it are refused.
///
/// # Errors
///
/// Returns [`RejectReason::TopKTooLarge`] when `top_k > cap`, after recording the rejection in
/// the rejected-recall metric.
pub fn check_top_k(top_k: usize, cap: usize) -> Result<(), RejectReason> {
    // Happy path first: anything at or below the cap goes through untouched.
    if top_k <= cap {
        return Ok(());
    }

    metrics::increment_recall_rejected("top_k_cap");
    Err(RejectReason::TopKTooLarge)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Admission state with the given semaphore sizes and default values for everything else.
    fn state_with_limits(expanded: usize, in_flight: usize) -> AdmissionState {
        AdmissionState::new(AdmissionConfig {
            max_concurrent_expanded_recall: expanded,
            max_in_flight_recall: in_flight,
            ..Default::default()
        })
    }

    // --- top_k validation ---

    #[test]
    fn check_top_k_accepts_within_cap() {
        // Both a small value and the cap itself must pass.
        for top_k in [10, HARD_TOP_K_CAP].iter().copied() {
            assert!(check_top_k(top_k, HARD_TOP_K_CAP).is_ok());
        }
    }

    #[test]
    fn check_top_k_rejects_above_cap() {
        let rejection = check_top_k(HARD_TOP_K_CAP + 1, HARD_TOP_K_CAP).unwrap_err();
        assert_eq!(rejection, RejectReason::TopKTooLarge);
        assert_eq!(rejection.http_status(), 400);
    }

    // --- metric labels ---

    #[test]
    fn reject_reason_metric_labels_are_stable() {
        let expected = [
            (RejectReason::TopKTooLarge, "top_k_cap"),
            (RejectReason::InFlightSaturated, "in_flight_saturated"),
            (RejectReason::ExpandedSaturated, "expanded_saturated"),
            (RejectReason::BodyTooLarge, "body_too_large"),
            (RejectReason::ServerShutdown, "server_shutdown"),
        ];
        for &(reason, label) in expected.iter() {
            assert_eq!(reason.metric_label(), label);
        }
    }

    // --- permit acquisition ---

    #[tokio::test]
    async fn acquire_permits_succeeds_within_cap() {
        let state = AdmissionState::new(AdmissionConfig::default());
        // Hold both sets of permits at once before releasing them.
        let with_expansion = state.acquire_recall_permits(true).await.unwrap();
        let without_expansion = state.acquire_recall_permits(false).await.unwrap();
        drop(with_expansion);
        drop(without_expansion);
    }

    #[tokio::test]
    async fn acquire_permits_rejects_when_expanded_saturated() {
        let state = state_with_limits(1, 8);
        // The single expanded-recall permit stays held for the rest of the test.
        let _held = state.acquire_recall_permits(true).await.unwrap();
        let rejection = state.acquire_recall_permits(true).await.err();
        assert_eq!(rejection, Some(RejectReason::ExpandedSaturated));
    }

    #[tokio::test]
    async fn acquire_permits_rejects_when_in_flight_saturated() {
        let state = state_with_limits(8, 1);
        // The single in-flight permit stays held for the rest of the test.
        let _held = state.acquire_recall_permits(false).await.unwrap();
        let rejection = state.acquire_recall_permits(false).await.err();
        assert_eq!(rejection, Some(RejectReason::InFlightSaturated));
    }
}