use std::sync::Arc;
use super::{CodeEmbedder, CodeEmbeddingError, CodeLanguage, Result};
/// How an ensemble merges the vectors produced by its member embedders
/// into a single embedding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EnsembleStrategy {
/// Join the member embeddings end to end, in embedder order. The output
/// dimension is the sum of the member dimensions. This is the default.
#[default]
Concatenate,
/// Element-wise sum of the member embeddings, each scaled by its
/// configured weight. The weights are applied as given; they are not
/// re-normalized to sum to one.
WeightedAverage,
/// Element-wise maximum across the member embeddings.
MaxPooling,
/// Element-wise arithmetic mean across the member embeddings (unweighted).
MeanPooling,
}
/// A `CodeEmbedder` that runs several member embedders on the same input and
/// merges their outputs into one vector according to an `EnsembleStrategy`.
pub struct EnsembleCodeEmbedder {
// Member embedders, queried in this order for every input.
embedders: Vec<Arc<dyn CodeEmbedder>>,
// One weight per embedder; only read when combining with `WeightedAverage`.
weights: Vec<f64>,
// How the member embeddings are merged.
strategy: EnsembleStrategy,
// Dimension reported by `embedding_dim()`: the sum of the member dimensions
// for `Concatenate`, otherwise the first member's dimension.
embedding_dim: usize,
// When true, the merged vector is passed through `super::normalize_embedding`
// before it is returned. Both constructors initialise this to `true`.
normalize_final: bool,
}
impl EnsembleCodeEmbedder {
    /// Creates a concatenating ensemble over `embedders`.
    ///
    /// Every weight is set to `1.0` (weights are not used by the
    /// `Concatenate` strategy) and final normalization is enabled.
    ///
    /// Unlike [`EnsembleCodeEmbedder::with_strategy`], an empty `embedders`
    /// list is accepted here; the resulting ensemble reports dimension `0`
    /// and produces empty embeddings.
    pub fn new(embedders: Vec<Arc<dyn CodeEmbedder>>) -> Self {
        let embedding_dim: usize = embedders.iter().map(|e| e.embedding_dim()).sum();
        let weights = vec![1.0; embedders.len()];
        Self {
            embedders,
            weights,
            strategy: EnsembleStrategy::Concatenate,
            embedding_dim,
            normalize_final: true,
        }
    }

    /// Creates an ensemble that merges member embeddings with `strategy`.
    ///
    /// When `weights` is `None`, every embedder gets the weight
    /// `1 / embedders.len()`. Final normalization is enabled.
    ///
    /// # Errors
    ///
    /// Returns `CodeEmbeddingError::Inference` when:
    /// - `embedders` is empty;
    /// - the number of weights differs from the number of embedders;
    /// - `strategy` is `WeightedAverage` and a weight is NaN, infinite, or too
    ///   large to be represented as a finite `f32`;
    /// - `strategy` is not `Concatenate` and the embedders do not all report
    ///   the same embedding dimension.
    pub fn with_strategy(
        embedders: Vec<Arc<dyn CodeEmbedder>>,
        strategy: EnsembleStrategy,
        weights: Option<Vec<f64>>,
    ) -> Result<Self> {
        if embedders.is_empty() {
            return Err(CodeEmbeddingError::Inference(
                "Ensemble requires at least one embedder".to_string(),
            ));
        }

        let weights =
            weights.unwrap_or_else(|| vec![1.0 / embedders.len() as f64; embedders.len()]);
        if weights.len() != embedders.len() {
            return Err(CodeEmbeddingError::Inference(format!(
                "Weight count ({}) must match embedder count ({})",
                weights.len(),
                embedders.len()
            )));
        }

        // Weights are multiplied into the embeddings as f32. A single
        // non-finite factor would turn every output component into NaN/inf,
        // so reject it up front instead of producing garbage later.
        if strategy == EnsembleStrategy::WeightedAverage {
            for (index, &weight) in weights.iter().enumerate() {
                if !(weight as f32).is_finite() {
                    return Err(CodeEmbeddingError::Inference(format!(
                        "Weight {} ({}) is not a finite f32 value",
                        index, weight
                    )));
                }
            }
        }

        let first_dim = embedders[0].embedding_dim();
        // Element-wise strategies only make sense when all vectors line up.
        if strategy != EnsembleStrategy::Concatenate {
            for (i, embedder) in embedders.iter().enumerate().skip(1) {
                if embedder.embedding_dim() != first_dim {
                    return Err(CodeEmbeddingError::Inference(format!(
                        "{:?} strategy requires equal embedding dimensions. \
                         Embedder {} has dim {} but expected {}",
                        strategy,
                        i,
                        embedder.embedding_dim(),
                        first_dim
                    )));
                }
            }
        }

        let embedding_dim = match strategy {
            EnsembleStrategy::Concatenate => embedders.iter().map(|e| e.embedding_dim()).sum(),
            EnsembleStrategy::WeightedAverage
            | EnsembleStrategy::MaxPooling
            | EnsembleStrategy::MeanPooling => first_dim,
        };

        Ok(Self {
            embedders,
            weights,
            strategy,
            embedding_dim,
            normalize_final: true,
        })
    }

    /// Enables or disables the normalization step applied to the merged
    /// embedding.
    pub fn set_normalize_final(&mut self, normalize: bool) {
        self.normalize_final = normalize;
    }

    /// Returns the strategy used to merge member embeddings.
    pub fn strategy(&self) -> EnsembleStrategy {
        self.strategy
    }

    /// Returns the per-embedder weights, in embedder order.
    pub fn weights(&self) -> &[f64] {
        &self.weights
    }

    /// Returns the number of member embedders.
    pub fn num_embedders(&self) -> usize {
        self.embedders.len()
    }

    /// Merges one embedding per member embedder into a single vector.
    ///
    /// Returns an empty vector when `embeddings` is empty. For the
    /// element-wise strategies the output length is the length of the first
    /// embedding; components of later embeddings beyond that length are
    /// ignored rather than causing an out-of-bounds panic, and shorter
    /// embeddings only contribute to their leading components.
    ///
    /// When `normalize_final` is set, the merged vector is passed through
    /// `super::normalize_embedding` before being returned.
    fn combine_embeddings(&self, embeddings: Vec<Vec<f32>>) -> Vec<f32> {
        if embeddings.is_empty() {
            return Vec::new();
        }

        let dim = embeddings[0].len();
        let mut result: Vec<f32> = match self.strategy {
            EnsembleStrategy::Concatenate => {
                let total: usize = embeddings.iter().map(Vec::len).sum();
                let mut joined = Vec::with_capacity(total);
                for embedding in embeddings {
                    joined.extend(embedding);
                }
                joined
            }
            EnsembleStrategy::WeightedAverage => {
                let mut combined = vec![0.0f32; dim];
                for (embedding, &weight) in embeddings.iter().zip(&self.weights) {
                    let weight = weight as f32;
                    for (slot, &val) in combined.iter_mut().zip(embedding) {
                        *slot += val * weight;
                    }
                }
                combined
            }
            EnsembleStrategy::MaxPooling => {
                // Start below every finite value so the first embedding wins.
                let mut combined = vec![f32::NEG_INFINITY; dim];
                for embedding in &embeddings {
                    for (slot, &val) in combined.iter_mut().zip(embedding) {
                        if val > *slot {
                            *slot = val;
                        }
                    }
                }
                combined
            }
            EnsembleStrategy::MeanPooling => {
                let n = embeddings.len() as f32;
                let mut combined = vec![0.0f32; dim];
                for embedding in &embeddings {
                    for (slot, &val) in combined.iter_mut().zip(embedding) {
                        *slot += val / n;
                    }
                }
                combined
            }
        };

        if self.normalize_final {
            super::normalize_embedding(&mut result);
        }
        result
    }
}
impl CodeEmbedder for EnsembleCodeEmbedder {
fn embed_code(&self, code: &str, language: CodeLanguage) -> Result<Vec<f32>> {
let embeddings: Vec<Vec<f32>> = self
.embedders
.iter()
.map(|e| e.embed_code(code, language))
.collect::<Result<Vec<_>>>()?;
Ok(self.combine_embeddings(embeddings))
}
fn embed_code_batch(
&self,
codes: &[&str],
languages: &[CodeLanguage],
) -> Result<Vec<Vec<f32>>> {
if codes.is_empty() {
return Ok(vec![]);
}
let all_model_embeddings: Vec<Vec<Vec<f32>>> = self
.embedders
.iter()
.map(|e| e.embed_code_batch(codes, languages))
.collect::<Result<Vec<_>>>()?;
let num_codes = codes.len();
let mut results = Vec::with_capacity(num_codes);
for i in 0..num_codes {
let embeddings: Vec<Vec<f32>> = all_model_embeddings
.iter()
.map(|model_embeddings| model_embeddings[i].clone())
.collect();
results.push(self.combine_embeddings(embeddings));
}
Ok(results)
}
fn embedding_dim(&self) -> usize {
self.embedding_dim
}
fn model_name(&self) -> &str {
"Ensemble"
}
fn max_sequence_length(&self) -> usize {
self.embedders
.iter()
.map(|e| e.max_sequence_length())
.min()
.unwrap_or(512)
}
fn supported_languages(&self) -> &[CodeLanguage] {
&[]
}
}
/// Hand-written `Debug` output: reports how many embedders the ensemble
/// holds instead of printing the embedders themselves, followed by the
/// strategy, the output dimension and the weights. `normalize_final` is not
/// included.
impl std::fmt::Debug for EnsembleCodeEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let embedder_count = self.embedders.len();
        let mut out = f.debug_struct("EnsembleCodeEmbedder");
        out.field("num_embedders", &embedder_count);
        out.field("strategy", &self.strategy);
        out.field("embedding_dim", &self.embedding_dim);
        out.field("weights", &self.weights);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that returns a constant-filled vector of a fixed length.
    struct MockEmbedder {
        dim: usize,
        value: f32,
    }

    impl MockEmbedder {
        fn new(dim: usize, value: f32) -> Self {
            Self { dim, value }
        }
    }

    impl CodeEmbedder for MockEmbedder {
        fn embed_code(&self, _code: &str, _language: CodeLanguage) -> Result<Vec<f32>> {
            Ok(vec![self.value; self.dim])
        }

        fn embed_code_batch(
            &self,
            codes: &[&str],
            languages: &[CodeLanguage],
        ) -> Result<Vec<Vec<f32>>> {
            // Pad a too-short language list with `Unknown` so every code
            // still gets embedded.
            let padded_languages = languages
                .iter()
                .copied()
                .chain(std::iter::repeat(CodeLanguage::Unknown));
            codes
                .iter()
                .zip(padded_languages)
                .map(|(code, lang)| self.embed_code(code, lang))
                .collect()
        }

        fn embedding_dim(&self) -> usize {
            self.dim
        }

        fn model_name(&self) -> &str {
            "Mock"
        }

        fn max_sequence_length(&self) -> usize {
            512
        }

        fn supported_languages(&self) -> &[CodeLanguage] {
            &[]
        }
    }

    /// Builds one mock embedder per `(dim, value)` pair.
    fn mock_embedders(specs: &[(usize, f32)]) -> Vec<Arc<dyn CodeEmbedder>> {
        specs
            .iter()
            .map(|&(dim, value)| Arc::new(MockEmbedder::new(dim, value)) as Arc<dyn CodeEmbedder>)
            .collect()
    }

    /// True when `actual` is within `1e-6` of `expected`.
    fn close_to(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn test_concatenate_strategy() {
        let mut ensemble = EnsembleCodeEmbedder::new(mock_embedders(&[(3, 1.0), (2, 2.0)]));
        ensemble.set_normalize_final(false);
        assert_eq!(ensemble.embedding_dim(), 5);

        let embedding = ensemble.embed_code("test", CodeLanguage::Rust).unwrap();
        assert_eq!(embedding.len(), 5);
        assert_eq!(&embedding[..3], &[1.0, 1.0, 1.0]);
        assert_eq!(&embedding[3..], &[2.0, 2.0]);
    }

    #[test]
    fn test_weighted_average_strategy() {
        let mut ensemble = EnsembleCodeEmbedder::with_strategy(
            mock_embedders(&[(3, 1.0), (3, 2.0)]),
            EnsembleStrategy::WeightedAverage,
            Some(vec![0.5, 0.5]),
        )
        .unwrap();
        ensemble.set_normalize_final(false);
        assert_eq!(ensemble.embedding_dim(), 3);

        let embedding = ensemble.embed_code("test", CodeLanguage::Rust).unwrap();
        assert_eq!(embedding.len(), 3);
        assert!(close_to(embedding[0], 1.5));
    }

    #[test]
    fn test_max_pooling_strategy() {
        let mut ensemble = EnsembleCodeEmbedder::with_strategy(
            mock_embedders(&[(3, 1.0), (3, 2.0)]),
            EnsembleStrategy::MaxPooling,
            None,
        )
        .unwrap();
        ensemble.set_normalize_final(false);

        let embedding = ensemble.embed_code("test", CodeLanguage::Rust).unwrap();
        assert_eq!(embedding.len(), 3);
        assert!(close_to(embedding[0], 2.0));
    }

    #[test]
    fn test_mean_pooling_strategy() {
        let mut ensemble = EnsembleCodeEmbedder::with_strategy(
            mock_embedders(&[(3, 1.0), (3, 3.0)]),
            EnsembleStrategy::MeanPooling,
            None,
        )
        .unwrap();
        ensemble.set_normalize_final(false);

        let embedding = ensemble.embed_code("test", CodeLanguage::Rust).unwrap();
        assert_eq!(embedding.len(), 3);
        assert!(close_to(embedding[0], 2.0));
    }

    #[test]
    fn test_dimension_mismatch_error() {
        // Element-wise strategies must reject members of different sizes.
        let result = EnsembleCodeEmbedder::with_strategy(
            mock_embedders(&[(3, 1.0), (4, 2.0)]),
            EnsembleStrategy::WeightedAverage,
            None,
        );
        assert!(result.is_err());
    }
}