1pub fn init_threads(n: Option<usize>) -> usize {
34 let mut builder = rayon::ThreadPoolBuilder::new();
35 if let Some(count) = n {
36 if count > 0 {
37 builder = builder.num_threads(count);
38 }
39 }
40 let _ = builder.build_global();
41 rayon::current_num_threads()
42}
43
44pub mod config;
47pub mod inference;
48pub mod model;
49pub mod tensor_utils;
50pub mod viterbi;
51pub mod weights;
52
53pub use config::{ModelConfig, ViterbiConfig};
56pub use inference::PrivacyFilterInference;
57pub use viterbi::PrivacySpan;
58
/// CPU tensor backend built on burn's NdArray.
///
/// Compiled whenever the `ndarray` feature is enabled. The other `backend`
/// variants in this file each require `ndarray` to be off, so this one takes
/// precedence when several backend features are enabled together.
#[cfg(feature = "ndarray")]
pub mod backend {
    /// Device handle type for the NdArray backend.
    pub type Device = burn::backend::ndarray::NdArrayDevice;

    /// Backend type exported as `backend::B` (re-export of burn's `NdArray`).
    pub use burn::backend::NdArray as B;

    /// Returns the device used for this backend: `NdArrayDevice::Cpu`.
    pub fn device() -> Device {
        burn::backend::ndarray::NdArrayDevice::Cpu
    }
}
67
/// Half-precision WGPU tensor backend.
///
/// Compiled only when `wgpu-f16` is enabled and none of `ndarray`, `wgpu` or
/// `mlx` are.
#[cfg(all(
    feature = "wgpu-f16",
    not(any(feature = "ndarray", feature = "wgpu", feature = "mlx"))
))]
pub mod backend {
    /// Device handle type for the WGPU backend.
    pub type Device = burn::backend::wgpu::WgpuDevice;

    /// Backend type exported as `backend::B`: burn's `Wgpu` instantiated as
    /// `Wgpu<half::f16, i32, u32>`. The first parameter is presumably the float
    /// element type; confirm against the burn version in use.
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Returns the device used for this backend: `WgpuDevice::DefaultDevice`.
    pub fn device() -> Device {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }
}
74
/// WGPU tensor backend with burn's default `Wgpu` type parameters.
///
/// Compiled only when `wgpu` is enabled and none of `ndarray`, `wgpu-f16` or
/// `mlx` are.
#[cfg(all(
    feature = "wgpu",
    not(any(feature = "ndarray", feature = "wgpu-f16", feature = "mlx"))
))]
pub mod backend {
    /// Device handle type for the WGPU backend.
    pub type Device = burn::backend::wgpu::WgpuDevice;

    /// Backend type exported as `backend::B` (re-export of burn's `Wgpu`).
    pub use burn::backend::Wgpu as B;

    /// Returns the device used for this backend: `WgpuDevice::DefaultDevice`.
    pub fn device() -> Device {
        burn::backend::wgpu::WgpuDevice::DefaultDevice
    }
}
81
/// MLX tensor backend provided by the `burn_mlx` crate.
///
/// Compiled only when `mlx` is enabled and none of `ndarray`, `wgpu` or
/// `wgpu-f16` are.
#[cfg(all(
    feature = "mlx",
    not(any(feature = "ndarray", feature = "wgpu", feature = "wgpu-f16"))
))]
pub mod backend {
    /// Device handle type for the MLX backend.
    pub type Device = burn_mlx::MlxDevice;

    /// Backend type exported as `backend::B` (re-export of `burn_mlx::Mlx`).
    pub use burn_mlx::Mlx as B;

    /// Returns the device's `Default` value.
    pub fn device() -> Device {
        <Device as Default>::default()
    }
}