#[path = "ml_weights.rs"]
pub(crate) mod ml_weights;
use std::cell::RefCell;
#[path = "ml_features.rs"]
mod ml_features;
pub use ml_features::compute_features_with_config;
pub use ml_features::{compute_features_public, NUM_FEATURES};
/// Number of experts in the mixture: the length of the gate-logit array built by
/// `compute_gate_logits` and of the probability array returned by `softmax`.
const EXPERT_COUNT: usize = 6;
/// Output width of the first dense + ReLU layer evaluated in `expert_logit`.
const EXPERT_HIDDEN_LAYER_1: usize = 32;
/// Output width of the second dense + ReLU layer evaluated in `expert_logit`.
const EXPERT_HIDDEN_LAYER_2: usize = 16;
/// Scores `text` within `context` without any extra configuration.
///
/// Forwards to `score_with_config` with all four keyword/prefix lists empty and
/// returns its result unchanged.
pub fn score(text: &str, context: &str) -> f64 {
    let no_entries: &[String] = &[];
    score_with_config(text, context, no_entries, no_entries, no_entries, no_entries)
}
/// Scores `text` within `context` using caller-supplied prefix and keyword lists.
///
/// Returns `0.0` immediately when `text` is empty. Otherwise the features produced by
/// `compute_features_with_config` are run through `forward_pass` and the result is
/// widened to `f64`.
///
/// Results are memoised per thread in a small map keyed by a 64-bit FNV-1a style hash.
/// The key covers `text`, `context` and all four configuration lists, so scoring the
/// same text under a different configuration never returns a stale entry computed for
/// another configuration. Every field is length-prefixed before hashing so adjacent
/// fields cannot blur into each other. Distinct inputs that collide on the 64-bit key
/// would still share an entry.
///
/// When the cache already holds 256 entries it is cleared before the new score is
/// inserted.
pub fn score_with_config(
    text: &str,
    context: &str,
    known_prefixes: &[String],
    secret_keywords: &[String],
    test_keywords: &[String],
    placeholder_keywords: &[String],
) -> f64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    const CACHE_INITIAL_CAPACITY: usize = 64;
    const CACHE_MAX_ENTRIES: usize = 256;

    thread_local! {
        static SCORE_CACHE: RefCell<std::collections::HashMap<u64, f64>> =
            RefCell::new(std::collections::HashMap::with_capacity(CACHE_INITIAL_CAPACITY));
    }

    /// Folds raw bytes into a running FNV-1a hash.
    fn mix_bytes(hash: u64, bytes: &[u8]) -> u64 {
        bytes
            .iter()
            .fold(hash, |acc, &byte| (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME))
    }

    /// Folds a length (widened to `u64`, which cannot truncate a `usize`) into the hash.
    fn mix_len(hash: u64, len: usize) -> u64 {
        mix_bytes(hash, &(len as u64).to_le_bytes())
    }

    /// Folds a length-prefixed string into the hash.
    fn mix_str(hash: u64, value: &str) -> u64 {
        mix_bytes(mix_len(hash, value.len()), value.as_bytes())
    }

    /// Folds a count-prefixed list of length-prefixed strings into the hash.
    fn mix_list(hash: u64, values: &[String]) -> u64 {
        values
            .iter()
            .fold(mix_len(hash, values.len()), |acc, value| mix_str(acc, value))
    }

    if text.is_empty() {
        return 0.0;
    }

    // The configuration lists influence the computed features, so they must be part
    // of the key alongside the text and its context.
    let cache_key = {
        let hash = mix_str(FNV_OFFSET_BASIS, text);
        let hash = mix_str(hash, context);
        [
            known_prefixes,
            secret_keywords,
            test_keywords,
            placeholder_keywords,
        ]
        .iter()
        .fold(hash, |acc, list| mix_list(acc, list))
    };

    // The shared borrow ends before any feature computation starts.
    let cached = SCORE_CACHE.with(|cache| cache.borrow().get(&cache_key).copied());
    if let Some(score) = cached {
        return score;
    }

    let features = compute_features_with_config(
        text,
        context,
        known_prefixes,
        secret_keywords,
        test_keywords,
        placeholder_keywords,
    );
    let score = f64::from(forward_pass(&features));

    SCORE_CACHE.with(|cache| {
        let mut cache = cache.borrow_mut();
        // Crude bound on memory: drop everything once the limit is reached.
        if cache.len() >= CACHE_MAX_ENTRIES {
            cache.clear();
        }
        cache.insert(cache_key, score);
    });
    score
}
/// Scores an already-computed feature vector.
///
/// Runs `forward_pass` on `features` and widens the `f32` result to `f64`.
pub(crate) fn score_features(features: &[f32; NUM_FEATURES]) -> f64 {
    let raw: f32 = forward_pass(features);
    f64::from(raw)
}
pub fn model_version() -> &'static str {
ml_weights::MODEL_VERSION
}
/// Evaluates the model returned by `ml_weights::model()` on `input`.
///
/// Thin wrapper around `forward_pass_impl`, which does the actual computation.
fn forward_pass(input: &[f32; NUM_FEATURES]) -> f32 {
    forward_pass_impl(ml_weights::model(), input)
}
/// Mixture-of-experts forward pass for a single feature vector.
///
/// The gate logits are turned into probabilities with `softmax`; each expert's logit
/// is weighted by its gate probability, the weighted logits are summed in expert
/// order, and the sum is squashed with `sigmoid`.
fn forward_pass_impl(model: &ml_weights::MoeModel, input: &[f32; NUM_FEATURES]) -> f32 {
    let gate_logits = compute_gate_logits(model, input);
    let gate_probs = softmax(&gate_logits);

    // Explicit fold from 0.0 keeps the accumulation order of a plain running sum.
    let blended_logit = gate_probs
        .iter()
        .enumerate()
        .fold(0.0f32, |acc, (expert_idx, &gate_prob)| {
            acc + gate_prob * expert_logit(&model.experts[expert_idx], input)
        });

    sigmoid(blended_logit)
}
/// Computes one gating logit per expert.
///
/// `model.gate_weight` is read as `EXPERT_COUNT` consecutive rows of `NUM_FEATURES`
/// weights; each logit is the dot product of its row with `input` plus the matching
/// entry of `model.gate_bias`. The expected lengths are checked in debug builds only.
fn compute_gate_logits(
    model: &ml_weights::MoeModel,
    input: &[f32; NUM_FEATURES],
) -> [f32; EXPERT_COUNT] {
    debug_assert_eq!(model.gate_weight.len(), NUM_FEATURES * EXPERT_COUNT);
    debug_assert_eq!(model.gate_bias.len(), EXPERT_COUNT);

    std::array::from_fn(|expert_idx| {
        // The row slice is open-ended; `dense_row` only consumes as many weights as
        // there are input features.
        let row_start = expert_idx * NUM_FEATURES;
        let row = &model.gate_weight[row_start..];
        dense_row(row, input, model.gate_bias[expert_idx])
    })
}
fn expert_logit(expert: &ml_weights::ExpertWeights, input: &[f32; NUM_FEATURES]) -> f32 {
let h1 = dense_relu_layer_t::<NUM_FEATURES, EXPERT_HIDDEN_LAYER_1>(
expert.fc1_weight_t,
expert.fc1_bias,
input,
);
let h2 = dense_relu_layer_t::<EXPERT_HIDDEN_LAYER_1, EXPERT_HIDDEN_LAYER_2>(
expert.fc2_weight_t,
expert.fc2_bias,
&h1,
);
dense_row(expert.fc3_weight, &h2, expert.fc3_bias)
}
/// Dense layer followed by ReLU, reading the weight matrix in transposed layout.
///
/// `weights_t` holds `INPUT` consecutive rows of `OUTPUT` weights (row `k` belongs to
/// input element `k`), and `bias` supplies the starting value of each output.
/// Output `o` is `max(bias[o] + sum_k input[k] * weights_t[k * OUTPUT + o], 0.0)`.
///
/// # Panics
///
/// Panics if `bias` has fewer than `OUTPUT` elements or `weights_t` has fewer than
/// `INPUT * OUTPUT` elements.
#[inline]
fn dense_relu_layer_t<const INPUT: usize, const OUTPUT: usize>(
    weights_t: &[f32],
    bias: &[f32],
    input: &[f32; INPUT],
) -> [f32; OUTPUT] {
    // Start every accumulator at its bias.
    let mut sums: [f32; OUTPUT] = std::array::from_fn(|o| bias[o]);

    // Walk the inputs in order; each one scales its own contiguous weight row.
    for (k, &x) in input.iter().enumerate() {
        let row_start = k * OUTPUT;
        let row = &weights_t[row_start..row_start + OUTPUT];
        for (&w, sum) in row.iter().zip(sums.iter_mut()) {
            *sum += x * w;
        }
    }

    // ReLU.
    sums.map(|value| value.max(0.0))
}
/// Dot product of `input` with the leading elements of `weights`, plus `bias`.
///
/// Elements are paired in order and accumulation starts from `bias`. Pairing stops at
/// the shorter of the two sequences, so surplus weights (or inputs) are ignored.
#[inline(always)]
fn dense_row<const INPUT: usize>(weights: &[f32], input: &[f32; INPUT], bias: f32) -> f32 {
    input
        .iter()
        .zip(weights.iter())
        .fold(bias, |acc, (&x, &w)| acc + x * w)
}
/// Cheap squashing function used in place of an exact logistic sigmoid.
///
/// Returns `0.0` for inputs at or below `-6.0`, `1.0` for inputs at or above `6.0`,
/// and `0.5 + 0.5 * x / (1 + |x|)` in between. NaN fails both comparisons and so
/// propagates through the rational branch.
fn sigmoid(value: f32) -> f32 {
    match value {
        x if x <= -6.0 => 0.0,
        x if x >= 6.0 => 1.0,
        x => 0.5 + 0.5 * x / (1.0 + x.abs()),
    }
}
/// Softmax over the gate logits.
///
/// The largest logit is subtracted before exponentiating, and each exponential is then
/// divided by their running sum (accumulated in index order).
fn softmax(logits: &[f32; EXPERT_COUNT]) -> [f32; EXPERT_COUNT] {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut weights = [0.0f32; EXPERT_COUNT];
    let mut total = 0.0f32;
    for (slot, &logit) in weights.iter_mut().zip(logits.iter()) {
        let shifted_exp = (logit - peak).exp();
        *slot = shifted_exp;
        total += shifted_exp;
    }

    weights.map(|weight| weight / total)
}