#![allow(dead_code)]
use super::KokoroModel;
use super::ops::{gelu_new, layer_norm, linear, softmax};
/// Epsilon value passed as the last argument to every `layer_norm` call in this file.
const EPS: f32 = 1e-12;
/// Width of one embedding row: the word and position tables are sliced in rows of
/// `EMB` values, and `EMB` is the input width of the projection into the encoder
/// hidden size.
const EMB: usize = 128;
impl KokoroModel {
    /// Runs the `k.bert.*` encoder over `input_ids` and returns the final hidden states.
    ///
    /// The computation proceeds in three stages:
    ///
    /// 1. **Embeddings** - for every token, the word row (selected by the id), the
    ///    position row (selected by the token's index in `input_ids`) and the first
    ///    `EMB` values of the token-type table are summed, then layer-normalised.
    /// 2. **Projection** - the `EMB`-wide rows are mapped to `cfg.plbert_hidden`
    ///    columns through `embedding_hidden_mapping_in`.
    /// 3. **Encoder** - one attention + feed-forward layer is applied
    ///    `cfg.plbert_layers` times. Its weights are looked up once, under
    ///    `albert_layer_groups.0.albert_layers.0`, and reused on every pass.
    ///
    /// The returned buffer is the last layer's output. It is indexed here as one
    /// `cfg.plbert_hidden`-wide row per input token, stored row after row.
    ///
    /// # Panics
    ///
    /// Panics if `cfg.plbert_heads` is zero (integer division), or if a token id
    /// or position selects a row outside the corresponding embedding table.
    /// Negative ids are cast with `as usize` and therefore also end up out of range.
    pub fn bert(&self, input_ids: &[i64]) -> Vec<f32> {
        /// Dot product accumulated left to right, starting from `0.0`.
        fn dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y)
        }

        /// Element-wise `a[i] + b[i]` over the first `n` entries of both inputs.
        fn add_prefix(a: &[f32], b: &[f32], n: usize) -> Vec<f32> {
            a[..n].iter().zip(&b[..n]).map(|(x, y)| x + y).collect()
        }

        let seq_len = input_ids.len();
        let hidden_dim = self.cfg.plbert_hidden;
        let n_heads = self.cfg.plbert_heads;
        let head_dim = hidden_dim / n_heads;

        // --- Stage 1: embedding lookup and sum -------------------------------
        let word_table = self.t("k.bert.embeddings.word_embeddings.weight");
        let pos_table = self.t("k.bert.embeddings.position_embeddings.weight");
        let type_table = self.t("k.bert.embeddings.token_type_embeddings.weight");

        let mut summed = vec![0.0f32; seq_len * EMB];
        let rows = summed.chunks_exact_mut(EMB);
        for (position, (&id, out_row)) in input_ids.iter().zip(rows).enumerate() {
            let token = id as usize;
            let word_row = &word_table[token * EMB..(token + 1) * EMB];
            let pos_row = &pos_table[position * EMB..(position + 1) * EMB];
            for (d, slot) in out_row.iter_mut().enumerate() {
                // Only the leading `EMB` entries of the token-type table are read,
                // and the same ones are used for every position.
                *slot = word_row[d] + pos_row[d] + type_table[d];
            }
        }

        let emb_ln_w = self.t("k.bert.embeddings.LayerNorm.weight");
        let emb_ln_b = self.t("k.bert.embeddings.LayerNorm.bias");
        let normed = layer_norm(&summed, seq_len, EMB, &emb_ln_w, &emb_ln_b, EPS);

        // --- Stage 2: projection from EMB to the encoder width ---------------
        let proj_w = self.t("k.bert.encoder.embedding_hidden_mapping_in.weight");
        let proj_b = self.t("k.bert.encoder.embedding_hidden_mapping_in.bias");
        let mut hidden = linear(&normed, seq_len, EMB, &proj_w, Some(&proj_b), hidden_dim);

        // --- Stage 3: shared encoder layer -----------------------------------
        // Every layer tensor lives under the same name prefix; only the suffix varies.
        let layer = |suffix: &str| {
            self.t(&format!(
                "k.bert.encoder.albert_layer_groups.0.albert_layers.0.{suffix}"
            ))
        };
        let q_w = layer("attention.query.weight");
        let q_b = layer("attention.query.bias");
        let k_w = layer("attention.key.weight");
        let k_b = layer("attention.key.bias");
        let v_w = layer("attention.value.weight");
        let v_b = layer("attention.value.bias");
        let o_w = layer("attention.dense.weight");
        let o_b = layer("attention.dense.bias");
        let attn_ln_w = layer("attention.LayerNorm.weight");
        let attn_ln_b = layer("attention.LayerNorm.bias");
        let ffn_w = layer("ffn.weight");
        let ffn_b = layer("ffn.bias");
        let ffn_out_w = layer("ffn_output.weight");
        let ffn_out_b = layer("ffn_output.bias");
        let out_ln_w = layer("full_layer_layer_norm.weight");
        let out_ln_b = layer("full_layer_layer_norm.bias");

        let ffn_dim = self.cfg.plbert_inter;
        // Attention scores are scaled by 1 / sqrt(per-head width).
        let inv_sqrt_head = 1.0 / (head_dim as f32).sqrt();

        for _ in 0..self.cfg.plbert_layers {
            let queries = linear(&hidden, seq_len, hidden_dim, &q_w, Some(&q_b), hidden_dim);
            let keys = linear(&hidden, seq_len, hidden_dim, &k_w, Some(&k_b), hidden_dim);
            let values = linear(&hidden, seq_len, hidden_dim, &v_w, Some(&v_b), hidden_dim);

            let mut context = vec![0.0f32; seq_len * hidden_dim];
            // Scratch row of attention weights, fully overwritten for each (head, query).
            let mut weights = vec![0.0f32; seq_len];

            for head in 0..n_heads {
                // First column of this head inside each `hidden_dim`-wide row.
                let col = head * head_dim;
                for i in 0..seq_len {
                    let row_start = i * hidden_dim + col;
                    let q_vec = &queries[row_start..row_start + head_dim];

                    // Scaled dot product of query `i` with every key.
                    for (j, w) in weights.iter_mut().enumerate() {
                        let k_start = j * hidden_dim + col;
                        *w = dot(q_vec, &keys[k_start..k_start + head_dim]) * inv_sqrt_head;
                    }
                    softmax(&mut weights);

                    // Weighted sum of the value rows, one output column at a time.
                    let out = &mut context[row_start..row_start + head_dim];
                    for (d, slot) in out.iter_mut().enumerate() {
                        *slot = weights
                            .iter()
                            .enumerate()
                            .fold(0.0f32, |acc, (j, &w)| {
                                acc + w * values[j * hidden_dim + col + d]
                            });
                    }
                }
            }

            // Attention output projection, residual connection, layer norm.
            let attn_proj = linear(&context, seq_len, hidden_dim, &o_w, Some(&o_b), hidden_dim);
            let attn_sum = add_prefix(&attn_proj, &hidden, seq_len * hidden_dim);
            let attn_out = layer_norm(&attn_sum, seq_len, hidden_dim, &attn_ln_w, &attn_ln_b, EPS);

            // Feed-forward block, residual connection, layer norm.
            let mut ffn_act = linear(&attn_out, seq_len, hidden_dim, &ffn_w, Some(&ffn_b), ffn_dim);
            gelu_new(&mut ffn_act);
            let ffn_out = linear(
                &ffn_act,
                seq_len,
                ffn_dim,
                &ffn_out_w,
                Some(&ffn_out_b),
                hidden_dim,
            );
            let ffn_sum = add_prefix(&ffn_out, &attn_out, seq_len * hidden_dim);
            hidden = layer_norm(&ffn_sum, seq_len, hidden_dim, &out_ln_w, &out_ln_b, EPS);
        }

        hidden
    }
}