use super::embed_registry::{self, EmbedModel};
use super::store::ModelProfile;
/// Returns the embedding model currently selected in the registry.
///
/// # Panics
/// Panics with the registry's error message when `embed_registry::selected()`
/// reports an error.
pub fn active_model() -> &'static EmbedModel {
    match embed_registry::selected() {
        Ok(model) => model,
        Err(err) => panic!("{err}"),
    }
}
/// Name of the active embedding model, as recorded in the registry entry.
///
/// # Panics
/// Panics if no model can be selected (see `active_model`).
pub fn model_name() -> &'static str {
    let model = active_model();
    model.model_name
}
/// Output dimensionality of the active embedding model.
///
/// # Panics
/// Panics if no model can be selected (see `active_model`).
pub fn dim() -> usize {
    let model = active_model();
    model.dim
}
/// The `max_tokens` limit declared by the active model's registry entry.
///
/// # Panics
/// Panics if no model can be selected (see `active_model`).
pub fn max_tokens() -> usize {
    let model = active_model();
    model.max_tokens
}
/// Model directory baked in at compile time from the `NORNIR_MODEL_DIR` build
/// environment variable (`env!` fails the build if it is unset). `model_dir()`
/// uses it as the last-resort fallback.
pub const MODEL_DIR: &str = env!("NORNIR_MODEL_DIR");
/// Weights checksum baked in at compile time from `NORNIR_MODEL_WEIGHTS_SHA`;
/// copied into `profile().weights_sha`. Presumably a SHA digest of the weights
/// file — confirm against the build script.
pub const WEIGHTS_SHA: &str = env!("NORNIR_MODEL_WEIGHTS_SHA");
/// Tokenizer checksum baked in at compile time from `NORNIR_MODEL_TOKENIZER_SHA`;
/// copied into `profile().tokenizer_sha`. Presumably a SHA digest of the
/// tokenizer file — confirm against the build script.
pub const TOKENIZER_SHA: &str = env!("NORNIR_MODEL_TOKENIZER_SHA");
/// Resolves the directory holding the model files.
///
/// Lookup order:
/// 1. the runtime `NORNIR_MODEL_DIR` environment variable, when set and non-empty;
/// 2. `/opt/nornir/models`, when it contains a `tokenizer.json`;
/// 3. the compile-time `MODEL_DIR`.
pub fn model_dir() -> std::path::PathBuf {
    // `var_os` rather than `var`: `var` returns `Err` for non-UTF-8 values, which
    // would silently ignore a perfectly valid (non-UTF-8) directory path.
    if let Some(dir) = std::env::var_os("NORNIR_MODEL_DIR").filter(|d| !d.is_empty()) {
        return std::path::PathBuf::from(dir);
    }
    let opt = std::path::Path::new("/opt/nornir/models");
    // Only trust the system-wide location if it actually looks populated.
    if opt.join("tokenizer.json").exists() {
        return opt.to_path_buf();
    }
    std::path::PathBuf::from(MODEL_DIR)
}
/// Builds the `ModelProfile` describing how embeddings from this module are produced.
///
/// Name and dimension come from the active model, and the checksums come from the
/// compile-time constants. Pooling/normalization are fixed to `"mean"`/`true`,
/// matching `pool_and_normalize`; the dtype is `"f32"`.
///
/// # Panics
/// Panics if no model can be selected (see `active_model`).
pub fn profile() -> ModelProfile {
    let name = model_name();
    let width = dim();
    ModelProfile {
        model_name: name.to_owned(),
        weights_sha: WEIGHTS_SHA.to_owned(),
        tokenizer_sha: TOKENIZER_SHA.to_owned(),
        pooling: String::from("mean"),
        normalize: true,
        dim: width,
        dtype: String::from("f32"),
    }
}
/// Mean-pools a row-major hidden-state buffer over its token rows, then L2-normalizes.
///
/// `hidden` is read as `n_tokens` consecutive rows of `dim` floats. Elements past
/// `n_tokens * dim` are not read. With `n_tokens == 0` the result is an all-zero
/// vector of length `dim`.
///
/// # Panics
/// Debug builds assert `hidden.len() == n_tokens * dim`. All builds panic if
/// `hidden` is shorter than `n_tokens * dim`.
pub fn pool_and_normalize(hidden: &[f32], n_tokens: usize, dim: usize) -> Vec<f32> {
    debug_assert_eq!(hidden.len(), n_tokens * dim);
    let rows = &hidden[..n_tokens * dim];
    // `chunks` rejects a zero size; when `dim == 0`, `rows` is empty so the size is moot.
    let mut pooled = rows
        .chunks(dim.max(1))
        .fold(vec![0f32; dim], |mut acc, row| {
            acc.iter_mut().zip(row).for_each(|(a, &x)| *a += x);
            acc
        });
    // Guard against dividing by zero tokens; an empty sum stays zero either way.
    let scale = 1.0 / n_tokens.max(1) as f32;
    pooled.iter_mut().for_each(|x| *x *= scale);
    l2_normalize(&mut pooled);
    pooled
}
/// Scales `v` in place to unit Euclidean length.
///
/// All-zero vectors and vectors containing NaN are left untouched. If the f32 sum
/// of squares overflows to infinity, or underflows to zero, even though every
/// component is finite and at least one is non-zero, the norm is recomputed on
/// components rescaled by the largest magnitude. Without that fallback, overflow
/// would zero the vector and underflow would leave it unnormalized. Vectors
/// whose norm is finite and positive take the plain `x / norm` path.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 && norm.is_finite() {
        for x in v.iter_mut() {
            *x /= norm;
        }
        return;
    }
    // `f32::max` ignores NaN, but a NaN component already made `norm` NaN above.
    let peak = v.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
    let lost_to_range = norm == 0.0 || norm.is_infinite();
    if lost_to_range && peak > 0.0 && peak.is_finite() {
        // After dividing by `peak` every component is in [-1, 1] and one is exactly
        // ±1, so the rescaled norm is finite and >= 1.
        let scaled_norm = v
            .iter()
            .map(|&x| {
                let y = x / peak;
                y * y
            })
            .sum::<f32>()
            .sqrt();
        for x in v.iter_mut() {
            *x = (*x / peak) / scaled_norm;
        }
    } else if norm > 0.0 {
        // `norm` is +inf because a component is itself infinite: keep plain division.
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
/// Converts raw token ids into `(ids, mask)` vectors of `i64`.
///
/// Keeps at most `max` leading ids (`max == 0` is treated as 1). The mask is all
/// ones and has the same length as `ids`. An empty `raw` yields two empty vectors.
///
/// NOTE(review): truncation keeps the prefix, so if the tokenizer appends a
/// trailing special token it is dropped for over-long input — confirm intended.
pub fn prepare_tokens(raw: &[u32], max: usize) -> (Vec<i64>, Vec<i64>) {
    // `min` rather than `clamp(1, ..)`: forcing n >= 1 sliced past the end of an
    // empty `raw` and panicked. For non-empty `raw` the two are identical.
    let n = raw.len().min(max.max(1));
    let ids: Vec<i64> = raw[..n].iter().copied().map(i64::from).collect();
    let mask = vec![1i64; n];
    (ids, mask)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The accessors and `profile()` must mirror the registry's default entry.
    #[test]
    fn default_model_profile_matches_registry() {
        let expected = super::embed_registry::default_model();
        assert_eq!(dim(), expected.dim);
        assert_eq!(dim(), 768, "default jina stays 768-dim");
        assert_eq!(model_name(), expected.model_name);

        let prof = profile();
        assert_eq!(prof.dim, expected.dim);
        assert_eq!(prof.model_name, expected.model_name);
    }

    /// Pooling all-ones rows gives all-ones; normalizing 4 equal entries gives 1/sqrt(4).
    #[test]
    fn pool_respects_dim() {
        const DIM: usize = 4;
        const TOKENS: usize = 2;
        let hidden = [1.0f32; TOKENS * DIM];
        let pooled = pool_and_normalize(&hidden, TOKENS, DIM);
        assert_eq!(pooled.len(), DIM);
        pooled
            .iter()
            .for_each(|x| assert!((x - 0.5).abs() < 1e-6, "{x}"));
    }
}