use crate::error::{RealizarError, Result};
/// Per-phase thread counts for inference work.
///
/// `threads_for(true)` selects the batch count and `threads_for(false)` the
/// decode count, so batch threads serve prefill and decode threads serve
/// per-token generation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ThreadConfig {
    // Threads used for prefill/batch processing.
    pub n_threads_batch: usize,
    // Threads used for decode (token-by-token) processing.
    pub n_threads_decode: usize,
}
impl ThreadConfig {
    /// Derives a configuration from the current Rayon pool size: every
    /// available thread for batch work, half of them (at least one) for
    /// decode work.
    #[must_use]
    pub fn auto() -> Self {
        let available = rayon::current_num_threads();
        let decode = std::cmp::max(available / 2, 1);
        Self {
            n_threads_batch: available,
            n_threads_decode: decode,
        }
    }

    /// Builds a configuration from explicit counts, clamping each count up
    /// to a minimum of one thread.
    #[must_use]
    pub fn new(n_threads_batch: usize, n_threads_decode: usize) -> Self {
        let batch = if n_threads_batch == 0 { 1 } else { n_threads_batch };
        let decode = if n_threads_decode == 0 { 1 } else { n_threads_decode };
        Self {
            n_threads_batch: batch,
            n_threads_decode: decode,
        }
    }

    /// Returns the thread count for a phase: `true` selects the batch
    /// (prefill) count, `false` the decode count.
    #[must_use]
    pub fn threads_for(&self, is_prefill: bool) -> usize {
        match is_prefill {
            true => self.n_threads_batch,
            false => self.n_threads_decode,
        }
    }
}
impl Default for ThreadConfig {
fn default() -> Self {
Self::auto()
}
}
/// The inference phase currently being executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InferenceMode {
    /// Prefill phase.
    Prefill,
    /// Decode phase.
    Decode,
}

impl InferenceMode {
    /// Returns `true` when this mode is [`InferenceMode::Prefill`].
    #[must_use]
    pub fn is_prefill(self) -> bool {
        match self {
            Self::Prefill => true,
            Self::Decode => false,
        }
    }

    /// Returns `true` when this mode is [`InferenceMode::Decode`].
    #[must_use]
    pub fn is_decode(self) -> bool {
        !self.is_prefill()
    }
}
/// Configures the global Rayon thread pool with an automatically chosen size.
///
/// The thread count comes from the `RAYON_NUM_THREADS` environment variable
/// when it parses to a positive integer; otherwise it falls back to the
/// machine's available parallelism, and finally to 16 when the OS cannot
/// report parallelism at all.
///
/// # Errors
///
/// Propagates any error from [`configure_thread_pool`] (for example, when
/// the global pool was already initialized).
pub fn configure_optimal_thread_pool() -> Result<()> {
    let optimal_threads = std::env::var("RAYON_NUM_THREADS")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        // A zero thread count is never a usable configuration; treat it the
        // same as an absent or unparsable variable.
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            // Prefer the real core count over a hard-coded guess; 16 is only
            // the last-resort default.
            std::thread::available_parallelism().map_or(16, std::num::NonZeroUsize::get)
        });
    configure_thread_pool(optimal_threads)
}
/// Installs a global Rayon thread pool sized to `num_threads`.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidConfiguration`] when the global pool
/// cannot be built (for example, if it was already initialized).
pub fn configure_thread_pool(num_threads: usize) -> Result<()> {
    let builder = rayon::ThreadPoolBuilder::new().num_threads(num_threads);
    match builder.build_global() {
        Ok(()) => Ok(()),
        Err(e) => Err(RealizarError::InvalidConfiguration(format!(
            "Failed to configure thread pool: {e}"
        ))),
    }
}
// Unit tests for `ThreadConfig` and `InferenceMode`. Test names are part of
// the `cargo test` filter surface, so they are kept as-is.
#[cfg(test)]
mod tests {
    use super::*;

    // --- ThreadConfig::auto ---

    // auto() must never hand out a zero thread count for either phase.
    #[test]
    fn test_thread_config_auto_returns_valid_config() {
        let config = ThreadConfig::auto();
        assert!(config.n_threads_batch >= 1, "batch threads must be >= 1");
        assert!(config.n_threads_decode >= 1, "decode threads must be >= 1");
    }

    // auto() halves the CPU count for decode, so batch >= decode always holds.
    #[test]
    fn test_thread_config_auto_batch_gte_decode() {
        let config = ThreadConfig::auto();
        assert!(
            config.n_threads_batch >= config.n_threads_decode,
            "batch threads should be >= decode threads"
        );
    }

    // --- ThreadConfig::new ---

    // new() clamps zero inputs up to the minimum of one thread.
    #[test]
    fn test_thread_config_new_clamps_to_minimum() {
        let config = ThreadConfig::new(0, 0);
        assert_eq!(config.n_threads_batch, 1, "0 should be clamped to 1");
        assert_eq!(config.n_threads_decode, 1, "0 should be clamped to 1");
    }

    // Positive inputs pass through new() untouched.
    #[test]
    fn test_thread_config_new_preserves_valid_values() {
        let config = ThreadConfig::new(8, 4);
        assert_eq!(config.n_threads_batch, 8);
        assert_eq!(config.n_threads_decode, 4);
    }

    // --- ThreadConfig::threads_for ---

    // threads_for(true) selects the batch (prefill) count.
    #[test]
    fn test_thread_config_threads_for_prefill() {
        let config = ThreadConfig::new(16, 8);
        assert_eq!(config.threads_for(true), 16);
    }

    // threads_for(false) selects the decode count.
    #[test]
    fn test_thread_config_threads_for_decode() {
        let config = ThreadConfig::new(16, 8);
        assert_eq!(config.threads_for(false), 8);
    }

    // Default must delegate to auto(); both calls observe the same Rayon
    // pool size, so the configs compare equal.
    #[test]
    fn test_thread_config_default_uses_auto() {
        let default_config = ThreadConfig::default();
        let auto_config = ThreadConfig::auto();
        assert_eq!(default_config, auto_config);
    }

    // --- derived traits on ThreadConfig ---

    // ThreadConfig is Copy, so assignment duplicates it; copies compare equal.
    #[test]
    fn test_thread_config_clone() {
        let config = ThreadConfig::new(12, 6);
        let cloned = config;
        assert_eq!(config, cloned);
    }

    // Debug output should name the type and include both field values.
    #[test]
    fn test_thread_config_debug() {
        let config = ThreadConfig::new(4, 2);
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("ThreadConfig"));
        assert!(debug_str.contains("4"));
        assert!(debug_str.contains("2"));
    }

    // --- InferenceMode predicates ---

    #[test]
    fn test_inference_mode_is_prefill() {
        assert!(InferenceMode::Prefill.is_prefill());
        assert!(!InferenceMode::Decode.is_prefill());
    }

    #[test]
    fn test_inference_mode_is_decode() {
        assert!(InferenceMode::Decode.is_decode());
        assert!(!InferenceMode::Prefill.is_decode());
    }

    // --- derived traits on InferenceMode ---

    #[test]
    fn test_inference_mode_equality() {
        assert_eq!(InferenceMode::Prefill, InferenceMode::Prefill);
        assert_eq!(InferenceMode::Decode, InferenceMode::Decode);
        assert_ne!(InferenceMode::Prefill, InferenceMode::Decode);
    }

    // InferenceMode is Copy; assignment duplicates the value.
    #[test]
    fn test_inference_mode_clone() {
        let mode = InferenceMode::Prefill;
        let cloned = mode;
        assert_eq!(mode, cloned);
    }

    // Derived Debug prints the bare variant name.
    #[test]
    fn test_inference_mode_debug() {
        assert_eq!(format!("{:?}", InferenceMode::Prefill), "Prefill");
        assert_eq!(format!("{:?}", InferenceMode::Decode), "Decode");
    }

    // Hash is derived, so both variants can live in a HashSet as distinct keys.
    #[test]
    fn test_inference_mode_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(InferenceMode::Prefill);
        set.insert(InferenceMode::Decode);
        assert_eq!(set.len(), 2);
        assert!(set.contains(&InferenceMode::Prefill));
        assert!(set.contains(&InferenceMode::Decode));
    }

    // --- interaction between the two types ---

    // is_prefill() bridges InferenceMode to the bool expected by threads_for.
    #[test]
    fn test_config_with_mode() {
        let config = ThreadConfig::new(16, 4);
        let prefill_threads = config.threads_for(InferenceMode::Prefill.is_prefill());
        let decode_threads = config.threads_for(InferenceMode::Decode.is_prefill());
        assert_eq!(prefill_threads, 16);
        assert_eq!(decode_threads, 4);
    }

    // --- boundary values ---

    // Minimum legal configuration: one thread for each phase.
    #[test]
    fn test_thread_config_with_one_thread() {
        let config = ThreadConfig::new(1, 1);
        assert_eq!(config.threads_for(true), 1);
        assert_eq!(config.threads_for(false), 1);
    }

    // Large counts are stored verbatim — no upper clamp exists.
    #[test]
    fn test_thread_config_large_values() {
        let config = ThreadConfig::new(1024, 512);
        assert_eq!(config.n_threads_batch, 1024);
        assert_eq!(config.n_threads_decode, 512);
    }

    // new() does not enforce batch >= decode; that invariant only comes
    // from auto().
    #[test]
    fn test_thread_config_decode_larger_than_batch() {
        let config = ThreadConfig::new(4, 8);
        assert_eq!(config.n_threads_batch, 4);
        assert_eq!(config.n_threads_decode, 8);
    }
}