use fs4::fs_std::FileExt;
use std::fs::{self, File, OpenOptions};
use std::path::PathBuf;
use std::time::{Duration, Instant};
use crate::errors::AppError;
/// Handle representing one acquired LLM concurrency slot.
///
/// Values are produced by `acquire_llm_slot`. The `Drop` implementation of
/// this type removes the slot's file, which makes the slot available again.
pub struct LlmSlotGuard {
    /// Open handle to the slot file. It is never read after acquisition; it is
    /// stored only so the handle (and the lock taken on it when the slot was
    /// acquired) stays alive for as long as the guard does.
    #[allow(dead_code)]
    slot_file: File,

    /// Zero-based index of the slot this guard holds.
    slot_id: u32,

    /// Moment the slot was acquired; used to log how long it was held.
    acquired_at: Instant,
}
impl LlmSlotGuard {
    /// Returns the zero-based index of the slot held by this guard.
    pub fn slot_id(&self) -> u32 {
        let Self { slot_id, .. } = self;
        *slot_id
    }
}
impl Drop for LlmSlotGuard {
    /// Releases the slot by deleting its file, then logs how long it was held.
    ///
    /// The file path is recomputed from the slot id via `slot_path` at drop
    /// time. A failed removal is only logged at debug level; dropping never
    /// panics because of it.
    fn drop(&mut self) {
        let removal = fs::remove_file(slot_path(self.slot_id));
        match removal {
            Ok(()) => {}
            Err(e) => {
                tracing::debug!(slot_id = self.slot_id, error = %e, "slot file removal failed (already gone?)");
            }
        }

        let held_ms = self.acquired_at.elapsed().as_millis() as u64;
        tracing::debug!(slot_id = self.slot_id, held_ms = held_ms, "llm slot released");
    }
}
pub fn acquire_llm_slot(max_concurrent: u32, wait_secs: u64) -> Result<LlmSlotGuard, AppError> {
if max_concurrent == 0 {
return Err(AppError::Validation(
"max_concurrent deve ser >= 1 para acquire_llm_slot".to_string(),
));
}
let dir = slots_dir();
fs::create_dir_all(&dir).map_err(|e| {
AppError::Io(std::io::Error::new(
e.kind(),
format!("failed to create slots dir {}: {e}", dir.display()),
))
})?;
let stale = find_stale_slots(max_concurrent);
for slot_id in &stale {
let _ = force_release(*slot_id);
tracing::info!(slot_id, "released stale LLM slot (PID dead)");
}
let start = Instant::now();
let timeout = Duration::from_secs(wait_secs);
loop {
for slot_id in 0..max_concurrent {
let path = slot_path(slot_id);
match OpenOptions::new().write(true).create_new(true).open(&path) {
Ok(mut file) => {
if file.try_lock_exclusive().is_ok() {
let pid = std::process::id();
use std::io::Write;
let _ = writeln!(file, "pid={pid}");
tracing::debug!(slot_id, pid, "llm slot acquired");
return Ok(LlmSlotGuard {
slot_file: file,
slot_id,
acquired_at: Instant::now(),
});
}
}
Err(_) => {
}
}
}
if start.elapsed() >= timeout {
return Err(AppError::LockBusy(format!(
"failed to acquire LLM slot within {wait_secs}s (max={max_concurrent} concurrent)"
)));
}
std::thread::sleep(Duration::from_millis(100));
}
}
/// Snapshot of slot usage as reported by `read_status`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct SlotStatus {
    /// The concurrency limit the snapshot was taken against.
    pub max: u32,

    /// Number of slot files that existed when the snapshot was taken.
    pub active: u32,

    /// Process ids successfully parsed from those slot files. May be shorter
    /// than `active` when a file has no readable `pid=` line.
    pub pids: Vec<u32>,
}
pub fn read_status(max_concurrent: u32) -> SlotStatus {
let mut active = 0u32;
let mut pids = Vec::new();
for slot_id in 0..max_concurrent {
let path = slot_path(slot_id);
if path.exists() {
active += 1;
if let Ok(content) = fs::read_to_string(&path) {
if let Some(pid_line) = content.lines().find(|l| l.starts_with("pid=")) {
if let Ok(pid) = pid_line[4..].parse::<u32>() {
pids.push(pid);
}
}
}
}
}
SlotStatus {
max: max_concurrent,
active,
pids,
}
}
pub fn force_release(slot_id: u32) -> Result<(), AppError> {
let path = slot_path(slot_id);
if path.exists() {
fs::remove_file(&path).map_err(|e| {
AppError::Io(std::io::Error::new(
e.kind(),
format!("failed to release slot {slot_id}: {e}"),
))
})?;
}
Ok(())
}
/// Returns the ids (in ascending order) of slots among the first
/// `max_concurrent` whose file records a PID that `pid_alive` reports as dead.
///
/// Slots without a file, with an unreadable file, or without a parseable
/// `pid=` line are never reported as stale.
pub fn find_stale_slots(max_concurrent: u32) -> Vec<u32> {
    (0..max_concurrent)
        .filter(|&slot_id| {
            let path = slot_path(slot_id);
            if !path.exists() {
                return false;
            }
            let content = match fs::read_to_string(&path) {
                Ok(text) => text,
                Err(_) => return false,
            };
            content
                .lines()
                .find_map(|line| line.strip_prefix("pid="))
                .and_then(|raw| raw.parse::<u32>().ok())
                .map_or(false, |pid| !pid_alive(pid))
        })
        .collect()
}
#[cfg(unix)]
fn pid_alive(pid: u32) -> bool {
unsafe { libc::kill(pid as i32, 0) == 0 }
}
#[cfg(not(unix))]
fn pid_alive(pid: u32) -> bool {
let _ = pid;
true
}
/// Returns the directory that holds the slot files.
///
/// The base directory is the first of these that is available:
/// 1. `$XDG_RUNTIME_DIR`
/// 2. `$SQLITE_GRAPHRAG_CACHE_DIR`
/// 3. `$HOME/.local/share`
/// 4. `/tmp`
///
/// and `sqlite-graphrag/llm-slots` is appended to it. Variables that are unset
/// or not valid Unicode are skipped.
///
/// NOTE(review): `XDG_RUNTIME_DIR` takes precedence over the application's own
/// `SQLITE_GRAPHRAG_CACHE_DIR`; confirm that this ordering is intended.
pub fn slots_dir() -> PathBuf {
    let base = if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
        runtime_dir
    } else if let Ok(cache_dir) = std::env::var("SQLITE_GRAPHRAG_CACHE_DIR") {
        cache_dir
    } else if let Ok(home) = std::env::var("HOME") {
        format!("{home}/.local/share")
    } else {
        "/tmp".to_string()
    };

    let mut dir = PathBuf::from(base);
    dir.push("sqlite-graphrag/llm-slots");
    dir
}
/// Returns the path of the file for slot `id`: `slot-<id>.lock` inside the
/// directory given by `slots_dir`.
pub fn slot_path(id: u32) -> PathBuf {
    let file_name = format!("slot-{id}.lock");
    let mut path = slots_dir();
    path.push(file_name);
    path
}
/// Computes a default for the number of concurrent LLM workers.
///
/// The result is the smallest of three limits, but never less than 1:
/// * the available parallelism reported by the standard library (4 when it
///   cannot be determined),
/// * how many workers of `LLM_WORKER_RSS_MB` fit into an assumed 4096 MB,
/// * `MAX_CONCURRENT_CLI_INSTANCES`.
pub fn default_max_concurrency() -> u32 {
    // Fixed assumption; actual free memory is not queried here.
    const ASSUMED_AVAILABLE_MB: u32 = 4096;

    let cpu_limit = match std::thread::available_parallelism() {
        Ok(n) => n.get() as u32,
        Err(_) => 4,
    };

    // Clamp to 1 so a zero constant cannot cause a division by zero.
    let worker_rss_mb = (crate::constants::LLM_WORKER_RSS_MB as u32).max(1);
    let memory_limit = ASSUMED_AVAILABLE_MB / worker_rss_mb;
    let instance_limit = crate::constants::MAX_CONCURRENT_CLI_INSTANCES as u32;

    cpu_limit.min(memory_limit.min(instance_limit)).max(1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Arc, Barrier, Mutex, MutexGuard};
    use std::thread;

    /// Serialises every test that touches the process-wide environment.
    /// The test harness runs tests on parallel threads, and environment
    /// variables are shared by all of them.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    const CACHE_DIR_VAR: &str = "SQLITE_GRAPHRAG_CACHE_DIR";
    const RUNTIME_DIR_VAR: &str = "XDG_RUNTIME_DIR";

    fn unique_test_dir() -> PathBuf {
        let mut dir = std::env::temp_dir();
        dir.push(format!(
            "llm-slots-test-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        dir
    }

    fn restore_var(name: &str, saved: Option<OsString>) {
        match saved {
            Some(value) => std::env::set_var(name, value),
            None => std::env::remove_var(name),
        }
    }

    /// Points `slots_dir()` at a fresh temporary directory for the lifetime of
    /// the value and restores the previous environment on drop (also when the
    /// test panics).
    ///
    /// `XDG_RUNTIME_DIR` is cleared because `slots_dir()` prefers it over
    /// `SQLITE_GRAPHRAG_CACHE_DIR`; otherwise the override would be ignored on
    /// hosts that define it.
    ///
    /// Declare it as the FIRST local of a test: locals drop in reverse order,
    /// so slot guards are released while the environment still points at the
    /// temporary directory (their `Drop` recomputes the path from it).
    struct IsolatedSlotsDir {
        base: PathBuf,
        saved_cache_dir: Option<OsString>,
        saved_runtime_dir: Option<OsString>,
        // Dropped after `Drop::drop` ran, i.e. after the environment is restored.
        _serial: MutexGuard<'static, ()>,
    }

    impl IsolatedSlotsDir {
        fn enter() -> Self {
            // A test that panicked while holding the lock poisons it; the
            // protected data is `()`, so continuing is always sound.
            let serial = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved_cache_dir = std::env::var_os(CACHE_DIR_VAR);
            let saved_runtime_dir = std::env::var_os(RUNTIME_DIR_VAR);
            let base = unique_test_dir();
            std::env::set_var(CACHE_DIR_VAR, &base);
            std::env::remove_var(RUNTIME_DIR_VAR);
            Self {
                base,
                saved_cache_dir,
                saved_runtime_dir,
                _serial: serial,
            }
        }
    }

    impl Drop for IsolatedSlotsDir {
        fn drop(&mut self) {
            restore_var(CACHE_DIR_VAR, self.saved_cache_dir.take());
            restore_var(RUNTIME_DIR_VAR, self.saved_runtime_dir.take());
            let _ = std::fs::remove_dir_all(&self.base);
        }
    }

    #[test]
    fn slot_enforces_max_concurrency() {
        let _env = IsolatedSlotsDir::enter();
        let _g1 = acquire_llm_slot(2, 5).expect("first slot");
        let _g2 = acquire_llm_slot(2, 5).expect("second slot");
        let start = std::time::Instant::now();
        let result = acquire_llm_slot(2, 1);
        assert!(result.is_err(), "third slot should fail with max=2");
        assert!(
            start.elapsed() >= std::time::Duration::from_secs(1),
            "should wait full timeout before failing"
        );
    }

    #[test]
    fn slot_releases_on_drop() {
        let _env = IsolatedSlotsDir::enter();
        let g1 = acquire_llm_slot(1, 5).expect("first slot");
        drop(g1);
        let _g2 = acquire_llm_slot(1, 5).expect("second slot after drop");
    }

    #[test]
    fn slot_max_concurrent_zero_is_validation_error() {
        // Rejected before the environment or the filesystem is consulted.
        let result = acquire_llm_slot(0, 1);
        assert!(matches!(result, Err(AppError::Validation(_))));
    }

    #[test]
    fn read_status_reflects_active_slots() {
        let _env = IsolatedSlotsDir::enter();
        let _g1 = acquire_llm_slot(4, 5).expect("first slot");
        let status = read_status(4);
        assert_eq!(status.max, 4);
        assert_eq!(status.active, 1, "exactly one slot is held in the isolated dir");
        assert_eq!(status.pids, vec![std::process::id()]);
    }

    #[test]
    fn concurrent_acquires_with_2_threads_serialize() {
        let _env = IsolatedSlotsDir::enter();
        // Three contenders race for two slots. Successful guards stay alive
        // inside `results`, so the third contender must time out.
        let barrier = Arc::new(Barrier::new(3));
        let mut handles = vec![];
        for _ in 0..3 {
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait();
                acquire_llm_slot(2, 2)
            }));
        }
        let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        let successes = results.iter().filter(|r| r.is_ok()).count();
        assert_eq!(successes, 2, "exactly max_concurrent contenders may hold a slot");
    }
}