use crate::classifier::{IntentClassifier, SharedClassifier};
use crate::error::NlError;
use crossbeam_channel::{Receiver, Sender, bounded};
/// Smallest number of classifiers a pool may hold.
///
/// Must be at least 1: `ClassifierPool::new` primes a `bounded(capacity)`
/// crossbeam channel, and a zero-capacity crossbeam channel is a rendezvous
/// channel on which the priming `send` would block forever.
pub const POOL_MIN: usize = 1;
/// Largest number of classifiers a pool may hold.
pub const POOL_MAX: usize = 8;
/// Pool size used by `resolve_pool_size` when no valid size is configured
/// explicitly or through the environment.
pub const POOL_DEFAULT: usize = 2;

// Compile-time guard for invariants the rest of the module relies on:
// `usize::clamp` panics when `min > max`, a zero-sized pool would deadlock
// during construction, and the default must itself be a legal pool size.
const _: () = assert!(POOL_MIN >= 1 && POOL_MIN <= POOL_DEFAULT && POOL_DEFAULT <= POOL_MAX);
/// Fixed-size pool of loaded intent classifiers.
///
/// Idle classifiers live in a bounded crossbeam channel. Checking one out is a
/// `recv` (see `acquire`), and handing it back is a `send` performed when the
/// returned `PoolGuard` is dropped. The pool keeps both channel ends alive for
/// its whole lifetime, so the channel cannot disconnect while the pool exists.
pub struct ClassifierPool {
// Producer end; cloned into every guard so its classifier can be sent back.
sender: Sender<SharedClassifier>,
// Consumer end; `acquire` blocks on it until a classifier is idle.
receiver: Receiver<SharedClassifier>,
// Slot count after clamping into `POOL_MIN..=POOL_MAX`; equals the number of
// classifiers created by `new`.
capacity: usize,
}
impl ClassifierPool {
    /// Build a pool and eagerly load every classifier it will hold.
    ///
    /// `capacity` is first clamped into `POOL_MIN..=POOL_MAX`. `loader` is then
    /// called once per slot, in order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `loader`. Loading stops at that
    /// point. Classifiers already loaded are dropped together with the
    /// partially filled channel.
    ///
    /// # Panics
    ///
    /// Does not panic in practice. The priming `send` targets a channel whose
    /// buffer holds exactly `capacity` items while the receiver is still alive.
    pub fn new<L>(capacity: usize, mut loader: L) -> Result<Self, NlError>
    where
        L: FnMut() -> Result<IntentClassifier, NlError>,
    {
        let capacity = capacity.clamp(POOL_MIN, POOL_MAX);
        let (sender, receiver) = bounded::<SharedClassifier>(capacity);
        for _ in 0..capacity {
            let shared = SharedClassifier::new(loader()?);
            // At most `capacity` sends into a `capacity`-slot buffer while the
            // receiver is alive: this can neither block nor fail.
            sender
                .send(shared)
                .expect("crossbeam_channel just created with capacity == iteration count");
        }
        Ok(Self {
            sender,
            receiver,
            capacity,
        })
    }

    /// Number of classifier slots in the pool (after clamping).
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Check out a classifier, blocking the calling thread until one is idle.
    ///
    /// The classifier goes back into the pool when the returned [`PoolGuard`]
    /// is dropped. Because this call blocks, a thread that already holds every
    /// guard and calls `acquire` again will wait forever.
    ///
    /// # Panics
    ///
    /// Panics if the channel is disconnected. The pool itself owns a `Sender`,
    /// so this cannot happen while `&self` is alive. A panic here would
    /// indicate a broken invariant.
    #[must_use = "dropping the guard immediately returns the classifier to the pool"]
    pub fn acquire(&self) -> PoolGuard<'_> {
        let shared = self
            .receiver
            .recv()
            .expect("ClassifierPool channel disconnected — pool dropped while in use");
        // Each guard owns its own sender clone, so the release callback is
        // `'static` and does not need to borrow `self`.
        let sender = self.sender.clone();
        let on_release: SlotReturn = Box::new(move |shared| {
            // Cannot block: only `capacity` classifiers exist, so a buffer slot
            // is always free. The error case (receiver gone) cannot occur while
            // the guard's borrow of the pool is live.
            let _ = sender.send(shared);
        });
        PoolGuard {
            scoped: Some(scopeguard::guard(shared, on_release)),
            _pool: std::marker::PhantomData,
        }
    }
}
impl std::fmt::Debug for ClassifierPool {
    /// Show the configured slot count and how many classifiers are currently
    /// idle in the channel (a point-in-time snapshot).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let idle_now = self.receiver.len();
        let mut builder = f.debug_struct("ClassifierPool");
        builder.field("capacity", &self.capacity);
        builder.field("available", &idle_now);
        builder.finish()
    }
}
/// Handle for one checked-out classifier.
///
/// Dropping the guard runs the release callback installed by
/// [`ClassifierPool::acquire`], which sends the classifier back into the
/// pool's channel. The lifetime ties the guard to the pool it came from.
#[must_use = "dropping a PoolGuard immediately returns the classifier to the pool"]
pub struct PoolGuard<'a> {
    /// Checked-out classifier plus its release callback. `acquire` always
    /// stores `Some(..)`.
    scoped: Option<scopeguard::ScopeGuard<SharedClassifier, SlotReturn>>,
    /// Borrows the pool for `'a` without storing a reference, so a guard
    /// cannot outlive the pool it was taken from.
    _pool: std::marker::PhantomData<&'a ClassifierPool>,
}
/// Release callback stored in each `PoolGuard`.
///
/// It receives the checked-out classifier when the guard is dropped. It is
/// boxed and `'static` so it can own a cloned channel sender instead of
/// borrowing the pool.
type SlotReturn = Box<dyn Send + FnOnce(SharedClassifier) + 'static>;
impl PoolGuard<'_> {
    /// Borrow the classifier checked out by this guard.
    ///
    /// The returned reference cannot outlive the guard. The classifier
    /// therefore stays out of the pool for as long as the reference is in use.
    ///
    /// # Panics
    ///
    /// Panics if the guard's slot is empty. `ClassifierPool::acquire` always
    /// stores `Some(..)`, so a panic here indicates a broken internal invariant.
    #[must_use]
    pub fn classifier(&self) -> &SharedClassifier {
        // `ScopeGuard<T, _>` implements `Deref<Target = T>`. Deref coercion at
        // the return site turns `&ScopeGuard<SharedClassifier, _>` into
        // `&SharedClassifier`.
        self.scoped
            .as_ref()
            .expect("PoolGuard accessed after drop — invariant violated")
    }
}
/// Decide how many classifiers the pool should hold.
///
/// Precedence:
/// 1. An explicit `configured` value wins.
/// 2. Otherwise the `SQRY_NL_POOL_SIZE` environment variable is read, with
///    surrounding whitespace ignored.
/// 3. If the variable is unset or not a valid `usize`, `POOL_DEFAULT` is used.
///
/// Whatever is chosen is clamped into `POOL_MIN..=POOL_MAX`.
#[must_use]
pub fn resolve_pool_size(configured: Option<usize>) -> usize {
    let requested = match configured {
        Some(explicit) => explicit,
        None => match std::env::var("SQRY_NL_POOL_SIZE") {
            Ok(text) => text.trim().parse::<usize>().unwrap_or(POOL_DEFAULT),
            Err(_) => POOL_DEFAULT,
        },
    };
    requested.clamp(POOL_MIN, POOL_MAX)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex, MutexGuard};

    /// Environment variable consulted by `resolve_pool_size`.
    const POOL_ENV: &str = "SQRY_NL_POOL_SIZE";

    /// Serialises the tests in this module that touch `POOL_ENV`.
    ///
    /// NOTE(review): this only coordinates tests in this module. Other tests
    /// in the crate that read or write the process environment are not covered.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire `ENV_LOCK`, recovering it if an earlier test panicked while holding it.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Set `POOL_ENV`. Callers must hold the guard returned by `lock_env`.
    fn set_pool_env(value: &str) {
        // SAFETY: callers hold `ENV_LOCK`, which serialises environment
        // mutation among this module's tests (see the NOTE on `ENV_LOCK`).
        unsafe { std::env::set_var(POOL_ENV, value) };
    }

    /// Remove `POOL_ENV`. Callers must hold the guard returned by `lock_env`.
    fn unset_pool_env() {
        // SAFETY: same contract as `set_pool_env`: `ENV_LOCK` is held.
        unsafe { std::env::remove_var(POOL_ENV) };
    }

    /// Unsets `POOL_ENV` when dropped, so a failing assertion cannot leak the
    /// value into later tests.
    ///
    /// Declare it after the `lock_env` guard so it is dropped first, while the
    /// lock is still held.
    struct UnsetPoolEnvOnDrop;

    impl Drop for UnsetPoolEnvOnDrop {
        fn drop(&mut self) {
            unset_pool_env();
        }
    }

    #[test]
    fn resolve_pool_size_prefers_configured() {
        let _lock = lock_env();
        let _unset = UnsetPoolEnvOnDrop;
        set_pool_env("6");
        assert_eq!(resolve_pool_size(Some(3)), 3);
    }

    #[test]
    fn resolve_pool_size_falls_back_to_env() {
        let _lock = lock_env();
        let _unset = UnsetPoolEnvOnDrop;
        set_pool_env("5");
        assert_eq!(resolve_pool_size(None), 5);
    }

    #[test]
    fn resolve_pool_size_trims_env_whitespace() {
        let _lock = lock_env();
        let _unset = UnsetPoolEnvOnDrop;
        set_pool_env(" 4\n");
        assert_eq!(resolve_pool_size(None), 4);
    }

    #[test]
    fn resolve_pool_size_ignores_unparseable_env() {
        let _lock = lock_env();
        let _unset = UnsetPoolEnvOnDrop;
        set_pool_env("lots");
        assert_eq!(resolve_pool_size(None), POOL_DEFAULT);
    }

    #[test]
    fn resolve_pool_size_clamps_env_value() {
        let _lock = lock_env();
        let _unset = UnsetPoolEnvOnDrop;
        set_pool_env("99");
        assert_eq!(resolve_pool_size(None), POOL_MAX);
    }

    #[test]
    fn resolve_pool_size_default_when_unset() {
        let _lock = lock_env();
        unset_pool_env();
        assert_eq!(resolve_pool_size(None), POOL_DEFAULT);
    }

    #[test]
    fn resolve_pool_size_clamped_to_max() {
        assert_eq!(resolve_pool_size(Some(999)), POOL_MAX);
    }

    #[test]
    fn resolve_pool_size_clamped_to_min() {
        assert_eq!(resolve_pool_size(Some(0)), POOL_MIN);
    }

    /// `ClassifierPool::new` must stop at the first loader error instead of
    /// trying the remaining slots.
    ///
    /// NOTE(review): this does not prove that a 999-slot request is clamped.
    /// Doing so would need a loader that returns real `IntentClassifier`s,
    /// which this test does not construct.
    #[test]
    fn new_stops_at_first_loader_error() {
        let count = Arc::new(AtomicUsize::new(0));
        let count_inner = Arc::clone(&count);
        let res = ClassifierPool::new(999, move || -> Result<IntentClassifier, NlError> {
            count_inner.fetch_add(1, Ordering::SeqCst);
            Err(NlError::Config("synthetic loader failure".into()))
        });
        assert!(res.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    /// Pins the crossbeam behaviour the pool relies on: a full bounded channel
    /// can be drained and refilled to capacity.
    #[test]
    fn channel_recv_send_round_trips() {
        let (tx, rx) = bounded::<u64>(2);
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        let a = rx.recv().unwrap();
        let b = rx.recv().unwrap();
        assert_eq!(rx.len(), 0);
        tx.send(a).unwrap();
        tx.send(b).unwrap();
        assert_eq!(rx.len(), 2);
    }
}