use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
/// Name of the environment variable that overrides the automatically chosen
/// number of concurrent inference permits. The value is expected to be a
/// positive integer; anything else is ignored with a warning.
pub const ENV_MAX_CONCURRENT: &str = "CAR_INFERENCE_MAX_CONCURRENT";
/// Wait time, in milliseconds, at or above which acquiring a permit is logged
/// as having been queued behind the concurrency limit.
const SLOW_ACQUIRE_LOG_MS: u128 = 100;
/// Admission controller that bounds how many inference requests may hold a
/// permit at the same time.
///
/// Cloning is cheap: the semaphore lives behind an `Arc`, so every clone
/// hands out permits from the same shared pool.
#[derive(Clone)]
pub struct InferenceAdmission {
/// Shared semaphore from which permits are handed out.
sem: Arc<Semaphore>,
/// Total number of permits the semaphore was created with.
permits: usize,
}
impl InferenceAdmission {
pub fn new() -> Self {
let permits = chosen_permit_count();
tracing::info!(
permits,
env = ENV_MAX_CONCURRENT,
"inference admission controller online"
);
Self::with_permits(permits)
}
pub fn with_permits(permits: usize) -> Self {
let permits = permits.max(1);
Self {
sem: Arc::new(Semaphore::new(permits)),
permits,
}
}
pub async fn acquire(&self) -> OwnedSemaphorePermit {
let started = std::time::Instant::now();
let permit = self
.sem
.clone()
.acquire_owned()
.await
.expect("inference admission semaphore is never closed");
let waited_ms = started.elapsed().as_millis();
if waited_ms >= SLOW_ACQUIRE_LOG_MS {
tracing::info!(
waited_ms,
permits_total = self.permits,
permits_available = self.sem.available_permits(),
"inference request queued behind concurrency limit"
);
}
permit
}
pub fn try_acquire(&self) -> Option<OwnedSemaphorePermit> {
self.sem.clone().try_acquire_owned().ok()
}
pub async fn acquire_with_timeout(
&self,
max_wait: Duration,
) -> Option<OwnedSemaphorePermit> {
match tokio::time::timeout(max_wait, self.acquire()).await {
Ok(permit) => Some(permit),
Err(_) => {
tracing::warn!(
max_wait_ms = max_wait.as_millis() as u64,
permits_total = self.permits,
"inference admission acquire timed out"
);
None
}
}
}
pub fn permits(&self) -> usize {
self.permits
}
pub fn permits_available(&self) -> usize {
self.sem.available_permits()
}
}
impl Default for InferenceAdmission {
fn default() -> Self {
Self::new()
}
}
/// Picks the number of concurrent inference permits.
///
/// If the `ENV_MAX_CONCURRENT` environment variable is set and parses (after
/// trimming) to an integer of at least 1, that value is used as-is. A set but
/// unusable value is reported with a warning and ignored. Otherwise the count
/// is one permit per 8192 MB of host RAM, kept within `1..=8`.
fn chosen_permit_count() -> usize {
    match std::env::var(ENV_MAX_CONCURRENT) {
        Ok(raw) => match raw.trim().parse::<usize>() {
            Ok(requested) if requested >= 1 => return requested,
            _ => {
                tracing::warn!(
                    value = %raw,
                    "{} must be a positive integer; ignoring and falling back to auto-sizing",
                    ENV_MAX_CONCURRENT
                );
            }
        },
        // Variable not set (or not readable as Unicode): auto-size silently.
        Err(_) => {}
    }

    let ram_slots = host_ram_mb() / 8192;
    ram_slots.clamp(1, 8) as usize
}
/// Returns the host's total RAM in megabytes.
///
/// * macOS: runs `sysctl -n hw.memsize` and converts the value from bytes.
/// * Linux: reads the `MemTotal:` entry of `/proc/meminfo` and converts the
///   value from kilobytes.
///
/// On any other platform, or whenever the probe fails, a default of
/// 16 * 1024 MB is returned.
fn host_ram_mb() -> u64 {
    // Used when no platform probe is compiled in or the probe fails.
    const FALLBACK_MB: u64 = 16 * 1024;

    #[cfg(target_os = "macos")]
    {
        let probed_bytes = std::process::Command::new("sysctl")
            .args(["-n", "hw.memsize"])
            .output()
            .ok()
            .filter(|out| out.status.success())
            .and_then(|out| String::from_utf8(out.stdout).ok())
            .and_then(|text| text.trim().parse::<u64>().ok());
        if let Some(bytes) = probed_bytes {
            return bytes / (1024 * 1024);
        }
    }

    #[cfg(target_os = "linux")]
    {
        let probed_kb = std::fs::read_to_string("/proc/meminfo")
            .ok()
            .and_then(|meminfo| {
                meminfo
                    .lines()
                    .filter_map(|line| line.strip_prefix("MemTotal:"))
                    // First `MemTotal:` line whose leading token parses wins.
                    .find_map(|rest| rest.split_whitespace().next()?.parse::<u64>().ok())
            });
        if let Some(kb) = probed_kb {
            return kb / 1024;
        }
    }

    FALLBACK_MB
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asking for zero permits still yields a controller with one permit.
    #[tokio::test]
    async fn permits_clamps_to_at_least_one() {
        assert_eq!(InferenceAdmission::with_permits(0).permits(), 1);
    }

    /// With the only permit held, a non-blocking acquire must fail.
    #[tokio::test]
    async fn try_acquire_returns_none_when_full() {
        let gate = InferenceAdmission::with_permits(1);
        let _only_permit = gate.acquire().await;
        assert!(gate.try_acquire().is_none());
    }

    /// With the only permit held, a bounded wait gives up after roughly the
    /// requested duration and yields `None`.
    #[tokio::test]
    async fn acquire_with_timeout_returns_none_on_full_queue() {
        let gate = InferenceAdmission::with_permits(1);
        let _only_permit = gate.acquire().await;

        let max_wait = Duration::from_millis(50);
        let clock = std::time::Instant::now();
        let outcome = gate.acquire_with_timeout(max_wait).await;

        assert!(outcome.is_none());
        // Small tolerance below the requested wait.
        assert!(clock.elapsed() >= Duration::from_millis(45));
    }

    /// Each held permit lowers the reported number of free permits by one.
    #[tokio::test]
    async fn permits_available_reflects_outstanding_holds() {
        let gate = InferenceAdmission::with_permits(2);
        assert_eq!(gate.permits_available(), 2);

        // Keep the permits alive until the end of the test.
        let mut held = Vec::with_capacity(2);
        for expected_free in [1usize, 0] {
            held.push(gate.acquire().await);
            assert_eq!(gate.permits_available(), expected_free);
        }
    }

    /// The RAM probe (or its fallback) never reports zero.
    #[test]
    fn host_ram_mb_returns_a_positive_value() {
        let ram_mb = host_ram_mb();
        assert!(ram_mb > 0);
    }
}