use crate::error::{Result, TextError};
/// Settings for a cross-lingual transfer experiment (language pair plus
/// label-space and fine-tuning options).
#[derive(Debug, Clone)]
pub struct CrossLingualConfig {
    /// ISO-style code of the language the model is trained on.
    pub source_lang: String,
    /// ISO-style code of the language predictions are transferred to.
    pub target_lang: String,
    /// Size of the label inventory (9 = BIO tags for 4 entity types + `O`).
    pub n_labels: usize,
    /// When true, embedding parameters are meant to stay fixed during training.
    pub freeze_embeddings: bool,
    /// Label-smoothing mass spread across non-gold classes.
    pub label_smoothing: f64,
}

impl Default for CrossLingualConfig {
    /// English source, unspecified target, CoNLL-style 9-label space.
    fn default() -> Self {
        CrossLingualConfig {
            source_lang: String::from("en"),
            target_lang: String::default(),
            n_labels: 9,
            freeze_embeddings: false,
            label_smoothing: 0.1,
        }
    }
}
/// BIO-scheme named-entity label over four entity types plus the outside tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum NerLabel {
    O,
    BPerson,
    IPerson,
    BOrganization,
    IOrganization,
    BLocation,
    ILocation,
    BMisc,
    IMisc,
}

impl NerLabel {
    /// Parses a BIO tag string; both short (`B-PER`) and long (`B-PERSON`)
    /// entity names are accepted. Returns `None` for anything unrecognised.
    pub fn from_bio_str(s: &str) -> Option<Self> {
        if s == "O" {
            return Some(NerLabel::O);
        }
        // Everything else must look like "<B|I>-<ENTITY>".
        let (bio, entity) = s.split_once('-')?;
        match (bio, entity) {
            ("B", "PER" | "PERSON") => Some(NerLabel::BPerson),
            ("I", "PER" | "PERSON") => Some(NerLabel::IPerson),
            ("B", "ORG" | "ORGANIZATION") => Some(NerLabel::BOrganization),
            ("I", "ORG" | "ORGANIZATION") => Some(NerLabel::IOrganization),
            ("B", "LOC" | "LOCATION") => Some(NerLabel::BLocation),
            ("I", "LOC" | "LOCATION") => Some(NerLabel::ILocation),
            ("B", "MISC") => Some(NerLabel::BMisc),
            ("I", "MISC") => Some(NerLabel::IMisc),
            _ => None,
        }
    }

    /// Canonical (short-form) BIO tag for this label.
    pub fn to_bio_str(self) -> &'static str {
        // Indexed by `label_id`, which is total over all variants.
        const TAGS: [&str; 9] = [
            "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
        ];
        TAGS[self.label_id()]
    }

    /// Dense id in `0..9`, matching the model's output-row ordering.
    pub fn label_id(self) -> usize {
        use NerLabel::*;
        match self {
            O => 0,
            BPerson => 1,
            IPerson => 2,
            BOrganization => 3,
            IOrganization => 4,
            BLocation => 5,
            ILocation => 6,
            BMisc => 7,
            IMisc => 8,
        }
    }

    /// Inverse of [`NerLabel::label_id`]; `None` when `id >= 9`.
    pub fn from_id(id: usize) -> Option<Self> {
        use NerLabel::*;
        const ORDER: [NerLabel; 9] = [
            O,
            BPerson,
            IPerson,
            BOrganization,
            IOrganization,
            BLocation,
            ILocation,
            BMisc,
            IMisc,
        ];
        ORDER.get(id).copied()
    }
}
/// Hyper-parameters for the [`CrossLingualNer`] classifier.
#[derive(Debug, Clone)]
pub struct CrossLingualNerConfig {
    /// Number of output classes (rows of the output head).
    pub n_labels: usize,
    /// Width of the hidden projection layer.
    pub hidden_dim: usize,
    /// Gradient-descent step size used by `train_step`.
    pub lr: f64,
    /// Fine-tuning passes performed by `transfer`.
    pub n_epochs: usize,
}

impl Default for CrossLingualNerConfig {
    /// 9 BIO labels, 128 hidden units, lr 0.01, 5 epochs.
    fn default() -> Self {
        CrossLingualNerConfig {
            n_labels: 9,
            hidden_dim: 128,
            lr: 0.01,
            n_epochs: 5,
        }
    }
}
/// Two-layer token classifier (linear projection -> ReLU -> linear head)
/// used for cross-lingual NER transfer.
pub struct CrossLingualNer {
    /// Hidden projection, `hidden_dim x embed_dim`, row-major.
    pub projection_weights: Vec<Vec<f64>>,
    /// Output head, `n_labels x hidden_dim`, row-major.
    pub output_weights: Vec<Vec<f64>>,
    /// Hyper-parameters the model was constructed with.
    pub config: CrossLingualNerConfig,
    // Input embedding dimensionality; set once in `new` and reused by `transfer`.
    embed_dim: usize,
}
/// Deterministic pseudo-random weight sample in `[-0.1, 0.1]` from a 64-bit
/// LCG (Knuth's MMIX constants); advances `seed` in place.
///
/// Fix: the previous version took the top 31 bits (`>> 33`) but divided by
/// `u32::MAX`, so the unit sample lay in `[0, 0.5)` and every generated
/// weight was negative — with a ReLU hidden layer and non-negative inputs the
/// network started out dead (all-zero hidden activations).
fn lcg_rand(seed: &mut u64) -> f64 {
    *seed = seed
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    // Top 32 bits of the state -> uniform in [0, 1].
    let unit = ((*seed >> 32) as u32) as f64 / (u32::MAX as f64);
    (unit - 0.5) * 0.2
}
impl CrossLingualNer {
    /// Builds a model with fixed-seed pseudo-random weights, so construction
    /// is deterministic across runs.
    ///
    /// `embed_dim` is the input embedding dimensionality; layer sizes come
    /// from `config`.
    pub fn new(embed_dim: usize, config: CrossLingualNerConfig) -> Self {
        let mut seed: u64 = 0xDEAD_BEEF_CAFE_1234;
        let hidden_dim = config.hidden_dim;
        let n_labels = config.n_labels;
        // projection: hidden_dim x embed_dim; output head: n_labels x hidden_dim
        let projection_weights = (0..hidden_dim)
            .map(|_| (0..embed_dim).map(|_| lcg_rand(&mut seed)).collect())
            .collect();
        let output_weights = (0..n_labels)
            .map(|_| (0..hidden_dim).map(|_| lcg_rand(&mut seed)).collect())
            .collect();
        Self {
            projection_weights,
            output_weights,
            config,
            embed_dim,
        }
    }

    fn relu(x: f64) -> f64 {
        x.max(0.0)
    }

    /// Numerically stable softmax (shifts by the max logit first).
    /// An empty slice comes back empty via the `sum == 0.0` branch.
    fn softmax(logits: &[f64]) -> Vec<f64> {
        let max_v = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = logits.iter().map(|&x| (x - max_v).exp()).collect();
        let sum: f64 = exps.iter().sum();
        if sum == 0.0 {
            exps
        } else {
            exps.iter().map(|&e| e / sum).collect()
        }
    }

    /// Row-major matrix times vector: each row of `mat` dotted with `vec`.
    fn matvec(mat: &[Vec<f64>], vec: &[f64]) -> Vec<f64> {
        mat.iter()
            .map(|row| row.iter().zip(vec.iter()).map(|(w, x)| w * x).sum())
            .collect()
    }

    /// ReLU(projection · embedding): the hidden representation of one token.
    fn hidden_vec(&self, embedding: &[f64]) -> Vec<f64> {
        CrossLingualNer::matvec(&self.projection_weights, embedding)
            .into_iter()
            .map(Self::relu)
            .collect()
    }

    /// Per-token raw logits (`n_labels` values per embedding).
    pub fn forward(&self, embeddings: &[Vec<f64>]) -> Vec<Vec<f64>> {
        embeddings
            .iter()
            .map(|emb| {
                let h = self.hidden_vec(emb);
                CrossLingualNer::matvec(&self.output_weights, &h)
            })
            .collect()
    }

    /// Argmax label per token. Ids outside `0..9` (possible if
    /// `config.n_labels > 9`) fall back to `NerLabel::O`.
    pub fn predict(&self, embeddings: &[Vec<f64>]) -> Vec<NerLabel> {
        self.forward(embeddings)
            .iter()
            .map(|logits| {
                let best = logits
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(i, _)| i)
                    .unwrap_or(0);
                NerLabel::from_id(best).unwrap_or(NerLabel::O)
            })
            .collect()
    }

    /// One finite-difference gradient-descent step on the output head only;
    /// returns the pre-update mean label-smoothed cross-entropy loss.
    ///
    /// # Panics
    /// If `embeddings.len() != labels.len()`.
    pub fn train_step(&mut self, embeddings: &[Vec<f64>], labels: &[usize]) -> f64 {
        assert_eq!(
            embeddings.len(),
            labels.len(),
            "embeddings and labels must have the same length"
        );
        // Fix: an empty batch previously produced 0/0 = NaN as the base loss
        // and then wrote NaN into every output weight via the gradient step.
        if embeddings.is_empty() {
            return 0.0;
        }
        let n_labels = self.config.n_labels;
        let eps = self.config.lr;
        // NOTE(review): hard-coded smoothing; matches the default of
        // `CrossLingualConfig::label_smoothing` but is not read from any
        // config — confirm whether it should be configurable.
        let smooth = 0.1_f64;
        let delta = 1e-5_f64;
        // Mean label-smoothed cross-entropy over the batch.
        let loss_at = |model: &Self| -> f64 {
            let mut total = 0.0_f64;
            for (emb, &lbl) in embeddings.iter().zip(labels.iter()) {
                let h = model.hidden_vec(emb);
                let logits = CrossLingualNer::matvec(&model.output_weights, &h);
                let probs = Self::softmax(&logits);
                for (k, &p) in probs.iter().enumerate() {
                    let target = if k == lbl {
                        1.0 - smooth + smooth / n_labels as f64
                    } else {
                        smooth / n_labels as f64
                    };
                    // 1e-15 floor keeps ln() finite for zero probabilities.
                    total -= target * (p + 1e-15).ln();
                }
            }
            total / embeddings.len() as f64
        };
        let base_loss = loss_at(self);
        let n_out_rows = self.output_weights.len();
        // Fix: indexing `output_weights[0]` panicked when n_labels == 0.
        let n_out_cols = self.output_weights.first().map_or(0, |row| row.len());
        // Forward-difference gradient of the loss w.r.t. each output weight.
        let mut grad = vec![vec![0.0_f64; n_out_cols]; n_out_rows];
        for i in 0..n_out_rows {
            for j in 0..n_out_cols {
                self.output_weights[i][j] += delta;
                let perturbed = loss_at(self);
                self.output_weights[i][j] -= delta;
                grad[i][j] = (perturbed - base_loss) / delta;
            }
        }
        for i in 0..n_out_rows {
            for j in 0..n_out_cols {
                self.output_weights[i][j] -= eps * grad[i][j];
            }
        }
        base_loss
    }

    /// Fine-tunes a copy of this model on the source-language data for
    /// `config.n_epochs` steps, then predicts labels for the target-language
    /// embeddings. `self` is left untouched.
    pub fn transfer(
        &self,
        source_embeddings: &[Vec<f64>],
        source_labels: &[usize],
        target_embeddings: &[Vec<f64>],
    ) -> Vec<NerLabel> {
        // Clone the trained weights directly instead of paying for a random
        // initialisation that was immediately overwritten.
        let mut fine_tuned = Self {
            projection_weights: self.projection_weights.clone(),
            output_weights: self.output_weights.clone(),
            config: self.config.clone(),
            embed_dim: self.embed_dim,
        };
        for _ in 0..fine_tuned.config.n_epochs {
            fine_tuned.train_step(source_embeddings, source_labels);
        }
        fine_tuned.predict(target_embeddings)
    }
}
/// Hashed character n-gram features for `token`: a 256-bucket, L2-normalised
/// count vector. Tokens shorter than `n` fall back to hashing individual
/// characters (equivalent to `n == 1`).
///
/// Fix: `n == 0` previously reached `chars.windows(0)`, which panics; an
/// empty n-gram carries no information, so the all-zero vector is returned.
pub fn compute_character_ngram_features(token: &str, n: usize) -> Vec<f64> {
    const BUCKETS: usize = 256;
    // 32-bit FNV-1a over the UTF-32LE bytes of a char window — the same
    // scheme as the standalone fnv1a_* helpers in this module.
    fn window_hash(chars: &[char]) -> usize {
        let mut hash: u32 = 2166136261;
        for &ch in chars {
            for byte in (ch as u32).to_le_bytes() {
                hash ^= byte as u32;
                hash = hash.wrapping_mul(16777619);
            }
        }
        hash as usize
    }
    let mut counts = vec![0.0_f64; BUCKETS];
    if n == 0 {
        // slice::windows panics on size 0; return the zero feature vector.
        return counts;
    }
    let chars: Vec<char> = token.chars().collect();
    if chars.len() < n {
        // Token too short for any n-gram: hash each character on its own.
        for ch in &chars {
            counts[window_hash(std::slice::from_ref(ch)) % BUCKETS] += 1.0;
        }
    } else {
        for window in chars.windows(n) {
            counts[window_hash(window) % BUCKETS] += 1.0;
        }
    }
    // L2-normalise non-empty feature vectors.
    let norm: f64 = counts.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm > 1e-12 {
        counts.iter_mut().for_each(|x| *x /= norm);
    }
    counts
}
/// 32-bit FNV-1a hash of a single character's UTF-32LE byte encoding.
fn fnv1a_char_hash(ch: char) -> usize {
    let hash = (ch as u32)
        .to_le_bytes()
        .iter()
        .fold(2166136261u32, |acc, &b| {
            (acc ^ u32::from(b)).wrapping_mul(16777619)
        });
    hash as usize
}
/// 32-bit FNV-1a hash over the concatenated UTF-32LE bytes of a character
/// window; for a one-character window this matches `fnv1a_char_hash`.
fn fnv1a_window_hash(chars: &[char]) -> usize {
    let hash = chars
        .iter()
        .flat_map(|&ch| (ch as u32).to_le_bytes())
        .fold(2166136261u32, |acc, b| {
            (acc ^ u32::from(b)).wrapping_mul(16777619)
        });
    hash as usize
}
/// Rotates `source` embeddings into the `target` space via orthogonal
/// Procrustes: maximise `tr(R·M)` over rotations `R`, where
/// `M[i][j] = Σ_k source[k][i]·target[k][j]`; the optimum is `R = V·Uᵀ`
/// for `M = U·Σ·Vᵀ`. Pairs beyond `min(len)` are ignored.
///
/// # Errors
/// `InvalidInput` when either side is empty or any row's dimension differs
/// from `source[0].len()`.
pub fn align_embeddings(source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    if source.is_empty() || target.is_empty() {
        return Err(TextError::InvalidInput(
            "align_embeddings: source and target must be non-empty".into(),
        ));
    }
    let d = source[0].len();
    if target[0].len() != d {
        return Err(TextError::InvalidInput(format!(
            "align_embeddings: embedding dimension mismatch ({} vs {})",
            d,
            target[0].len()
        )));
    }
    // Fix: validate every row (not just the first) so ragged input surfaces
    // as an error instead of an index panic in the accumulation loop below.
    if source.iter().chain(target.iter()).any(|row| row.len() != d) {
        return Err(TextError::InvalidInput(
            "align_embeddings: all embeddings must share the same dimension".into(),
        ));
    }
    let n = source.len().min(target.len());
    // Cross-covariance M = sourceᵀ · target over the paired rows.
    let mut m = vec![vec![0.0_f64; d]; d];
    for k in 0..n {
        for i in 0..d {
            for j in 0..d {
                m[i][j] += source[k][i] * target[k][j];
            }
        }
    }
    let (u_mat, _sigma, v_mat) = jacobi_svd(&m);
    // R = V·Uᵀ, then apply R to every source embedding.
    let r = mat_mul_transpose(&v_mat, &u_mat);
    let aligned = source.iter().map(|s| matvec_t(&r, s)).collect();
    Ok(aligned)
}
/// `A · Bᵀ` for square row-major matrices: `res[i][j] = <a_row_i, b_row_j>`.
fn mat_mul_transpose(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    a.iter()
        .map(|row_a| {
            b.iter()
                .map(|row_b| row_a.iter().zip(row_b.iter()).map(|(x, y)| x * y).sum())
                .collect()
        })
        .collect()
}
/// Row-major matrix-vector product: each row of `m` dotted with `x`.
fn matvec_t(m: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(m.len());
    for row in m {
        let mut acc = 0.0_f64;
        for (w, v) in row.iter().zip(x.iter()) {
            acc += w * v;
        }
        out.push(acc);
    }
    out
}
/// SVD of a square row-major matrix `m` via Jacobi diagonalisation of the
/// Gram matrix; returns `(u, sigma, v)` with `m ≈ U·diag(sigma)·Vᵀ`.
///
/// Fixes relative to the previous version:
/// * the iteration now diagonalises `MᵀM`, whose eigenvectors are the RIGHT
///   singular vectors — matching the recovery `u_j = M·v_j/σ_j` below; it
///   previously diagonalised `M·Mᵀ` (left singular vectors), so the recovered
///   factor was not orthogonal;
/// * the eigenvector accumulator now uses the one-sided update `V ← V·G`;
///   the previous code applied the two-sided similarity rotation to `V` as
///   well, and since `Gᵀ·I·G = I` the returned `V` was always the identity.
fn jacobi_svd(m: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>) {
    let d = m.len();
    // Gram matrix A = Mᵀ·M (symmetric positive semi-definite).
    let mut a = vec![vec![0.0_f64; d]; d];
    for i in 0..d {
        for j in 0..d {
            for k in 0..d {
                a[i][j] += m[k][i] * m[k][j];
            }
        }
    }
    // V accumulates the Jacobi rotations, starting from the identity.
    let mut v = vec![vec![0.0_f64; d]; d];
    for i in 0..d {
        v[i][i] = 1.0;
    }
    let max_iter = 100 * d * d;
    for _ in 0..max_iter {
        // Largest off-diagonal element of A picks the rotation plane (p, q).
        let mut p = 0usize;
        let mut q = 1usize;
        let mut max_val = 0.0_f64;
        for i in 0..d {
            for j in (i + 1)..d {
                if a[i][j].abs() > max_val {
                    max_val = a[i][j].abs();
                    p = i;
                    q = j;
                }
            }
        }
        if max_val < 1e-12 {
            break;
        }
        // Rotation angle that zeroes a[p][q]: tan(2θ) = 2·a_pq / (a_qq - a_pp).
        let theta = if (a[q][q] - a[p][p]).abs() < 1e-12 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * ((2.0 * a[p][q]) / (a[q][q] - a[p][p])).atan()
        };
        let (s, c) = theta.sin_cos();
        // A ← Gᵀ·A·G (two-sided similarity rotation in the (p, q) plane).
        for k in 0..d {
            let (ap, aq) = (a[k][p], a[k][q]);
            a[k][p] = c * ap - s * aq;
            a[k][q] = s * ap + c * aq;
        }
        for k in 0..d {
            let (ap, aq) = (a[p][k], a[q][k]);
            a[p][k] = c * ap - s * aq;
            a[q][k] = s * ap + c * aq;
        }
        // V ← V·G (column rotation only — V collects the eigenvectors).
        for k in 0..d {
            let (vp, vq) = (v[k][p], v[k][q]);
            v[k][p] = c * vp - s * vq;
            v[k][q] = s * vp + c * vq;
        }
    }
    // Eigenvalues of MᵀM are σ²; clamp tiny σ to avoid dividing by ~0 below.
    let sigma: Vec<f64> = (0..d).map(|i| a[i][i].abs().sqrt()).collect();
    // Left singular vectors: u_j = M·v_j / σ_j.
    let mut u = vec![vec![0.0_f64; d]; d];
    for j in 0..d {
        let sig = sigma[j].max(1e-12);
        for i in 0..d {
            let mut dot = 0.0_f64;
            for t in 0..d {
                dot += m[i][t] * v[t][j];
            }
            u[i][j] = dot / sig;
        }
    }
    (u, sigma, v)
}
/// `M · Mᵀ` for a square row-major matrix: `res[i][j] = <row_i, row_j>`.
/// The result is symmetric, so each dot product is computed once.
fn mat_mul_mat_t(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let d = m.len();
    let mut res = vec![vec![0.0_f64; d]; d];
    for i in 0..d {
        for j in i..d {
            let dot: f64 = m[i].iter().zip(m[j].iter()).map(|(x, y)| x * y).sum();
            res[i][j] = dot;
            res[j][i] = dot;
        }
    }
    res
}
/// The `d x d` identity matrix as nested `Vec`s.
fn identity(d: usize) -> Vec<Vec<f64>> {
    (0..d)
        .map(|i| (0..d).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
        .collect()
}
/// In-place two-sided Givens rotation `A ← Gᵀ·A·G` in the `(p, q)` plane,
/// where the rotation has `G[p][p] = G[q][q] = c`, `G[p][q] = s`,
/// `G[q][p] = -s`.
fn jacobi_rotate(a: &mut [Vec<f64>], p: usize, q: usize, c: f64, s: f64) {
    // Column update: A ← A·G.
    for row in a.iter_mut() {
        let (x, y) = (row[p], row[q]);
        row[p] = c * x - s * y;
        row[q] = s * x + c * y;
    }
    // Row update: A ← Gᵀ·A.
    let n = a.len();
    for k in 0..n {
        let (x, y) = (a[p][k], a[q][k]);
        a[p][k] = c * x - s * y;
        a[q][k] = s * x + c * y;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // BIO parsing: both short forms and a few rejects.
    #[test]
    fn test_ner_label_from_bio_str() {
        assert_eq!(NerLabel::from_bio_str("B-PER"), Some(NerLabel::BPerson));
        assert_eq!(
            NerLabel::from_bio_str("I-ORG"),
            Some(NerLabel::IOrganization)
        );
        assert_eq!(NerLabel::from_bio_str("O"), Some(NerLabel::O));
        assert_eq!(NerLabel::from_bio_str("B-LOC"), Some(NerLabel::BLocation));
        assert_eq!(NerLabel::from_bio_str("UNKNOWN"), None);
    }

    // label_id / from_id must be mutual inverses over all 9 variants.
    #[test]
    fn test_ner_label_round_trip() {
        let labels = [
            NerLabel::O,
            NerLabel::BPerson,
            NerLabel::IPerson,
            NerLabel::BOrganization,
            NerLabel::IOrganization,
            NerLabel::BLocation,
            NerLabel::ILocation,
            NerLabel::BMisc,
            NerLabel::IMisc,
        ];
        for lbl in &labels {
            assert_eq!(NerLabel::from_id(lbl.label_id()), Some(*lbl));
        }
    }

    // forward() yields one logit row per token, n_labels wide.
    #[test]
    fn test_forward_shape() {
        let embed_dim = 16;
        let config = CrossLingualNerConfig {
            n_labels: 9,
            hidden_dim: 32,
            ..Default::default()
        };
        let model = CrossLingualNer::new(embed_dim, config);
        let embeddings: Vec<Vec<f64>> = (0..5).map(|_| vec![0.1_f64; embed_dim]).collect();
        let logits = model.forward(&embeddings);
        assert_eq!(logits.len(), 5);
        assert_eq!(logits[0].len(), 9);
    }

    // predict() yields one label per input embedding.
    #[test]
    fn test_predict_length() {
        let embed_dim = 8;
        let model = CrossLingualNer::new(embed_dim, CrossLingualNerConfig::default());
        let embeddings: Vec<Vec<f64>> = (0..3).map(|i| vec![i as f64 * 0.1; embed_dim]).collect();
        let preds = model.predict(&embeddings);
        assert_eq!(preds.len(), 3);
    }

    // Two consecutive gradient steps on the same batch should not increase
    // the reported loss (small tolerance for numerical noise).
    #[test]
    fn test_train_step_reduces_loss() {
        let embed_dim = 8;
        let config = CrossLingualNerConfig {
            n_labels: 9,
            hidden_dim: 16,
            lr: 0.05,
            n_epochs: 1,
        };
        let mut model = CrossLingualNer::new(embed_dim, config);
        let embeddings = vec![vec![0.5_f64; embed_dim]];
        let labels = vec![0usize];
        let loss1 = model.train_step(&embeddings, &labels);
        let loss2 = model.train_step(&embeddings, &labels);
        assert!(
            loss2 <= loss1 + 1e-6,
            "loss should decrease: {} -> {}",
            loss1,
            loss2
        );
    }

    // Feature vector is always 256 buckets wide.
    #[test]
    fn test_character_ngram_features_dim() {
        let feat = compute_character_ngram_features("hello", 3);
        assert_eq!(feat.len(), 256, "feature vector must be 256-dimensional");
    }

    // Tokens shorter than n take the per-character fallback path.
    #[test]
    fn test_character_ngram_features_short_token() {
        let feat = compute_character_ngram_features("ab", 4);
        assert_eq!(feat.len(), 256);
    }

    // Non-empty features are L2-normalised to unit length.
    #[test]
    fn test_character_ngram_features_normalised() {
        let feat = compute_character_ngram_features("hello world", 2);
        let norm: f64 = feat.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-9,
            "features should be L2-normalised, got norm={}",
            norm
        );
    }

    // Aligning a basis to itself must be (numerically) the identity map.
    #[test]
    fn test_align_embeddings_identity() {
        let vecs: Vec<Vec<f64>> = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let aligned = align_embeddings(&vecs, &vecs).expect("alignment failed");
        for (a, b) in aligned.iter().zip(vecs.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-6, "identity alignment failed");
            }
        }
    }

    // transfer() predicts one label per target embedding.
    #[test]
    fn test_transfer_returns_correct_length() {
        let embed_dim = 8;
        let model = CrossLingualNer::new(embed_dim, CrossLingualNerConfig::default());
        let src: Vec<Vec<f64>> = (0..4).map(|_| vec![0.1_f64; embed_dim]).collect();
        let src_labels: Vec<usize> = vec![0, 1, 0, 2];
        let tgt: Vec<Vec<f64>> = (0..6).map(|_| vec![0.2_f64; embed_dim]).collect();
        let preds = model.transfer(&src, &src_labels, &tgt);
        assert_eq!(preds.len(), 6);
    }
}