use crate::cross_modal_embeddings::{
AudioData, AudioEncoder, CrossModalConfig, CrossModalEncoder, GraphData, GraphEncoder,
ImageData, ImageEncoder, ImageFormat, Modality, ModalityData, MultiModalContent, TextEncoder,
VideoData, VideoEncoder,
};
use crate::Vector;
use crate::VectorStore;
use anyhow::{anyhow, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
pub struct MultiModalSearchEngine {
config: MultiModalConfig,
encoder: Arc<CrossModalEncoder>,
vector_store: Arc<RwLock<VectorStore>>,
modality_stores: HashMap<Modality, Arc<RwLock<VectorStore>>>,
query_cache: Arc<RwLock<HashMap<String, Vec<SearchResult>>>>,
total_indexed: Arc<RwLock<usize>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalConfig {
pub cross_modal_config: CrossModalConfig,
pub use_modality_specific_indices: bool,
pub enable_caching: bool,
pub cache_size_limit: usize,
pub search_strategy: SearchStrategy,
pub enable_query_expansion: bool,
pub query_expansion_factor: f32,
}
impl Default for MultiModalConfig {
fn default() -> Self {
Self {
cross_modal_config: CrossModalConfig::default(),
use_modality_specific_indices: true,
enable_caching: true,
cache_size_limit: 1000,
search_strategy: SearchStrategy::HybridFusion,
enable_query_expansion: true,
query_expansion_factor: 1.5,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchStrategy {
JointSpaceOnly,
ModalitySpecific,
HybridFusion,
Adaptive,
}
#[derive(Debug, Clone)]
pub struct MultiModalQuery {
pub modalities: Vec<QueryModality>,
pub weights: Option<HashMap<Modality, f32>>,
pub filters: Vec<QueryFilter>,
pub metadata: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub enum QueryModality {
Text(String),
Image(Vec<u8>),
Audio(Vec<f32>, u32), Video(Vec<Vec<u8>>), Embedding(Vector), }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryFilter {
pub field: String,
pub operator: FilterOperator,
pub value: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FilterOperator {
Equals,
NotEquals,
Contains,
GreaterThan,
LessThan,
Regex,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
pub id: String,
pub score: f32,
pub modality: Modality,
pub metadata: HashMap<String, String>,
pub embedding: Option<Vec<f32>>,
pub modality_scores: HashMap<Modality, f32>,
}
impl MultiModalQuery {
pub fn text(text: impl Into<String>) -> Self {
Self {
modalities: vec![QueryModality::Text(text.into())],
weights: None,
filters: Vec::new(),
metadata: HashMap::new(),
}
}
pub fn image(image_data: Vec<u8>) -> Self {
Self {
modalities: vec![QueryModality::Image(image_data)],
weights: None,
filters: Vec::new(),
metadata: HashMap::new(),
}
}
pub fn audio(samples: Vec<f32>, sample_rate: u32) -> Self {
Self {
modalities: vec![QueryModality::Audio(samples, sample_rate)],
weights: None,
filters: Vec::new(),
metadata: HashMap::new(),
}
}
pub fn hybrid(modalities: Vec<QueryModality>) -> Self {
Self {
modalities,
weights: None,
filters: Vec::new(),
metadata: HashMap::new(),
}
}
pub fn with_filter(mut self, filter: QueryFilter) -> Self {
self.filters.push(filter);
self
}
pub fn with_weights(mut self, weights: HashMap<Modality, f32>) -> Self {
self.weights = Some(weights);
self
}
pub fn with_metadata(mut self, key: String, value: String) -> Self {
self.metadata.insert(key, value);
self
}
}
impl MultiModalSearchEngine {
pub fn new_default() -> Result<Self> {
Self::new(MultiModalConfig::default())
}
pub fn new(config: MultiModalConfig) -> Result<Self> {
let text_encoder = Box::new(ProductionTextEncoder::new(
config.cross_modal_config.joint_embedding_dim,
)?);
let image_encoder = Box::new(ProductionImageEncoder::new(
config.cross_modal_config.joint_embedding_dim,
)?);
let audio_encoder = Box::new(ProductionAudioEncoder::new(
config.cross_modal_config.joint_embedding_dim,
)?);
let video_encoder = Box::new(ProductionVideoEncoder::new(
config.cross_modal_config.joint_embedding_dim,
)?);
let graph_encoder = Box::new(ProductionGraphEncoder::new(
config.cross_modal_config.joint_embedding_dim,
)?);
let encoder = Arc::new(CrossModalEncoder::new(
config.cross_modal_config.clone(),
text_encoder,
image_encoder,
audio_encoder,
video_encoder,
graph_encoder,
));
let vector_store = Arc::new(RwLock::new(VectorStore::new()));
let mut modality_stores = HashMap::new();
if config.use_modality_specific_indices {
for modality in &[
Modality::Text,
Modality::Image,
Modality::Audio,
Modality::Video,
] {
let store = Arc::new(RwLock::new(VectorStore::new()));
modality_stores.insert(*modality, store);
}
}
Ok(Self {
config,
encoder,
vector_store,
modality_stores,
query_cache: Arc::new(RwLock::new(HashMap::new())),
total_indexed: Arc::new(RwLock::new(0)),
})
}
pub fn index_content(&self, id: String, content: MultiModalContent) -> Result<()> {
let embedding = self.encoder.encode(&content)?;
{
let mut store = self.vector_store.write();
store.index_vector(id.clone(), embedding.clone())?;
}
if self.config.use_modality_specific_indices {
for (modality, data) in &content.modalities {
if let Some(store) = self.modality_stores.get(modality) {
let modality_embedding = self.encode_modality(*modality, data)?;
let mut store = store.write();
store.index_vector(id.clone(), modality_embedding)?;
}
}
}
*self.total_indexed.write() += 1;
Ok(())
}
pub fn search(&self, query: &MultiModalQuery, k: usize) -> Result<Vec<SearchResult>> {
if self.config.enable_caching {
let cache_key = self.compute_cache_key(query);
if let Some(cached_results) = self.query_cache.read().get(&cache_key) {
return Ok(cached_results.clone());
}
}
let results = match self.config.search_strategy {
SearchStrategy::JointSpaceOnly => self.search_joint_space(query, k)?,
SearchStrategy::ModalitySpecific => self.search_modality_specific(query, k)?,
SearchStrategy::HybridFusion => self.search_hybrid(query, k)?,
SearchStrategy::Adaptive => self.search_adaptive(query, k)?,
};
let filtered_results = self.apply_filters(&results, &query.filters)?;
if self.config.enable_caching {
let cache_key = self.compute_cache_key(query);
let mut cache = self.query_cache.write();
if cache.len() >= self.config.cache_size_limit {
if let Some(first_key) = cache.keys().next().cloned() {
cache.remove(&first_key);
}
}
cache.insert(cache_key, filtered_results.clone());
}
Ok(filtered_results)
}
fn search_joint_space(&self, query: &MultiModalQuery, k: usize) -> Result<Vec<SearchResult>> {
let query_content = self.query_to_content(query)?;
let query_embedding = self.encoder.encode(&query_content)?;
let store = self.vector_store.read();
let results = store.similarity_search_vector(&query_embedding, k)?;
Ok(results
.into_iter()
.map(|(id, score)| SearchResult {
id,
score,
modality: Modality::Text, metadata: HashMap::new(),
embedding: None,
modality_scores: HashMap::new(),
})
.collect())
}
fn search_modality_specific(
&self,
query: &MultiModalQuery,
k: usize,
) -> Result<Vec<SearchResult>> {
let mut all_results: Vec<SearchResult> = Vec::new();
let mut modality_results: HashMap<Modality, Vec<SearchResult>> = HashMap::new();
for query_modality in &query.modalities {
let (modality, data) = match query_modality {
QueryModality::Text(text) => (Modality::Text, ModalityData::Text(text.clone())),
QueryModality::Image(img_data) => {
let image_data = self.parse_image_data(img_data)?;
(Modality::Image, ModalityData::Image(image_data))
}
QueryModality::Audio(samples, rate) => {
let audio_data = AudioData {
samples: samples.clone(),
sample_rate: *rate,
channels: 1,
duration: samples.len() as f32 / *rate as f32,
features: None,
};
(Modality::Audio, ModalityData::Audio(audio_data))
}
QueryModality::Embedding(embedding) => {
let store = self.vector_store.read();
let results = store.similarity_search_vector(embedding, k)?;
all_results.extend(results.into_iter().map(|(id, score)| SearchResult {
id,
score,
modality: Modality::Numeric,
metadata: HashMap::new(),
embedding: None,
modality_scores: HashMap::new(),
}));
continue;
}
QueryModality::Video(_frames) => {
continue; }
};
if let Some(store) = self.modality_stores.get(&modality) {
let embedding = self.encode_modality(modality, &data)?;
let store = store.read();
let results = store.similarity_search_vector(&embedding, k)?;
let search_results: Vec<SearchResult> = results
.into_iter()
.map(|(id, score)| SearchResult {
id,
score,
modality,
metadata: HashMap::new(),
embedding: None,
modality_scores: HashMap::new(),
})
.collect();
modality_results.insert(modality, search_results);
}
}
let fused_results = self.fuse_modality_results(modality_results, query, k)?;
Ok(fused_results)
}
fn search_hybrid(&self, query: &MultiModalQuery, k: usize) -> Result<Vec<SearchResult>> {
let joint_results = self.search_joint_space(query, k * 2)?;
let modality_results = self.search_modality_specific(query, k * 2)?;
let fused = self.fuse_search_results(vec![joint_results, modality_results], &[0.6, 0.4])?;
Ok(fused.into_iter().take(k).collect())
}
fn search_adaptive(&self, query: &MultiModalQuery, k: usize) -> Result<Vec<SearchResult>> {
let num_modalities = query.modalities.len();
if num_modalities == 1 {
return self.search_modality_specific(query, k);
}
self.search_hybrid(query, k)
}
fn encode_modality(&self, _modality: Modality, data: &ModalityData) -> Result<Vector> {
let mut content_map = HashMap::new();
match data {
ModalityData::Text(_text) => {
content_map.insert(Modality::Text, data.clone());
}
ModalityData::Image(_image) => {
content_map.insert(Modality::Image, data.clone());
}
ModalityData::Audio(_audio) => {
content_map.insert(Modality::Audio, data.clone());
}
ModalityData::Video(_video) => {
content_map.insert(Modality::Video, data.clone());
}
ModalityData::Graph(_graph) => {
content_map.insert(Modality::Graph, data.clone());
}
ModalityData::Numeric(values) => {
return Ok(Vector::new(values.clone()));
}
ModalityData::Raw(_) => {
return Err(anyhow!("Raw data encoding not supported"));
}
}
let content = MultiModalContent {
modalities: content_map,
metadata: HashMap::new(),
temporal_info: None,
spatial_info: None,
};
self.encoder.encode(&content)
}
fn query_to_content(&self, query: &MultiModalQuery) -> Result<MultiModalContent> {
let mut modalities = HashMap::new();
for query_modality in &query.modalities {
match query_modality {
QueryModality::Text(text) => {
modalities.insert(Modality::Text, ModalityData::Text(text.clone()));
}
QueryModality::Image(img_data) => {
let image_data = self.parse_image_data(img_data)?;
modalities.insert(Modality::Image, ModalityData::Image(image_data));
}
QueryModality::Audio(samples, rate) => {
let audio_data = AudioData {
samples: samples.clone(),
sample_rate: *rate,
channels: 1,
duration: samples.len() as f32 / *rate as f32,
features: None,
};
modalities.insert(Modality::Audio, ModalityData::Audio(audio_data));
}
QueryModality::Embedding(_) => {
}
QueryModality::Video(frames) => {
let video_frames: Result<Vec<ImageData>> =
frames.iter().map(|f| self.parse_image_data(f)).collect();
let video_data = VideoData {
frames: video_frames?,
audio: None,
fps: 30.0,
duration: frames.len() as f32 / 30.0,
keyframes: vec![0],
};
modalities.insert(Modality::Video, ModalityData::Video(video_data));
}
}
}
Ok(MultiModalContent {
modalities,
metadata: query.metadata.clone(),
temporal_info: None,
spatial_info: None,
})
}
fn parse_image_data(&self, data: &[u8]) -> Result<ImageData> {
#[cfg(feature = "images")]
{
use image::GenericImageView;
let img = image::load_from_memory(data)
.map_err(|e| anyhow!("Failed to decode image: {}", e))?;
let (width, height) = img.dimensions();
let rgb_img = img.to_rgb8();
let raw_data = rgb_img.into_raw();
Ok(ImageData {
data: raw_data,
width,
height,
channels: 3,
format: ImageFormat::RGB,
features: None,
})
}
#[cfg(not(feature = "images"))]
{
Ok(ImageData {
data: data.to_vec(),
width: 0,
height: 0,
channels: 3,
format: ImageFormat::RGB,
features: None,
})
}
}
fn fuse_modality_results(
&self,
modality_results: HashMap<Modality, Vec<SearchResult>>,
query: &MultiModalQuery,
k: usize,
) -> Result<Vec<SearchResult>> {
let mut score_map: HashMap<String, (f32, SearchResult)> = HashMap::new();
for (modality, results) in modality_results {
let weight = query
.weights
.as_ref()
.and_then(|w| w.get(&modality))
.copied()
.unwrap_or(1.0);
for (rank, result) in results.into_iter().enumerate() {
let rrf_score = weight / (60.0 + rank as f32 + 1.0);
score_map
.entry(result.id.clone())
.and_modify(|(score, existing)| {
*score += rrf_score;
existing.modality_scores.insert(modality, result.score);
})
.or_insert_with(|| {
let mut updated_result = result.clone();
updated_result
.modality_scores
.insert(modality, result.score);
(rrf_score, updated_result)
});
}
}
let mut fused_results: Vec<(f32, SearchResult)> = score_map.into_values().collect();
fused_results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
Ok(fused_results
.into_iter()
.take(k)
.map(|(score, mut result)| {
result.score = score;
result
})
.collect())
}
fn fuse_search_results(
&self,
result_sets: Vec<Vec<SearchResult>>,
weights: &[f32],
) -> Result<Vec<SearchResult>> {
if result_sets.len() != weights.len() {
return Err(anyhow!("Weights length must match result sets length"));
}
let mut score_map: HashMap<String, (f32, SearchResult)> = HashMap::new();
for (results, &weight) in result_sets.into_iter().zip(weights.iter()) {
for (rank, result) in results.into_iter().enumerate() {
let rrf_score = weight / (60.0 + rank as f32 + 1.0);
score_map
.entry(result.id.clone())
.and_modify(|(score, _)| *score += rrf_score)
.or_insert_with(|| (rrf_score, result));
}
}
let mut fused_results: Vec<(f32, SearchResult)> = score_map.into_values().collect();
fused_results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
Ok(fused_results
.into_iter()
.map(|(score, mut result)| {
result.score = score;
result
})
.collect())
}
fn apply_filters(
&self,
results: &[SearchResult],
filters: &[QueryFilter],
) -> Result<Vec<SearchResult>> {
if filters.is_empty() {
return Ok(results.to_vec());
}
let filtered: Vec<SearchResult> = results
.iter()
.filter(|result| self.matches_filters(result, filters))
.cloned()
.collect();
Ok(filtered)
}
fn matches_filters(&self, result: &SearchResult, filters: &[QueryFilter]) -> bool {
filters.iter().all(|filter| {
if let Some(value) = result.metadata.get(&filter.field) {
match filter.operator {
FilterOperator::Equals => value == &filter.value,
FilterOperator::NotEquals => value != &filter.value,
FilterOperator::Contains => value.contains(&filter.value),
FilterOperator::GreaterThan => value > &filter.value,
FilterOperator::LessThan => value < &filter.value,
FilterOperator::Regex => {
if let Ok(re) = regex::Regex::new(&filter.value) {
re.is_match(value)
} else {
false
}
}
}
} else {
false
}
})
}
fn compute_cache_key(&self, query: &MultiModalQuery) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
for modality in &query.modalities {
match modality {
QueryModality::Text(text) => text.hash(&mut hasher),
QueryModality::Image(data) => data.hash(&mut hasher),
QueryModality::Audio(samples, rate) => {
samples.len().hash(&mut hasher);
rate.hash(&mut hasher);
}
QueryModality::Video(frames) => frames.len().hash(&mut hasher),
QueryModality::Embedding(emb) => emb.dimensions.hash(&mut hasher),
}
}
format!("{:x}", hasher.finish())
}
pub fn get_statistics(&self) -> MultiModalStatistics {
let num_vectors = *self.total_indexed.read();
let mut modality_counts = HashMap::new();
for modality in self.modality_stores.keys() {
modality_counts.insert(*modality, 0);
}
MultiModalStatistics {
total_vectors: num_vectors,
modality_counts,
cache_size: self.query_cache.read().len(),
cache_hit_rate: 0.0, }
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalStatistics {
pub total_vectors: usize,
pub modality_counts: HashMap<Modality, usize>,
pub cache_size: usize,
pub cache_hit_rate: f32,
}
pub struct ProductionTextEncoder {
embedding_dim: usize,
vocab_size: usize,
}
impl ProductionTextEncoder {
pub fn new(embedding_dim: usize) -> Result<Self> {
Ok(Self {
embedding_dim,
vocab_size: 10000,
})
}
fn tokenize(&self, text: &str) -> Vec<String> {
text.to_lowercase()
.split_whitespace()
.map(|s| s.to_string())
.collect()
}
fn compute_embedding(&self, tokens: &[String]) -> Vec<f32> {
use std::collections::HashMap;
let mut freq_map = HashMap::new();
for token in tokens {
*freq_map.entry(token.clone()).or_insert(0) += 1;
}
let mut embedding = vec![0.0f32; self.embedding_dim];
for (token, count) in freq_map {
let hash = Self::hash_token(&token);
let index = (hash % self.embedding_dim as u64) as usize;
embedding[index] += count as f32 / tokens.len() as f32;
}
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
embedding.iter_mut().for_each(|x| *x /= norm);
}
embedding
}
fn hash_token(token: &str) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
token.hash(&mut hasher);
hasher.finish()
}
}
impl TextEncoder for ProductionTextEncoder {
fn encode(&self, text: &str) -> Result<Vector> {
let tokens = self.tokenize(text);
let embedding = self.compute_embedding(&tokens);
Ok(Vector::new(embedding))
}
fn encode_batch(&self, texts: &[String]) -> Result<Vec<Vector>> {
texts.iter().map(|text| self.encode(text)).collect()
}
fn get_embedding_dim(&self) -> usize {
self.embedding_dim
}
}
pub struct ProductionImageEncoder {
embedding_dim: usize,
}
impl ProductionImageEncoder {
pub fn new(embedding_dim: usize) -> Result<Self> {
Ok(Self { embedding_dim })
}
fn extract_image_features(&self, image: &ImageData) -> Result<Vec<f32>> {
let mut features = vec![0.0f32; self.embedding_dim];
let histogram_size = self.embedding_dim / 3;
for i in 0..histogram_size.min(image.data.len()) {
let pixel_value = image.data[i] as f32 / 255.0;
features[i % histogram_size] += pixel_value;
}
let spatial_offset = histogram_size;
features[spatial_offset] = image.width as f32 / 1000.0;
features[spatial_offset + 1] = image.height as f32 / 1000.0;
features[spatial_offset + 2] = (image.width * image.height) as f32 / 1_000_000.0;
let edge_offset = 2 * histogram_size;
for i in 0..(self.embedding_dim - edge_offset).min(100) {
if i + 1 < image.data.len() {
let gradient = (image.data[i + 1] as i32 - image.data[i] as i32).abs() as f32;
features[edge_offset + (i % (self.embedding_dim - edge_offset))] +=
gradient / 255.0;
}
}
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
features.iter_mut().for_each(|x| *x /= norm);
}
Ok(features)
}
}
impl ImageEncoder for ProductionImageEncoder {
fn encode(&self, image: &ImageData) -> Result<Vector> {
let features = self.extract_image_features(image)?;
Ok(Vector::new(features))
}
fn encode_batch(&self, images: &[ImageData]) -> Result<Vec<Vector>> {
images.iter().map(|img| self.encode(img)).collect()
}
fn get_embedding_dim(&self) -> usize {
self.embedding_dim
}
fn extract_features(&self, image: &ImageData) -> Result<Vec<f32>> {
self.extract_image_features(image)
}
}
pub struct ProductionAudioEncoder {
embedding_dim: usize,
}
impl ProductionAudioEncoder {
pub fn new(embedding_dim: usize) -> Result<Self> {
Ok(Self { embedding_dim })
}
fn extract_audio_features(&self, audio: &AudioData) -> Result<Vec<f32>> {
let mut features = vec![0.0f32; self.embedding_dim];
let chunk_size = audio.samples.len().max(1) / (self.embedding_dim / 4).max(1);
for (i, chunk) in audio.samples.chunks(chunk_size).enumerate() {
if i >= self.embedding_dim / 4 {
break;
}
let energy: f32 = chunk.iter().map(|x| x * x).sum::<f32>() / chunk.len() as f32;
features[i] = energy.sqrt();
}
let zcr_offset = self.embedding_dim / 4;
let mut zero_crossings = 0;
for i in 1..audio.samples.len() {
if (audio.samples[i] >= 0.0) != (audio.samples[i - 1] >= 0.0) {
zero_crossings += 1;
}
}
if zcr_offset < features.len() {
features[zcr_offset] = zero_crossings as f32 / audio.samples.len() as f32;
}
let spectral_offset = self.embedding_dim / 2;
for i in 0..(self.embedding_dim - spectral_offset).min(audio.samples.len()) {
features[spectral_offset + i] =
audio.samples[i].abs() * (i as f32 / audio.samples.len() as f32);
}
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
features.iter_mut().for_each(|x| *x /= norm);
}
Ok(features)
}
}
impl AudioEncoder for ProductionAudioEncoder {
fn encode(&self, audio: &AudioData) -> Result<Vector> {
let features = self.extract_audio_features(audio)?;
Ok(Vector::new(features))
}
fn encode_batch(&self, audios: &[AudioData]) -> Result<Vec<Vector>> {
audios.iter().map(|audio| self.encode(audio)).collect()
}
fn get_embedding_dim(&self) -> usize {
self.embedding_dim
}
fn extract_features(&self, audio: &AudioData) -> Result<Vec<f32>> {
self.extract_audio_features(audio)
}
}
pub struct ProductionVideoEncoder {
embedding_dim: usize,
image_encoder: ProductionImageEncoder,
}
impl ProductionVideoEncoder {
pub fn new(embedding_dim: usize) -> Result<Self> {
Ok(Self {
embedding_dim,
image_encoder: ProductionImageEncoder::new(embedding_dim)?,
})
}
fn extract_video_features(&self, video: &VideoData) -> Result<Vec<f32>> {
let mut all_features = Vec::new();
for keyframe_idx in &video.keyframes {
if let Some(frame) = video.frames.get(*keyframe_idx) {
let frame_features = self.image_encoder.extract_image_features(frame)?;
all_features.extend(frame_features);
}
}
if all_features.is_empty() && !video.frames.is_empty() {
let first_features = self
.image_encoder
.extract_image_features(&video.frames[0])?;
all_features.extend(first_features);
if video.frames.len() > 1 {
let last_features = self.image_encoder.extract_image_features(
video
.frames
.last()
.expect("video frames validated to have more than one element"),
)?;
all_features.extend(last_features);
}
}
let mut aggregated = vec![0.0f32; self.embedding_dim];
if !all_features.is_empty() {
let chunk_size = all_features.len() / self.embedding_dim.max(1);
if chunk_size > 0 {
for (i, chunk) in all_features.chunks(chunk_size).enumerate() {
if i >= self.embedding_dim {
break;
}
aggregated[i] = chunk.iter().sum::<f32>() / chunk.len() as f32;
}
}
}
if self.embedding_dim > 3 {
aggregated[self.embedding_dim - 3] = video.fps / 60.0;
aggregated[self.embedding_dim - 2] = video.duration / 600.0;
aggregated[self.embedding_dim - 1] = video.frames.len() as f32 / 1000.0;
}
let norm: f32 = aggregated.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
aggregated.iter_mut().for_each(|x| *x /= norm);
}
Ok(aggregated)
}
}
impl VideoEncoder for ProductionVideoEncoder {
fn encode(&self, video: &VideoData) -> Result<Vector> {
let features = self.extract_video_features(video)?;
Ok(Vector::new(features))
}
fn encode_keyframes(&self, video: &VideoData) -> Result<Vec<Vector>> {
video
.keyframes
.iter()
.filter_map(|&idx| video.frames.get(idx))
.map(|frame| self.image_encoder.encode(frame))
.collect()
}
fn get_embedding_dim(&self) -> usize {
self.embedding_dim
}
}
pub struct ProductionGraphEncoder {
embedding_dim: usize,
}
impl ProductionGraphEncoder {
pub fn new(embedding_dim: usize) -> Result<Self> {
Ok(Self { embedding_dim })
}
fn extract_graph_features(&self, graph: &GraphData) -> Result<Vec<f32>> {
let mut features = vec![0.0f32; self.embedding_dim];
let mut node_degrees = HashMap::new();
for edge in &graph.edges {
*node_degrees.entry(edge.source.clone()).or_insert(0) += 1;
*node_degrees.entry(edge.target.clone()).or_insert(0) += 1;
}
let degrees: Vec<usize> = node_degrees.values().copied().collect();
if !degrees.is_empty() {
let avg_degree = degrees.iter().sum::<usize>() as f32 / degrees.len() as f32;
let max_degree = *degrees.iter().max().unwrap_or(&0) as f32;
features[0] = avg_degree / 100.0;
features[1] = max_degree / 100.0;
features[2] = graph.nodes.len() as f32 / 1000.0;
features[3] = graph.edges.len() as f32 / 1000.0;
}
for (_i, node) in graph.nodes.iter().enumerate().take(self.embedding_dim / 2) {
if !node.labels.is_empty() {
let label_hash = Self::hash_string(&node.labels[0]);
let idx = 4 + (label_hash % (self.embedding_dim as u64 / 2 - 4)) as usize;
features[idx] += 1.0 / graph.nodes.len() as f32;
}
}
let norm: f32 = features.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
features.iter_mut().for_each(|x| *x /= norm);
}
Ok(features)
}
fn hash_string(s: &str) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
s.hash(&mut hasher);
hasher.finish()
}
}
impl GraphEncoder for ProductionGraphEncoder {
fn encode(&self, graph: &GraphData) -> Result<Vector> {
let features = self.extract_graph_features(graph)?;
Ok(Vector::new(features))
}
fn encode_node(&self, node: &crate::cross_modal_embeddings::GraphNode) -> Result<Vector> {
let graph = GraphData {
nodes: vec![node.clone()],
edges: Vec::new(),
metadata: HashMap::new(),
};
self.encode(&graph)
}
fn encode_subgraph(
&self,
nodes: &[crate::cross_modal_embeddings::GraphNode],
edges: &[crate::cross_modal_embeddings::GraphEdge],
) -> Result<Vector> {
let graph = GraphData {
nodes: nodes.to_vec(),
edges: edges.to_vec(),
metadata: HashMap::new(),
};
self.encode(&graph)
}
fn get_embedding_dim(&self) -> usize {
self.embedding_dim
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_text_query() -> Result<()> {
let _engine = MultiModalSearchEngine::new_default()?;
let query = MultiModalQuery::text("test query");
assert_eq!(query.modalities.len(), 1);
Ok(())
}
#[test]
fn test_hybrid_query() -> Result<()> {
let query = MultiModalQuery::hybrid(vec![
QueryModality::Text("test".to_string()),
QueryModality::Embedding(Vector::new(vec![0.0; 128])),
]);
assert_eq!(query.modalities.len(), 2);
Ok(())
}
#[test]
fn test_text_encoder() -> Result<()> {
let encoder = ProductionTextEncoder::new(128)?;
let vector = encoder.encode("hello world")?;
assert_eq!(vector.dimensions, 128);
let magnitude = vector.magnitude();
assert!((magnitude - 1.0).abs() < 0.1);
Ok(())
}
#[test]
fn test_image_encoder() -> Result<()> {
let encoder = ProductionImageEncoder::new(256)?;
let image_data = ImageData {
data: vec![128; 1024],
width: 32,
height: 32,
channels: 3,
format: ImageFormat::RGB,
features: None,
};
let vector = encoder.encode(&image_data)?;
assert_eq!(vector.dimensions, 256);
Ok(())
}
#[test]
fn test_audio_encoder() -> Result<()> {
let encoder = ProductionAudioEncoder::new(128)?;
let audio_data = AudioData {
samples: vec![0.5; 44100], sample_rate: 44100,
channels: 1,
duration: 1.0,
features: None,
};
let vector = encoder.encode(&audio_data)?;
assert_eq!(vector.dimensions, 128);
Ok(())
}
#[test]
fn test_modality_fusion() -> Result<()> {
let engine = MultiModalSearchEngine::new_default()?;
let mut modalities = HashMap::new();
modalities.insert(Modality::Text, ModalityData::Text("test".to_string()));
let content = MultiModalContent {
modalities,
metadata: HashMap::new(),
temporal_info: None,
spatial_info: None,
};
engine.index_content("test1".to_string(), content)?;
let stats = engine.get_statistics();
assert_eq!(stats.total_vectors, 1);
Ok(())
}
}