use crate::error::{AprenderError, Result};
use std::collections::HashMap;
/// Strategy for initializing target-embedding rows that have no direct
/// counterpart in the source vocabulary (see `transplant_embeddings`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SurgeryMethod {
    /// Copy overlapping rows only; non-overlapping rows are left untouched.
    #[default]
    DirectCopy,
    /// Fill a non-overlapping row with the source row closest to it
    /// (squared Euclidean distance).
    NearestNeighbor,
    /// Fill non-overlapping rows with the mean of all source rows.
    AveragePool,
}
/// Settings controlling an embedding-transplant operation.
#[derive(Debug, Clone)]
pub struct TokenizerSurgeryConfig {
    /// Number of rows in the source embedding matrix.
    pub source_vocab_size: usize,
    /// Number of rows in the target embedding matrix.
    pub target_vocab_size: usize,
    /// Minimum acceptable vocabulary-overlap ratio, in `[0.0, 1.0]`.
    pub overlap_threshold: f64,
    /// How non-overlapping target rows are initialized.
    pub method: SurgeryMethod,
}

impl Default for TokenizerSurgeryConfig {
    /// Empty vocabularies, a 50% overlap requirement, and the default
    /// (`DirectCopy`) surgery method.
    fn default() -> Self {
        TokenizerSurgeryConfig {
            method: SurgeryMethod::default(),
            overlap_threshold: 0.5,
            target_vocab_size: 0,
            source_vocab_size: 0,
        }
    }
}
/// Bidirectional token-index correspondence between two vocabularies,
/// as produced by `compute_vocab_overlap`.
#[derive(Debug, Clone)]
pub struct VocabMapping {
    // For each source index, the matching target index (None when the
    // token does not appear in the target vocabulary).
    pub source_to_target: Vec<Option<usize>>,
    // For each target index, the matching source index (None when absent).
    pub target_to_source: Vec<Option<usize>>,
    // Number of tokens present in both vocabularies.
    pub overlap_count: usize,
    // overlap_count divided by the size of the smaller vocabulary
    // (denominator clamped to at least 1, so empty inputs give 0.0).
    pub overlap_ratio: f64,
}
/// Summary of what `transplant_embeddings` did to the target matrix.
#[derive(Debug, Clone)]
pub struct SurgeryReport {
    // Rows copied verbatim from the source (vocabulary-overlap hits).
    pub tokens_copied: usize,
    // Rows filled by the nearest-neighbor or average-pool fallback.
    pub tokens_averaged: usize,
    // Rows left without a transplanted value (no mapping, or the fallback
    // could not produce one).
    // NOTE(review): despite the name, these rows are not written with
    // zeros — they keep whatever the caller initialized them to; confirm
    // callers pre-zero/pre-init the target buffer.
    pub tokens_zeroed: usize,
    // Overlap ratio carried over unchanged from the input VocabMapping.
    pub overlap_ratio: f64,
}
/// Builds the bidirectional index mapping between two token lists.
///
/// Tokens are matched by exact string equality. The overlap ratio is
/// measured against the smaller of the two vocabularies; the denominator
/// is clamped to at least 1 so empty inputs yield 0.0 instead of a
/// division by zero.
pub fn compute_vocab_overlap(source_tokens: &[String], target_tokens: &[String]) -> VocabMapping {
    // Index every source token once so the target scan is O(1) per lookup.
    let mut source_index: HashMap<&str, usize> = HashMap::with_capacity(source_tokens.len());
    for (idx, tok) in source_tokens.iter().enumerate() {
        source_index.insert(tok.as_str(), idx);
    }

    let mut source_to_target: Vec<Option<usize>> = vec![None; source_tokens.len()];
    let mut target_to_source: Vec<Option<usize>> = vec![None; target_tokens.len()];
    let mut overlap_count = 0usize;

    for (t_idx, tok) in target_tokens.iter().enumerate() {
        let Some(&s_idx) = source_index.get(tok.as_str()) else {
            continue;
        };
        source_to_target[s_idx] = Some(t_idx);
        target_to_source[t_idx] = Some(s_idx);
        overlap_count += 1;
    }

    let denom = source_tokens.len().min(target_tokens.len()).max(1);
    VocabMapping {
        source_to_target,
        target_to_source,
        overlap_count,
        overlap_ratio: overlap_count as f64 / denom as f64,
    }
}
/// Copies source embedding rows into the target matrix according to a
/// precomputed `VocabMapping`, filling unmapped target rows per
/// `config.method`.
///
/// Both `source_embeddings` and `target_embeddings` are row-major flat
/// matrices of `hidden_dim`-wide rows. Rows whose slices would fall
/// outside either buffer are skipped (counted in `tokens_zeroed`) rather
/// than panicking.
///
/// Returns a `SurgeryReport` with per-strategy counts and the mapping's
/// overlap ratio.
///
/// NOTE(review): rows counted in `tokens_zeroed` are left untouched, not
/// overwritten with zeros — presumably the caller pre-initializes the
/// target buffer; confirm at the call sites.
pub fn transplant_embeddings(
    source_embeddings: &[f64],
    target_embeddings: &mut [f64],
    mapping: &VocabMapping,
    config: &TokenizerSurgeryConfig,
    hidden_dim: usize,
) -> SurgeryReport {
    let mut tokens_copied = 0usize;
    let mut tokens_averaged = 0usize;
    let mut tokens_zeroed = 0usize;
    // The pooled mean is only needed — and only computed — for AveragePool.
    let avg_embedding: Vec<f64> = if config.method == SurgeryMethod::AveragePool {
        compute_average_embedding(source_embeddings, config.source_vocab_size, hidden_dim)
    } else {
        Vec::new()
    };
    for target_idx in 0..config.target_vocab_size {
        let target_offset = target_idx * hidden_dim;
        // Stop once the configured vocab size runs past the actual buffer.
        if target_offset + hidden_dim > target_embeddings.len() {
            break;
        }
        if let Some(source_idx) = mapping.target_to_source.get(target_idx).copied().flatten() {
            // Overlapping token: copy its source row verbatim when in bounds.
            let source_offset = source_idx * hidden_dim;
            if source_offset + hidden_dim <= source_embeddings.len() {
                target_embeddings[target_offset..target_offset + hidden_dim]
                    .copy_from_slice(&source_embeddings[source_offset..source_offset + hidden_dim]);
                tokens_copied += 1;
            } else {
                // Mapping points past the source buffer; leave the row as-is.
                tokens_zeroed += 1;
            }
        } else {
            // No source counterpart: fall back to the configured strategy.
            match config.method {
                SurgeryMethod::DirectCopy => {
                    tokens_zeroed += 1;
                }
                SurgeryMethod::NearestNeighbor => {
                    // The query vector is the target row's CURRENT contents,
                    // i.e. whatever initialization the caller gave it —
                    // TODO confirm callers pre-initialize the target.
                    if let Some(nearest_offset) = find_nearest_source_embedding(
                        target_embeddings,
                        target_offset,
                        source_embeddings,
                        config.source_vocab_size,
                        hidden_dim,
                    ) {
                        target_embeddings[target_offset..target_offset + hidden_dim]
                            .copy_from_slice(
                                &source_embeddings[nearest_offset..nearest_offset + hidden_dim],
                            );
                        // Counted as "averaged" (i.e. synthesized, not an
                        // exact vocabulary match).
                        tokens_averaged += 1;
                    } else {
                        tokens_zeroed += 1;
                    }
                }
                SurgeryMethod::AveragePool => {
                    // Defensive check; compute_average_embedding returns a
                    // hidden_dim-length vector for the inputs used above.
                    if avg_embedding.len() == hidden_dim {
                        target_embeddings[target_offset..target_offset + hidden_dim]
                            .copy_from_slice(&avg_embedding);
                        tokens_averaged += 1;
                    } else {
                        tokens_zeroed += 1;
                    }
                }
            }
        }
    }
    SurgeryReport {
        tokens_copied,
        tokens_averaged,
        tokens_zeroed,
        overlap_ratio: mapping.overlap_ratio,
    }
}
/// Checks that a planned surgery is safe to perform.
///
/// # Errors
///
/// * `AprenderError::InvalidHyperparameter` — `config.overlap_threshold`
///   lies outside `[0.0, 1.0]`. A NaN threshold is also rejected here
///   (`RangeInclusive::contains` is false for NaN); previously NaN slipped
///   through and made the overlap comparison below vacuously pass.
/// * `AprenderError::ValidationError` — the measured overlap ratio is
///   below the configured threshold.
pub fn validate_surgery(mapping: &VocabMapping, config: &TokenizerSurgeryConfig) -> Result<()> {
    // `contains` rejects NaN as well as out-of-range values.
    if !(0.0..=1.0).contains(&config.overlap_threshold) {
        return Err(AprenderError::InvalidHyperparameter {
            param: "overlap_threshold".to_string(),
            value: config.overlap_threshold.to_string(),
            constraint: "must be between 0.0 and 1.0".to_string(),
        });
    }
    if mapping.overlap_ratio < config.overlap_threshold {
        return Err(AprenderError::ValidationError {
            message: format!(
                "vocabulary overlap {:.2}% is below threshold {:.2}%: \
                 surgery would destroy too many pre-trained representations \
                 ({} tokens matched out of {} target tokens)",
                mapping.overlap_ratio * 100.0,
                config.overlap_threshold * 100.0,
                mapping.overlap_count,
                config.target_vocab_size,
            ),
        });
    }
    Ok(())
}
/// Mean of the first `source_vocab_size` complete rows of a row-major
/// embedding matrix. A trailing partial row is ignored; if no complete
/// row exists the result is an all-zero vector of length `hidden_dim`.
fn compute_average_embedding(
    source_embeddings: &[f64],
    source_vocab_size: usize,
    hidden_dim: usize,
) -> Vec<f64> {
    // chunks_exact(0) would panic, so degenerate sizes are handled up front.
    if source_vocab_size == 0 || hidden_dim == 0 {
        return vec![0.0; hidden_dim];
    }
    let mut sums = vec![0.0f64; hidden_dim];
    let mut rows_seen = 0usize;
    // chunks_exact yields only complete hidden_dim-wide rows, so a partial
    // row at the end of the buffer is skipped automatically.
    for row in source_embeddings
        .chunks_exact(hidden_dim)
        .take(source_vocab_size)
    {
        for (acc, &v) in sums.iter_mut().zip(row) {
            *acc += v;
        }
        rows_seen += 1;
    }
    if rows_seen > 0 {
        let inv = 1.0 / rows_seen as f64;
        sums.iter_mut().for_each(|v| *v *= inv);
    }
    sums
}
/// Flat offset of the source row with the smallest squared Euclidean
/// distance to the target row starting at `target_offset`.
///
/// Only the first `source_vocab_size` complete rows are considered; a
/// trailing partial row is skipped. Exact ties keep the earliest row.
/// Returns `None` when there are no candidate rows (or when every
/// distance fails the strict `<` comparison, e.g. NaN inputs).
fn find_nearest_source_embedding(
    target_embeddings: &[f64],
    target_offset: usize,
    source_embeddings: &[f64],
    source_vocab_size: usize,
    hidden_dim: usize,
) -> Option<usize> {
    if source_vocab_size == 0 || hidden_dim == 0 {
        return None;
    }
    let query = &target_embeddings[target_offset..target_offset + hidden_dim];
    let mut best_offset: Option<usize> = None;
    let mut best_dist = f64::MAX;
    for (row_idx, candidate) in source_embeddings
        .chunks_exact(hidden_dim)
        .take(source_vocab_size)
        .enumerate()
    {
        let dist: f64 = query
            .iter()
            .zip(candidate)
            .map(|(a, b)| {
                let d = a - b;
                d * d
            })
            .sum();
        // Strict `<` so the first minimal row wins on ties.
        if dist < best_dist {
            best_dist = dist;
            best_offset = Some(row_idx * hidden_dim);
        }
    }
    best_offset
}
// Unit tests for this module live in the sibling file
// `tokenizer_surgery_tests.rs` (compiled only under `cfg(test)`).
#[cfg(test)]
#[path = "tokenizer_surgery_tests.rs"]
mod tests;