use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
/// A single feature value attached to a classifier input.
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand for this type rather
/// than derived; only `Debug` and `Clone` come from the derive below.
#[derive(Debug, Clone)]
pub enum FeatureValue {
/// A 64-bit floating-point value.
Float(f64),
/// A signed 64-bit integer value.
Int(i64),
/// An owned text value.
String(String),
/// A sequence of 32-bit floats.
Vector(Vec<f32>),
/// A boolean value.
Bool(bool),
/// An explicit "no value" marker.
Null,
}
/// Bit-exact equality for [`FeatureValue`].
///
/// Floating-point payloads (scalar and vector elements) are compared through
/// their raw bit patterns instead of IEEE-754 `==`. As a consequence a `NaN`
/// equals a `NaN` with the same bits, and `0.0` is different from `-0.0`.
/// Values of different variants are never equal.
impl PartialEq for FeatureValue {
    fn eq(&self, other: &Self) -> bool {
        match self {
            Self::Float(lhs) => {
                matches!(other, Self::Float(rhs) if lhs.to_bits() == rhs.to_bits())
            }
            Self::Int(lhs) => matches!(other, Self::Int(rhs) if lhs == rhs),
            Self::String(lhs) => matches!(other, Self::String(rhs) if lhs == rhs),
            Self::Vector(lhs) => match other {
                Self::Vector(rhs) => {
                    // The length check comes first: `zip` alone stops at the
                    // shorter vector and would treat a prefix as equal.
                    lhs.len() == rhs.len()
                        && lhs
                            .iter()
                            .zip(rhs)
                            .all(|(x, y)| x.to_bits() == y.to_bits())
                }
                _ => false,
            },
            Self::Bool(lhs) => matches!(other, Self::Bool(rhs) if lhs == rhs),
            Self::Null => matches!(other, Self::Null),
        }
    }
}
/// `Eq` marker for [`FeatureValue`]. The hand-written `PartialEq` compares
/// floats by bit pattern, so a value (including one holding `NaN`) is always
/// equal to itself.
impl Eq for FeatureValue {}
/// Hashing for [`FeatureValue`] that agrees with its bit-exact equality.
///
/// The variant discriminant is hashed first so that two variants carrying the
/// same payload bits (for example `Float(0.0)` and `Int(0)`) feed different
/// data to the hasher. Floats contribute their raw bit pattern, and a vector
/// contributes its length followed by each element's bits.
impl std::hash::Hash for FeatureValue {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // The `.hash(..)` method calls below resolve only while the `Hash`
        // trait is in scope. Implementing the trait does not bring it into
        // scope, so import it locally.
        use std::hash::Hash;

        std::mem::discriminant(self).hash(state);
        match self {
            Self::Float(f) => f.to_bits().hash(state),
            Self::Int(i) => i.hash(state),
            Self::String(s) => s.hash(state),
            Self::Vector(v) => {
                // Length prefix keeps the element stream unambiguous.
                v.len().hash(state);
                for f in v {
                    f.to_bits().hash(state);
                }
            }
            Self::Bool(b) => b.hash(state),
            // The discriminant alone identifies `Null`.
            Self::Null => {}
        }
    }
}
/// One classifier input: a set of named feature values.
///
/// The derived `Default` yields an input with no features.
#[derive(Debug, Clone, Default)]
pub struct ClassifyInput {
/// Feature values keyed by feature name.
pub features: HashMap<String, FeatureValue>,
}
impl ClassifyInput {
    /// Creates an input that carries no features.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter: stores `value` under `name` and hands the input
    /// back. A value already stored under the same name is replaced.
    pub fn with(mut self, name: impl Into<String>, value: FeatureValue) -> Self {
        let key: String = name.into();
        self.features.insert(key, value);
        self
    }

    /// Hashes the feature set independently of map iteration order.
    ///
    /// Entries are visited in ascending feature-name order; the entry count
    /// is hashed first, followed by every name and value. The hasher is the
    /// standard library's `DefaultHasher`.
    pub fn stable_hash(&self) -> u64 {
        use std::hash::{Hash, Hasher};

        let mut pairs: Vec<(&String, &FeatureValue)> = self.features.iter().collect();
        // Map keys are unique, so an unstable sort yields the same order as a
        // stable one.
        pairs.sort_unstable_by_key(|&(name, _)| name);

        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        pairs.len().hash(&mut hasher);
        for (name, value) in pairs {
            name.hash(&mut hasher);
            value.hash(&mut hasher);
        }
        hasher.finish()
    }
}
/// Two inputs are equal when their feature maps are equal, i.e. they hold the
/// same feature names with equal values.
impl PartialEq for ClassifyInput {
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.features, &other.features);
        mine == theirs
    }
}
/// `Eq` marker for [`ClassifyInput`]; its equality is the equality of the
/// feature maps.
impl Eq for ClassifyInput {}
/// Errors reported by the classifiers in this file.
#[derive(Debug, Clone, PartialEq)]
pub enum ClassifierError {
/// A different number of outputs was produced than expected.
ArityMismatch { expected: usize, actual: usize },
/// A classifier produced `value`, which is not an acceptable output; the
/// `Display` text describes it as lying outside `[0, 1]`.
DomainViolation { value: f64 },
/// A free-form failure message from the underlying provider.
Provider(String),
}
/// Human-readable messages for [`ClassifierError`], one fixed sentence per
/// variant with the variant's payload interpolated.
impl std::fmt::Display for ClassifierError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Provider(msg) => {
                f.write_fmt(format_args!("classifier provider error: {msg}"))
            }
            Self::DomainViolation { value } => {
                f.write_fmt(format_args!("classifier output {value} outside [0, 1]"))
            }
            Self::ArityMismatch { expected, actual } => f.write_fmt(format_args!(
                "classifier arity mismatch: expected {expected} outputs, got {actual}"
            )),
        }
    }
}
/// Makes [`ClassifierError`] a standard error type. No method is overridden,
/// so the trait's defaults apply.
impl std::error::Error for ClassifierError {}
/// Shorthand for a `Result` whose error type is [`ClassifierError`].
pub type ClassifierResult<T> = std::result::Result<T, ClassifierError>;
#[async_trait]
pub trait NeuralClassifier: Send + Sync + std::fmt::Debug {
async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>>;
async fn classify_logits(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
let probs = self.classify(inputs).await?;
Ok(probs.into_iter().map(inverse_sigmoid).collect())
}
fn name(&self) -> &str;
fn get_calibrator(&self) -> Option<Arc<dyn crate::calibration::Calibrator>> {
None
}
async fn raw_and_calibrated(
&self,
inputs: &[ClassifyInput],
) -> ClassifierResult<Vec<(f64, Option<f64>)>> {
let raw = self.classify(inputs).await?;
Ok(raw.into_iter().map(|p| (p, None)).collect())
}
}
/// A classifier whose score for each input comes from a caller-supplied
/// closure.
pub struct MockClassifier {
/// Name returned by the classifier's `name` method.
name: String,
/// Scoring function, invoked once per input during classification.
f: Arc<dyn Fn(&ClassifyInput) -> f64 + Send + Sync>,
}
impl MockClassifier {
    /// Builds a classifier named `name` that scores every input with `f`.
    ///
    /// The closure's result is stored as-is here; no range check is applied
    /// at construction time.
    pub fn new<F>(name: impl Into<String>, f: F) -> Self
    where
        F: Fn(&ClassifyInput) -> f64 + Send + Sync + 'static,
    {
        let name: String = name.into();
        let f: Arc<dyn Fn(&ClassifyInput) -> f64 + Send + Sync> = Arc::new(f);
        Self { name, f }
    }

    /// Builds a classifier that returns the same score for every input.
    ///
    /// `value` is clamped into `[0.0, 1.0]` once, up front. A `NaN` stays
    /// `NaN` because `f64::clamp` passes it through.
    pub fn constant(name: impl Into<String>, value: f64) -> Self {
        let clamped = value.clamp(0.0, 1.0);
        Self::new(name, move |_: &ClassifyInput| clamped)
    }
}
/// Debug output shows only the name; the scoring closure cannot be printed,
/// so the struct is rendered as non-exhaustive (`..`).
impl std::fmt::Debug for MockClassifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MockClassifier");
        builder.field("name", &self.name);
        builder.finish_non_exhaustive()
    }
}
#[async_trait]
impl NeuralClassifier for MockClassifier {
    /// Scores each input with the stored closure, in order.
    ///
    /// A `NaN` score aborts the batch with `ClassifierError::DomainViolation`;
    /// inputs after it are not scored. Any other score is clamped into
    /// `[0.0, 1.0]`, so out-of-range values are silently pulled to the
    /// nearest bound rather than rejected.
    async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
        inputs
            .iter()
            .map(|input| match (self.f)(input) {
                score if score.is_nan() => {
                    Err(ClassifierError::DomainViolation { value: score })
                }
                score => Ok(score.clamp(0.0, 1.0)),
            })
            // Collecting into `Result` stops at the first error.
            .collect()
    }

    /// Name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// A classifier that passes the output of a base classifier through a
/// calibrator.
pub struct CalibratedClassifier {
/// Name returned by this wrapper's `name` method (independent of the base's name).
name: String,
/// Classifier that produces the raw scores.
base: Arc<dyn NeuralClassifier>,
/// Calibrator applied to the raw scores as a batch.
calibrator: Arc<dyn crate::calibration::Calibrator>,
}
impl CalibratedClassifier {
    /// Wraps `base` so that its scores are post-processed by `calibrator`.
    ///
    /// `name` becomes the name of the wrapper itself; the base classifier
    /// keeps its own name.
    pub fn new(
        name: impl Into<String>,
        base: Arc<dyn NeuralClassifier>,
        calibrator: Arc<dyn crate::calibration::Calibrator>,
    ) -> Self {
        let name: String = name.into();
        Self {
            name,
            base,
            calibrator,
        }
    }
}
/// Debug output lists the wrapper's name, the base classifier's name and the
/// calibrator's `method()`; the struct is rendered as non-exhaustive (`..`).
impl std::fmt::Debug for CalibratedClassifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CalibratedClassifier");
        builder.field("name", &self.name);
        builder.field("base", &self.base.name());
        builder.field("method", &self.calibrator.method());
        builder.finish_non_exhaustive()
    }
}
#[async_trait]
impl NeuralClassifier for CalibratedClassifier {
async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
let raw = self.base.classify(inputs).await?;
Ok(self.calibrator.apply_batch(&raw))
}
fn name(&self) -> &str {
&self.name
}
fn get_calibrator(&self) -> Option<Arc<dyn crate::calibration::Calibrator>> {
Some(Arc::clone(&self.calibrator))
}
async fn raw_and_calibrated(
&self,
inputs: &[ClassifyInput],
) -> ClassifierResult<Vec<(f64, Option<f64>)>> {
let raw = self.base.classify(inputs).await?;
let calibrated = self.calibrator.apply_batch(&raw);
Ok(raw
.into_iter()
.zip(calibrated)
.map(|(r, c)| (r, Some(c)))
.collect())
}
}
/// A lock-protected map from `(model name, input hash)` to a value of type `V`.
#[derive(Debug)]
struct KeyedStore<V> {
/// The map itself, behind a standard-library reader/writer lock.
inner: std::sync::RwLock<HashMap<(String, u64), V>>,
}
/// An empty store. Implemented by hand so that `V` does not need to be
/// `Default` itself.
impl<V> Default for KeyedStore<V> {
    fn default() -> Self {
        let empty: HashMap<(String, u64), V> = HashMap::new();
        Self {
            inner: std::sync::RwLock::new(empty),
        }
    }
}
impl<V: Clone> KeyedStore<V> {
    /// Returns a clone of the value stored under `(model, input_hash)`.
    ///
    /// Yields `None` when the key is absent and also when the lock is
    /// poisoned (the poison error is discarded). Each lookup allocates an
    /// owned copy of `model` to build the map key.
    fn get(&self, model: &str, input_hash: u64) -> Option<V> {
        let guard = self.inner.read().ok()?;
        let key = (model.to_string(), input_hash);
        guard.get(&key).cloned()
    }
}
impl<V> KeyedStore<V> {
    /// Stores `value` under `(model, input_hash)`, replacing any previous
    /// value for that key.
    ///
    /// Best-effort: if the lock is poisoned the value is dropped silently.
    fn insert(&self, model: &str, input_hash: u64, value: V) {
        if let Ok(mut guard) = self.inner.write() {
            guard.insert((model.to_string(), input_hash), value);
        }
    }

    /// Like [`insert`](Self::insert), but keeps the map at `max_entries`
    /// entries or fewer.
    ///
    /// Eviction is all-or-nothing: when adding a *new* key would exceed the
    /// limit, the whole map is cleared first. Overwriting a key that is
    /// already present never evicts, because it does not grow the map.
    /// `max_entries == 0` disables the limit.
    ///
    /// Best-effort: if the lock is poisoned the value is dropped silently.
    fn insert_bounded(&self, model: &str, input_hash: u64, value: V, max_entries: usize) {
        if let Ok(mut guard) = self.inner.write() {
            let key = (model.to_string(), input_hash);
            let at_capacity = max_entries > 0 && guard.len() >= max_entries;
            // Only a brand-new key needs room; replacing an existing entry
            // must not flush every other cached value.
            if at_capacity && !guard.contains_key(&key) {
                guard.clear();
            }
            guard.insert(key, value);
        }
    }

    /// Removes every entry. Does nothing if the lock is poisoned.
    fn clear(&self) {
        if let Ok(mut guard) = self.inner.write() {
            guard.clear();
        }
    }

    /// Number of stored entries; reports `0` if the lock is poisoned.
    fn len(&self) -> usize {
        self.inner.read().map(|guard| guard.len()).unwrap_or(0)
    }
}
/// Store of [`NeuralProvenanceRecord`]s keyed by model name and input hash.
#[derive(Debug, Default)]
pub struct NeuralProvenanceStore {
/// Underlying lock-protected keyed map.
inner: KeyedStore<NeuralProvenanceRecord>,
}
/// Data kept about one classifier invocation.
///
/// NOTE(review): field meanings below are taken from the field names and
/// types; nothing in this file populates a record — confirm against the
/// code that creates them.
#[derive(Debug, Clone)]
pub struct NeuralProvenanceRecord {
/// Probability before calibration.
pub raw_probability: f64,
/// Probability after calibration, when one is available.
pub calibrated_probability: Option<f64>,
/// Optional confidence band associated with the result.
pub confidence_band: Option<crate::result::ConfidenceBand>,
/// Feature values recorded for the invocation, keyed by feature name.
pub feature_inputs: HashMap<String, FeatureValue>,
}
/// Public operations of the provenance store; each one delegates to the
/// private keyed store.
impl NeuralProvenanceStore {
/// Creates an empty store.
pub fn new() -> Self {
Self::default()
}
/// Stores `record` under `(model, input_hash)`, replacing any record already
/// stored for that key. The record is dropped silently if the underlying lock
/// is poisoned.
pub fn record(&self, model: &str, input_hash: u64, record: NeuralProvenanceRecord) {
self.inner.insert(model, input_hash, record);
}
/// Returns a clone of the record stored under `(model, input_hash)`, or
/// `None` if there is none (or the underlying lock is poisoned).
pub fn get(&self, model: &str, input_hash: u64) -> Option<NeuralProvenanceRecord> {
self.inner.get(model, input_hash)
}
/// Removes every record.
pub fn clear(&self) {
self.inner.clear();
}
/// Number of stored records (reported as `0` if the underlying lock is poisoned).
pub fn len(&self) -> usize {
self.inner.len()
}
/// `true` when `len()` is zero.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Cache of classifier outputs keyed by model name and input hash.
///
/// The derived `Default` sets `max_entries` to `0`, which the bounded insert
/// of the underlying store treats as "no limit".
#[derive(Debug, Default)]
pub struct ModelInvocationCache {
/// Underlying lock-protected keyed map of cached values.
inner: KeyedStore<f64>,
/// Entry limit passed to every insert; `0` disables the limit.
max_entries: usize,
}
/// Public operations of the invocation cache; each one delegates to the
/// private keyed store.
impl ModelInvocationCache {
/// Creates an empty cache with the given entry limit (`0` means unlimited).
pub fn new(max_entries: usize) -> Self {
Self {
inner: KeyedStore::default(),
max_entries,
}
}
/// Returns the cached value for `(model, input_hash)`, or `None` on a miss
/// (or if the underlying lock is poisoned).
pub fn get(&self, model: &str, input_hash: u64) -> Option<f64> {
self.inner.get(model, input_hash)
}
/// Caches `value` under `(model, input_hash)` via the store's bounded insert.
/// When the cache is already at `max_entries`, the store may clear all
/// existing entries to make room rather than evicting a single one.
pub fn insert(&self, model: &str, input_hash: u64, value: f64) {
self.inner
.insert_bounded(model, input_hash, value, self.max_entries);
}
/// Removes every cached value.
pub fn clear(&self) {
self.inner.clear();
}
/// Number of cached values (reported as `0` if the underlying lock is poisoned).
pub fn len(&self) -> usize {
self.inner.len()
}
/// `true` when `len()` is zero.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// A linear model followed by a sigmoid, evaluated with candle tensors.
pub struct CandleLinearClassifier {
/// Name returned by the classifier's `name` method.
name: String,
/// Feature names in the order in which they are encoded into the input
/// matrix; `load` requires this to have the same length as `weight`.
feature_order: Vec<String>,
/// Flattened weight values, one per feature.
weight: Vec<f32>,
/// Scalar bias added to every logit.
bias: f32,
/// Device on which tensors are created (`load` always uses the CPU device).
device: candle_core::Device,
}
/// Debug output lists the name, the feature order and the number of weights
/// (as `n_features`); the remaining fields are elided with `..`.
impl std::fmt::Debug for CandleLinearClassifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let n_features = self.weight.len();
        let mut builder = f.debug_struct("CandleLinearClassifier");
        builder.field("name", &self.name);
        builder.field("feature_order", &self.feature_order);
        builder.field("n_features", &n_features);
        builder.finish_non_exhaustive()
    }
}
impl CandleLinearClassifier {
    /// Loads a linear classifier from a safetensors file.
    ///
    /// The file must contain a `weight` tensor, whose flattened length must
    /// equal `feature_order.len()`, and a `bias` tensor holding exactly one
    /// value. Both are read as `f32` on the CPU device.
    ///
    /// # Errors
    ///
    /// Returns `ClassifierError::Provider` when the file cannot be loaded,
    /// when either tensor is missing or cannot be read as `f32`, when the bias
    /// is not a single value, or when the weight length does not match
    /// `feature_order`.
    pub fn load(
        name: impl Into<String>,
        feature_order: Vec<String>,
        weights_path: impl AsRef<std::path::Path>,
    ) -> ClassifierResult<Self> {
        let device = candle_core::Device::Cpu;
        let path = weights_path.as_ref();
        let tensors = match candle_core::safetensors::load(path, &device) {
            Ok(loaded) => loaded,
            Err(e) => {
                return Err(ClassifierError::Provider(format!(
                    "candle: failed to load safetensors from {path:?}: {e}"
                )));
            }
        };

        // Both tensors must be present before either one is decoded.
        let weight_t = match tensors.get("weight") {
            Some(tensor) => tensor,
            None => {
                return Err(ClassifierError::Provider(
                    "candle: safetensors missing 'weight' tensor".to_string(),
                ));
            }
        };
        let bias_t = match tensors.get("bias") {
            Some(tensor) => tensor,
            None => {
                return Err(ClassifierError::Provider(
                    "candle: safetensors missing 'bias' tensor".to_string(),
                ));
            }
        };

        let weight: Vec<f32> = weight_t
            .flatten_all()
            .and_then(|t| t.to_vec1::<f32>())
            .map_err(|e| ClassifierError::Provider(format!("candle: weight read: {e}")))?;
        let bias_vec: Vec<f32> = bias_t
            .flatten_all()
            .and_then(|t| t.to_vec1::<f32>())
            .map_err(|e| ClassifierError::Provider(format!("candle: bias read: {e}")))?;

        // The bias is validated before the weight length.
        let bias = match bias_vec.as_slice() {
            [single] => *single,
            other => {
                return Err(ClassifierError::Provider(format!(
                    "candle: 'bias' must be scalar (shape [1]); got len={}",
                    other.len()
                )));
            }
        };
        if weight.len() != feature_order.len() {
            return Err(ClassifierError::Provider(format!(
                "candle: weight length {} != feature_order length {}",
                weight.len(),
                feature_order.len()
            )));
        }

        Ok(Self {
            name: name.into(),
            feature_order,
            weight,
            bias,
            device,
        })
    }

    /// Converts one optional feature value into the `f32` fed to the model.
    ///
    /// Floats and integers are cast with `as`, booleans become `0.0`/`1.0`,
    /// and strings are hashed (multiply-by-33 rolling hash seeded with 5381)
    /// and scaled by `i32::MAX`. Missing features, `Null` and `Vector` values
    /// all encode as `0.0`.
    fn encode_feature(&self, v: Option<&FeatureValue>) -> f32 {
        let value = match v {
            Some(present) => present,
            None => return 0.0,
        };
        match value {
            FeatureValue::Float(f) => *f as f32,
            FeatureValue::Int(i) => *i as f32,
            FeatureValue::Bool(b) => f32::from(*b),
            FeatureValue::String(s) => {
                let hash = s.bytes().fold(5381u32, |acc, byte| {
                    acc.wrapping_mul(33).wrapping_add(u32::from(byte))
                });
                // Reinterpret as signed so the result is spread around zero.
                (hash as i32) as f32 / i32::MAX as f32
            }
            FeatureValue::Vector(_) | FeatureValue::Null => 0.0,
        }
    }
}
#[async_trait]
impl NeuralClassifier for CandleLinearClassifier {
    /// Scores a batch as `sigmoid(x · weight + bias)`.
    ///
    /// Each input becomes one row of an `(inputs.len(), n_features)` matrix,
    /// with columns in `feature_order` and every feature converted by
    /// `encode_feature` (missing features encode as `0.0`). An empty batch
    /// returns an empty vector without touching candle.
    ///
    /// # Errors
    ///
    /// Returns `ClassifierError::Provider` when any candle operation fails,
    /// and `ClassifierError::DomainViolation` when the model produces a `NaN`
    /// (for example from a non-finite feature or weight), so that a `NaN` is
    /// never handed out as a probability.
    async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }
        let n_features = self.weight.len();
        let mut data: Vec<f32> = Vec::with_capacity(inputs.len() * n_features);
        for inp in inputs {
            data.extend(
                self.feature_order
                    .iter()
                    .map(|fname| self.encode_feature(inp.features.get(fname))),
            );
        }
        let x = candle_core::Tensor::from_vec(data, (inputs.len(), n_features), &self.device)
            .map_err(|e| ClassifierError::Provider(format!("candle: input tensor: {e}")))?;
        let w = candle_core::Tensor::from_slice(&self.weight, (n_features, 1), &self.device)
            .map_err(|e| ClassifierError::Provider(format!("candle: weight tensor: {e}")))?;
        let logits = x
            .matmul(&w)
            .and_then(|t| t.broadcast_add(&candle_core::Tensor::new(&[self.bias], &self.device)?))
            .map_err(|e| ClassifierError::Provider(format!("candle: forward pass: {e}")))?;
        let probs = candle_nn::ops::sigmoid(&logits)
            .and_then(|t| t.flatten_all())
            .and_then(|t| t.to_vec1::<f32>())
            .map_err(|e| ClassifierError::Provider(format!("candle: sigmoid: {e}")))?;
        // A sigmoid output is either within [0, 1] or NaN; reject the latter
        // the same way `MockClassifier` does.
        probs
            .into_iter()
            .map(|p| {
                let value = f64::from(p);
                if value.is_nan() {
                    Err(ClassifierError::DomainViolation { value })
                } else {
                    Ok(value)
                }
            })
            .collect()
    }

    /// Name given at load time.
    fn name(&self) -> &str {
        &self.name
    }
}
/// Logit (inverse of the logistic sigmoid) of a probability.
///
/// `p` is first clamped into `[0.0, 1.0]`. The endpoints map to negative and
/// positive infinity respectively; anything in between maps to
/// `ln(p / (1 - p))`. A `NaN` input yields `NaN`.
fn inverse_sigmoid(p: f64) -> f64 {
    match p.clamp(0.0, 1.0) {
        q if q == 0.0 => f64::NEG_INFINITY,
        q if q == 1.0 => f64::INFINITY,
        q => (q / (1.0 - q)).ln(),
    }
}
/// Unit tests for the mock classifier, feature hashing, the invocation cache
/// and the logit helper.
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashes a feature value with the standard library's default hasher.
    fn hash_of(value: &FeatureValue) -> u64 {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[tokio::test]
    async fn mock_constant_returns_value_per_row() {
        let classifier = MockClassifier::constant("classify/test", 0.7);
        let inputs: Vec<ClassifyInput> = [1.0, 2.0, 3.0]
            .iter()
            .map(|&x| ClassifyInput::new().with("x", FeatureValue::Float(x)))
            .collect();
        let scores = classifier.classify(&inputs).await.unwrap();
        assert_eq!(scores, vec![0.7, 0.7, 0.7]);
        assert_eq!(scores.len(), inputs.len());
        assert_eq!(classifier.name(), "classify/test");
    }

    #[tokio::test]
    async fn mock_feature_driven() {
        // Score is severity / 10, saturating at the [0, 1] bounds.
        let classifier = MockClassifier::new("classify/feature", |input| {
            if let Some(FeatureValue::Float(severity)) = input.features.get("severity") {
                (*severity / 10.0).clamp(0.0, 1.0)
            } else {
                0.0
            }
        });
        let inputs: Vec<ClassifyInput> = [2.0, 9.0, 15.0]
            .iter()
            .map(|&severity| ClassifyInput::new().with("severity", FeatureValue::Float(severity)))
            .collect();
        let scores = classifier.classify(&inputs).await.unwrap();
        assert_eq!(scores, vec![0.2, 0.9, 1.0]);
    }

    #[tokio::test]
    async fn classify_logits_default_inverse_sigmoid() {
        let classifier = MockClassifier::constant("classify/test", 0.5);
        let logits = classifier
            .classify_logits(&[ClassifyInput::new()])
            .await
            .unwrap();
        assert!((logits[0] - 0.0).abs() < 1e-12);
    }

    #[tokio::test]
    async fn mock_rejects_nan() {
        let classifier = MockClassifier::new("classify/nan", |_| f64::NAN);
        let outcome = classifier.classify(&[ClassifyInput::new()]).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, ClassifierError::DomainViolation { .. }));
    }

    #[test]
    fn feature_value_hash_distinguishes_variants() {
        assert_ne!(
            hash_of(&FeatureValue::Float(0.0)),
            hash_of(&FeatureValue::Int(0))
        );
        assert_ne!(
            hash_of(&FeatureValue::Null),
            hash_of(&FeatureValue::Bool(false))
        );
        assert_eq!(
            hash_of(&FeatureValue::Float(0.5)),
            hash_of(&FeatureValue::Float(0.5))
        );
    }

    #[test]
    fn classify_input_hash_order_independent() {
        let us_first = ClassifyInput::new()
            .with("country", FeatureValue::String("US".into()))
            .with("revenue", FeatureValue::Float(1.0e6));
        let us_reordered = ClassifyInput::new()
            .with("revenue", FeatureValue::Float(1.0e6))
            .with("country", FeatureValue::String("US".into()));
        assert_eq!(us_first.stable_hash(), us_reordered.stable_hash());

        let de = ClassifyInput::new()
            .with("country", FeatureValue::String("DE".into()))
            .with("revenue", FeatureValue::Float(1.0e6));
        assert_ne!(us_first.stable_hash(), de.stable_hash());
    }

    #[test]
    fn feature_value_vector_hash() {
        let base = FeatureValue::Vector(vec![1.0, 2.0, 3.0]);
        let same = FeatureValue::Vector(vec![1.0, 2.0, 3.0]);
        let different = FeatureValue::Vector(vec![1.0, 2.0, 3.5]);
        assert_eq!(hash_of(&base), hash_of(&same));
        assert_ne!(hash_of(&base), hash_of(&different));
    }

    #[test]
    fn model_invocation_cache_hit_miss() {
        let cache = ModelInvocationCache::new(100);
        assert!(cache.get("m", 42).is_none());
        cache.insert("m", 42, 0.7);
        assert_eq!(cache.get("m", 42), Some(0.7));
        assert!(cache.get("other", 42).is_none());
        assert!(cache.get("m", 43).is_none());
    }

    #[test]
    fn model_invocation_cache_evicts_on_overflow() {
        let cache = ModelInvocationCache::new(2);
        cache.insert("m", 1, 0.1);
        cache.insert("m", 2, 0.2);
        assert_eq!(cache.len(), 2);
        // A third distinct key exceeds the limit, which flushes the cache.
        cache.insert("m", 3, 0.3);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("m", 3), Some(0.3));
    }

    #[test]
    fn inverse_sigmoid_endpoints() {
        let at_zero = inverse_sigmoid(0.0);
        let at_one = inverse_sigmoid(1.0);
        assert!(at_zero.is_infinite() && at_zero < 0.0);
        assert!(at_one.is_infinite() && at_one > 0.0);
        assert!((inverse_sigmoid(0.5) - 0.0).abs() < 1e-12);
    }
}