use super::tokenizer::Tokenizer;
/// A single node of the token trie.
///
/// A node does not own its children: each outgoing edge is keyed by a token id
/// and stores the index of the child node in the owning node list.
struct TrieNode {
    /// Outgoing edges, `token id -> child node index`.
    children: std::collections::HashMap<usize, usize>,
}

impl TrieNode {
    /// Creates a node with no outgoing edges.
    fn new() -> Self {
        let children = std::collections::HashMap::new();
        Self { children }
    }
}
/// Hotword biaser: a prefix trie over token-id sequences plus an additive
/// logit boost.
///
/// The trie is stored as a flat list of nodes; index `0` is always the root.
/// Traversal position is kept outside the biaser in a [`BiasState`], so one
/// `Biaser` can be shared between several independent traversals.
pub struct Biaser {
    /// Flat trie storage. `nodes[0]` is the root; edges refer to nodes by index.
    nodes: Vec<TrieNode>,
    /// Value added to a logit for each active trie edge labelled with its token.
    boost: f32,
    /// Number of non-empty sequences the trie was built from (duplicates count).
    phrase_count: usize,
}

impl Biaser {
    /// Builds a biaser from already-tokenized phrases.
    ///
    /// Empty sequences are ignored. Returns `None` when no non-empty sequence
    /// remains. `boost` is stored as given; it is not validated here.
    pub(crate) fn from_sequences(sequences: Vec<Vec<usize>>, boost: f32) -> Option<Self> {
        let mut nodes = vec![TrieNode::new()];
        let mut phrase_count = 0usize;

        for seq in sequences.into_iter().filter(|seq| !seq.is_empty()) {
            phrase_count += 1;
            let mut cursor = 0usize;
            for tok in seq {
                // Every existing child index is < nodes.len(), so getting `fresh`
                // back from `or_insert` means the edge was just created and the
                // node it points at still has to be allocated.
                let fresh = nodes.len();
                let child = *nodes[cursor].children.entry(tok).or_insert(fresh);
                if child == fresh {
                    nodes.push(TrieNode::new());
                }
                cursor = child;
            }
        }

        (phrase_count > 0).then_some(Self {
            nodes,
            boost,
            phrase_count,
        })
    }

    /// Builds a biaser from text phrases, tokenizing each with `tokenizer`.
    ///
    /// * Phrases whose weight is `<= 0.0` are skipped silently; the weight is
    ///   not otherwise used by this function.
    /// * Phrases for which `tokenizer.encode_phrase` returns `None` are dropped;
    ///   each one is logged at debug level and the total at warn level.
    ///
    /// Returns `None` when `boost` is NaN or `<= 0.0`, or when no phrase yields
    /// a non-empty token sequence.
    pub fn from_phrases(
        tokenizer: &Tokenizer,
        phrases: &[(String, f32)],
        boost: f32,
    ) -> Option<Self> {
        // NaN compares false against everything, so `boost <= 0.0` alone would
        // accept it and every boosted logit would then become NaN.
        if boost.is_nan() || boost <= 0.0 {
            return None;
        }

        let mut sequences = Vec::with_capacity(phrases.len());
        let mut dropped = 0usize;
        for (phrase, weight) in phrases {
            if *weight <= 0.0 {
                continue;
            }
            match tokenizer.encode_phrase(phrase) {
                Some(ids) => sequences.push(ids),
                None => {
                    dropped += 1;
                    tracing::debug!(phrase = %phrase, "hotword dropped: not representable in active vocab");
                }
            }
        }
        if dropped > 0 {
            tracing::warn!(
                "{dropped} hotword phrase(s) dropped (not representable in active vocab)"
            );
        }

        Self::from_sequences(sequences, boost)
    }

    /// Number of non-empty phrases this biaser was built from.
    pub fn phrase_count(&self) -> usize {
        self.phrase_count
    }

    /// Returns a fresh traversal state positioned at the trie root only.
    pub(crate) fn new_state(&self) -> BiasState {
        BiasState { active: vec![0] }
    }

    /// Adds the boost to every logit whose token continues an active trie node.
    ///
    /// Token ids that fall outside `logits` are ignored. When the same token
    /// labels an edge out of more than one active node, the boost is added once
    /// per such edge.
    ///
    /// # Panics
    ///
    /// Panics if `state` holds a node index that is out of range for this
    /// biaser (for example a state created by a different `Biaser`).
    pub(crate) fn boost_logits(&self, state: &BiasState, logits: &mut [f32]) {
        let continuations = state
            .active
            .iter()
            .flat_map(|&node| self.nodes[node].children.keys());
        for &tok in continuations {
            if let Some(logit) = logits.get_mut(tok) {
                *logit += self.boost;
            }
        }
    }

    /// Advances `state` by one emitted token.
    ///
    /// Every active node that has an edge for `tok` is replaced by that edge's
    /// child; nodes without such an edge are discarded. The root is always kept
    /// active afterwards so that a new phrase can start at any step.
    ///
    /// # Panics
    ///
    /// Panics if `state` holds a node index that is out of range for this
    /// biaser.
    pub(crate) fn advance(&self, state: &mut BiasState, tok: usize) {
        let mut next: Vec<usize> = Vec::with_capacity(state.active.len() + 1);

        let matched = state
            .active
            .iter()
            .filter_map(|&node| self.nodes[node].children.get(&tok).copied());
        for child in matched {
            if !next.contains(&child) {
                next.push(child);
            }
        }

        if !next.contains(&0) {
            next.push(0);
        }
        state.active = next;
    }
}
/// Traversal state used together with a `Biaser`: the trie nodes that are
/// currently matched. Created by `Biaser::new_state` and updated by
/// `Biaser::advance`.
pub(crate) struct BiasState {
// Indices into the biaser's node list. Starts as `[0]` (the root only);
// `advance` always leaves the root in this list.
active: Vec<usize>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a biaser from raw token sequences, panicking if none is usable.
    fn biaser(seqs: Vec<Vec<usize>>, boost: f32) -> Biaser {
        Biaser::from_sequences(seqs, boost).expect("non-empty sequences")
    }

    /// Applies the boost for `state` to five zeroed logits and returns them.
    fn boosted(b: &Biaser, state: &BiasState) -> Vec<f32> {
        let mut logits = vec![0.0; 5];
        b.boost_logits(state, &mut logits);
        logits
    }

    #[test]
    fn test_from_sequences_empty_returns_none() {
        let no_sequences: Vec<Vec<usize>> = Vec::new();
        assert!(Biaser::from_sequences(no_sequences, 5.0).is_none());

        let only_empty_sequence: Vec<Vec<usize>> = vec![Vec::new()];
        assert!(Biaser::from_sequences(only_empty_sequence, 5.0).is_none());
    }

    #[test]
    fn test_boost_applies_to_first_token_of_each_hotword() {
        let b = biaser(vec![vec![1, 2], vec![3]], 5.0);
        let logits = boosted(&b, &b.new_state());

        assert_eq!(logits[1], 5.0);
        assert_eq!(logits[3], 5.0);
        assert_eq!(logits[2], 0.0, "mid-hotword token not boosted at root");
        assert_eq!(logits[0], 0.0);
    }

    #[test]
    fn test_advance_then_boost_continuation() {
        let b = biaser(vec![vec![1, 2]], 5.0);
        let mut state = b.new_state();
        b.advance(&mut state, 1);

        let logits = boosted(&b, &state);
        assert_eq!(logits[2], 5.0, "continuation token boosted after prefix");
        assert_eq!(logits[1], 5.0, "root keeps a fresh hotword start available");
    }

    #[test]
    fn test_advance_off_prefix_resets_to_root_only() {
        let b = biaser(vec![vec![1, 2]], 5.0);
        let mut state = b.new_state();
        // Enter the phrase, then emit a token that is not part of any phrase.
        b.advance(&mut state, 1);
        b.advance(&mut state, 9);

        let logits = boosted(&b, &state);
        assert_eq!(logits[2], 0.0, "continuation no longer boosted after reset");
        assert_eq!(logits[1], 5.0, "root start still boosted");
    }

    #[test]
    fn test_shared_prefix_keeps_both_branches_active() {
        let b = biaser(vec![vec![1, 2], vec![1, 3]], 4.0);
        let mut state = b.new_state();
        b.advance(&mut state, 1);

        let logits = boosted(&b, &state);
        assert_eq!(logits[2], 4.0);
        assert_eq!(logits[3], 4.0);
    }

    #[test]
    fn test_boost_ignores_out_of_range_token_id() {
        // Token 99 lies outside the five-element logit vector.
        let b = biaser(vec![vec![99]], 5.0);
        let logits = boosted(&b, &b.new_state());
        assert!(logits.iter().all(|&l| l == 0.0));
    }
}