// privacy_filter_rs/lib.rs

//! # privacy-filter-rs — OpenAI Privacy Filter inference in Rust
//!
//! Pure-Rust inference for the [OpenAI Privacy Filter](https://huggingface.co/openai/privacy-filter)
//! token-classification model, built on [Burn 0.20](https://burn.dev).
//!
//! ## Quick start
//!
//! ```rust,ignore
//! use privacy_filter_rs::PrivacyFilterInference;
//! use std::path::Path;
//!
//! // Choose a backend
//! type B = burn::backend::NdArray;
//! let device = burn::backend::ndarray::NdArrayDevice::Cpu;
//!
//! let engine = PrivacyFilterInference::<B>::load(
//!     Path::new("/path/to/privacy-filter"),
//!     device,
//! )?;
//!
//! let spans = engine.predict("My name is Alice Smith")?;
//! for span in &spans {
//!     println!("{}: {} (score: {:.4})", span.entity_group, span.word, span.score);
//! }
//! ```

// ── Thread configuration ────────────────────────────────────────────────────

29/// Configure the global Rayon thread pool.
30///
31/// Call this **once**, before any model operations.
32/// Returns the actual number of threads in the pool.
33pub fn init_threads(n: Option<usize>) -> usize {
34    let mut builder = rayon::ThreadPoolBuilder::new();
35    if let Some(count) = n {
36        if count > 0 {
37            builder = builder.num_threads(count);
38        }
39    }
40    let _ = builder.build_global();
41    rayon::current_num_threads()
42}
43
// ── Internal modules ────────────────────────────────────────────────────────

46pub mod config;
47pub mod inference;
48pub mod model;
49pub mod tensor_utils;
50pub mod viterbi;
51pub mod weights;
52
// ── Flat re-exports ─────────────────────────────────────────────────────────

55pub use config::{ModelConfig, ViterbiConfig};
56pub use inference::PrivacyFilterInference;
57pub use viterbi::PrivacySpan;
58
// ── Backend selection ───────────────────────────────────────────────────────

/// Backend aliases for Burn's CPU `NdArray` backend.
///
/// This module is compiled whenever the `ndarray` feature is enabled,
/// regardless of any other backend features. Every other `backend` variant
/// requires `ndarray` to be off, so `ndarray` takes precedence.
#[cfg(feature = "ndarray")]
pub mod backend {
    /// The Burn backend selected by the enabled feature.
    pub use burn::backend::NdArray as B;

    /// Device type used with [`B`].
    pub type Device = burn::backend::ndarray::NdArrayDevice;

    /// Returns the device to run on, which is always `NdArrayDevice::Cpu`.
    pub fn device() -> Device {
        Device::Cpu
    }
}

/// Backend aliases for Burn's WGPU backend with `half::f16` as its first
/// (float) type parameter.
///
/// This module is compiled only when `wgpu-f16` is enabled and none of
/// `ndarray`, `wgpu`, or `mlx` are. If `wgpu-f16` is combined with `wgpu` or
/// `mlx` (and `ndarray` is off), none of the `backend` variants is compiled.
#[cfg(all(feature = "wgpu-f16", not(feature = "ndarray"), not(feature = "wgpu"), not(feature = "mlx")))]
pub mod backend {
    /// The Burn backend selected by the enabled feature.
    pub type B = burn::backend::wgpu::Wgpu<half::f16, i32, u32>;

    /// Device type used with [`B`].
    pub type Device = burn::backend::wgpu::WgpuDevice;

    /// Returns `WgpuDevice::DefaultDevice`.
    pub fn device() -> Device {
        Device::DefaultDevice
    }
}

/// Backend aliases for Burn's default-precision WGPU backend.
///
/// This module is compiled only when `wgpu` is enabled and none of `ndarray`,
/// `wgpu-f16`, or `mlx` are. If `wgpu` is combined with `wgpu-f16` or `mlx`
/// (and `ndarray` is off), none of the `backend` variants is compiled.
#[cfg(all(feature = "wgpu", not(feature = "ndarray"), not(feature = "wgpu-f16"), not(feature = "mlx")))]
pub mod backend {
    /// The Burn backend selected by the enabled feature.
    pub use burn::backend::Wgpu as B;

    /// Device type used with [`B`].
    pub type Device = burn::backend::wgpu::WgpuDevice;

    /// Returns `WgpuDevice::DefaultDevice`.
    pub fn device() -> Device {
        Device::DefaultDevice
    }
}

/// Backend aliases for the `burn_mlx` backend.
///
/// This module is compiled only when `mlx` is enabled and none of `ndarray`,
/// `wgpu`, or `wgpu-f16` are. If `mlx` is combined with `wgpu` or `wgpu-f16`
/// (and `ndarray` is off), none of the `backend` variants is compiled.
#[cfg(all(feature = "mlx", not(feature = "ndarray"), not(feature = "wgpu"), not(feature = "wgpu-f16")))]
pub mod backend {
    /// The Burn backend selected by the enabled feature.
    pub use burn_mlx::Mlx as B;

    /// Device type used with [`B`].
    pub type Device = burn_mlx::MlxDevice;

    /// Returns `MlxDevice`'s `Default` value.
    pub fn device() -> Device {
        Default::default()
    }
}