use anyhow::Result;
/// The kind of input a model can consume: text, image, or audio.
///
/// `Hash` is derived alongside `Eq` so a `Modality` can be used directly as a
/// `HashMap`/`HashSet` key (e.g. per-modality encoder registries or counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modality {
    /// Text input.
    Text,
    /// Image input.
    Image,
    /// Audio input.
    Audio,
}
/// An image split into a `grid_h` x `grid_w` grid of patches.
///
/// Each patch holds `channels * patch_h * patch_w` scalar values.
/// `PartialEq` is derived so patch sets can be compared, for example in tests
/// or when deduplicating preprocessed inputs. It is not `Eq` because the
/// buffer contains `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct ImagePatches {
    /// Pixel data for all patches.
    /// NOTE(review): presumably a flat buffer of `grid_h * grid_w` patches of
    /// `channels * patch_h * patch_w` values each. The element ordering is not
    /// established in this file; confirm with the preprocessor implementation.
    pub patches: Vec<f32>,
    /// Number of patch rows in the grid.
    pub grid_h: usize,
    /// Number of patch columns in the grid.
    pub grid_w: usize,
    /// Height of one patch.
    pub patch_h: usize,
    /// Width of one patch.
    pub patch_w: usize,
    /// Number of channels per patch.
    pub channels: usize,
}
impl ImagePatches {
    /// Total number of patches in the grid: `grid_h * grid_w`.
    pub fn num_patches(&self) -> usize {
        let Self { grid_h, grid_w, .. } = self;
        grid_h * grid_w
    }

    /// Number of scalar values in a single patch: `channels * patch_h * patch_w`.
    pub fn patch_dim(&self) -> usize {
        let Self {
            channels,
            patch_h,
            patch_w,
            ..
        } = self;
        channels * patch_h * patch_w
    }
}
/// Produces [`ImagePatches`] from an image source.
///
/// Implementors must be `Send`.
pub trait ImagePreprocessor: Send {
    /// Produces patches from the image located at `path`.
    fn preprocess_path(&self, path: &std::path::Path) -> Result<ImagePatches>;

    /// Produces patches from an image supplied as raw bytes.
    /// NOTE(review): presumably an encoded image file's contents; confirm the
    /// accepted formats with the implementors.
    fn preprocess_bytes(&self, bytes: &[u8]) -> Result<ImagePatches>;
}
/// Encodes [`ImagePatches`] into a flat buffer of `f32` embeddings.
///
/// Implementors must be `Send`. `embed` takes `&mut self`, so an implementation
/// may mutate internal state while encoding.
pub trait VisionTower: Send {
    /// Embeds `patches` and returns the embeddings as one flat vector.
    /// NOTE(review): presumably `patches.num_patches() * self.hidden_size()`
    /// values; confirm with the implementations.
    fn embed(&mut self, patches: &ImagePatches) -> Result<Vec<f32>>;

    /// Width of the embeddings produced by [`VisionTower::embed`]
    /// (presumably per patch; confirm).
    fn hidden_size(&self) -> usize;
}
/// Maps vision embeddings into another embedding space.
///
/// NOTE(review): presumably the language model's input space; confirm.
/// Implementors must be `Send`.
pub trait Projector: Send {
    /// Projects the flat `vision_embed` buffer, which covers `num_patches`
    /// patches, and returns the projected values as a flat vector.
    fn project(&mut self, vision_embed: &[f32], num_patches: usize) -> Result<Vec<f32>>;

    /// Dimension of the projected output (presumably per patch; confirm).
    fn output_dim(&self) -> usize;
}
/// Encodes raw audio samples into a flat buffer of `f32` embeddings.
///
/// Implementors must be `Send`.
pub trait AudioEncoder: Send {
    /// Embeds `samples` recorded at `sample_rate` and returns a flat vector.
    /// NOTE(review): `sample_rate` is presumably in Hz and `samples`
    /// presumably mono; confirm with the implementations.
    fn embed_audio(&mut self, samples: &[f32], sample_rate: u32) -> Result<Vec<f32>>;

    /// Width of the embeddings produced by [`AudioEncoder::embed_audio`].
    fn hidden_size(&self) -> usize;
}
/// An ordered sequence of text, image, and audio chunks forming one prompt.
///
/// `Default` yields an empty prompt.
#[derive(Debug, Clone, Default)]
pub struct MultimodalPrompt {
    /// The chunks, in the order they were added.
    pub chunks: Vec<PromptChunk>,
}
/// One segment of a multimodal prompt.
#[derive(Debug, Clone)]
pub enum PromptChunk {
    /// A run of text as `u32` values (presumably token ids; confirm).
    Text(Vec<u32>),
    /// A preprocessed image.
    Image(ImagePatches),
    /// Raw audio samples plus the rate they were sampled at
    /// (presumably in Hz; confirm).
    Audio {
        samples: Vec<f32>,
        sample_rate: u32,
    },
}
impl MultimodalPrompt {
    /// Appends `chunk` after all existing chunks.
    pub fn push(&mut self, chunk: PromptChunk) {
        let chunks = &mut self.chunks;
        chunks.push(chunk);
    }

    /// Returns `true` when every chunk is a [`PromptChunk::Text`].
    ///
    /// An empty prompt has no non-text chunks, so it counts as text-only.
    pub fn is_text_only(&self) -> bool {
        let has_non_text = self
            .chunks
            .iter()
            .any(|chunk| !matches!(chunk, PromptChunk::Text(_)));
        !has_non_text
    }

    /// Number of chunks currently in the prompt.
    pub fn num_chunks(&self) -> usize {
        Vec::len(&self.chunks)
    }
}