use std::path::Path;
use crate::backend::EmbedBackend;
use crate::chunk::CodeChunk;
use crate::embed::{SearchConfig, embed_all};
use crate::encoder::VectorEncoder;
use crate::profile::Profiler;
/// Encoder that owns a set of embedding backends together with the tokenizer
/// and model metadata that accompany them.
///
/// Its [`VectorEncoder`] implementation forwards the backends and tokenizer to
/// `embed_all`, and reports `hidden_dim` / `model_repo` through the trait's
/// `hidden_dim()` and `identity()` methods.
pub struct BertEncoder {
    /// Owned, boxed backends; borrowed as `&dyn EmbedBackend` when embedding.
    backends: Vec<Box<dyn EmbedBackend>>,
    /// Tokenizer passed alongside the backends to `embed_all`.
    tokenizer: tokenizers::Tokenizer,
    /// String returned verbatim by `identity()`.
    /// NOTE(review): presumably the model repository name — confirm with callers.
    model_repo: String,
    /// Value returned by `hidden_dim()`.
    /// NOTE(review): presumably the embedding width — confirm with callers.
    hidden_dim: usize,
}
impl BertEncoder {
    /// Assembles an encoder from already-constructed parts.
    ///
    /// Every argument is stored as given; no validation or conversion is
    /// performed here.
    #[must_use]
    pub fn new(
        backends: Vec<Box<dyn EmbedBackend>>,
        tokenizer: tokenizers::Tokenizer,
        model_repo: String,
        hidden_dim: usize,
    ) -> Self {
        Self {
            backends,
            tokenizer,
            model_repo,
            hidden_dim,
        }
    }

    /// Read-only view of the backends owned by this encoder.
    #[must_use]
    pub fn backends(&self) -> &[Box<dyn EmbedBackend>] {
        self.backends.as_slice()
    }

    /// Shared reference to the tokenizer owned by this encoder.
    #[must_use]
    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        let tokenizer = &self.tokenizer;
        tokenizer
    }
}
impl VectorEncoder for BertEncoder {
    /// Embeds `root` by delegating to `embed_all`.
    ///
    /// The owned boxes are first borrowed as plain `&dyn EmbedBackend`
    /// references, because `embed_all` is handed a slice of trait-object
    /// references rather than the boxes themselves. The stored tokenizer,
    /// `cfg` and `profiler` are passed through untouched.
    ///
    /// # Errors
    ///
    /// Returns whatever error `embed_all` reports.
    fn embed_root(
        &self,
        root: &Path,
        cfg: &SearchConfig,
        profiler: &Profiler,
    ) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
        // Reborrow through the `Box` to get a reference to the trait object.
        let borrowed: Vec<&dyn EmbedBackend> =
            self.backends.iter().map(|boxed| &**boxed).collect();
        let tokenizer = &self.tokenizer;
        embed_all(root, &borrowed, tokenizer, cfg, profiler)
    }

    /// Returns the `hidden_dim` value supplied at construction.
    fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Returns the stored `model_repo` string as this encoder's identity.
    fn identity(&self) -> &str {
        self.model_repo.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `BertEncoder` satisfies
    /// `VectorEncoder + Send + Sync`; there is nothing to assert at runtime.
    #[test]
    fn bert_encoder_implements_vector_encoder() {
        fn require_bounds<T>()
        where
            T: VectorEncoder + Send + Sync,
        {
        }

        require_bounds::<BertEncoder>();
    }

    /// Compile-time check of the `embed_root` signature: the helper only
    /// type-checks if the trait method accepts these arguments and returns
    /// `crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)>`.
    #[test]
    fn embed_root_return_type_matches_embed_all() {
        fn call_embed_root<E>(
            encoder: &E,
            root: &Path,
            cfg: &SearchConfig,
            profiler: &Profiler,
        ) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)>
        where
            E: VectorEncoder,
        {
            encoder.embed_root(root, cfg, profiler)
        }

        // Instantiate the helper for `BertEncoder` without ever calling it.
        let _ = call_embed_root::<BertEncoder>;
    }
}