use std::sync::Mutex;
use crate::hardware::AcceleratorFamily;
use crate::profile::AcceleratorProfile;
use crate::registry::{AcceleratorRegistry, DetectBuilder};
/// Accelerator registry that defers detection of non-CPU families until they
/// are first requested.
///
/// All state lives behind a [`Mutex`], so lookups that take `&self` can still
/// record the results of a probe.
pub struct LazyRegistry {
    /// Profiles gathered so far together with the per-family probe flags.
    state: Mutex<LazyState>,
}
/// Mutable contents of a `LazyRegistry`, kept behind its mutex.
struct LazyState {
    /// Every profile collected so far; grows as further families are probed.
    profiles: Vec<AcceleratorProfile>,
    /// One flag per accelerator family, addressed through `family_index`.
    /// A `true` entry means that family will not be probed again.
    probed: [bool; 5],
}
impl LazyRegistry {
    /// Creates a registry that initially holds only the CPU profile.
    ///
    /// The CPU slot starts out flagged as probed; all other families are
    /// probed the first time they are asked for.
    pub fn new() -> Self {
        let initial = LazyState {
            profiles: vec![crate::detect::cpu_profile()],
            probed: [true, false, false, false, false],
        };
        Self {
            state: Mutex::new(initial),
        }
    }

    /// Returns clones of the available profiles whose accelerator belongs to
    /// `family`, probing that family first if it has not been probed yet.
    ///
    /// A poisoned mutex is recovered rather than propagated as a panic.
    pub fn by_family(&self, family: AcceleratorFamily) -> Vec<AcceleratorProfile> {
        let mut guard = match self.state.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        Self::ensure_probed(&mut guard, family);

        let selected = guard
            .profiles
            .iter()
            .filter(|candidate| candidate.available && candidate.accelerator.family() == family);
        selected.cloned().collect()
    }

    /// Returns a copy of every profile collected so far without triggering
    /// any further probing.
    ///
    /// Unlike [`LazyRegistry::by_family`], no availability filter is applied.
    pub fn probed_profiles(&self) -> Vec<AcceleratorProfile> {
        self.state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .profiles
            .clone()
    }

    /// Consumes the lazy registry, probes every non-CPU family that has not
    /// been probed yet, and builds an [`AcceleratorRegistry`] from the result.
    pub fn into_registry(self) -> AcceleratorRegistry {
        let mut state = match self.state.into_inner() {
            Ok(state) => state,
            Err(poisoned) => poisoned.into_inner(),
        };

        // The CPU slot is already flagged in `new`, so only these remain.
        let lazy_families = [
            AcceleratorFamily::Gpu,
            AcceleratorFamily::Npu,
            AcceleratorFamily::Tpu,
            AcceleratorFamily::AiAsic,
        ];
        for family in lazy_families {
            Self::ensure_probed(&mut state, family);
        }

        AcceleratorRegistry::from_profiles(state.profiles)
    }

    /// Runs detection for `family` unless its slot is already flagged, and
    /// appends every non-CPU profile the detection reports to `state`.
    ///
    /// The flag is set before detection runs, so each family is probed at
    /// most once per registry.
    fn ensure_probed(state: &mut LazyState, family: AcceleratorFamily) {
        let slot = family_index(family);
        // Flag the slot and bail out if it had been flagged before.
        if std::mem::replace(&mut state.probed[slot], true) {
            return;
        }

        let builder = match family {
            // The CPU profile is inserted by `new`; there is nothing to detect.
            AcceleratorFamily::Cpu => return,
            AcceleratorFamily::Gpu => DetectBuilder::none()
                .with_cuda()
                .with_rocm()
                .with_vulkan(),
            AcceleratorFamily::Npu => DetectBuilder::none()
                .with_intel_npu()
                .with_amd_xdna()
                .with_samsung_npu()
                .with_mediatek_apu(),
            AcceleratorFamily::Tpu => DetectBuilder::none().with_tpu(),
            AcceleratorFamily::AiAsic => DetectBuilder::none()
                .with_gaudi()
                .with_aws_neuron()
                .with_intel_oneapi()
                .with_qualcomm()
                .with_cerebras()
                .with_graphcore()
                .with_groq(),
        };

        let detected = builder.detect();
        for profile in detected.all_profiles() {
            // Skip CPU entries: the registry already holds the one from `new`.
            if matches!(profile.accelerator, crate::hardware::AcceleratorType::Cpu) {
                continue;
            }
            state.profiles.push(profile.clone());
        }
    }
}
impl Default for LazyRegistry {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for LazyRegistry {
    /// Formats the registry as a struct showing the number of collected
    /// profiles and the per-family probe flags.
    ///
    /// The state mutex is locked for the duration of the call; a poisoned
    /// mutex is recovered instead of panicking.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let guard = match self.state.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };

        let profile_count = guard.profiles.len();
        let mut out = f.debug_struct("LazyRegistry");
        out.field("profiles", &profile_count);
        out.field("probed_families", &guard.probed);
        out.finish()
    }
}
/// Maps an accelerator family to its slot in `LazyState::probed`.
///
/// The match is deliberately exhaustive with no wildcard arm, so adding a
/// family to `AcceleratorFamily` fails to compile here until it is given a
/// slot.
fn family_index(family: AcceleratorFamily) -> usize {
    match family {
        // Slot 0 is the one `LazyRegistry::new` flags as already probed.
        AcceleratorFamily::Cpu => 0,
        AcceleratorFamily::Gpu => 1,
        AcceleratorFamily::Npu => 2,
        AcceleratorFamily::Tpu => 3,
        AcceleratorFamily::AiAsic => 4,
    }
}