use crate::cross_modal_embeddings::{
    AudioData, AudioEncoder, CrossModalConfig, CrossModalEncoder, GraphData, GraphEncoder,
    ImageData, ImageEncoder, ImageFormat, Modality, ModalityData, MultiModalContent, TextEncoder,
    VideoData, VideoEncoder,
};
use crate::Vector;
use crate::VectorStore;
use anyhow::{anyhow, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

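/// Multi-modal search engine that indexes heterogeneous content (text,
/// images, audio, video, graphs) into a joint embedding space, optionally
/// alongside per-modality indices, with query-result caching.
///
/// A minimal usage sketch (illustrative only; the query text is arbitrary):
///
/// ```ignore
/// let engine = MultiModalSearchEngine::new_default()?;
/// let query = MultiModalQuery::text("sunset over mountains");
/// let results = engine.search(&query, 10)?;
/// ```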
pub struct MultiModalSearchEngine {
    /// Engine configuration.
    config: MultiModalConfig,
    /// Shared cross-modal encoder mapping every modality into the joint space.
    encoder: Arc<CrossModalEncoder>,
    /// Joint-space index over all content.
    vector_store: Arc<RwLock<VectorStore>>,
    /// Optional per-modality indices (see `use_modality_specific_indices`).
    modality_stores: HashMap<Modality, Arc<RwLock<VectorStore>>>,
    /// Cache of recent query results, keyed by a query hash.
    query_cache: Arc<RwLock<HashMap<String, Vec<SearchResult>>>>,
    /// Running count of indexed items.
    total_indexed: Arc<RwLock<usize>>,
}

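/// Configuration for [`MultiModalSearchEngine`].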
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalConfig {
    /// Configuration forwarded to the cross-modal encoder.
    pub cross_modal_config: CrossModalConfig,
    /// Maintain a separate vector index per modality alongside the joint index.
    pub use_modality_specific_indices: bool,
    /// Cache search results keyed by a hash of the query.
    pub enable_caching: bool,
    /// Maximum number of cached queries before entries are evicted.
    pub cache_size_limit: usize,
    /// Strategy used to answer queries.
    pub search_strategy: SearchStrategy,
    /// Expand queries before searching (not yet consumed by the search paths
    /// in this module).
    pub enable_query_expansion: bool,
    /// Multiplier applied during query expansion.
    pub query_expansion_factor: f32,
}

impl Default for MultiModalConfig {
    fn default() -> Self {
        Self {
            cross_modal_config: CrossModalConfig::default(),
            use_modality_specific_indices: true,
            enable_caching: true,
            cache_size_limit: 1000,
            search_strategy: SearchStrategy::HybridFusion,
            enable_query_expansion: true,
            query_expansion_factor: 1.5,
        }
    }
}

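/// How a query is executed against the available indices.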
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchStrategy {
    /// Search only the joint embedding space.
    JointSpaceOnly,
    /// Search only the per-modality indices.
    ModalitySpecific,
    /// Search both and fuse the ranked lists (the default).
    HybridFusion,
    /// Choose a strategy based on the query's shape.
    Adaptive,
}

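/// A search query that can combine several modalities, optional per-modality
/// weights, and metadata filters, constructed via the builder-style helpers
/// below. An illustrative sketch (field values are arbitrary):
///
/// ```ignore
/// let query = MultiModalQuery::text("red sports car")
///     .with_filter(QueryFilter {
///         field: "category".to_string(),
///         operator: FilterOperator::Equals,
///         value: "vehicles".to_string(),
///     })
///     .with_metadata("source".to_string(), "unit-test".to_string());
/// ```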
#[derive(Debug, Clone)]
pub struct MultiModalQuery {
    pub modalities: Vec<QueryModality>,
    pub weights: Option<HashMap<Modality, f32>>,
    pub filters: Vec<QueryFilter>,
    pub metadata: HashMap<String, String>,
}

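/// One modality's worth of query input.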
#[derive(Debug, Clone)]
pub enum QueryModality {
    /// Free-text query.
    Text(String),
    /// Encoded image bytes.
    Image(Vec<u8>),
    /// Raw audio samples plus the sample rate in Hz.
    Audio(Vec<f32>, u32),
    /// Video as a sequence of encoded frame images.
    Video(Vec<Vec<u8>>),
    /// A precomputed embedding, searched directly against the joint index.
    Embedding(Vector),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryFilter {
    pub field: String,
    pub operator: FilterOperator,
    pub value: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FilterOperator {
    Equals,
    NotEquals,
    Contains,
    GreaterThan,
    LessThan,
    Regex,
}

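/// A single scored hit returned by a search.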
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    pub id: String,
    pub score: f32,
    pub modality: Modality,
    pub metadata: HashMap<String, String>,
    pub embedding: Option<Vec<f32>>,
    pub modality_scores: HashMap<Modality, f32>,
}

impl MultiModalQuery {
    /// Create a text-only query.
    pub fn text(text: impl Into<String>) -> Self {
        Self {
            modalities: vec![QueryModality::Text(text.into())],
            weights: None,
            filters: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Create an image-only query from encoded image bytes.
    pub fn image(image_data: Vec<u8>) -> Self {
        Self {
            modalities: vec![QueryModality::Image(image_data)],
            weights: None,
            filters: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Create an audio-only query from raw samples and a sample rate in Hz.
    pub fn audio(samples: Vec<f32>, sample_rate: u32) -> Self {
        Self {
            modalities: vec![QueryModality::Audio(samples, sample_rate)],
            weights: None,
            filters: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Create a query that combines several modalities.
    pub fn hybrid(modalities: Vec<QueryModality>) -> Self {
        Self {
            modalities,
            weights: None,
            filters: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Add a metadata filter (builder style).
    pub fn with_filter(mut self, filter: QueryFilter) -> Self {
        self.filters.push(filter);
        self
    }

    /// Set per-modality fusion weights (builder style).
    pub fn with_weights(mut self, weights: HashMap<Modality, f32>) -> Self {
        self.weights = Some(weights);
        self
    }

    /// Attach a metadata key/value pair to the query (builder style).
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}

impl MultiModalSearchEngine {
    /// Create an engine with the default configuration.
    pub fn new_default() -> Result<Self> {
        Self::new(MultiModalConfig::default())
    }

    /// Create an engine with the given configuration.
    pub fn new(config: MultiModalConfig) -> Result<Self> {
        // Build one encoder per modality, all projecting into the same
        // joint embedding dimension.
        let text_encoder = Box::new(ProductionTextEncoder::new(
            config.cross_modal_config.joint_embedding_dim,
        )?);
        let image_encoder = Box::new(ProductionImageEncoder::new(
            config.cross_modal_config.joint_embedding_dim,
        )?);
        let audio_encoder = Box::new(ProductionAudioEncoder::new(
            config.cross_modal_config.joint_embedding_dim,
        )?);
        let video_encoder = Box::new(ProductionVideoEncoder::new(
            config.cross_modal_config.joint_embedding_dim,
        )?);
        let graph_encoder = Box::new(ProductionGraphEncoder::new(
            config.cross_modal_config.joint_embedding_dim,
        )?);

        let encoder = Arc::new(CrossModalEncoder::new(
            config.cross_modal_config.clone(),
            text_encoder,
            image_encoder,
            audio_encoder,
            video_encoder,
            graph_encoder,
        ));

        // Joint-space index shared by all strategies.
        let vector_store = Arc::new(RwLock::new(VectorStore::new()));

        // Optional per-modality indices.
        let mut modality_stores = HashMap::new();
        if config.use_modality_specific_indices {
            for modality in &[
                Modality::Text,
                Modality::Image,
                Modality::Audio,
                Modality::Video,
            ] {
                let store = Arc::new(RwLock::new(VectorStore::new()));
                modality_stores.insert(*modality, store);
            }
        }

        Ok(Self {
            config,
            encoder,
            vector_store,
            modality_stores,
            query_cache: Arc::new(RwLock::new(HashMap::new())),
            total_indexed: Arc::new(RwLock::new(0)),
        })
    }

    /// Index a piece of multi-modal content under the given id.
    pub fn index_content(&self, id: String, content: MultiModalContent) -> Result<()> {
        // Encode into the joint space and store in the main index.
        let embedding = self.encoder.encode(&content)?;

        {
            let mut store = self.vector_store.write();
            store.index_vector(id.clone(), embedding.clone())?;
        }

        // Also index each modality separately when enabled.
        if self.config.use_modality_specific_indices {
            for (modality, data) in &content.modalities {
                if let Some(store) = self.modality_stores.get(modality) {
                    let modality_embedding = self.encode_modality(*modality, data)?;

                    let mut store = store.write();
                    store.index_vector(id.clone(), modality_embedding)?;
                }
            }
        }

        *self.total_indexed.write() += 1;

        Ok(())
    }

    /// Run a query and return the top-`k` results after filtering.
    pub fn search(&self, query: &MultiModalQuery, k: usize) -> Result<Vec<SearchResult>> {
        // Serve repeated queries from the cache when enabled.
        if self.config.enable_caching {
            let cache_key = self.compute_cache_key(query);
            if let Some(cached_results) = self.query_cache.read().get(&cache_key) {
                return Ok(cached_results.clone());
            }
        }

        let results = match self.config.search_strategy {
            SearchStrategy::JointSpaceOnly => self.search_joint_space(query, k)?,
            SearchStrategy::ModalitySpecific => self.search_modality_specific(query, k)?,
            SearchStrategy::HybridFusion => self.search_hybrid(query, k)?,
            SearchStrategy::Adaptive => self.search_adaptive(query, k)?,
        };

        let filtered_results = self.apply_filters(&results, &query.filters)?;

        if self.config.enable_caching {
            let cache_key = self.compute_cache_key(query);
            let mut cache = self.query_cache.write();

            // Naive eviction: drop an arbitrary entry once the limit is
            // reached (HashMap iteration order is unspecified).
            if cache.len() >= self.config.cache_size_limit {
                if let Some(first_key) = cache.keys().next().cloned() {
                    cache.remove(&first_key);
                }
            }

            cache.insert(cache_key, filtered_results.clone());
        }

        Ok(filtered_results)
    }

    fn search_joint_space(&self, query: &MultiModalQuery, k: usize) -> Result<Vec<SearchResult>> {
        // Encode the whole query into the joint space and search the main index.
        let query_content = self.query_to_content(query)?;

        let query_embedding = self.encoder.encode(&query_content)?;

        let store = self.vector_store.read();
        let results = store.similarity_search_vector(&query_embedding, k)?;

        Ok(results
            .into_iter()
            .map(|(id, score)| SearchResult {
                id,
                score,
                // Placeholder: the true modality of a joint-space hit is not
                // tracked in the index.
                modality: Modality::Text,
                metadata: HashMap::new(),
                embedding: None,
                modality_scores: HashMap::new(),
            })
            .collect())
    }

    fn search_modality_specific(
        &self,
        query: &MultiModalQuery,
        k: usize,
    ) -> Result<Vec<SearchResult>> {
        let mut modality_results: HashMap<Modality, Vec<SearchResult>> = HashMap::new();

        for query_modality in &query.modalities {
            let (modality, data) = match query_modality {
                QueryModality::Text(text) => (Modality::Text, ModalityData::Text(text.clone())),
                QueryModality::Image(img_data) => {
                    let image_data = self.parse_image_data(img_data)?;
                    (Modality::Image, ModalityData::Image(image_data))
                }
                QueryModality::Audio(samples, rate) => {
                    let audio_data = AudioData {
                        samples: samples.clone(),
                        sample_rate: *rate,
                        channels: 1,
                        duration: samples.len() as f32 / *rate as f32,
                        features: None,
                    };
                    (Modality::Audio, ModalityData::Audio(audio_data))
                }
                QueryModality::Embedding(embedding) => {
                    // Precomputed embeddings skip encoding and are searched
                    // directly against the joint index; feeding their hits
                    // into the fusion below keeps them from being dropped.
                    let store = self.vector_store.read();
                    let results = store.similarity_search_vector(embedding, k)?;
                    modality_results
                        .entry(Modality::Numeric)
                        .or_default()
                        .extend(results.into_iter().map(|(id, score)| SearchResult {
                            id,
                            score,
                            modality: Modality::Numeric,
                            metadata: HashMap::new(),
                            embedding: None,
                            modality_scores: HashMap::new(),
                        }));
                    continue;
                }
                QueryModality::Video(_frames) => {
                    // Video queries are only supported through the joint space.
                    continue;
                }
            };

            if let Some(store) = self.modality_stores.get(&modality) {
                let embedding = self.encode_modality(modality, &data)?;

                let store = store.read();
                let results = store.similarity_search_vector(&embedding, k)?;

                let search_results: Vec<SearchResult> = results
                    .into_iter()
                    .map(|(id, score)| SearchResult {
                        id,
                        score,
                        modality,
                        metadata: HashMap::new(),
                        embedding: None,
                        modality_scores: HashMap::new(),
                    })
                    .collect();

                modality_results.insert(modality, search_results);
            }
        }

        self.fuse_modality_results(modality_results, query, k)
    }

    fn search_hybrid(&self, query: &MultiModalQuery, k: usize) -> Result<Vec<SearchResult>> {
        // Over-fetch from both strategies, then fuse with fixed weights that
        // favor the joint space.
        let joint_results = self.search_joint_space(query, k * 2)?;
        let modality_results = self.search_modality_specific(query, k * 2)?;

        let fused = self.fuse_search_results(vec![joint_results, modality_results], &[0.6, 0.4])?;

        Ok(fused.into_iter().take(k).collect())
    }

    fn search_adaptive(&self, query: &MultiModalQuery, k: usize) -> Result<Vec<SearchResult>> {
        let num_modalities = query.modalities.len();

        // Single-modality queries go straight to their own index; richer
        // queries fall back to hybrid fusion.
        if num_modalities == 1 {
            return self.search_modality_specific(query, k);
        }

        self.search_hybrid(query, k)
    }

    fn encode_modality(&self, _modality: Modality, data: &ModalityData) -> Result<Vector> {
        // Wrap the single modality in a one-entry content map so the shared
        // cross-modal encoder can be reused.
        let mut content_map = HashMap::new();

        match data {
            ModalityData::Text(_text) => {
                content_map.insert(Modality::Text, data.clone());
            }
            ModalityData::Image(_image) => {
                content_map.insert(Modality::Image, data.clone());
            }
            ModalityData::Audio(_audio) => {
                content_map.insert(Modality::Audio, data.clone());
            }
            ModalityData::Video(_video) => {
                content_map.insert(Modality::Video, data.clone());
            }
            ModalityData::Graph(_graph) => {
                content_map.insert(Modality::Graph, data.clone());
            }
            ModalityData::Numeric(values) => {
                // Numeric data is already a vector; no encoding required.
                return Ok(Vector::new(values.clone()));
            }
            ModalityData::Raw(_) => {
                return Err(anyhow!("Raw data encoding not supported"));
            }
        }

        let content = MultiModalContent {
            modalities: content_map,
            metadata: HashMap::new(),
            temporal_info: None,
            spatial_info: None,
        };

        self.encoder.encode(&content)
    }

    fn query_to_content(&self, query: &MultiModalQuery) -> Result<MultiModalContent> {
        let mut modalities = HashMap::new();

        for query_modality in &query.modalities {
            match query_modality {
                QueryModality::Text(text) => {
                    modalities.insert(Modality::Text, ModalityData::Text(text.clone()));
                }
                QueryModality::Image(img_data) => {
                    let image_data = self.parse_image_data(img_data)?;
                    modalities.insert(Modality::Image, ModalityData::Image(image_data));
                }
                QueryModality::Audio(samples, rate) => {
                    let audio_data = AudioData {
                        samples: samples.clone(),
                        sample_rate: *rate,
                        channels: 1,
                        duration: samples.len() as f32 / *rate as f32,
                        features: None,
                    };
                    modalities.insert(Modality::Audio, ModalityData::Audio(audio_data));
                }
                QueryModality::Embedding(_) => {
                    // Precomputed embeddings bypass encoding entirely; they
                    // are handled in the modality-specific search path.
                }
                QueryModality::Video(frames) => {
                    let video_frames: Result<Vec<ImageData>> =
                        frames.iter().map(|f| self.parse_image_data(f)).collect();

                    let video_data = VideoData {
                        frames: video_frames?,
                        audio: None,
                        // The query carries no timing info, so assume 30 fps.
                        fps: 30.0,
                        duration: frames.len() as f32 / 30.0,
                        keyframes: vec![0],
                    };
                    modalities.insert(Modality::Video, ModalityData::Video(video_data));
                }
            }
        }

        Ok(MultiModalContent {
            modalities,
            metadata: query.metadata.clone(),
            temporal_info: None,
            spatial_info: None,
        })
    }

    fn parse_image_data(&self, data: &[u8]) -> Result<ImageData> {
        // With the `images` feature, decode the bytes into raw RGB pixels.
        #[cfg(feature = "images")]
        {
            use image::GenericImageView;

            let img = image::load_from_memory(data)
                .map_err(|e| anyhow!("Failed to decode image: {}", e))?;

            let (width, height) = img.dimensions();
            let rgb_img = img.to_rgb8();
            let raw_data = rgb_img.into_raw();

            Ok(ImageData {
                data: raw_data,
                width,
                height,
                channels: 3,
                format: ImageFormat::RGB,
                features: None,
            })
        }

        // Without the feature, pass the bytes through with unknown dimensions.
        #[cfg(not(feature = "images"))]
        {
            Ok(ImageData {
                data: data.to_vec(),
                width: 0,
                height: 0,
                channels: 3,
                format: ImageFormat::RGB,
                features: None,
            })
        }
    }

    fn fuse_modality_results(
        &self,
        modality_results: HashMap<Modality, Vec<SearchResult>>,
        query: &MultiModalQuery,
        k: usize,
    ) -> Result<Vec<SearchResult>> {
        // Weighted reciprocal rank fusion (RRF): each list contributes
        // weight / (60 + rank) per document, using 1-based ranks and the
        // customary RRF constant of 60.
        let mut score_map: HashMap<String, (f32, SearchResult)> = HashMap::new();

        for (modality, results) in modality_results {
            let weight = query
                .weights
                .as_ref()
                .and_then(|w| w.get(&modality))
                .copied()
                .unwrap_or(1.0);

            for (rank, result) in results.into_iter().enumerate() {
                let rrf_score = weight / (60.0 + rank as f32 + 1.0);

                score_map
                    .entry(result.id.clone())
                    .and_modify(|(score, existing)| {
                        *score += rrf_score;
                        existing.modality_scores.insert(modality, result.score);
                    })
                    .or_insert_with(|| {
                        let mut updated_result = result.clone();
                        updated_result
                            .modality_scores
                            .insert(modality, result.score);
                        (rrf_score, updated_result)
                    });
            }
        }

        // total_cmp avoids a panic on NaN scores.
        let mut fused_results: Vec<(f32, SearchResult)> = score_map.into_values().collect();
        fused_results.sort_by(|a, b| b.0.total_cmp(&a.0));

        Ok(fused_results
            .into_iter()
            .take(k)
            .map(|(score, mut result)| {
                result.score = score;
                result
            })
            .collect())
    }

    fn fuse_search_results(
        &self,
        result_sets: Vec<Vec<SearchResult>>,
        weights: &[f32],
    ) -> Result<Vec<SearchResult>> {
        if result_sets.len() != weights.len() {
            return Err(anyhow!("Weights length must match result sets length"));
        }

        // Same weighted RRF scheme as `fuse_modality_results`, applied across
        // whole result sets instead of per-modality lists.
        let mut score_map: HashMap<String, (f32, SearchResult)> = HashMap::new();

        for (results, &weight) in result_sets.into_iter().zip(weights.iter()) {
            for (rank, result) in results.into_iter().enumerate() {
                let rrf_score = weight / (60.0 + rank as f32 + 1.0);

                score_map
                    .entry(result.id.clone())
                    .and_modify(|(score, _)| *score += rrf_score)
                    .or_insert_with(|| (rrf_score, result));
            }
        }

        let mut fused_results: Vec<(f32, SearchResult)> = score_map.into_values().collect();
        fused_results.sort_by(|a, b| b.0.total_cmp(&a.0));

        Ok(fused_results
            .into_iter()
            .map(|(score, mut result)| {
                result.score = score;
                result
            })
            .collect())
    }

    fn apply_filters(
        &self,
        results: &[SearchResult],
        filters: &[QueryFilter],
    ) -> Result<Vec<SearchResult>> {
        if filters.is_empty() {
            return Ok(results.to_vec());
        }

        let filtered: Vec<SearchResult> = results
            .iter()
            .filter(|result| self.matches_filters(result, filters))
            .cloned()
            .collect();

        Ok(filtered)
    }

    fn matches_filters(&self, result: &SearchResult, filters: &[QueryFilter]) -> bool {
        filters.iter().all(|filter| {
            if let Some(value) = result.metadata.get(&filter.field) {
                match filter.operator {
                    FilterOperator::Equals => value == &filter.value,
                    FilterOperator::NotEquals => value != &filter.value,
                    FilterOperator::Contains => value.contains(&filter.value),
                    // Compare numerically when both sides parse as numbers;
                    // otherwise fall back to lexicographic order.
                    FilterOperator::GreaterThan => {
                        match (value.parse::<f64>(), filter.value.parse::<f64>()) {
                            (Ok(a), Ok(b)) => a > b,
                            _ => value.as_str() > filter.value.as_str(),
                        }
                    }
                    FilterOperator::LessThan => {
                        match (value.parse::<f64>(), filter.value.parse::<f64>()) {
                            (Ok(a), Ok(b)) => a < b,
                            _ => value.as_str() < filter.value.as_str(),
                        }
                    }
                    FilterOperator::Regex => {
                        if let Ok(re) = regex::Regex::new(&filter.value) {
                            re.is_match(value)
                        } else {
                            false
                        }
                    }
                }
            } else {
                false
            }
        })
    }

    fn compute_cache_key(&self, query: &MultiModalQuery) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();

        // Hash cheap summaries rather than full payloads; audio and video
        // hash only their lengths, so distinct queries of equal size can
        // collide in the cache.
        for modality in &query.modalities {
            match modality {
                QueryModality::Text(text) => text.hash(&mut hasher),
                QueryModality::Image(data) => data.hash(&mut hasher),
                QueryModality::Audio(samples, rate) => {
                    samples.len().hash(&mut hasher);
                    rate.hash(&mut hasher);
                }
                QueryModality::Video(frames) => frames.len().hash(&mut hasher),
                QueryModality::Embedding(emb) => emb.dimensions.hash(&mut hasher),
            }
        }

        format!("{:x}", hasher.finish())
    }

    /// Snapshot of engine statistics.
    pub fn get_statistics(&self) -> MultiModalStatistics {
        let num_vectors = *self.total_indexed.read();

        // Per-modality counts are not tracked yet; report a zero for each
        // configured modality index.
        let mut modality_counts = HashMap::new();
        for modality in self.modality_stores.keys() {
            modality_counts.insert(*modality, 0);
        }

        MultiModalStatistics {
            total_vectors: num_vectors,
            modality_counts,
            cache_size: self.query_cache.read().len(),
            // Hit/miss accounting is not wired up yet.
            cache_hit_rate: 0.0,
        }
    }
}

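/// Aggregate statistics reported by [`MultiModalSearchEngine::get_statistics`].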
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalStatistics {
    pub total_vectors: usize,
    pub modality_counts: HashMap<Modality, usize>,
    pub cache_size: usize,
    pub cache_hit_rate: f32,
}

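/// Hashing-based text encoder: tokenizes on whitespace, embeds term
/// frequencies via the hashing trick, and L2-normalizes the result.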
pub struct ProductionTextEncoder {
    embedding_dim: usize,
    vocab_size: usize,
}

impl ProductionTextEncoder {
    pub fn new(embedding_dim: usize) -> Result<Self> {
        Ok(Self {
            embedding_dim,
            vocab_size: 10000,
        })
    }

    /// Lowercase, whitespace-split tokenization.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.to_string())
            .collect()
    }

    /// Build an L2-normalized bag-of-words embedding via the hashing trick.
    fn compute_embedding(&self, tokens: &[String]) -> Vec<f32> {
        use std::collections::HashMap;

        // Term frequencies.
        let mut freq_map = HashMap::new();
        for token in tokens {
            *freq_map.entry(token.clone()).or_insert(0) += 1;
        }

        // Scatter each token's relative frequency into a hashed bucket.
        let mut embedding = vec![0.0f32; self.embedding_dim];
        for (token, count) in freq_map {
            let hash = Self::hash_token(&token);
            let index = (hash % self.embedding_dim as u64) as usize;
            embedding[index] += count as f32 / tokens.len() as f32;
        }

        // L2-normalize so dot products behave like cosine similarity.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|x| *x /= norm);
        }

        embedding
    }

    fn hash_token(token: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        token.hash(&mut hasher);
        hasher.finish()
    }
}

impl TextEncoder for ProductionTextEncoder {
    fn encode(&self, text: &str) -> Result<Vector> {
        let tokens = self.tokenize(text);
        let embedding = self.compute_embedding(&tokens);
        Ok(Vector::new(embedding))
    }

    fn encode_batch(&self, texts: &[String]) -> Result<Vec<Vector>> {
        texts.iter().map(|text| self.encode(text)).collect()
    }

    fn get_embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}

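/// Heuristic image encoder combining an intensity histogram, coarse spatial
/// statistics, and first-difference edge responses.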
pub struct ProductionImageEncoder {
    embedding_dim: usize,
}

impl ProductionImageEncoder {
    pub fn new(embedding_dim: usize) -> Result<Self> {
        Ok(Self { embedding_dim })
    }

    /// Extract a fixed-size feature vector; assumes `embedding_dim` is at
    /// least 9 so the three feature regions do not collide.
    fn extract_image_features(&self, image: &ImageData) -> Result<Vec<f32>> {
        let mut features = vec![0.0f32; self.embedding_dim];

        // Region 1: pixel-intensity histogram over the leading bytes.
        let histogram_size = self.embedding_dim / 3;
        for i in 0..histogram_size.min(image.data.len()) {
            let pixel_value = image.data[i] as f32 / 255.0;
            features[i % histogram_size] += pixel_value;
        }

        // Region 2: coarse spatial statistics (scaled width, height, area).
        let spatial_offset = histogram_size;
        features[spatial_offset] = image.width as f32 / 1000.0;
        features[spatial_offset + 1] = image.height as f32 / 1000.0;
        features[spatial_offset + 2] = (image.width * image.height) as f32 / 1_000_000.0;

        // Region 3: first-difference "edge" responses over early pixels.
        let edge_offset = 2 * histogram_size;
        for i in 0..(self.embedding_dim - edge_offset).min(100) {
            if i + 1 < image.data.len() {
                let gradient = (image.data[i + 1] as i32 - image.data[i] as i32).abs() as f32;
                features[edge_offset + (i % (self.embedding_dim - edge_offset))] +=
                    gradient / 255.0;
            }
        }

        // L2-normalize.
        let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            features.iter_mut().for_each(|x| *x /= norm);
        }

        Ok(features)
    }
}

impl ImageEncoder for ProductionImageEncoder {
    fn encode(&self, image: &ImageData) -> Result<Vector> {
        let features = self.extract_image_features(image)?;
        Ok(Vector::new(features))
    }

    fn encode_batch(&self, images: &[ImageData]) -> Result<Vec<Vector>> {
        images.iter().map(|img| self.encode(img)).collect()
    }

    fn get_embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    fn extract_features(&self, image: &ImageData) -> Result<Vec<f32>> {
        self.extract_image_features(image)
    }
}

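/// Heuristic audio encoder built from chunked RMS energy, zero-crossing
/// rate, and position-weighted amplitudes.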
pub struct ProductionAudioEncoder {
    embedding_dim: usize,
}

impl ProductionAudioEncoder {
    pub fn new(embedding_dim: usize) -> Result<Self> {
        Ok(Self { embedding_dim })
    }

    fn extract_audio_features(&self, audio: &AudioData) -> Result<Vec<f32>> {
        let mut features = vec![0.0f32; self.embedding_dim];

        // RMS energy per chunk in the first quarter of the vector. The outer
        // `.max(1)` keeps the chunk size nonzero for short clips, since
        // `chunks(0)` would panic.
        let chunk_size =
            (audio.samples.len().max(1) / (self.embedding_dim / 4).max(1)).max(1);
        for (i, chunk) in audio.samples.chunks(chunk_size).enumerate() {
            if i >= self.embedding_dim / 4 {
                break;
            }
            let energy: f32 = chunk.iter().map(|x| x * x).sum::<f32>() / chunk.len() as f32;
            features[i] = energy.sqrt();
        }

        // Zero-crossing rate as a single coarse pitch/noisiness feature.
        let zcr_offset = self.embedding_dim / 4;
        let mut zero_crossings = 0;
        for i in 1..audio.samples.len() {
            if (audio.samples[i] >= 0.0) != (audio.samples[i - 1] >= 0.0) {
                zero_crossings += 1;
            }
        }
        // Guard against division by zero (NaN) for empty input.
        if zcr_offset < features.len() && !audio.samples.is_empty() {
            features[zcr_offset] = zero_crossings as f32 / audio.samples.len() as f32;
        }

        // Position-weighted amplitudes as a crude spectral proxy.
        let spectral_offset = self.embedding_dim / 2;
        for i in 0..(self.embedding_dim - spectral_offset).min(audio.samples.len()) {
            features[spectral_offset + i] =
                audio.samples[i].abs() * (i as f32 / audio.samples.len() as f32);
        }

        // L2-normalize.
        let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            features.iter_mut().for_each(|x| *x /= norm);
        }

        Ok(features)
    }
}

impl AudioEncoder for ProductionAudioEncoder {
    fn encode(&self, audio: &AudioData) -> Result<Vector> {
        let features = self.extract_audio_features(audio)?;
        Ok(Vector::new(features))
    }

    fn encode_batch(&self, audios: &[AudioData]) -> Result<Vec<Vector>> {
        audios.iter().map(|audio| self.encode(audio)).collect()
    }

    fn get_embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    fn extract_features(&self, audio: &AudioData) -> Result<Vec<f32>> {
        self.extract_audio_features(audio)
    }
}

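/// Video encoder that pools per-keyframe image features and appends coarse
/// temporal statistics (fps, duration, frame count).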
pub struct ProductionVideoEncoder {
    embedding_dim: usize,
    image_encoder: ProductionImageEncoder,
}

impl ProductionVideoEncoder {
    pub fn new(embedding_dim: usize) -> Result<Self> {
        Ok(Self {
            embedding_dim,
            image_encoder: ProductionImageEncoder::new(embedding_dim)?,
        })
    }

    fn extract_video_features(&self, video: &VideoData) -> Result<Vec<f32>> {
        let mut all_features = Vec::new();

        // Prefer the declared keyframes.
        for keyframe_idx in &video.keyframes {
            if let Some(frame) = video.frames.get(*keyframe_idx) {
                let frame_features = self.image_encoder.extract_image_features(frame)?;
                all_features.extend(frame_features);
            }
        }

        // Fall back to the first and last frames when no keyframes matched.
        if all_features.is_empty() && !video.frames.is_empty() {
            let first_features = self
                .image_encoder
                .extract_image_features(&video.frames[0])?;
            all_features.extend(first_features);

            if video.frames.len() > 1 {
                let last_features = self
                    .image_encoder
                    .extract_image_features(video.frames.last().unwrap())?;
                all_features.extend(last_features);
            }
        }

        // Mean-pool the concatenated frame features down to the target dim.
        let mut aggregated = vec![0.0f32; self.embedding_dim];
        if !all_features.is_empty() {
            let chunk_size = all_features.len() / self.embedding_dim.max(1);
            if chunk_size > 0 {
                for (i, chunk) in all_features.chunks(chunk_size).enumerate() {
                    if i >= self.embedding_dim {
                        break;
                    }
                    aggregated[i] = chunk.iter().sum::<f32>() / chunk.len() as f32;
                }
            }
        }

        // Reserve the tail of the vector for scaled temporal statistics.
        if self.embedding_dim > 3 {
            aggregated[self.embedding_dim - 3] = video.fps / 60.0;
            aggregated[self.embedding_dim - 2] = video.duration / 600.0;
            aggregated[self.embedding_dim - 1] = video.frames.len() as f32 / 1000.0;
        }

        // L2-normalize.
        let norm: f32 = aggregated.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            aggregated.iter_mut().for_each(|x| *x /= norm);
        }

        Ok(aggregated)
    }
}

impl VideoEncoder for ProductionVideoEncoder {
    fn encode(&self, video: &VideoData) -> Result<Vector> {
        let features = self.extract_video_features(video)?;
        Ok(Vector::new(features))
    }

    fn encode_keyframes(&self, video: &VideoData) -> Result<Vec<Vector>> {
        video
            .keyframes
            .iter()
            .filter_map(|&idx| video.frames.get(idx))
            .map(|frame| self.image_encoder.encode(frame))
            .collect()
    }

    fn get_embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}

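/// Heuristic graph encoder built from degree statistics and hashed node
/// labels.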
pub struct ProductionGraphEncoder {
    embedding_dim: usize,
}

impl ProductionGraphEncoder {
    pub fn new(embedding_dim: usize) -> Result<Self> {
        Ok(Self { embedding_dim })
    }

    fn extract_graph_features(&self, graph: &GraphData) -> Result<Vec<f32>> {
        let mut features = vec![0.0f32; self.embedding_dim];

        // Degree statistics (each edge endpoint counts once).
        let mut node_degrees = HashMap::new();
        for edge in &graph.edges {
            *node_degrees.entry(edge.source.clone()).or_insert(0) += 1;
            *node_degrees.entry(edge.target.clone()).or_insert(0) += 1;
        }

        let degrees: Vec<usize> = node_degrees.values().copied().collect();
        if !degrees.is_empty() {
            let avg_degree = degrees.iter().sum::<usize>() as f32 / degrees.len() as f32;
            let max_degree = *degrees.iter().max().unwrap_or(&0) as f32;

            features[0] = avg_degree / 100.0;
            features[1] = max_degree / 100.0;
            features[2] = graph.nodes.len() as f32 / 1000.0;
            features[3] = graph.edges.len() as f32 / 1000.0;
        }

        // Hash each node's first label into the second half of the vector;
        // assumes `embedding_dim` is greater than 8.
        for node in graph.nodes.iter().take(self.embedding_dim / 2) {
            if !node.labels.is_empty() {
                let label_hash = Self::hash_string(&node.labels[0]);
                let idx = 4 + (label_hash % (self.embedding_dim as u64 / 2 - 4)) as usize;
                features[idx] += 1.0 / graph.nodes.len() as f32;
            }
        }

        // L2-normalize.
        let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            features.iter_mut().for_each(|x| *x /= norm);
        }

        Ok(features)
    }

    fn hash_string(s: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        s.hash(&mut hasher);
        hasher.finish()
    }
}

impl GraphEncoder for ProductionGraphEncoder {
    fn encode(&self, graph: &GraphData) -> Result<Vector> {
        let features = self.extract_graph_features(graph)?;
        Ok(Vector::new(features))
    }

    fn encode_node(&self, node: &crate::cross_modal_embeddings::GraphNode) -> Result<Vector> {
        // Encode a single node as a trivial one-node graph.
        let graph = GraphData {
            nodes: vec![node.clone()],
            edges: Vec::new(),
            metadata: HashMap::new(),
        };
        self.encode(&graph)
    }

    fn encode_subgraph(
        &self,
        nodes: &[crate::cross_modal_embeddings::GraphNode],
        edges: &[crate::cross_modal_embeddings::GraphEdge],
    ) -> Result<Vector> {
        let graph = GraphData {
            nodes: nodes.to_vec(),
            edges: edges.to_vec(),
            metadata: HashMap::new(),
        };
        self.encode(&graph)
    }

    fn get_embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_query() -> Result<()> {
        let _engine = MultiModalSearchEngine::new_default()?;

        let query = MultiModalQuery::text("test query");
        assert_eq!(query.modalities.len(), 1);

        Ok(())
    }

    #[test]
    fn test_hybrid_query() -> Result<()> {
        let query = MultiModalQuery::hybrid(vec![
            QueryModality::Text("test".to_string()),
            QueryModality::Embedding(Vector::new(vec![0.0; 128])),
        ]);

        assert_eq!(query.modalities.len(), 2);

        Ok(())
    }

    #[test]
    fn test_text_encoder() -> Result<()> {
        let encoder = ProductionTextEncoder::new(128)?;

        let vector = encoder.encode("hello world")?;
        assert_eq!(vector.dimensions, 128);

        // Embeddings are L2-normalized, so the magnitude should be ~1.
        let magnitude = vector.magnitude();
        assert!((magnitude - 1.0).abs() < 0.1);

        Ok(())
    }

    #[test]
    fn test_image_encoder() -> Result<()> {
        let encoder = ProductionImageEncoder::new(256)?;

        let image_data = ImageData {
            data: vec![128; 1024],
            width: 32,
            height: 32,
            channels: 3,
            format: ImageFormat::RGB,
            features: None,
        };

        let vector = encoder.encode(&image_data)?;
        assert_eq!(vector.dimensions, 256);

        Ok(())
    }

    #[test]
    fn test_audio_encoder() -> Result<()> {
        let encoder = ProductionAudioEncoder::new(128)?;

        let audio_data = AudioData {
            // One second of audio at 44.1 kHz.
            samples: vec![0.5; 44100],
            sample_rate: 44100,
            channels: 1,
            duration: 1.0,
            features: None,
        };

        let vector = encoder.encode(&audio_data)?;
        assert_eq!(vector.dimensions, 128);

        Ok(())
    }

    #[test]
    fn test_modality_fusion() -> Result<()> {
        let engine = MultiModalSearchEngine::new_default()?;

        let mut modalities = HashMap::new();
        modalities.insert(Modality::Text, ModalityData::Text("test".to_string()));

        let content = MultiModalContent {
            modalities,
            metadata: HashMap::new(),
            temporal_info: None,
            spatial_info: None,
        };

        engine.index_content("test1".to_string(), content)?;

        let stats = engine.get_statistics();
        assert_eq!(stats.total_vectors, 1);

        Ok(())
    }
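
    // Sketch of an end-to-end round trip, assuming the underlying
    // `VectorStore` similarity search returns previously indexed ids for any
    // query vector; the id and text values here are arbitrary.
    #[test]
    fn test_index_then_search_roundtrip() -> Result<()> {
        let engine = MultiModalSearchEngine::new_default()?;

        let mut modalities = HashMap::new();
        modalities.insert(
            Modality::Text,
            ModalityData::Text("a quick brown fox".to_string()),
        );
        let content = MultiModalContent {
            modalities,
            metadata: HashMap::new(),
            temporal_info: None,
            spatial_info: None,
        };
        engine.index_content("doc1".to_string(), content)?;

        // With a single indexed document, a text query should surface it.
        let query = MultiModalQuery::text("quick fox");
        let results = engine.search(&query, 5)?;
        assert!(results.iter().any(|r| r.id == "doc1"));

        Ok(())
    }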
}