use super::Embedder;
use crate::{Chunk, Error, Result};
/// Configuration for [`NemotronEmbedder`].
///
/// Start from [`NemotronConfig::new`] or [`Default::default`] and refine it with the
/// `with_*` builder methods.
#[cfg(feature = "nemotron")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NemotronConfig {
    /// Path to the GGUF model file that [`NemotronEmbedder::new`] loads.
    pub model_path: std::path::PathBuf,
    /// GPU preference flag.
    ///
    /// NOTE(review): not read by `NemotronEmbedder` in this module; confirm whether it
    /// is honored anywhere else.
    pub use_gpu: bool,
    /// Intended number of texts per batch.
    ///
    /// NOTE(review): `embed_batch` in this module embeds texts one at a time and does
    /// not read this value.
    pub batch_size: usize,
    /// Text prepended to queries before tokenization (used by `embed_query`).
    pub query_prefix: String,
    /// Text prepended to documents before tokenization (used by `embed_document`).
    pub passage_prefix: String,
    /// Maximum number of tokens kept per input; longer inputs keep their first
    /// `max_length` tokens and the rest is dropped.
    pub max_length: usize,
    /// Whether the final embedding is scaled to unit L2 norm.
    pub normalize: bool,
}
#[cfg(feature = "nemotron")]
impl Default for NemotronConfig {
    /// Returns the stock settings: an empty model path, GPU flag on, batch size 8,
    /// an instruction-style query prefix, no passage prefix, a 8192-token limit,
    /// and L2 normalization enabled.
    fn default() -> Self {
        let query_prefix =
            String::from("Instruct: Given a query, retrieve relevant documents\nQuery: ");
        Self {
            model_path: std::path::PathBuf::default(),
            use_gpu: true,
            batch_size: 8,
            query_prefix,
            passage_prefix: String::default(),
            max_length: 8192,
            normalize: true,
        }
    }
}
#[cfg(feature = "nemotron")]
impl NemotronConfig {
    /// Creates a configuration for the model at `model_path`; every other field takes
    /// its [`Default`] value.
    #[must_use]
    pub fn new(model_path: impl AsRef<std::path::Path>) -> Self {
        Self::default().with_model_path(model_path)
    }

    /// Replaces the model file path.
    #[must_use]
    pub fn with_model_path(self, path: impl AsRef<std::path::Path>) -> Self {
        Self { model_path: path.as_ref().to_path_buf(), ..self }
    }

    /// Sets the GPU preference flag.
    #[must_use]
    pub fn with_gpu(self, use_gpu: bool) -> Self {
        Self { use_gpu, ..self }
    }

    /// Sets the batch size.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Sets the text prepended to queries.
    #[must_use]
    pub fn with_query_prefix(self, prefix: impl Into<String>) -> Self {
        Self { query_prefix: prefix.into(), ..self }
    }

    /// Sets the text prepended to documents.
    #[must_use]
    pub fn with_passage_prefix(self, prefix: impl Into<String>) -> Self {
        Self { passage_prefix: prefix.into(), ..self }
    }

    /// Sets the maximum number of tokens kept per input.
    #[must_use]
    pub fn with_max_length(self, max_length: usize) -> Self {
        Self { max_length, ..self }
    }

    /// Enables or disables L2 normalization of the final embedding.
    #[must_use]
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }
}
#[cfg(feature = "nemotron")]
/// Text embedder backed by a GGUF transformer model loaded through `realizar`.
///
/// Construct it with [`NemotronEmbedder::new`]; it implements [`Embedder`].
pub struct NemotronEmbedder {
    /// Transformer weights built from `model` (input embedding table, layers, output norm).
    transformer: realizar::gguf::GGUFTransformer,
    /// Parsed GGUF model; used here for tokenization via `encode`.
    model: realizar::gguf::GGUFModel,
    config: NemotronConfig,
    /// Embedding width, copied from `transformer.config.hidden_dim` at construction.
    dimension: usize,
}
#[cfg(feature = "nemotron")]
impl std::fmt::Debug for NemotronEmbedder {
    /// Shows the dimension and configuration only; the model weights are omitted
    /// (signalled by `..` via `finish_non_exhaustive`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("NemotronEmbedder");
        builder.field("dimension", &self.dimension);
        builder.field("config", &self.config);
        builder.finish_non_exhaustive()
    }
}
#[cfg(feature = "nemotron")]
impl NemotronEmbedder {
    /// Loads the GGUF model at `config.model_path` and prepares it for embedding.
    ///
    /// The whole file is read into memory, parsed as GGUF, and converted into a
    /// transformer. The embedding dimension is the transformer's `hidden_dim`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidConfig`] if the file does not exist, cannot be read,
    /// cannot be parsed as GGUF, or cannot be converted into a transformer.
    pub fn new(config: NemotronConfig) -> Result<Self> {
        // Checked up front so a missing file yields a clear "not found" message
        // rather than a raw I/O error.
        if !config.model_path.exists() {
            return Err(Error::InvalidConfig(format!(
                "Model file not found: {}",
                config.model_path.display()
            )));
        }
        let file_data = std::fs::read(&config.model_path).map_err(|e| {
            Error::InvalidConfig(format!(
                "Failed to read model file {}: {e}",
                config.model_path.display()
            ))
        })?;
        let model = realizar::gguf::GGUFModel::from_bytes(&file_data)
            .map_err(|e| Error::InvalidConfig(format!("Failed to parse GGUF model: {e}")))?;
        let transformer = realizar::gguf::GGUFTransformer::from_gguf(&model, &file_data)
            .map_err(|e| Error::InvalidConfig(format!("Failed to create transformer: {e}")))?;
        let dimension = transformer.config.hidden_dim;
        Ok(Self { transformer, model, config, dimension })
    }

    /// Returns the configuration this embedder was built with.
    #[must_use]
    pub fn config(&self) -> &NemotronConfig {
        &self.config
    }

    /// Embeds `prefix` followed by `text`.
    ///
    /// The concatenation is tokenized with the model's tokenizer and truncated to at
    /// most `config.max_length` tokens (the head is kept) before pooling.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Embedding`] if tokenization fails, if no tokens remain after
    /// truncation, or if [`Self::extract_embedding_from_model`] fails.
    fn embed_with_prefix(&self, text: &str, prefix: &str) -> Result<Vec<f32>> {
        let prefixed = format!("{prefix}{text}");
        let mut tokens: Vec<u32> = self
            .model
            .encode(&prefixed)
            .ok_or_else(|| Error::Embedding("Failed to tokenize text".to_string()))?;
        // Truncate in place rather than copying the kept head into a new Vec.
        tokens.truncate(self.config.max_length);
        if tokens.is_empty() {
            return Err(Error::Embedding("Empty token sequence".to_string()));
        }
        self.extract_embedding_from_model(&tokens)
    }

    /// Turns a non-empty token sequence into one embedding vector.
    ///
    /// Steps:
    /// 1. Look up each token's row in the input embedding table.
    /// 2. Run the flattened hidden state through every layer via [`Self::process_layer`].
    /// 3. Take the last position's hidden vector (last-token pooling).
    /// 4. Apply the output RMS norm.
    /// 5. L2-normalize when `config.normalize` is set.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Embedding`] if `tokens` is empty, if a token id has no row in
    /// the embedding table, or if a layer fails.
    fn extract_embedding_from_model(&self, tokens: &[u32]) -> Result<Vec<f32>> {
        let hidden_dim = self.dimension;
        // Guards the "last position" computation below against usize underflow.
        let Some(last_pos) = tokens.len().checked_sub(1) else {
            return Err(Error::Embedding("Empty token sequence".to_string()));
        };

        let table = &self.transformer.token_embedding;
        let mut hidden: Vec<f32> = Vec::with_capacity(tokens.len() * hidden_dim);
        for &token_id in tokens {
            // A token id past the end of the table (tokenizer/model mismatch, corrupt
            // file) must surface as an error rather than an out-of-bounds panic.
            let row = (token_id as usize)
                .checked_mul(hidden_dim)
                .and_then(|start| table.get(start..start.checked_add(hidden_dim)?))
                .ok_or_else(|| {
                    Error::Embedding(format!(
                        "Token id {token_id} is out of range for the embedding table \
                         ({} values, hidden_dim={hidden_dim})",
                        table.len()
                    ))
                })?;
            hidden.extend_from_slice(row);
        }

        for layer in &self.transformer.layers {
            hidden = self.process_layer(layer, &hidden, tokens.len())?;
        }

        let last_start = last_pos * hidden_dim;
        let mut embedding = hidden
            .get(last_start..last_start + hidden_dim)
            .ok_or_else(|| {
                Error::Embedding(format!(
                    "Hidden state too short: {} values for {} tokens of dim {hidden_dim}",
                    hidden.len(),
                    tokens.len()
                ))
            })?
            .to_vec();
        Self::rms_normalize(&mut embedding, &self.transformer.output_norm_weight);
        if self.config.normalize {
            Self::l2_normalize(&mut embedding);
        }
        Ok(embedding)
    }

    /// Applies one transformer layer to the flattened hidden state.
    ///
    /// `hidden` holds `seq_len` consecutive rows of `self.dimension` values each.
    ///
    /// NOTE(review): this is a placeholder. For each position it RMS-normalizes a copy
    /// of the row with `layer.attn_norm_weight` and then discards that copy. The
    /// returned state is an unchanged copy of `hidden`; no attention or feed-forward
    /// is applied. As a result, the final embedding currently depends only on the last
    /// token's input-embedding row. Replace this with the real layer computation.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Embedding`] if `hidden` is shorter than `seq_len * dimension`.
    fn process_layer(
        &self,
        layer: &realizar::gguf::GGUFTransformerLayer,
        hidden: &[f32],
        seq_len: usize,
    ) -> Result<Vec<f32>> {
        let hidden_dim = self.dimension;
        let output = hidden.to_vec();
        for pos in 0..seq_len {
            let start = pos * hidden_dim;
            let end = start + hidden_dim;
            if end > output.len() {
                return Err(Error::Embedding(format!(
                    "Layer processing out of bounds: pos={pos}, dim={hidden_dim}"
                )));
            }
            let mut normed = output[start..end].to_vec();
            Self::rms_normalize(&mut normed, &layer.attn_norm_weight);
        }
        Ok(output)
    }

    /// RMS-normalizes `vector` in place and scales each element by `weight`.
    ///
    /// Each element becomes `v[i] / sqrt(mean(v^2) + 1e-6) * weight[i]`. Elements are
    /// paired with `zip`, so if `weight` is shorter than `vector`, the trailing
    /// elements are left untouched. An empty `vector` is a no-op; the length is
    /// clamped to at least 1 to avoid dividing by zero.
    fn rms_normalize(vector: &mut [f32], weight: &[f32]) {
        let eps = 1e-6;
        let ss: f32 = vector.iter().map(|x| x * x).sum::<f32>() / vector.len().max(1) as f32;
        let scale = 1.0 / (ss + eps).sqrt();
        for (v, w) in vector.iter_mut().zip(weight.iter()) {
            *v = *v * scale * w;
        }
    }

    /// Scales `vector` in place to unit Euclidean length; an all-zero vector is left
    /// unchanged.
    fn l2_normalize(vector: &mut [f32]) {
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in vector.iter_mut() {
                *x /= norm;
            }
        }
    }
}
#[cfg(feature = "nemotron")]
impl Embedder for NemotronEmbedder {
    /// Embeds `text` as a document, so the passage prefix applies.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_document(text)
    }

    /// Embeds each text in order as a document; stops at the first error.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().copied().map(|text| self.embed(text)).collect()
    }

    /// Embedding width (the model's hidden dimension).
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Fixed identifier reported for this embedder.
    fn model_id(&self) -> &str {
        "nvidia/NV-Embed-v2"
    }

    /// Embeds a search query using the configured query prefix.
    ///
    /// An empty query is rejected with [`Error::Query`].
    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        match query {
            "" => Err(Error::Query("empty query".to_string())),
            non_empty => self.embed_with_prefix(non_empty, &self.config.query_prefix),
        }
    }

    /// Embeds a document using the configured passage prefix.
    ///
    /// An empty document is rejected with [`Error::EmptyDocument`].
    fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
        match document {
            "" => Err(Error::EmptyDocument("empty document for embedding".to_string())),
            non_empty => self.embed_with_prefix(non_empty, &self.config.passage_prefix),
        }
    }

    /// Embeds each chunk's content and stores the result on the chunk.
    ///
    /// Processing stops at the first failure; chunks before it keep their new
    /// embeddings.
    fn embed_chunks(&self, chunks: &mut [Chunk]) -> Result<()> {
        chunks.iter_mut().try_for_each(|chunk| {
            let vector = self.embed_document(&chunk.content)?;
            chunk.set_embedding(vector);
            Ok(())
        })
    }
}
#[cfg(test)]
#[cfg(feature = "nemotron")]
mod tests {
    use super::*;

    #[test]
    fn test_nemotron_config_default() {
        let config = NemotronConfig::default();
        assert!(config.use_gpu);
        assert_eq!(config.batch_size, 8);
        assert_eq!(config.max_length, 8192);
        assert!(config.normalize);
        assert!(config.query_prefix.contains("Instruct"));
        assert!(config.passage_prefix.is_empty());
    }

    #[test]
    fn test_nemotron_config_new() {
        let config = NemotronConfig::new("/tmp/model.gguf");
        assert_eq!(config.model_path, std::path::PathBuf::from("/tmp/model.gguf"));
        assert!(config.use_gpu);
    }

    #[test]
    fn test_nemotron_config_builder() {
        let config = NemotronConfig::default()
            .with_model_path("/tmp/model.gguf")
            .with_gpu(false)
            .with_batch_size(16)
            .with_max_length(4096)
            .with_normalize(false)
            .with_query_prefix("Query: ")
            .with_passage_prefix("Passage: ");
        assert_eq!(config.model_path, std::path::PathBuf::from("/tmp/model.gguf"));
        assert!(!config.use_gpu);
        assert_eq!(config.batch_size, 16);
        assert_eq!(config.max_length, 4096);
        assert!(!config.normalize);
        assert_eq!(config.query_prefix, "Query: ");
        assert_eq!(config.passage_prefix, "Passage: ");
    }

    #[test]
    fn test_nemotron_embedder_missing_model() {
        let config = NemotronConfig::new("/nonexistent/model.gguf");
        let result = NemotronEmbedder::new(config);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_nemotron_embedder_invalid_gguf() {
        // Include the process id so concurrent test runs sharing the system temp
        // directory cannot overwrite or delete each other's fixture file.
        let temp_file = std::env::temp_dir()
            .join(format!("nemotron_invalid_model_{}.gguf", std::process::id()));
        std::fs::write(&temp_file, b"not a valid gguf file").unwrap();
        let config = NemotronConfig::new(&temp_file);
        let result = NemotronEmbedder::new(config);
        // Clean up before asserting so a failing assertion does not leak the file.
        let _ = std::fs::remove_file(&temp_file);
        let err = result.expect_err("garbage bytes must not load as a GGUF model");
        let msg = err.to_string();
        assert!(
            msg.contains("parse") || msg.contains("GGUF"),
            "Expected parse error, got: {msg}"
        );
    }

    #[test]
    fn test_nemotron_l2_normalize() {
        let mut vector = vec![3.0, 4.0];
        NemotronEmbedder::l2_normalize(&mut vector);
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
        assert!((vector[0] - 0.6).abs() < 1e-5);
        assert!((vector[1] - 0.8).abs() < 1e-5);
    }

    #[test]
    fn test_nemotron_l2_normalize_zero() {
        let mut vector = vec![0.0, 0.0, 0.0];
        NemotronEmbedder::l2_normalize(&mut vector);
        assert_eq!(vector, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_nemotron_rms_normalize() {
        let mut vector = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0; 4];
        NemotronEmbedder::rms_normalize(&mut vector, &weight);
        // mean square = (1 + 4 + 9 + 16) / 4 = 7.5
        let expected_scale = 1.0 / (7.5f32 + 1e-6).sqrt();
        // Check every element tightly, not just the first with a loose tolerance.
        for (i, v) in vector.iter().enumerate() {
            let expected = (i as f32 + 1.0) * expected_scale;
            assert!((v - expected).abs() < 1e-5, "index {i}: got {v}, expected {expected}");
        }
    }

    #[test]
    fn test_nemotron_rms_normalize_empty() {
        let mut vector: Vec<f32> = Vec::new();
        NemotronEmbedder::rms_normalize(&mut vector, &[]);
        assert!(vector.is_empty());
    }

    #[test]
    fn test_nemotron_config_debug() {
        let config = NemotronConfig::new("/tmp/test.gguf");
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("NemotronConfig"));
        assert!(debug_str.contains("model_path"));
    }

    #[test]
    fn test_nemotron_config_clone() {
        let config = NemotronConfig::new("/tmp/test.gguf").with_batch_size(32);
        let cloned = config.clone();
        assert_eq!(cloned.batch_size, 32);
        assert_eq!(cloned.model_path, config.model_path);
    }

    #[test]
    fn test_nemotron_rms_normalize_with_weights() {
        let mut vector = vec![2.0, 2.0];
        let weight = vec![0.5, 2.0];
        NemotronEmbedder::rms_normalize(&mut vector, &weight);
        assert!((vector[0] - 0.5).abs() < 0.01);
        assert!((vector[1] - 2.0).abs() < 0.01);
    }
}