#![cfg(feature = "classifier")]
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;
use std::time::{Duration, Instant};
use sqry_nl::classifier::{ClassifierPool, IntentClassifier, SharedClassifier, TrustMode};
use sqry_nl::error::NlError;
/// Directory that holds the committed ONNX model fixtures for this crate
/// (`<crate root>/models`), resolved at compile time via `CARGO_MANIFEST_DIR`.
fn in_tree_model_dir() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    dir.push("models");
    dir
}
#[test]
#[ignore = "requires ONNX Runtime dylib + committed model fixtures; run manually with --ignored"]
fn n_concurrent_translates_use_n_distinct_sessions() {
    const POOL_SIZE: usize = 4;
    const FANIN: usize = 16;
    let model_dir = in_tree_model_dir();
    assert!(
        model_dir.join("intent_classifier.onnx").exists(),
        "expected committed model at {}; install ONNX fixtures \
         before running this test",
        model_dir.display(),
    );

    // Count loader invocations so we can prove the pool loads each slot
    // exactly once at construction and never lazily re-loads afterwards.
    let load_calls = Arc::new(AtomicUsize::new(0));
    let load_calls_for_pool = Arc::clone(&load_calls);
    let model_dir_clone = model_dir.clone();
    let pool = ClassifierPool::new(POOL_SIZE, move || -> Result<IntentClassifier, NlError> {
        load_calls_for_pool.fetch_add(1, Ordering::SeqCst);
        IntentClassifier::load(&model_dir_clone, false, TrustMode::Custom).map_err(NlError::from)
    })
    .expect("pool init must succeed against in-tree model fixtures");

    let load_calls_after_init = load_calls.load(Ordering::SeqCst);
    assert_eq!(
        load_calls_after_init, POOL_SIZE,
        "pool of size {POOL_SIZE} must call loader exactly {POOL_SIZE} times during init, got {load_calls_after_init}"
    );
    assert_eq!(pool.capacity(), POOL_SIZE);
    let pool = Arc::new(pool);

    // Phase 1: deterministically observe every slot by holding all
    // POOL_SIZE guards at the same time.  A plain fan-in of FANIN threads
    // cannot guarantee this — the scheduler may serialize acquire/release
    // so a single slot serves every call — so each holder rendezvous at a
    // barrier *while its guard is live*, forcing full pool occupancy.
    let hold_barrier = Arc::new(Barrier::new(POOL_SIZE));
    let mut holders = Vec::with_capacity(POOL_SIZE);
    for _ in 0..POOL_SIZE {
        let pool = Arc::clone(&pool);
        let hold_barrier = Arc::clone(&hold_barrier);
        holders.push(thread::spawn(move || {
            let guard = pool.acquire();
            let shared: &SharedClassifier = guard.classifier();
            let ptr = shared.identity();
            // All POOL_SIZE guards are checked out at this point.
            hold_barrier.wait();
            ptr
        }));
    }
    let mut slot_ids: std::collections::HashSet<usize> =
        std::collections::HashSet::with_capacity(POOL_SIZE);
    for h in holders {
        slot_ids.insert(h.join().expect("holder thread panicked"));
    }
    assert_eq!(
        slot_ids.len(),
        POOL_SIZE,
        "expected {POOL_SIZE} distinct pool slots while all guards were held; got {} unique pointers",
        slot_ids.len()
    );

    // Phase 2: a FANIN-wide concurrent wave must be served entirely from
    // the slots observed above, with no additional loader calls.
    let barrier = Arc::new(Barrier::new(FANIN));
    let observed = Arc::new(parking_lot::Mutex::new(Vec::with_capacity(FANIN)));
    let mut handles = Vec::with_capacity(FANIN);
    for _ in 0..FANIN {
        let pool = Arc::clone(&pool);
        let barrier = Arc::clone(&barrier);
        let observed = Arc::clone(&observed);
        handles.push(thread::spawn(move || {
            barrier.wait();
            let guard = pool.acquire();
            let shared: &SharedClassifier = guard.classifier();
            let ptr = shared.identity();
            let mut classifier = shared.lock();
            // Touch the classifier so the lock acquisition is not elided.
            std::hint::black_box(&mut *classifier);
            drop(classifier);
            observed.lock().push(ptr);
        }));
    }
    for h in handles {
        h.join().expect("worker thread panicked");
    }

    let load_calls_after_wave = load_calls.load(Ordering::SeqCst);
    assert_eq!(
        load_calls_after_wave, POOL_SIZE,
        "no further loader calls allowed after init; got {load_calls_after_wave} total \
         (expected {POOL_SIZE} from init)"
    );
    for ptr in observed.lock().iter() {
        assert!(
            slot_ids.contains(ptr),
            "wave observed a classifier pointer outside the {POOL_SIZE} known pool slots"
        );
    }
}
#[test]
#[ignore = "requires ONNX Runtime dylib + committed model fixtures; run manually with --ignored"]
fn panic_in_classify_does_not_lose_slot() {
    const POOL_SIZE: usize = 2;
    let model_dir = in_tree_model_dir();
    assert!(
        model_dir.join("intent_classifier.onnx").exists(),
        "expected committed model at {}",
        model_dir.display(),
    );
    let pool = ClassifierPool::new(POOL_SIZE, || -> Result<IntentClassifier, NlError> {
        IntentClassifier::load(&model_dir, false, TrustMode::Custom).map_err(NlError::from)
    })
    .expect("pool init");
    let pool = Arc::new(pool);

    // Simulate a panic mid-classify while a slot is checked out; the
    // guard must return its slot to the pool during unwinding.
    let pool_a = Arc::clone(&pool);
    let join_a = thread::spawn(move || {
        let _guard = pool_a.acquire();
        let _shared = _guard.classifier().clone();
        panic!("synthetic panic during classify");
    });
    let result = join_a.join();
    assert!(
        result.is_err(),
        "worker should have panicked; got {:?}",
        result.map(|()| "no panic")
    );

    // Re-acquire every slot on a helper thread and enforce the budget
    // from here via recv_timeout.  Previously the elapsed-time assert ran
    // only *after* acquire() returned, so a leaked slot produced an
    // unbounded hang instead of a failure within the budget.
    let budget = Duration::from_secs(2);
    let (tx, rx) = std::sync::mpsc::channel();
    let pool_b = Arc::clone(&pool);
    let reacquire = thread::spawn(move || {
        // Drain the whole pool: both slots must still be available.
        let g1 = pool_b.acquire();
        let g2 = pool_b.acquire();
        drop(g1);
        drop(g2);
        // One more acquire proves the slots recycle cleanly.
        drop(pool_b.acquire());
        let _ = tx.send(());
    });
    match rx.recv_timeout(budget) {
        Ok(()) => reacquire.join().expect("re-acquire thread panicked"),
        Err(_) => panic!(
            "acquire after panic must succeed within {budget:?}; \
             post-panic deadlock indicates the slot was leaked"
        ),
    }
}
#[test]
#[ignore = "requires ONNX Runtime dylib + committed model fixtures; run manually with --ignored"]
fn translator_pool_serves_concurrent_translates() {
    use sqry_nl::{Translator, TranslatorConfig};
    const FANIN: usize = 16;
    const BUDGET: Duration = Duration::from_secs(60);
    let model_dir = in_tree_model_dir();
    assert!(
        model_dir.join("intent_classifier.onnx").exists(),
        "expected committed model at {}",
        model_dir.display(),
    );
    let config = TranslatorConfig {
        model_dir_override: Some(model_dir),
        allow_unverified_model: false,
        classifier_pool_size: Some(4),
        ..TranslatorConfig::default()
    };
    let translator = Arc::new(Translator::new(config).expect("translator init"));
    let barrier = Arc::new(Barrier::new(FANIN));

    // Workers report completion on a channel so the main thread can
    // enforce BUDGET with recv_timeout.  Previously the elapsed check ran
    // only after join(), so a pool deadlock hung the test forever instead
    // of failing within the budget.
    let (tx, rx) = std::sync::mpsc::channel();
    let start = Instant::now();
    let mut handles = Vec::with_capacity(FANIN);
    for tid in 0..FANIN {
        let translator = Arc::clone(&translator);
        let barrier = Arc::clone(&barrier);
        let tx = tx.clone();
        handles.push(thread::spawn(move || {
            barrier.wait();
            let q = format!("find functions named worker_{tid}");
            let _resp = translator.translate_shared(&q);
            // Receiver may already be gone if the test timed out; ignore.
            let _ = tx.send(());
        }));
    }
    // Drop the main thread's sender so only live workers hold one.
    drop(tx);

    let deadline = start + BUDGET;
    for completed in 0..FANIN {
        let remaining = deadline.saturating_duration_since(Instant::now());
        rx.recv_timeout(remaining).unwrap_or_else(|_| {
            panic!(
                "only {completed}/{FANIN} concurrent translates finished within \
                 {BUDGET:?} — pool deadlock?"
            )
        });
    }
    for h in handles {
        h.join().expect("worker thread panicked — pool deadlock?");
    }
}