/// Builds rayon's global thread pool and reports how many threads it has.
///
/// * `n` — requested worker count. `Some(k)` with `k > 0` asks rayon for exactly
///   `k` threads. `None` and `Some(0)` make no explicit request, so rayon picks
///   the count itself.
///
/// Returns [`rayon::current_num_threads`]. When called from outside any rayon
/// pool, this is the size of the global pool that is actually in effect.
///
/// Building the global pool is best-effort. If `build_global` fails (for
/// example because the global pool was already initialised earlier in the
/// process), the error is discarded and the existing pool is kept. The returned
/// count may therefore differ from `n`.
pub fn init_threads(n: Option<usize>) -> usize {
    // A zero count is treated exactly like "no preference".
    let builder = match n.filter(|&count| count > 0) {
        Some(count) => rayon::ThreadPoolBuilder::new().num_threads(count),
        None => rayon::ThreadPoolBuilder::new(),
    };

    // Deliberately ignored: a second initialisation attempt must not be fatal.
    // The caller learns the effective size from the return value instead.
    let _ = builder.build_global();

    rayon::current_num_threads()
}
pub mod config;
pub mod inference;
pub mod model;
pub mod tensor_utils;
pub mod viterbi;
pub mod weights;
pub use config::{ModelConfig, ViterbiConfig};
pub use inference::PrivacyFilterInference;
pub use viterbi::PrivacySpan;
#[cfg(feature = "ndarray")]
pub mod backend {
    //! CPU backend built on burn's `NdArray` backend.
    //!
    //! Compiled whenever the `ndarray` feature is on. The other backend
    //! variants are all gated on `not(feature = "ndarray")`, so this one wins
    //! whenever several backend features are enabled together.

    use burn::backend::ndarray::NdArrayDevice;

    // Re-exported (not aliased) so `B` keeps `NdArray`'s generic parameters.
    pub use burn::backend::NdArray as B;

    /// Device type used with [`B`].
    pub type Device = NdArrayDevice;

    /// Returns the device tensors are created on: the CPU.
    pub fn device() -> Device {
        NdArrayDevice::Cpu
    }
}
#[cfg(all(
    feature = "wgpu-f16",
    not(feature = "ndarray"),
    not(feature = "wgpu"),
    not(feature = "mlx")
))]
pub mod backend {
    //! GPU backend: burn's `Wgpu` backend instantiated with `half::f16`.
    //!
    //! NOTE(review): this variant requires `ndarray`, `wgpu` and `mlx` to all be
    //! disabled. Without `ndarray`, the sibling variants' cfgs also reject
    //! `wgpu-f16` combined with `wgpu` or `mlx`. In those combinations no
    //! `backend` module is compiled. Confirm whether that should instead be an
    //! explicit `compile_error!`.

    use burn::backend::wgpu::{Wgpu, WgpuDevice};

    /// Burn backend type: `Wgpu<half::f16, i32, u32>`.
    ///
    /// Presumably the parameters are the float/int/bool element types — confirm
    /// against the burn version in use.
    pub type B = Wgpu<half::f16, i32, u32>;

    /// Device type used with [`B`].
    pub type Device = WgpuDevice;

    /// Returns the default wgpu device (`WgpuDevice::DefaultDevice`).
    pub fn device() -> Device {
        WgpuDevice::DefaultDevice
    }
}
#[cfg(all(
    feature = "wgpu",
    not(feature = "ndarray"),
    not(feature = "wgpu-f16"),
    not(feature = "mlx")
))]
pub mod backend {
    //! GPU backend built on burn's default `Wgpu` backend.
    //!
    //! NOTE(review): this variant requires `ndarray`, `wgpu-f16` and `mlx` to all
    //! be disabled. Enabling `wgpu` together with `wgpu-f16` or `mlx` (without
    //! `ndarray`) selects none of the sibling variants either, so no `backend`
    //! module is compiled. Confirm this is intended.

    use burn::backend::wgpu::WgpuDevice;

    // Re-exported (not aliased) so `B` keeps `Wgpu`'s generic parameters.
    pub use burn::backend::Wgpu as B;

    /// Device type used with [`B`].
    pub type Device = WgpuDevice;

    /// Returns the default wgpu device (`WgpuDevice::DefaultDevice`).
    pub fn device() -> Device {
        WgpuDevice::DefaultDevice
    }
}
#[cfg(all(
    feature = "mlx",
    not(feature = "ndarray"),
    not(feature = "wgpu"),
    not(feature = "wgpu-f16")
))]
pub mod backend {
    //! Backend provided by the `burn_mlx` crate.
    //!
    //! NOTE(review): this variant requires `ndarray`, `wgpu` and `wgpu-f16` to
    //! all be disabled. Enabling `mlx` together with `wgpu` or `wgpu-f16`
    //! (without `ndarray`) leaves no `backend` module among the visible variants.
    //! Confirm this is intended.

    use burn_mlx::MlxDevice;

    // Re-exported (not aliased) so `B` is exactly `burn_mlx::Mlx`.
    pub use burn_mlx::Mlx as B;

    /// Device type used with [`B`].
    pub type Device = MlxDevice;

    /// Returns `MlxDevice`'s `Default` value.
    pub fn device() -> Device {
        <MlxDevice as Default>::default()
    }
}