mod codet5;
mod ensemble;
mod graphcodebert;
mod unixcoder;
pub use codet5::{CodeT5Config, CodeT5Embedder};
pub use ensemble::{EnsembleCodeEmbedder, EnsembleStrategy};
pub use graphcodebert::{GraphCodeBertConfig, GraphCodeBertEmbedder};
pub use unixcoder::{UniXcoderConfig, UniXcoderEmbedder};
use std::sync::Arc;
pub type Result<T> = std::result::Result<T, CodeEmbeddingError>;
#[derive(Debug, thiserror::Error)]
pub enum CodeEmbeddingError {
#[error("Failed to load model: {0}")]
ModelLoad(String),
#[error("Tokenization error: {0}")]
Tokenization(String),
#[error("Inference error: {0}")]
Inference(String),
#[error("ONNX Runtime error: {0}")]
Onnx(String),
#[error("Unsupported language: {0}")]
UnsupportedLanguage(String),
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
}
impl From<ort::Error> for CodeEmbeddingError {
fn from(e: ort::Error) -> Self {
CodeEmbeddingError::Onnx(e.to_string())
}
}
impl From<tokenizers::Error> for CodeEmbeddingError {
fn from(e: tokenizers::Error) -> Self {
CodeEmbeddingError::Tokenization(e.to_string())
}
}
/// Programming languages this module can tag source code with.
///
/// `Unknown` is the fallback used when a language cannot be determined
/// (for example by [`CodeLanguage::from_extension`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodeLanguage {
    Python,
    Java,
    JavaScript,
    TypeScript,
    Go,
    Ruby,
    Php,
    C,
    Cpp,
    CSharp,
    Rust,
    Kotlin,
    Scala,
    Swift,
    Haskell,
    OCaml,
    Elixir,
    Bash,
    Rholang,
    MeTTa,
    Unknown,
}

impl CodeLanguage {
    /// Returns the angle-bracketed tag for this language (e.g. `"<python>"`).
    ///
    /// `Unknown` yields the empty string. Note that C# uses `"<c_sharp>"` here
    /// while [`CodeLanguage::name`] returns `"csharp"`; the two spellings are
    /// deliberately kept exactly as they were.
    pub fn prefix(&self) -> &'static str {
        match self {
            Self::Python => "<python>",
            Self::Java => "<java>",
            Self::JavaScript => "<javascript>",
            Self::TypeScript => "<typescript>",
            Self::Go => "<go>",
            Self::Ruby => "<ruby>",
            Self::Php => "<php>",
            Self::C => "<c>",
            Self::Cpp => "<cpp>",
            Self::CSharp => "<c_sharp>",
            Self::Rust => "<rust>",
            Self::Kotlin => "<kotlin>",
            Self::Scala => "<scala>",
            Self::Swift => "<swift>",
            Self::Haskell => "<haskell>",
            Self::OCaml => "<ocaml>",
            Self::Elixir => "<elixir>",
            Self::Bash => "<bash>",
            Self::Rholang => "<rholang>",
            Self::MeTTa => "<metta>",
            Self::Unknown => "",
        }
    }

    /// Returns the lowercase identifier of this language (e.g. `"python"`).
    ///
    /// `Unknown` yields `"unknown"`.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Python => "python",
            Self::Java => "java",
            Self::JavaScript => "javascript",
            Self::TypeScript => "typescript",
            Self::Go => "go",
            Self::Ruby => "ruby",
            Self::Php => "php",
            Self::C => "c",
            Self::Cpp => "cpp",
            Self::CSharp => "csharp",
            Self::Rust => "rust",
            Self::Kotlin => "kotlin",
            Self::Scala => "scala",
            Self::Swift => "swift",
            Self::Haskell => "haskell",
            Self::OCaml => "ocaml",
            Self::Elixir => "elixir",
            Self::Bash => "bash",
            Self::Rholang => "rholang",
            Self::MeTTa => "metta",
            Self::Unknown => "unknown",
        }
    }

    /// Maps a file extension (without the leading dot) to a language.
    ///
    /// Matching is case-insensitive. Extensions that are not listed resolve to
    /// [`CodeLanguage::Unknown`].
    pub fn from_extension(ext: &str) -> Self {
        match ext.to_lowercase().as_str() {
            "py" => Self::Python,
            "java" => Self::Java,
            // `jsx` mirrors the `tsx` entry on the TypeScript row.
            "js" | "mjs" | "cjs" | "jsx" => Self::JavaScript,
            // `mts` / `cts` mirror the `mjs` / `cjs` entries on the JavaScript row.
            "ts" | "tsx" | "mts" | "cts" => Self::TypeScript,
            "go" => Self::Go,
            "rb" => Self::Ruby,
            "php" => Self::Php,
            "c" | "h" => Self::C,
            "cpp" | "cc" | "cxx" | "hpp" | "hxx" => Self::Cpp,
            "cs" => Self::CSharp,
            "rs" => Self::Rust,
            "kt" | "kts" => Self::Kotlin,
            "scala" | "sc" => Self::Scala,
            "swift" => Self::Swift,
            "hs" | "lhs" => Self::Haskell,
            "ml" | "mli" => Self::OCaml,
            "ex" | "exs" => Self::Elixir,
            "sh" | "bash" => Self::Bash,
            "rho" => Self::Rholang,
            "metta" | "mtt" => Self::MeTTa,
            _ => Self::Unknown,
        }
    }
}
/// Common interface for models that turn source code into embedding vectors.
///
/// Implementors must be `Send + Sync`.
pub trait CodeEmbedder: Send + Sync {
    /// Produces the embedding for one `code` snippet tagged with `language`.
    fn embed_code(&self, code: &str, language: CodeLanguage) -> Result<Vec<f32>>;

    /// Produces one embedding per snippet in `codes`.
    ///
    /// Presumably `languages[i]` describes `codes[i]`; the trait itself does
    /// not enforce that the two slices have equal length (confirm against the
    /// implementations).
    fn embed_code_batch(
        &self,
        codes: &[&str],
        languages: &[CodeLanguage],
    ) -> Result<Vec<Vec<f32>>>;

    /// Number of components reported for this embedder's vectors.
    fn embedding_dim(&self) -> usize;

    /// Name reported for the underlying model.
    fn model_name(&self) -> &str;

    /// Maximum sequence length reported by the embedder.
    fn max_sequence_length(&self) -> usize;

    /// Languages the embedder declares support for.
    ///
    /// An empty slice is treated by [`CodeEmbedder::supports_language`] as
    /// "every language is supported".
    fn supported_languages(&self) -> &[CodeLanguage];

    /// Returns `true` when `language` is listed by
    /// [`CodeEmbedder::supported_languages`], or when that list is empty.
    fn supports_language(&self, language: CodeLanguage) -> bool {
        match self.supported_languages() {
            // No explicit list: nothing is excluded.
            [] => true,
            declared => declared.contains(&language),
        }
    }
}
/// Settings for the embedding cache.
#[derive(Clone, Debug)]
pub struct CodeEmbeddingCacheConfig {
    /// Upper bound on the number of cached embeddings.
    pub max_entries: usize,
    /// Key-hashing toggle. NOTE(review): no code visible in this file reads
    /// this flag; confirm whether it is consumed elsewhere.
    pub hash_keys: bool,
}

impl Default for CodeEmbeddingCacheConfig {
    /// Defaults to 10 000 entries with `hash_keys` enabled.
    fn default() -> Self {
        let max_entries = 10_000;
        let hash_keys = true;
        Self {
            max_entries,
            hash_keys,
        }
    }
}
/// Concurrent cache of embeddings keyed by a 64-bit hash of `(code, language)`.
///
/// Only the hash is stored, never the source text, so two different inputs
/// whose hashes collide share a slot. Embeddings are stored as `Arc<[f32]>`,
/// which makes handing one out a reference-count bump rather than a copy.
pub struct CodeEmbeddingCache {
    cache: dashmap::DashMap<u64, Arc<[f32]>>,
    config: CodeEmbeddingCacheConfig,
}

impl CodeEmbeddingCache {
    /// Creates an empty cache with room pre-allocated for `config.max_entries`.
    pub fn new(config: CodeEmbeddingCacheConfig) -> Self {
        Self {
            cache: dashmap::DashMap::with_capacity(config.max_entries),
            config,
        }
    }

    /// Returns the cached embedding for `(code, language)`, if present.
    pub fn get(&self, code: &str, language: CodeLanguage) -> Option<Arc<[f32]>> {
        let key = self.compute_key(code, language);
        self.cache.get(&key).map(|entry| Arc::clone(entry.value()))
    }

    /// Stores `embedding` under `(code, language)`, replacing any previous value.
    ///
    /// When the cache already holds `max_entries` items and the key is new, one
    /// arbitrary entry (the first the map's iterator yields) is evicted first.
    /// Overwriting an existing key never evicts anything, and a cache configured
    /// with `max_entries == 0` stores nothing. The size check and the insert are
    /// separate steps, so concurrent writers may briefly overshoot the limit.
    pub fn insert(&self, code: &str, language: CodeLanguage, embedding: Vec<f32>) {
        if self.config.max_entries == 0 {
            return;
        }

        let key = self.compute_key(code, language);

        // Replacing an existing entry does not grow the map, so only a new key
        // can require making room.
        let is_new_key = !self.cache.contains_key(&key);
        if is_new_key && self.cache.len() >= self.config.max_entries {
            // Pick the victim in its own statement: the iterator temporary (and
            // whatever shard guard it holds) is dropped at the end of this `let`.
            // Inside an `if let` scrutinee it would stay alive for the whole
            // body, and DashMap documents `remove` as liable to deadlock while
            // any reference into the map is still held.
            let victim = self.cache.iter().next().map(|entry| *entry.key());
            if let Some(victim) = victim {
                self.cache.remove(&victim);
            }
        }

        self.cache.insert(key, Arc::from(embedding));
    }

    /// Removes every cached embedding.
    pub fn clear(&self) {
        self.cache.clear();
    }

    /// Number of embeddings currently cached.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` when nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Derives the 64-bit cache key for `(code, language)`.
    ///
    /// Inputs of at least `GXHASH_MIN_SIZE` bytes are fed through `GxHasher`
    /// together with the language; shorter inputs use `fnv1a` over the bytes,
    /// then mix in the language discriminant with an XOR and a wrapping
    /// multiply by `0x100000001b3`.
    fn compute_key(&self, code: &str, language: CodeLanguage) -> u64 {
        use crate::util::hash::{fnv1a, GXHASH_MIN_SIZE};
        use std::hash::{Hash, Hasher};

        if code.len() >= GXHASH_MIN_SIZE {
            let mut hasher = gxhash::GxHasher::default();
            code.hash(&mut hasher);
            language.hash(&mut hasher);
            hasher.finish()
        } else {
            let mut hash = fnv1a(code.as_bytes());
            hash ^= language as u64;
            hash.wrapping_mul(0x100000001b3)
        }
    }
}
/// Cosine of the angle between `a` and `b`.
///
/// Returns `0.0` when either vector has zero Euclidean norm. The slices are
/// expected to have equal length; this is checked only in debug builds.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Embedding dimensions must match");

    let l2_norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));

    // A zero-length direction has no angle; report "no similarity".
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}
/// Scales `embedding` in place to unit Euclidean length.
///
/// The slice is left untouched when its norm is not strictly positive
/// (all zeros, empty, or NaN).
pub fn normalize_embedding(embedding: &mut [f32]) {
    let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    if norm > 0.0 {
        embedding.iter_mut().for_each(|c| *c /= norm);
    }
}
/// Returns a unit-length copy of `embedding`, leaving the input untouched.
///
/// When the norm is not strictly positive (all zeros, empty, or NaN) the copy
/// is returned unscaled.
pub fn normalize_embedding_clone(embedding: &[f32]) -> Vec<f32> {
    let norm = embedding.iter().map(|c| c * c).sum::<f32>().sqrt();

    let mut normalized = embedding.to_vec();
    if norm > 0.0 {
        for component in &mut normalized {
            *component /= norm;
        }
    }
    normalized
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating-point comparisons below.
    const TOLERANCE: f32 = 1e-6;

    /// Known extensions resolve to their language; anything else is `Unknown`.
    #[test]
    fn test_language_from_extension() {
        let expectations = [
            ("py", CodeLanguage::Python),
            ("rs", CodeLanguage::Rust),
            ("rho", CodeLanguage::Rholang),
            ("metta", CodeLanguage::MeTTa),
            ("xyz", CodeLanguage::Unknown),
        ];
        for &(ext, expected) in &expectations {
            assert_eq!(CodeLanguage::from_extension(ext), expected);
        }
    }

    /// Parallel, orthogonal and opposite unit vectors give 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = [1.0_f32, 0.0, 0.0];
        let unit_y = [0.0_f32, 1.0, 0.0];
        let neg_x = [-1.0_f32, 0.0, 0.0];

        let parallel = cosine_similarity(&unit_x, &unit_x);
        assert!((parallel - 1.0).abs() < TOLERANCE);

        let orthogonal = cosine_similarity(&unit_x, &unit_y);
        assert!((orthogonal - 0.0).abs() < TOLERANCE);

        let opposite = cosine_similarity(&unit_x, &neg_x);
        assert!((opposite - (-1.0)).abs() < TOLERANCE);
    }

    /// A 3-4-5 vector ends up with unit length after in-place normalisation.
    #[test]
    fn test_normalize_embedding() {
        let mut embedding = vec![3.0, 4.0];
        normalize_embedding(&mut embedding);

        let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
        assert!((squared_sum.sqrt() - 1.0).abs() < TOLERANCE);
    }

    /// An inserted embedding is returned for the same (code, language) pair and
    /// not for the same code under a different language.
    #[test]
    fn test_embedding_cache() {
        let config = CodeEmbeddingCacheConfig {
            max_entries: 10,
            hash_keys: true,
        };
        let cache = CodeEmbeddingCache::new(config);
        let source = "fn main() {}";
        let embedding = vec![1.0, 2.0, 3.0];

        cache.insert(source, CodeLanguage::Rust, embedding.clone());

        let hit = cache.get(source, CodeLanguage::Rust);
        assert_eq!(hit.as_deref(), Some(embedding.as_slice()));

        assert!(cache.get(source, CodeLanguage::Python).is_none());
    }
}