use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
/// Heap entry pairing an item with its similarity score, used for top-k selection.
///
/// The ordering is *reversed* on `similarity`, so a `BinaryHeap<TopKEntry<T>>`
/// behaves as a min-heap: `peek()` returns the entry with the lowest score,
/// i.e. the current eviction candidate.
///
/// Scores are compared with a total order (see `compare_similarity`) so that
/// `Eq` and `Ord` stay lawful even if a NaN score is pushed; a plain
/// `partial_cmp` cannot order NaN at all.
struct TopKEntry<T> {
    similarity: f32,
    item: T,
}

impl<T> TopKEntry<T> {
    /// Total ascending order on similarity scores.
    ///
    /// Real values compare numerically (so `-0.0` equals `0.0`); NaN equals NaN
    /// and is greater than every real value. Ranking NaN above the reals keeps
    /// it at the bottom of the reversed heap, so a real score stays at the top
    /// as the eviction candidate whenever one is present.
    fn compare_similarity(a: f32, b: f32) -> Ordering {
        a.partial_cmp(&b)
            // `partial_cmp` only returns `None` when at least one side is NaN.
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}

impl<T> PartialEq for TopKEntry<T> {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality always agrees with the ordering,
        // as the `Ord` contract requires.
        self.cmp(other) == Ordering::Equal
    }
}

impl<T> Eq for TopKEntry<T> {}

impl<T> PartialOrd for TopKEntry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for TopKEntry<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Arguments are swapped on purpose: a higher similarity sorts as
        // *smaller*, turning the max-heap `BinaryHeap` into a min-heap on score.
        Self::compare_similarity(other.similarity, self.similarity)
    }
}
/// Configuration for a LaTeX embedding model.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Length of every embedding vector.
    pub dimension: usize,
    /// Whether aggregated vectors should be scaled to unit length.
    pub normalize: bool,
    /// Context window size (presumably a training hyperparameter; not read by
    /// the embedding lookups themselves).
    pub window_size: usize,
    /// Minimum occurrence count (presumably a training vocabulary cutoff).
    pub min_count: usize,
    /// Number of negative samples (presumably a training hyperparameter).
    pub negative_samples: usize,
    /// Learning rate (presumably a training hyperparameter).
    pub learning_rate: f64,
}

impl Default for EmbeddingConfig {
    /// 128-dimensional, normalized embeddings with word2vec-style defaults for
    /// the remaining fields.
    fn default() -> Self {
        EmbeddingConfig {
            learning_rate: 0.025,
            negative_samples: 5,
            min_count: 5,
            window_size: 5,
            normalize: true,
            dimension: 128,
        }
    }
}
/// An embedding vector for a single LaTeX command.
#[derive(Debug, Clone)]
pub struct CommandEmbedding {
    /// Name of the command this vector represents.
    pub command: String,
    /// The embedding vector.
    pub vector: Vec<f32>,
    /// Occurrence count supplied by whoever created the embedding.
    pub frequency: u64,
    /// Optional category; `None` until set with [`CommandEmbedding::with_category`].
    pub category: Option<CommandCategory>,
}

impl CommandEmbedding {
    /// Creates an embedding with no category assigned.
    pub fn new(command: String, vector: Vec<f32>, frequency: u64) -> Self {
        let category = None;
        CommandEmbedding {
            command,
            vector,
            frequency,
            category,
        }
    }

    /// Builder-style setter: returns this embedding tagged with `category`.
    pub fn with_category(self, category: CommandCategory) -> Self {
        CommandEmbedding {
            category: Some(category),
            ..self
        }
    }

    /// Cosine similarity between this embedding's vector and `other`'s.
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        let (lhs, rhs) = (self.vector.as_slice(), other.vector.as_slice());
        cosine_similarity(lhs, rhs)
    }
}
/// Coarse grouping of LaTeX commands by their typical role.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CommandCategory {
    /// Greek letters such as `alpha`, `Gamma`, `varphi`.
    GreekLetter,
    /// Large operators and limits such as `sum`, `int`, `lim`.
    Operator,
    /// Relations such as `leq`, `approx`, `subseteq`.
    Relation,
    /// Accents and over/under decorations such as `hat`, `overline`.
    Accent,
    /// Delimiters and sizing commands such as `left`, `langle`, `big`.
    Delimiter,
    /// Named functions such as `sin`, `log`, `det`.
    Function,
    /// Spacing commands such as `quad`, `,`, `hspace`.
    Spacing,
    /// Environment delimiters `begin` / `end`.
    Environment,
    /// Font and text formatting such as `textbf`, `mathbb`.
    Formatting,
    /// Document structure such as `section`, `title`.
    Structure,
    /// Arrows such as `rightarrow`, `mapsto`.
    Arrow,
    /// Anything not recognised above.
    Other,
}

impl CommandCategory {
    /// Classifies a command by exact name match.
    ///
    /// Names are matched without a leading backslash (`"alpha"`, not
    /// `"\\alpha"`); any unrecognised name maps to [`CommandCategory::Other`].
    pub fn from_command(name: &str) -> Self {
        let category = match name {
            "alpha" | "beta" | "gamma" | "delta" | "epsilon" | "zeta" | "eta" | "theta"
            | "iota" | "kappa" | "lambda" | "mu" | "nu" | "xi" | "omicron" | "pi" | "rho"
            | "sigma" | "tau" | "upsilon" | "phi" | "chi" | "psi" | "omega" | "Gamma"
            | "Delta" | "Theta" | "Lambda" | "Xi" | "Pi" | "Sigma" | "Upsilon" | "Phi"
            | "Psi" | "Omega" | "varepsilon" | "vartheta" | "varpi" | "varrho"
            | "varsigma" | "varphi" => Self::GreekLetter,

            "sum" | "prod" | "int" | "oint" | "iint" | "iiint" | "coprod" | "bigcup"
            | "bigcap" | "bigvee" | "bigwedge" | "bigoplus" | "bigotimes" | "biguplus"
            | "bigsqcup" | "lim" | "sup" | "inf" | "max" | "min" => Self::Operator,

            "leq" | "geq" | "neq" | "equiv" | "sim" | "simeq" | "approx" | "cong"
            | "propto" | "subset" | "supset" | "subseteq" | "supseteq" | "in" | "ni"
            | "notin" | "prec" | "succ" | "preceq" | "succeq" | "ll" | "gg" | "perp"
            | "parallel" | "vdash" | "dashv" | "models" => Self::Relation,

            "hat" | "check" | "breve" | "acute" | "grave" | "tilde" | "bar" | "vec"
            | "dot" | "ddot" | "dddot" | "ddddot" | "widehat" | "widetilde" | "overline"
            | "underline" | "overbrace" | "underbrace" => Self::Accent,

            "left" | "right" | "big" | "Big" | "bigg" | "Bigg" | "lfloor" | "rfloor"
            | "lceil" | "rceil" | "langle" | "rangle" | "vert" | "Vert" => Self::Delimiter,

            "sin" | "cos" | "tan" | "cot" | "sec" | "csc" | "arcsin" | "arccos" | "arctan"
            | "sinh" | "cosh" | "tanh" | "coth" | "exp" | "log" | "ln" | "lg" | "det"
            | "dim" | "ker" | "hom" | "arg" | "deg" | "gcd" => Self::Function,

            "quad" | "qquad" | "," | ";" | ":" | "!" | "hspace" | "vspace" | "hfill"
            | "vfill" | "smallskip" | "medskip" | "bigskip" | "thinspace" | "enspace" => {
                Self::Spacing
            }

            "begin" | "end" => Self::Environment,

            "textbf" | "textit" | "texttt" | "textrm" | "textsf" | "textsc" | "emph"
            | "mathbf" | "mathit" | "mathrm" | "mathsf" | "mathtt" | "mathcal"
            | "mathfrak" | "mathbb" | "boldsymbol" => Self::Formatting,

            "section" | "subsection" | "subsubsection" | "chapter" | "part" | "paragraph"
            | "subparagraph" | "title" | "author" | "date" | "maketitle" => Self::Structure,

            "rightarrow" | "leftarrow" | "leftrightarrow" | "Rightarrow" | "Leftarrow"
            | "Leftrightarrow" | "longrightarrow" | "longleftarrow" | "longleftrightarrow"
            | "Longrightarrow" | "Longleftarrow" | "Longleftrightarrow" | "uparrow"
            | "downarrow" | "updownarrow" | "Uparrow" | "Downarrow" | "Updownarrow"
            | "nearrow" | "searrow" | "nwarrow" | "swarrow" | "mapsto" | "hookrightarrow"
            | "hookleftarrow" => Self::Arrow,

            _ => Self::Other,
        };
        category
    }
}
/// An embedding vector for a single equation.
#[derive(Debug, Clone)]
pub struct EquationEmbedding {
    /// Source text of the equation.
    pub source: String,
    /// The embedding vector.
    pub vector: Vec<f32>,
    /// Optional identifier of where the equation came from.
    pub source_id: Option<String>,
    /// Optional label attached to the equation.
    pub label: Option<String>,
}

impl EquationEmbedding {
    /// Creates an embedding with neither a source id nor a label.
    pub fn new(source: String, vector: Vec<f32>) -> Self {
        let (source_id, label) = (None, None);
        EquationEmbedding {
            source,
            vector,
            source_id,
            label,
        }
    }

    /// Builder-style setter: returns this embedding with `source_id` set to `id`.
    pub fn with_source_id(self, id: String) -> Self {
        EquationEmbedding {
            source_id: Some(id),
            ..self
        }
    }

    /// Builder-style setter: returns this embedding with `label` set.
    pub fn with_label(self, label: String) -> Self {
        EquationEmbedding {
            label: Some(label),
            ..self
        }
    }

    /// Cosine similarity between this equation's vector and `other`'s.
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        let (lhs, rhs) = (&self.vector[..], &other.vector[..]);
        cosine_similarity(lhs, rhs)
    }
}
/// Stores LaTeX command embeddings and equation embeddings and answers
/// nearest-neighbour queries over them by cosine similarity.
pub struct LaTeXEmbedder {
    /// Command embeddings keyed by command name.
    command_embeddings: HashMap<String, CommandEmbedding>,
    /// Equation embeddings in insertion order.
    equation_embeddings: Vec<EquationEmbedding>,
    /// Configuration supplied at construction.
    config: EmbeddingConfig,
    /// All-zero vector of length `config.dimension`, returned for unknown commands.
    unknown_embedding: Vec<f32>,
}

impl LaTeXEmbedder {
    /// Creates an empty embedder using `EmbeddingConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(EmbeddingConfig::default())
    }

    /// Creates an empty embedder with the given configuration.
    pub fn with_config(config: EmbeddingConfig) -> Self {
        let unknown_embedding = vec![0.0; config.dimension];
        Self {
            command_embeddings: HashMap::new(),
            equation_embeddings: Vec::new(),
            config,
            unknown_embedding,
        }
    }

    /// Loads command embeddings from a whitespace-separated text file.
    ///
    /// Each accepted line has the form `<command> <v1> ... <vD>`, where every
    /// value parses as `f32` and `D == config.dimension`. Any other line
    /// (unparseable values, wrong dimension, no values, or a possible header
    /// line) is skipped. Accepted entries get frequency `0` and a category from
    /// `CommandCategory::from_command`, and replace any existing entry with the
    /// same name.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened or a line cannot be read.
    #[cfg(feature = "serde-extras")]
    pub fn load_command_embeddings(&mut self, path: &std::path::Path) -> crate::Result<()> {
        use std::io::BufRead;
        let file = std::fs::File::open(path)?;
        let reader = std::io::BufReader::new(file);
        for line in reader.lines() {
            let line = line?;
            let mut fields = line.split_whitespace();
            let command = match fields.next() {
                Some(name) => name.to_string(),
                None => continue,
            };
            // Reject the whole row if any value fails to parse: silently dropping
            // a bad token would shift the following values into the wrong
            // dimensions while still passing the length check below.
            let vector: Vec<f32> = match fields
                .map(str::parse::<f32>)
                .collect::<std::result::Result<Vec<f32>, _>>()
            {
                Ok(values) => values,
                Err(_) => continue,
            };
            if vector.is_empty() || vector.len() != self.config.dimension {
                continue;
            }
            let category = CommandCategory::from_command(&command);
            let embedding =
                CommandEmbedding::new(command.clone(), vector, 0).with_category(category);
            self.command_embeddings.insert(command, embedding);
        }
        Ok(())
    }

    /// Returns the stored embedding for `command`, if any.
    pub fn command_embedding(&self, command: &str) -> Option<&CommandEmbedding> {
        self.command_embeddings.get(command)
    }

    /// Returns the vector for `command`, or an all-zero vector of length
    /// `config.dimension` when the command is unknown.
    pub fn command_vector(&self, command: &str) -> &[f32] {
        self.command_embeddings
            .get(command)
            .map(|e| e.vector.as_slice())
            .unwrap_or(&self.unknown_embedding)
    }

    /// Inserts (or replaces) a command embedding keyed by `embedding.command`.
    ///
    /// The vector length is not checked against `config.dimension`.
    pub fn add_command_embedding(&mut self, embedding: CommandEmbedding) {
        self.command_embeddings
            .insert(embedding.command.clone(), embedding);
    }

    /// Appends an equation embedding to the searchable corpus.
    pub fn add_equation_embedding(&mut self, embedding: EquationEmbedding) {
        self.equation_embeddings.push(embedding);
    }

    /// Returns up to `k` commands most cosine-similar to `command` (excluding
    /// `command` itself), best first.
    ///
    /// Returns an empty vector when `k == 0` or `command` is unknown. NaN scores
    /// are ignored; ties are broken arbitrarily.
    pub fn most_similar_commands(&self, command: &str, k: usize) -> Vec<(&str, f32)> {
        let query = match self.command_embeddings.get(command) {
            Some(e) => e.vector.as_slice(),
            None => return Vec::new(),
        };
        let candidates = self
            .command_embeddings
            .iter()
            .filter(|(name, _)| name.as_str() != command)
            .map(|(name, embedding)| (name.as_str(), cosine_similarity(query, &embedding.vector)));
        Self::select_top_k(candidates, k)
    }

    /// Returns up to `k` equations most cosine-similar to `query_vector`, best first.
    ///
    /// Returns an empty vector when `k == 0`. NaN scores are ignored; ties are
    /// broken arbitrarily.
    pub fn most_similar_equations(
        &self,
        query_vector: &[f32],
        k: usize,
    ) -> Vec<(&EquationEmbedding, f32)> {
        let candidates = self
            .equation_embeddings
            .iter()
            .map(|embedding| (embedding, cosine_similarity(query_vector, &embedding.vector)));
        Self::select_top_k(candidates, k)
    }

    /// Keeps the `k` highest-scoring `(item, score)` pairs, returned best first.
    ///
    /// Uses a min-heap of at most `k` entries (O(n log k)). NaN scores are
    /// skipped because they cannot be ranked. The initial capacity is bounded
    /// by the candidate count, so a huge `k` (e.g. `usize::MAX` to rank
    /// everything) neither overflows nor over-allocates.
    fn select_top_k<T, I>(scored: I, k: usize) -> Vec<(T, f32)>
    where
        I: Iterator<Item = (T, f32)>,
    {
        if k == 0 {
            return Vec::new();
        }
        let (lower, upper) = scored.size_hint();
        let capacity = k.min(upper.unwrap_or(lower));
        let mut heap: BinaryHeap<TopKEntry<T>> = BinaryHeap::with_capacity(capacity);
        for (item, similarity) in scored {
            if similarity.is_nan() {
                continue;
            }
            if heap.len() < k {
                heap.push(TopKEntry { similarity, item });
            } else if let Some(mut worst) = heap.peek_mut() {
                // The heap is reverse-ordered, so its top is the lowest score
                // kept so far; overwriting it re-sifts when `worst` is dropped.
                if similarity > worst.similarity {
                    *worst = TopKEntry { similarity, item };
                }
            }
        }
        let mut results: Vec<(T, f32)> = heap
            .into_iter()
            .map(|entry| (entry.item, entry.similarity))
            .collect();
        // NaN was filtered out above, so this comparison is total.
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        results
    }

    /// Averages the vectors of the known commands in `commands`.
    ///
    /// Unknown commands are skipped. The mean is scaled to unit length when
    /// `config.normalize` is set. Returns an all-zero vector of length
    /// `config.dimension` when `commands` is empty or none of them are known.
    /// A stored vector whose length differs from `config.dimension` is
    /// truncated or zero-padded rather than causing a panic.
    pub fn sequence_embedding(&self, commands: &[&str]) -> Vec<f32> {
        if commands.is_empty() {
            return self.unknown_embedding.clone();
        }
        let mut centroid = vec![0.0f32; self.config.dimension];
        let mut count = 0usize;
        for embedding in commands
            .iter()
            .filter_map(|command| self.command_embeddings.get(*command))
        {
            // `zip` stops at the shorter side, so an over-long vector (which
            // `add_command_embedding` does not reject) cannot index out of bounds.
            for (acc, value) in centroid.iter_mut().zip(&embedding.vector) {
                *acc += value;
            }
            count += 1;
        }
        if count > 0 {
            let divisor = count as f32;
            for v in &mut centroid {
                *v /= divisor;
            }
            if self.config.normalize {
                normalize_vector(&mut centroid);
            }
        }
        centroid
    }

    /// Number of stored command embeddings.
    pub fn vocab_size(&self) -> usize {
        self.command_embeddings.len()
    }

    /// Configured embedding dimension.
    pub fn dimension(&self) -> usize {
        self.config.dimension
    }

    /// Whether an embedding is stored for `command`.
    pub fn contains_command(&self, command: &str) -> bool {
        self.command_embeddings.contains_key(command)
    }

    /// Names of the stored commands whose category is `Some(category)`.
    ///
    /// Uncategorised entries are never returned. Order is unspecified (hash-map order).
    pub fn commands_in_category(&self, category: CommandCategory) -> Vec<&str> {
        self.command_embeddings
            .iter()
            .filter(|(_, e)| e.category == Some(category))
            .map(|(name, _)| name.as_str())
            .collect()
    }
}

impl Default for LaTeXEmbedder {
    fn default() -> Self {
        Self::new()
    }
}
/// Cosine similarity `a·b / (‖a‖·‖b‖)` between two equal-length vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or the
/// denominator is not positive (e.g. either vector is all zeros).
///
/// The two norms are square-rooted separately before being multiplied. The
/// product of the *squared* norms overflows `f32` to infinity (or underflows to
/// zero) far sooner than the product of the norms themselves, which would
/// otherwise turn clearly parallel vectors into a similarity of `0.0`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom > 0.0 {
        dot / denom
    } else {
        0.0
    }
}
/// Scales `v` in place to unit Euclidean length.
///
/// When the length is not positive (all zeros, empty, or NaN) the vector is
/// left untouched, so an all-zero vector stays all-zero instead of becoming NaN.
fn normalize_vector(v: &mut [f32]) {
    let mut sum_of_squares = 0.0f32;
    for &component in v.iter() {
        sum_of_squares += component * component;
    }
    let length = sum_of_squares.sqrt();
    if length > 0.0 {
        v.iter_mut().for_each(|component| *component /= length);
    }
}
/// A scored match: an item identifier, its similarity score, and an optional category.
#[derive(Clone, Debug)]
pub struct SimilarityResult {
    /// Identifier of the matched item.
    pub item: String,
    /// Similarity score of the match.
    pub score: f32,
    /// Category of the matched item, if known.
    pub category: Option<CommandCategory>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with a custom embedding dimension.
    fn config_with_dimension(dimension: usize) -> EmbeddingConfig {
        EmbeddingConfig {
            dimension,
            ..EmbeddingConfig::default()
        }
    }

    /// Builds an embedder from `config` holding the given `(name, vector)` commands.
    fn embedder_with(config: EmbeddingConfig, commands: &[(&str, Vec<f32>)]) -> LaTeXEmbedder {
        let mut embedder = LaTeXEmbedder::with_config(config);
        for (name, vector) in commands {
            embedder.add_command_embedding(CommandEmbedding::new(
                name.to_string(),
                vector.clone(),
                100,
            ));
        }
        embedder
    }

    /// One-hot vector of length `dimension` with a 1.0 at `axis`.
    fn unit(dimension: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dimension];
        v[axis] = 1.0;
        v
    }

    #[test]
    fn test_command_category() {
        assert_eq!(
            CommandCategory::from_command("alpha"),
            CommandCategory::GreekLetter
        );
        assert_eq!(
            CommandCategory::from_command("sum"),
            CommandCategory::Operator
        );
        assert_eq!(
            CommandCategory::from_command("sin"),
            CommandCategory::Function
        );
        assert_eq!(
            CommandCategory::from_command("leq"),
            CommandCategory::Relation
        );
        assert_eq!(
            CommandCategory::from_command("begin"),
            CommandCategory::Environment
        );
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 1e-6);
        // Mismatched lengths and empty inputs are reported as "no similarity".
        assert_eq!(cosine_similarity(&a, &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_normalize_vector() {
        let mut v = vec![3.0, 4.0];
        normalize_vector(&mut v);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);

        let mut zeros = vec![0.0f32; 3];
        normalize_vector(&mut zeros);
        assert_eq!(zeros, vec![0.0f32; 3]);
    }

    #[test]
    fn test_embedder_basic() {
        let mut embedder = LaTeXEmbedder::new();
        let alpha = CommandEmbedding::new("alpha".to_string(), vec![1.0; 128], 100);
        embedder.add_command_embedding(alpha);
        assert!(embedder.contains_command("alpha"));
        assert!(!embedder.contains_command("beta"));
        assert_eq!(embedder.vocab_size(), 1);
        assert_eq!(embedder.dimension(), 128);
    }

    #[test]
    fn test_sequence_embedding() {
        let commands = [("alpha", unit(8, 0)), ("beta", unit(8, 1))];

        // Normalized (default): mean of the two unit vectors, rescaled to unit length.
        let embedder = embedder_with(config_with_dimension(8), &commands);
        let seq_emb = embedder.sequence_embedding(&["alpha", "beta"]);
        assert_eq!(seq_emb.len(), 8);
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!((seq_emb[0] - expected).abs() < 1e-6);
        assert!((seq_emb[1] - expected).abs() < 1e-6);
        assert!(seq_emb[2..].iter().all(|&x| x == 0.0));

        // Unknown commands are skipped rather than averaged in as zeros.
        let only_alpha = embedder.sequence_embedding(&["alpha", "nope"]);
        assert_eq!(only_alpha, unit(8, 0));

        // Unnormalized: plain mean.
        let raw_config = EmbeddingConfig {
            normalize: false,
            ..config_with_dimension(8)
        };
        let raw = embedder_with(raw_config, &commands).sequence_embedding(&["alpha", "beta"]);
        assert_eq!(&raw[..2], &[0.5f32, 0.5][..]);
    }

    #[test]
    fn test_sequence_embedding_without_known_commands() {
        let embedder = embedder_with(config_with_dimension(8), &[("alpha", unit(8, 0))]);
        assert_eq!(embedder.sequence_embedding(&[]), vec![0.0f32; 8]);
        assert_eq!(embedder.sequence_embedding(&["nope"]), vec![0.0f32; 8]);
    }

    #[test]
    fn test_most_similar_commands() {
        let embedder = embedder_with(
            config_with_dimension(3),
            &[
                ("alpha", vec![1.0, 0.0, 0.0]),
                ("beta", vec![0.9, 0.1, 0.0]),
                ("gamma", vec![0.0, 1.0, 0.0]),
                ("delta", vec![-1.0, 0.0, 0.0]),
            ],
        );

        let top2 = embedder.most_similar_commands("alpha", 2);
        let names: Vec<&str> = top2.iter().map(|(name, _)| *name).collect();
        assert_eq!(names, vec!["beta", "gamma"]);
        assert!(top2[0].1 > top2[1].1);

        // Asking for more neighbours than exist returns every other command, best first.
        let all = embedder.most_similar_commands("alpha", 10);
        let names: Vec<&str> = all.iter().map(|(name, _)| *name).collect();
        assert_eq!(names, vec!["beta", "gamma", "delta"]);
        assert!((all[2].1 + 1.0).abs() < 1e-6);

        assert!(embedder.most_similar_commands("alpha", 0).is_empty());
        assert!(embedder.most_similar_commands("unknown", 3).is_empty());
    }

    #[test]
    fn test_commands_in_category() {
        let mut embedder = LaTeXEmbedder::with_config(config_with_dimension(2));
        embedder.add_command_embedding(
            CommandEmbedding::new("alpha".to_string(), vec![1.0, 0.0], 1)
                .with_category(CommandCategory::GreekLetter),
        );
        embedder.add_command_embedding(
            CommandEmbedding::new("sum".to_string(), vec![0.0, 1.0], 1)
                .with_category(CommandCategory::Operator),
        );
        embedder.add_command_embedding(CommandEmbedding::new("foo".to_string(), vec![1.0, 1.0], 1));

        assert_eq!(
            embedder.commands_in_category(CommandCategory::GreekLetter),
            vec!["alpha"]
        );
        assert_eq!(
            embedder.commands_in_category(CommandCategory::Operator),
            vec!["sum"]
        );
        // Entries without a category are not reported under any category.
        assert!(embedder
            .commands_in_category(CommandCategory::Other)
            .is_empty());
    }

    #[test]
    fn test_command_vector_falls_back_to_zeros() {
        let embedder = embedder_with(
            config_with_dimension(4),
            &[("alpha", vec![1.0, 2.0, 3.0, 4.0])],
        );
        assert_eq!(
            embedder.command_vector("alpha").to_vec(),
            vec![1.0f32, 2.0, 3.0, 4.0]
        );
        assert_eq!(embedder.command_vector("nope").to_vec(), vec![0.0f32; 4]);
    }
}