use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock, RwLock};
/// Cooperative-inference settings together with a live counter of in-flight
/// remote requests.
///
/// Constructed through `CooperativeInferencePolicy::new`, which normalises the
/// numeric limits (the private `in_flight` field prevents struct-literal
/// construction outside this module). Because `in_flight` is atomic, slots can
/// be acquired and released through a shared `&self`.
#[derive(Debug)]
pub struct CooperativeInferencePolicy {
/// Master switch: when false, `check_coordinator`/`check_worker` return false
/// and `as_register_policy_block` returns `None`.
pub enabled: bool,
/// Whether this node may act as a coordinator (only honoured when `enabled`).
pub allow_coordinator: bool,
/// Whether this node may act as a worker (only honoured when `enabled`).
pub allow_worker: bool,
/// Replica limit; `new` guarantees it is at least 1.
pub max_replicas: usize,
/// Worker timeout ceiling in milliseconds; `new` clamps it into `1..=60_000`.
pub max_worker_timeout_ms: u32,
/// Number of slots handed out by `try_acquire_cip_slot`; at least 1.
pub max_concurrent_remote: usize,
/// Whether intents whose domain segment is `"tool"` are permitted.
pub allow_tool_execution: bool,
/// Count of currently held remote-request slots.
in_flight: AtomicUsize,
}
/// Caller-supplied settings used to build a `CooperativeInferencePolicy`.
///
/// Values are stored as given; `CooperativeInferencePolicy::new` is what
/// clamps the numeric limits into their supported ranges. `PartialEq`/`Eq`
/// are derived so configurations can be compared (e.g. to detect changes or
/// in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CooperativeInferencePolicyOptions {
    /// Master switch for the whole policy.
    pub enabled: bool,
    /// Whether this node may act as a coordinator.
    pub allow_coordinator: bool,
    /// Whether this node may act as a worker.
    pub allow_worker: bool,
    /// Replica limit; raised to at least 1 by `CooperativeInferencePolicy::new`.
    pub max_replicas: usize,
    /// Worker timeout ceiling in milliseconds; clamped into `1..=60_000` by
    /// `CooperativeInferencePolicy::new`.
    pub max_worker_timeout_ms: u32,
    /// Maximum number of concurrently held remote-request slots; raised to at
    /// least 1 by `CooperativeInferencePolicy::new`.
    pub max_concurrent_remote: usize,
    /// Whether intents whose domain segment is `"tool"` are permitted.
    pub allow_tool_execution: bool,
}

impl Default for CooperativeInferencePolicyOptions {
    /// Conservative defaults: every capability switched off, with limits of
    /// 3 replicas, a 30 000 ms worker timeout and 2 concurrent remote slots.
    fn default() -> Self {
        Self {
            enabled: false,
            allow_coordinator: false,
            allow_worker: false,
            max_replicas: 3,
            max_worker_timeout_ms: 30_000,
            max_concurrent_remote: 2,
            allow_tool_execution: false,
        }
    }
}
impl CooperativeInferencePolicy {
    /// Builds a policy from `opts`, normalising the numeric limits.
    ///
    /// * `max_replicas` and `max_concurrent_remote` are raised to at least 1.
    /// * `max_worker_timeout_ms` is clamped into `1..=60_000`.
    ///
    /// The in-flight counter starts at zero.
    pub fn new(opts: CooperativeInferencePolicyOptions) -> Self {
        Self {
            enabled: opts.enabled,
            allow_coordinator: opts.allow_coordinator,
            allow_worker: opts.allow_worker,
            max_replicas: opts.max_replicas.max(1),
            max_worker_timeout_ms: opts.max_worker_timeout_ms.clamp(1, 60_000),
            max_concurrent_remote: opts.max_concurrent_remote.max(1),
            allow_tool_execution: opts.allow_tool_execution,
            in_flight: AtomicUsize::new(0),
        }
    }

    /// Returns whether `intent` may be served under this policy.
    ///
    /// The intent is split on `':'` and its fourth segment (index 3) is taken
    /// as the domain; a missing segment counts as the empty domain. Only the
    /// exact domain `"tool"` is gated, and it passes only when
    /// `allow_tool_execution` is set. `enabled` is not consulted here.
    pub fn permits_intent(&self, intent: &str) -> bool {
        let domain = intent.split(':').nth(3).unwrap_or("");
        self.allow_tool_execution || domain != "tool"
    }

    /// True when the policy is enabled and coordinator mode is allowed.
    pub fn check_coordinator(&self) -> bool {
        self.enabled && self.allow_coordinator
    }

    /// True when the policy is enabled and worker mode is allowed.
    pub fn check_worker(&self) -> bool {
        self.enabled && self.allow_worker
    }

    /// Attempts to reserve one of the `max_concurrent_remote` slots.
    ///
    /// Returns `true` if a slot was taken; the caller must then call
    /// [`release_cip_slot`](Self::release_cip_slot) exactly once on this same
    /// policy instance (each policy has its own counter). Returns `false`
    /// without changing the counter when every slot is in use.
    #[must_use = "a successful acquire must be paired with release_cip_slot"]
    pub fn try_acquire_cip_slot(&self) -> bool {
        let limit = self.max_concurrent_remote;
        self.in_flight
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
                // `cur < limit` guarantees `cur + 1` cannot overflow.
                (cur < limit).then(|| cur + 1)
            })
            .is_ok()
    }

    /// Returns a slot previously obtained from `try_acquire_cip_slot`.
    ///
    /// The counter saturates at zero. An unmatched release (for example one
    /// issued against a freshly configured policy whose counter is still 0)
    /// is ignored in release builds instead of wrapping the counter to
    /// `usize::MAX`, which would make every later acquire fail permanently.
    ///
    /// # Panics
    /// In debug builds, panics if called while no slot is held.
    pub fn release_cip_slot(&self) {
        let released = self
            .in_flight
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| cur.checked_sub(1));
        debug_assert!(
            released.is_ok(),
            "release_cip_slot called more times than acquire"
        );
    }

    /// Builds the policy block advertised at registration time.
    ///
    /// Returns `None` when the policy is disabled. Otherwise returns a JSON
    /// object with `"allow_remote_inference"` (taken from `allow_worker`) and
    /// `"allow_tool_execution"`.
    pub fn as_register_policy_block(&self) -> Option<serde_json::Value> {
        if !self.enabled {
            return None;
        }
        Some(serde_json::json!({
            "allow_remote_inference": self.allow_worker,
            "allow_tool_execution": self.allow_tool_execution,
        }))
    }
}
/// Process-wide holder for the active policy.
///
/// The `OnceLock` is filled on first access; the `RwLock` lets
/// `configure_cip_policy` swap in a new `Arc` while readers clone the current one.
static GLOBAL_POLICY: OnceLock<RwLock<Arc<CooperativeInferencePolicy>>> = OnceLock::new();

/// Returns the global policy lock.
///
/// On first use this seeds it with a policy built from
/// `CooperativeInferencePolicyOptions::default()`.
fn lock() -> &'static RwLock<Arc<CooperativeInferencePolicy>> {
    GLOBAL_POLICY.get_or_init(|| {
        let defaults = CooperativeInferencePolicyOptions::default();
        let initial = CooperativeInferencePolicy::new(defaults);
        RwLock::new(Arc::new(initial))
    })
}
/// Returns a shared handle to the currently installed global policy.
///
/// The handle is a snapshot. A later `configure_cip_policy` call replaces the
/// stored `Arc` but does not affect handles already returned, including their
/// in-flight slot counters.
///
/// # Panics
/// Panics with `"poisoned"` if the global lock is poisoned.
pub fn get_cip_policy() -> Arc<CooperativeInferencePolicy> {
    let current = lock().read().expect("poisoned");
    Arc::clone(&current)
}
/// Builds a policy from `opts`, installs it as the global policy, and returns it.
///
/// The new policy starts with an in-flight counter of zero. Handles obtained
/// earlier from `get_cip_policy` keep pointing at the previous policy.
///
/// # Panics
/// Panics with `"poisoned"` if the global lock is poisoned.
pub fn configure_cip_policy(
    opts: CooperativeInferencePolicyOptions,
) -> Arc<CooperativeInferencePolicy> {
    let installed = Arc::new(CooperativeInferencePolicy::new(opts));
    {
        let mut slot = lock().write().expect("poisoned");
        *slot = Arc::clone(&installed);
    }
    installed
}