1use std::path::Path;
37use std::sync::Arc;
38
39use ndarray::IxDyn;
40use ort::session::Session;
41use ort::value::Tensor;
42use parking_lot::Mutex;
43use serde::{Deserialize, Serialize};
44use tokenizers::Tokenizer;
45
46use crate::embeddings::Embedder;
47use crate::error::{Error, Result};
48
/// Retrieval task type; selects the instruction prefix prepended to
/// queries and passages (see `query_prefix` / `passage_prefix`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EmbeddingTask {
    /// Natural-language query -> code snippet retrieval (the default).
    #[default]
    NL2Code,
    /// Code snippet -> equivalent code snippet retrieval.
    Code2Code,
    /// Code snippet -> natural-language comment retrieval.
    Code2NL,
    /// Start of a code snippet -> completion retrieval.
    Code2Completion,
    /// Question -> answer retrieval.
    QA,
}
89
90impl EmbeddingTask {
91 pub fn query_prefix(&self) -> &'static str {
95 match self {
96 Self::NL2Code => "Find the most relevant code snippet given the following query:\n",
97 Self::Code2Code => "Find an equivalent code snippet given the following code snippet:\n",
98 Self::Code2NL => "Find the most relevant comment given the following code snippet:\n",
99 Self::Code2Completion => "Find the most relevant completion given the following start of code snippet:\n",
100 Self::QA => "Find the most relevant answer given the following question:\n",
101 }
102 }
103
104 pub fn passage_prefix(&self) -> &'static str {
108 match self {
109 Self::NL2Code => "Candidate code snippet:\n",
110 Self::Code2Code => "Candidate code snippet:\n",
111 Self::Code2NL => "Candidate comment:\n",
112 Self::Code2Completion => "Candidate completion:\n",
113 Self::QA => "Candidate answer:\n",
114 }
115 }
116
117 pub fn instruction_prefix(&self) -> &'static str {
121 self.query_prefix()
122 }
123
124 pub fn name(&self) -> &'static str {
126 match self {
127 Self::NL2Code => "nl2code",
128 Self::Code2Code => "code2code",
129 Self::Code2NL => "code2nl",
130 Self::Code2Completion => "code2completion",
131 Self::QA => "qa",
132 }
133 }
134
135 pub fn from_name(name: &str) -> Option<Self> {
137 match name.to_lowercase().as_str() {
138 "nl2code" | "text2code" | "natural-language-to-code" => Some(Self::NL2Code),
139 "code2code" | "code-to-code" | "similar-code" => Some(Self::Code2Code),
140 "code2nl" | "code-to-text" | "summarize" => Some(Self::Code2NL),
141 "code2completion" | "completion" | "autocomplete" => Some(Self::Code2Completion),
142 "qa" | "question-answering" | "technical-qa" => Some(Self::QA),
143 _ => None,
144 }
145 }
146}
147
/// Supported matryoshka output dimensions. Embeddings are truncated to this
/// length and re-normalized (see `JinaCodeEmbedder::truncate_embedding`);
/// the discriminant equals the numeric dimension (read via `value()`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MatryoshkaDimension {
    D128 = 128,
    D256 = 256,
    D512 = 512,
    D1024 = 1024,
    /// Full output width (the default, see the `Default` impl below).
    D1536 = 1536,
}
165
166impl Default for MatryoshkaDimension {
167 fn default() -> Self {
168 Self::D1536
169 }
170}
171
172impl MatryoshkaDimension {
173 pub fn value(&self) -> usize {
175 *self as usize
176 }
177
178 pub fn all() -> &'static [MatryoshkaDimension] {
180 &[
181 Self::D128,
182 Self::D256,
183 Self::D512,
184 Self::D1024,
185 Self::D1536,
186 ]
187 }
188
189 pub fn from_value(dim: usize) -> Self {
191 if dim <= 128 {
192 Self::D128
193 } else if dim <= 256 {
194 Self::D256
195 } else if dim <= 512 {
196 Self::D512
197 } else if dim <= 1024 {
198 Self::D1024
199 } else {
200 Self::D1536
201 }
202 }
203}
204
/// Describes which ONNX Runtime execution provider a session is believed to
/// be using (see `detect_actual_execution_provider`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionProviderInfo {
    /// Provider name, e.g. "CPU", "CUDA", "TensorRT", "DirectML", "CoreML".
    pub name: String,
    /// Whether the provider is GPU-backed (CoreML counts as GPU here).
    pub is_gpu: bool,
    /// GPU device index, when the provider addresses a specific device.
    pub device_id: Option<u32>,
    /// Free-form human-readable detail string.
    pub details: Option<String>,
}
221
222impl ExecutionProviderInfo {
223 pub fn cpu() -> Self {
225 Self {
226 name: "CPU".to_string(),
227 is_gpu: false,
228 device_id: None,
229 details: Some("Default CPU execution".to_string()),
230 }
231 }
232
233 #[allow(dead_code)]
235 pub fn cuda(device_id: u32) -> Self {
236 Self {
237 name: "CUDA".to_string(),
238 is_gpu: true,
239 device_id: Some(device_id),
240 details: Some(format!("NVIDIA CUDA GPU (device {})", device_id)),
241 }
242 }
243
244 #[allow(dead_code)]
246 pub fn tensorrt(device_id: u32) -> Self {
247 Self {
248 name: "TensorRT".to_string(),
249 is_gpu: true,
250 device_id: Some(device_id),
251 details: Some(format!("NVIDIA TensorRT GPU (device {})", device_id)),
252 }
253 }
254
255 #[allow(dead_code)]
257 pub fn directml(device_id: u32) -> Self {
258 Self {
259 name: "DirectML".to_string(),
260 is_gpu: true,
261 device_id: Some(device_id),
262 details: Some(format!("DirectML GPU (device {})", device_id)),
263 }
264 }
265
266 #[allow(dead_code)]
268 pub fn coreml() -> Self {
269 Self {
270 name: "CoreML".to_string(),
271 is_gpu: true,
272 device_id: None,
273 details: Some("Apple CoreML (Neural Engine/GPU)".to_string()),
274 }
275 }
276
277 pub fn description(&self) -> String {
279 if self.is_gpu {
280 format!("{} (GPU accelerated)", self.name)
281 } else {
282 format!("{} (no GPU)", self.name)
283 }
284 }
285}
286
/// Embedder backed by a local ONNX Runtime session plus a HuggingFace
/// tokenizer. The session sits behind `Arc<Mutex<_>>` and is locked for the
/// duration of each per-text inference call (see `embed_batch`).
pub struct OnnxEmbedder {
    // ONNX Runtime session; locked per text in `embed_batch`.
    session: Arc<Mutex<Session>>,
    // Tokenizer loaded from tokenizer.json next to the model.
    tokenizer: Tokenizer,
    // Reported embedding width (see `Embedder::dimension`).
    dimension: usize,
    // Token-count truncation limit applied before inference.
    max_length: usize,
    // Which execution provider was detected for this session.
    execution_provider: ExecutionProviderInfo,
}
309
310impl OnnxEmbedder {
311 pub fn from_directory<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
317 let model_dir = model_dir.as_ref();
318
319 let model_names = [
321 "model.onnx",
322 "model_optimized.onnx",
323 "model-w-mean-pooling.onnx",
324 "model_quantized.onnx",
325 "encoder_model.onnx",
326 ];
327
328 let model_path = model_names
329 .iter()
330 .map(|name| model_dir.join(name))
331 .find(|p| p.exists())
332 .ok_or_else(|| {
333 Error::model_load(format!(
334 "No ONNX model file found in {}. Expected one of: {:?}",
335 model_dir.display(),
336 model_names
337 ))
338 })?;
339
340 let tokenizer_path = model_dir.join("tokenizer.json");
342 if !tokenizer_path.exists() {
343 return Err(Error::model_load(format!(
344 "tokenizer.json not found in {}",
345 model_dir.display()
346 )));
347 }
348
349 Self::new(&model_path, &tokenizer_path, None, 512)
350 }
351
352 pub fn new<P: AsRef<Path>>(
354 model_path: P,
355 tokenizer_path: P,
356 dimension: Option<usize>,
357 max_length: usize,
358 ) -> Result<Self> {
359 let model_path = model_path.as_ref();
360 let tokenizer_path = tokenizer_path.as_ref();
361
362 let tokenizer = Tokenizer::from_file(tokenizer_path)
364 .map_err(|e| Error::model_load(format!("Failed to load tokenizer: {}", e)))?;
365
366 tracing::info!("Loading ONNX model from: {}", model_path.display());
368
369 #[allow(unused_mut)]
370 let mut builder = Session::builder()
371 .map_err(|e| Error::model_load(format!("Failed to create session builder: {}", e)))?;
372
373 #[allow(unused_mut, unused_variables)]
376 let mut _execution_provider = ExecutionProviderInfo::cpu();
377
378 #[cfg(feature = "cuda")]
380 {
381 use ort::execution_providers::CUDAExecutionProvider;
382 tracing::info!("CUDA support enabled, attempting GPU acceleration");
383 builder = builder
384 .with_execution_providers([CUDAExecutionProvider::default().build()])
385 .map_err(|e| Error::model_load(format!("Failed to configure CUDA: {}", e)))?;
386 _execution_provider = ExecutionProviderInfo::cuda(0);
387 }
388
389 #[cfg(feature = "tensorrt")]
390 {
391 use ort::execution_providers::TensorRTExecutionProvider;
392 tracing::info!("TensorRT support enabled, attempting GPU acceleration");
393 builder = builder
394 .with_execution_providers([TensorRTExecutionProvider::default().build()])
395 .map_err(|e| Error::model_load(format!("Failed to configure TensorRT: {}", e)))?;
396 _execution_provider = ExecutionProviderInfo::tensorrt(0);
397 }
398
399 #[cfg(feature = "directml")]
400 {
401 use ort::execution_providers::DirectMLExecutionProvider;
402 tracing::info!("DirectML support enabled, attempting GPU acceleration");
403 builder = builder
404 .with_execution_providers([DirectMLExecutionProvider::default().build()])
405 .map_err(|e| Error::model_load(format!("Failed to configure DirectML: {}", e)))?;
406 _execution_provider = ExecutionProviderInfo::directml(0);
407 }
408
409 #[cfg(feature = "coreml")]
410 {
411 use ort::execution_providers::CoreMLExecutionProvider;
412 tracing::info!("CoreML support enabled, attempting GPU acceleration");
413 builder = builder
414 .with_execution_providers([CoreMLExecutionProvider::default().build()])
415 .map_err(|e| Error::model_load(format!("Failed to configure CoreML: {}", e)))?;
416 _execution_provider = ExecutionProviderInfo::coreml();
417 }
418
419 let session = builder
420 .with_intra_threads(4)
421 .map_err(|e| Error::model_load(format!("Failed to set threads: {}", e)))?
422 .commit_from_file(model_path)
423 .map_err(|e| Error::model_load(format!("Failed to load ONNX model: {}", e)))?;
424
425 let detected_dimension = if model_path.to_string_lossy().contains("jina-code-embeddings-1.5b-ONNX") {
427 1536 } else {
429 dimension.unwrap_or(768) };
431
432 let dimension = detected_dimension;
433
434 let actual_provider = detect_actual_execution_provider(&session);
437 let execution_provider = actual_provider;
438
439 tracing::info!(
440 "Loaded ONNX model (dim={}, max_len={}, provider={})",
441 dimension,
442 max_length,
443 execution_provider.description()
444 );
445
446 Ok(Self {
447 session: Arc::new(Mutex::new(session)),
448 tokenizer,
449 dimension,
450 max_length,
451 execution_provider,
452 })
453 }
454
455 pub fn with_max_length(mut self, max_length: usize) -> Self {
457 self.max_length = max_length;
458 self
459 }
460
461 pub fn execution_provider(&self) -> &ExecutionProviderInfo {
463 &self.execution_provider
464 }
465
466 pub fn is_gpu_accelerated(&self) -> bool {
468 self.execution_provider.is_gpu
469 }
470
471 fn mean_pooling(&self, data: &[f32], shape: &[i64], attention_mask: &[i64]) -> Vec<f32> {
473 if shape.len() != 3 {
474 return data.to_vec();
476 }
477
478 let seq_len = shape[1] as usize;
479 let dim = shape[2] as usize;
480 let mut result = vec![0.0f32; dim];
481 let mut count = 0.0f32;
482
483 for i in 0..seq_len {
484 if i < attention_mask.len() && attention_mask[i] == 1 {
485 for j in 0..dim {
486 result[j] += data[i * dim + j];
487 }
488 count += 1.0;
489 }
490 }
491
492 if count > 0.0 {
493 for val in &mut result {
494 *val /= count;
495 }
496 }
497
498 let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
500 if norm > 0.0 {
501 for val in &mut result {
502 *val /= norm;
503 }
504 }
505
506 result
507 }
508}
509
510impl Embedder for OnnxEmbedder {
511 fn embed(&self, text: &str) -> Result<Vec<f32>> {
512 let results = self.embed_batch(&[text])?;
513 Ok(results.into_iter().next().unwrap_or_default())
514 }
515
516 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
517 if texts.is_empty() {
518 return Ok(Vec::new());
519 }
520
521 let mut all_embeddings = Vec::new();
522
523 for text in texts {
525 let encoding = self
527 .tokenizer
528 .encode(*text, true)
529 .map_err(|e| Error::embedding(format!("Tokenization failed: {}", e)))?;
530
531 let ids = encoding.get_ids();
532 let mask = encoding.get_attention_mask();
533 let seq_len = ids.len().min(self.max_length);
534
535 let input_ids: Vec<i64> = ids.iter().take(seq_len).map(|&id| id as i64).collect();
537 let attention_mask: Vec<i64> = mask.iter().take(seq_len).map(|&m| m as i64).collect();
538 let position_ids: Vec<i64> = (0..seq_len as i64).collect();
539
540 let input_ids_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), input_ids.clone())
542 .map_err(|e| Error::embedding(format!("Failed to create input tensor: {}", e)))?;
543 let input_ids_tensor = Tensor::from_array(input_ids_array)
544 .map_err(|e| Error::embedding(format!("Failed to create input tensor: {}", e)))?;
545
546 let attention_mask_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), attention_mask.clone())
547 .map_err(|e| Error::embedding(format!("Failed to create mask tensor: {}", e)))?;
548 let attention_mask_tensor = Tensor::from_array(attention_mask_array)
549 .map_err(|e| Error::embedding(format!("Failed to create mask tensor: {}", e)))?;
550
551 let position_ids_array = ndarray::Array::from_shape_vec(IxDyn(&[1, seq_len]), position_ids)
552 .map_err(|e| Error::embedding(format!("Failed to create position tensor: {}", e)))?;
553 let position_ids_tensor = Tensor::from_array(position_ids_array)
554 .map_err(|e| Error::embedding(format!("Failed to create position tensor: {}", e)))?;
555
556 let mut session = self.session.lock();
558
559 let first_output_name = session.outputs.first()
561 .map(|o| o.name.clone())
562 .unwrap_or_else(|| "output".to_string());
563
564 let outputs = if session.inputs.iter().any(|input| input.name == "position_ids") {
566 session
567 .run(ort::inputs![
568 "input_ids" => input_ids_tensor,
569 "attention_mask" => attention_mask_tensor,
570 "position_ids" => position_ids_tensor,
571 ])
572 .map_err(|e| Error::embedding(format!("Inference failed: {}", e)))?
573 } else if session.inputs.iter().any(|input| input.name == "token_type_ids") {
574 session
575 .run(ort::inputs![
576 "input_ids" => input_ids_tensor,
577 "attention_mask" => attention_mask_tensor,
578 "token_type_ids" => position_ids_tensor, ])
580 .map_err(|e| Error::embedding(format!("Inference failed: {}", e)))?
581 } else {
582 session
583 .run(ort::inputs![
584 "input_ids" => input_ids_tensor,
585 "attention_mask" => attention_mask_tensor,
586 ])
587 .map_err(|e| Error::embedding(format!("Inference failed: {}", e)))?
588 };
589
590 let output = if let Some(val) = outputs.get("last_hidden_state") {
592 val
593 } else if let Some(val) = outputs.get("sentence_embedding") {
594 val
595 } else {
596 outputs.get(&first_output_name)
597 .ok_or_else(|| Error::embedding("No output found".to_string()))?
598 };
599
600 let (output_shape, output_data) = output
601 .try_extract_tensor::<f32>()
602 .map_err(|e| Error::embedding(format!("Failed to extract output: {}", e)))?;
603
604 let shape_vec: Vec<i64> = output_shape.iter().map(|&d| d as i64).collect();
605
606 let embedding = if shape_vec.len() == 2 {
607 let emb: Vec<f32> = output_data.to_vec();
609 normalize_vector(emb)
610 } else if shape_vec.len() == 3 {
611 self.mean_pooling(output_data, &shape_vec, &attention_mask)
613 } else {
614 return Err(Error::embedding(format!(
615 "Unexpected output shape: {:?}",
616 shape_vec
617 )));
618 };
619
620 all_embeddings.push(embedding);
621 }
622
623 Ok(all_embeddings)
624 }
625
626 fn dimension(&self) -> usize {
627 self.dimension
628 }
629
630 fn name(&self) -> &'static str {
631 "onnx-runtime"
632 }
633
634 fn max_sequence_length(&self) -> usize {
635 self.max_length
636 }
637}
638
/// Best-effort guess at which execution provider the session actually uses.
///
/// NOTE(review): this infers the provider from compile-time features plus
/// runtime probes rather than querying the session itself, so it can
/// over-report (CoreML is assumed available whenever its feature is on).
/// There is also no TensorRT probe even though `OnnxEmbedder::new` can
/// register TensorRT - confirm whether that asymmetry is intentional.
fn detect_actual_execution_provider(session: &Session) -> ExecutionProviderInfo {
    // Metadata is logged for diagnostics only; it does not identify the provider.
    if let Ok(metadata) = session.metadata() {
        let producer = metadata.producer().unwrap_or_default();
        let description = metadata.description().unwrap_or_default();

        tracing::debug!("Session metadata - producer: {}, description: {}", producer, description);
    }

    // Feature-gated probes, checked in this priority order.
    #[cfg(feature = "cuda")]
    {
        if is_cuda_available() {
            return ExecutionProviderInfo::cuda(0);
        } else {
            tracing::warn!("CUDA feature enabled but CUDA runtime not available - falling back to CPU");
        }
    }

    #[cfg(feature = "directml")]
    {
        if is_directml_available() {
            return ExecutionProviderInfo::directml(0);
        } else {
            tracing::warn!("DirectML feature enabled but not available - falling back to CPU");
        }
    }

    #[cfg(feature = "coreml")]
    {
        // Assumed available on any build with the feature enabled.
        return ExecutionProviderInfo::coreml();
    }

    ExecutionProviderInfo::cpu()
}
684
/// Heuristic probe for a usable CUDA runtime on the current machine.
#[cfg(feature = "cuda")]
fn is_cuda_available() -> bool {
    #[cfg(target_os = "windows")]
    {
        use std::path::Path;

        // A directory "has CUDA" if it contains the CUDA 11 or 12 runtime DLL.
        let has_cudart = |dir: &Path| {
            dir.join("cudart64_12.dll").exists() || dir.join("cudart64_11.dll").exists()
        };

        // 1) Standard toolkit install roots: look for a vX.Y directory whose
        //    bin/ holds the runtime DLL.
        for base in [
            "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA",
            "C:\\CUDA",
        ] {
            let base = Path::new(base);
            if !base.exists() {
                continue;
            }
            let entries = match std::fs::read_dir(base) {
                Ok(entries) => entries,
                Err(_) => continue,
            };
            for entry in entries.flatten() {
                let is_version_dir =
                    entry.file_name().to_string_lossy().starts_with('v') && entry.path().is_dir();
                if is_version_dir && has_cudart(&entry.path().join("bin")) {
                    tracing::info!("Found CUDA runtime at: {}", entry.path().display());
                    return true;
                }
            }
        }

        // 2) Scan PATH entries for the runtime DLLs.
        if let Ok(path) = std::env::var("PATH") {
            if let Some(dir) = path.split(';').find(|dir| has_cudart(Path::new(dir))) {
                tracing::info!("Found CUDA runtime in PATH: {}", dir);
                return true;
            }
        }

        // 3) Last resort: a working nvidia-smi implies a usable driver.
        if let Ok(output) = std::process::Command::new("nvidia-smi").output() {
            if output.status.success() {
                tracing::info!("nvidia-smi available, assuming CUDA works");
                return true;
            }
        }

        false
    }

    #[cfg(not(target_os = "windows"))]
    {
        // On Linux/macOS, nvidia-smi succeeding is taken as CUDA being usable.
        std::process::Command::new("nvidia-smi")
            .output()
            .map(|output| output.status.success())
            .unwrap_or(false)
    }
}
755
/// DirectML is a Windows-only API: assume present there, absent elsewhere.
#[cfg(feature = "directml")]
fn is_directml_available() -> bool {
    cfg!(target_os = "windows")
}
768
/// L2-normalize `v`; an all-zero (or empty) vector is returned unchanged to
/// avoid dividing by zero.
fn normalize_vector(v: Vec<f32>) -> Vec<f32> {
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.into_iter().map(|x| x / norm).collect()
    } else {
        v
    }
}
779
/// Deterministic, dependency-free fallback embedder: hashes the text and
/// expands the hash into a pseudo-random unit vector. Useful for tests and
/// placeholders - the vectors carry no semantic similarity.
pub struct HashEmbedder {
    // Output vector length.
    dimension: usize,
}
787
788impl HashEmbedder {
789 pub fn new(dimension: usize) -> Self {
791 Self { dimension }
792 }
793
794 fn hash_to_embedding(&self, text: &str) -> Vec<f32> {
795 use std::collections::hash_map::DefaultHasher;
796 use std::hash::{Hash, Hasher};
797
798 let mut result = vec![0.0f32; self.dimension];
799 let hash = {
800 let mut hasher = DefaultHasher::new();
801 text.hash(&mut hasher);
802 hasher.finish()
803 };
804
805 let mut seed = hash;
807 for val in result.iter_mut() {
808 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
809 *val = ((seed >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0;
810 }
811
812 normalize_vector(result)
813 }
814}
815
816impl Embedder for HashEmbedder {
817 fn embed(&self, text: &str) -> Result<Vec<f32>> {
818 Ok(self.hash_to_embedding(text))
819 }
820
821 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
822 Ok(texts.iter().map(|t| self.hash_to_embedding(t)).collect())
823 }
824
825 fn dimension(&self) -> usize {
826 self.dimension
827 }
828
829 fn name(&self) -> &'static str {
830 "hash"
831 }
832}
833
/// Task-aware wrapper around `OnnxEmbedder` for the Jina code embedding
/// model: prepends task-specific instruction prefixes and truncates output
/// to a matryoshka dimension before re-normalizing.
pub struct JinaCodeEmbedder {
    // Underlying ONNX session/tokenizer pair.
    inner: OnnxEmbedder,
    // Retrieval task; selects the query/passage prefixes.
    task: EmbeddingTask,
    // Output truncation width (matryoshka).
    output_dimension: MatryoshkaDimension,
    // Whether plain `Embedder::embed` calls are prefixed as query or passage.
    mode: EmbeddingMode,
}
878
/// Which side of the retrieval pair the generic `Embedder::embed` /
/// `embed_batch` calls should be prefixed as (see `apply_task_prefix`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum EmbeddingMode {
    /// Treat input as an indexable document (the default).
    #[default]
    Passage,
    /// Treat input as a search query.
    Query,
}
894
895impl JinaCodeEmbedder {
896 pub const DEFAULT_MAX_LENGTH: usize = 32768;
898 pub const DEFAULT_DIMENSION: usize = 1536;
900
901 pub fn from_directory<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
907 let inner = OnnxEmbedder::from_directory(&model_dir)?;
908 Ok(Self {
909 inner,
910 task: EmbeddingTask::default(),
911 output_dimension: MatryoshkaDimension::default(),
912 mode: EmbeddingMode::default(),
913 })
914 }
915
916 pub fn from_onnx_embedder(inner: OnnxEmbedder) -> Self {
918 Self {
919 inner,
920 task: EmbeddingTask::default(),
921 output_dimension: MatryoshkaDimension::default(),
922 mode: EmbeddingMode::default(),
923 }
924 }
925
926 pub fn with_task(mut self, task: EmbeddingTask) -> Self {
928 self.task = task;
929 self
930 }
931
932 pub fn with_dimension(mut self, dimension: MatryoshkaDimension) -> Self {
937 self.output_dimension = dimension;
938 self
939 }
940
941 pub fn with_max_length(mut self, max_length: usize) -> Self {
943 self.inner.max_length = max_length;
944 self
945 }
946
947 pub fn with_mode(mut self, mode: EmbeddingMode) -> Self {
952 self.mode = mode;
953 self
954 }
955
956 pub fn mode(&self) -> EmbeddingMode {
958 self.mode
959 }
960
961 pub fn task(&self) -> EmbeddingTask {
963 self.task
964 }
965
966 pub fn output_dimension(&self) -> MatryoshkaDimension {
968 self.output_dimension
969 }
970
971 pub fn execution_provider(&self) -> &ExecutionProviderInfo {
973 self.inner.execution_provider()
974 }
975
976 pub fn is_gpu_accelerated(&self) -> bool {
978 self.inner.is_gpu_accelerated()
979 }
980
981 pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
995 let prefixed = format!("{}{}", self.task.query_prefix(), text);
996 let embedding = self.inner.embed(&prefixed)?;
997 Ok(self.truncate_embedding(embedding))
998 }
999
1000 pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
1009 let prefixed = format!("{}{}", self.task.passage_prefix(), text);
1010 let embedding = self.inner.embed(&prefixed)?;
1011 Ok(self.truncate_embedding(embedding))
1012 }
1013
1014 pub fn embed_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
1016 let prefixed: Vec<String> = texts.iter()
1017 .map(|t| format!("{}{}", self.task.query_prefix(), t))
1018 .collect();
1019 let refs: Vec<&str> = prefixed.iter().map(|s| s.as_str()).collect();
1020
1021 let embeddings = self.inner.embed_batch(&refs)?;
1022 Ok(embeddings.into_iter()
1023 .map(|e| self.truncate_embedding(e))
1024 .collect())
1025 }
1026
1027 pub fn embed_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
1031 let prefixed: Vec<String> = texts.iter()
1032 .map(|t| format!("{}{}", self.task.passage_prefix(), t))
1033 .collect();
1034 let refs: Vec<&str> = prefixed.iter().map(|s| s.as_str()).collect();
1035
1036 let embeddings = self.inner.embed_batch(&refs)?;
1037 Ok(embeddings.into_iter()
1038 .map(|e| self.truncate_embedding(e))
1039 .collect())
1040 }
1041
1042 fn apply_task_prefix(&self, text: &str) -> String {
1050 let prefix = match self.mode {
1051 EmbeddingMode::Passage => self.task.passage_prefix(),
1052 EmbeddingMode::Query => self.task.query_prefix(),
1053 };
1054 format!("{}{}", prefix, text)
1055 }
1056
1057 fn truncate_embedding(&self, embedding: Vec<f32>) -> Vec<f32> {
1059 let target_dim = self.output_dimension.value();
1060 if embedding.len() <= target_dim {
1061 normalize_vector(embedding)
1062 } else {
1063 let truncated: Vec<f32> = embedding.into_iter().take(target_dim).collect();
1064 normalize_vector(truncated)
1065 }
1066 }
1067
1068 fn last_token_pooling(&self, data: &[f32], shape: &[i64], attention_mask: &[i64]) -> Vec<f32> {
1070 if shape.len() != 3 {
1071 return data.to_vec();
1073 }
1074
1075 let seq_len = shape[1] as usize;
1076 let dim = shape[2] as usize;
1077
1078 let last_valid_pos = attention_mask.iter()
1080 .enumerate()
1081 .rev()
1082 .find(|(_, &mask)| mask == 1)
1083 .map(|(i, _)| i)
1084 .unwrap_or(seq_len.saturating_sub(1));
1085
1086 let start = last_valid_pos * dim;
1088 let end = start + dim;
1089
1090 if end <= data.len() {
1091 let result: Vec<f32> = data[start..end].to_vec();
1092 normalize_vector(result)
1093 } else {
1094 self.inner.mean_pooling(data, shape, attention_mask)
1096 }
1097 }
1098}
1099
1100impl Embedder for JinaCodeEmbedder {
1101 fn embed(&self, text: &str) -> Result<Vec<f32>> {
1102 let prefixed = self.apply_task_prefix(text);
1103 let embedding = self.inner.embed(&prefixed)?;
1104 Ok(self.truncate_embedding(embedding))
1105 }
1106
1107 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
1108 let prefixed: Vec<String> = texts.iter()
1109 .map(|t| self.apply_task_prefix(t))
1110 .collect();
1111 let refs: Vec<&str> = prefixed.iter().map(|s| s.as_str()).collect();
1112
1113 let embeddings = self.inner.embed_batch(&refs)?;
1114 Ok(embeddings.into_iter()
1115 .map(|e| self.truncate_embedding(e))
1116 .collect())
1117 }
1118
1119 fn embed_for_query(&self, text: &str) -> Result<Vec<f32>> {
1121 let prefixed = format!("{}{}", self.task.query_prefix(), text);
1123 let embedding = self.inner.embed(&prefixed)?;
1124 Ok(self.truncate_embedding(embedding))
1125 }
1126
1127 fn dimension(&self) -> usize {
1128 self.output_dimension.value()
1129 }
1130
1131 fn name(&self) -> &'static str {
1132 "jina-code-1.5b"
1133 }
1134
1135 fn max_sequence_length(&self) -> usize {
1136 self.inner.max_sequence_length()
1137 }
1138}
1139
/// Serializable configuration for constructing a `JinaCodeEmbedder`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JinaCodeConfig {
    /// Directory containing the ONNX model and tokenizer.json.
    pub model_path: std::path::PathBuf,
    /// Retrieval task (selects instruction prefixes).
    pub task: EmbeddingTask,
    /// Matryoshka output dimension.
    pub dimension: MatryoshkaDimension,
    /// Token truncation limit.
    pub max_length: usize,
}
1156
1157impl Default for JinaCodeConfig {
1158 fn default() -> Self {
1159 Self {
1160 model_path: std::path::PathBuf::new(),
1161 task: EmbeddingTask::default(),
1162 dimension: MatryoshkaDimension::default(),
1163 max_length: JinaCodeEmbedder::DEFAULT_MAX_LENGTH,
1164 }
1165 }
1166}
1167
1168impl JinaCodeConfig {
1169 pub fn from_directory<P: AsRef<Path>>(path: P) -> Self {
1171 Self {
1172 model_path: path.as_ref().to_path_buf(),
1173 ..Default::default()
1174 }
1175 }
1176
1177 pub fn with_task(mut self, task: EmbeddingTask) -> Self {
1179 self.task = task;
1180 self
1181 }
1182
1183 pub fn with_dimension(mut self, dimension: MatryoshkaDimension) -> Self {
1185 self.dimension = dimension;
1186 self
1187 }
1188
1189 pub fn with_max_length(mut self, max_length: usize) -> Self {
1191 self.max_length = max_length;
1192 self
1193 }
1194
1195 pub fn load(&self) -> Result<JinaCodeEmbedder> {
1197 JinaCodeEmbedder::from_directory(&self.model_path)
1198 .map(|e| e
1199 .with_task(self.task)
1200 .with_dimension(self.dimension)
1201 .with_max_length(self.max_length))
1202 }
1203}
1204
/// Serializable configuration for constructing a generic `OnnxEmbedder`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelConfig {
    /// Path to a model directory (auto-discovery) or a specific .onnx file.
    pub model_path: std::path::PathBuf,
    /// Explicit tokenizer.json path; when `None`, looked up next to the model.
    pub tokenizer_path: Option<std::path::PathBuf>,
    /// Override for the reported embedding dimension.
    pub dimension: Option<usize>,
    /// Token truncation limit.
    pub max_length: usize,
}
1221
1222impl Default for ModelConfig {
1223 fn default() -> Self {
1224 Self {
1225 model_path: std::path::PathBuf::new(),
1226 tokenizer_path: None,
1227 dimension: None,
1228 max_length: 512,
1229 }
1230 }
1231}
1232
1233impl ModelConfig {
1234 pub fn from_directory<P: AsRef<Path>>(path: P) -> Self {
1236 Self {
1237 model_path: path.as_ref().to_path_buf(),
1238 tokenizer_path: None,
1239 dimension: None,
1240 max_length: 512,
1241 }
1242 }
1243
1244 pub fn with_max_length(mut self, max_length: usize) -> Self {
1246 self.max_length = max_length;
1247 self
1248 }
1249
1250 pub fn with_dimension(mut self, dimension: usize) -> Self {
1252 self.dimension = Some(dimension);
1253 self
1254 }
1255
1256 pub fn load(&self) -> Result<OnnxEmbedder> {
1258 if self.model_path.is_dir() {
1259 let mut embedder = OnnxEmbedder::from_directory(&self.model_path)?;
1260 embedder.max_length = self.max_length;
1261 if let Some(dim) = self.dimension {
1262 embedder.dimension = dim;
1263 }
1264 Ok(embedder)
1265 } else {
1266 let tokenizer_path = self
1267 .tokenizer_path
1268 .clone()
1269 .unwrap_or_else(|| self.model_path.with_file_name("tokenizer.json"));
1270
1271 OnnxEmbedder::new(
1272 &self.model_path,
1273 &tokenizer_path,
1274 self.dimension,
1275 self.max_length,
1276 )
1277 }
1278 }
1279}
1280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_embedder() {
        let embedder = HashEmbedder::new(384);

        // Correct length and (approximately) unit norm.
        let first = embedder.embed("test code").unwrap();
        assert_eq!(first.len(), 384);
        let norm = first.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        // Deterministic for equal input; different for different input.
        assert_eq!(first, embedder.embed("test code").unwrap());
        assert_ne!(first, embedder.embed("other code").unwrap());
    }

    #[test]
    fn test_batch_embedding() {
        let embedder = HashEmbedder::new(128);

        let embeddings = embedder.embed_batch(&["hello", "world", "test"]).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|emb| emb.len() == 128));
    }
}