Skip to main content

llm_multimodal/
types.rs

1use std::{collections::HashMap, fmt, path::PathBuf, sync::Arc};
2
3use image::DynamicImage;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7/// Supported multimodal modalities.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9#[serde(rename_all = "snake_case")]
10pub enum Modality {
11    Image,
12    ImageEmbeds,
13    Audio,
14    Video,
15}
16
17impl fmt::Display for Modality {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        match self {
20            Modality::Image => write!(f, "image"),
21            Modality::ImageEmbeds => write!(f, "image_embeds"),
22            Modality::Audio => write!(f, "audio"),
23            Modality::Video => write!(f, "video"),
24        }
25    }
26}
27
28/// Detail level passed by OpenAI style APIs.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
30#[serde(rename_all = "snake_case")]
31pub enum ImageDetail {
32    #[default]
33    Auto,
34    Low,
35    High,
36}
37
38/// A normalized content part understood by the tracker.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(tag = "type", rename_all = "snake_case")]
41pub enum MediaContentPart {
42    Text {
43        text: String,
44    },
45    ImageUrl {
46        url: String,
47        #[serde(skip_serializing_if = "Option::is_none")]
48        detail: Option<ImageDetail>,
49        #[serde(skip_serializing_if = "Option::is_none")]
50        uuid: Option<String>,
51    },
52    ImageData {
53        data: Vec<u8>,
54        #[serde(skip_serializing_if = "Option::is_none")]
55        mime_type: Option<String>,
56        #[serde(skip_serializing_if = "Option::is_none")]
57        uuid: Option<String>,
58        #[serde(skip_serializing_if = "Option::is_none")]
59        detail: Option<ImageDetail>,
60    },
61    ImageEmbeds {
62        payload: Value,
63        #[serde(skip_serializing_if = "Option::is_none")]
64        uuid: Option<String>,
65    },
66}
67
68/// Image source metadata (useful for hashing & tracing).
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(tag = "kind", rename_all = "snake_case")]
71pub enum ImageSource {
72    Url { url: String },
73    DataUrl,
74    InlineBytes,
75    File { path: PathBuf },
76}
77
78/// Concrete image payload captured by the media connector.
79#[derive(Debug, Clone)]
80pub struct ImageFrame {
81    pub image: DynamicImage,
82    pub raw_bytes: bytes::Bytes,
83    pub detail: ImageDetail,
84    pub source: ImageSource,
85    /// Blake3 hex-digest of raw_bytes, computed at decode time.
86    pub hash: String,
87}
88
89impl ImageFrame {
90    pub fn new(
91        image: DynamicImage,
92        raw_bytes: bytes::Bytes,
93        detail: ImageDetail,
94        source: ImageSource,
95        hash: String,
96    ) -> Self {
97        Self {
98            image,
99            raw_bytes,
100            detail,
101            source,
102            hash,
103        }
104    }
105
106    pub fn data(&self) -> &DynamicImage {
107        &self.image
108    }
109
110    pub fn raw_bytes(&self) -> &[u8] {
111        &self.raw_bytes
112    }
113
114    pub fn source(&self) -> &ImageSource {
115        &self.source
116    }
117
118    pub fn size(&self) -> ImageSize {
119        ImageSize::new(self.image.width(), self.image.height())
120    }
121}
122
123/// Container for all supported multimodal media objects.
124#[derive(Debug, Clone)]
125pub enum TrackedMedia {
126    Image(Arc<ImageFrame>),
127    /// Placeholder variants for future modalities.
128    Audio,
129    Video,
130    Embeddings,
131}
132
133pub type MultiModalData = HashMap<Modality, Vec<TrackedMedia>>;
134pub type MultiModalUUIDs = HashMap<Modality, Vec<Option<String>>>;
135
136pub type TokenId = i32;
137
/// Declares how a multimodal tensor's first dimension maps to items (images).
///
/// Used by [`ModelProcessorSpec::field_layouts`] to tell the backend how to
/// split tensors for per-item scheduling (vLLM `MultiModalFieldConfig`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FieldLayout {
    /// First dimension equals num_images (one slice per image).
    Batched,
    /// Variable-length slices per image.  The sizes are stored in the
    /// tensor named by `sizes_key` (e.g. `"patches_per_image"`).
    Flat {
        /// Name of the tensor that holds the per-image slice sizes.
        sizes_key: String,
    },
}
150
151impl FieldLayout {
152    /// Convenience constructor for `Flat`.
153    pub fn flat(sizes_key: impl Into<String>) -> Self {
154        Self::Flat {
155            sizes_key: sizes_key.into(),
156        }
157    }
158}
159
160#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
161pub struct ImageSize {
162    pub width: u32,
163    pub height: u32,
164}
165
166impl ImageSize {
167    pub fn new(width: u32, height: u32) -> Self {
168        Self { width, height }
169    }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize, Default)]
173pub struct PlaceholderRange {
174    pub offset: usize,
175    pub length: usize,
176}
177
178#[derive(Debug, Clone)]
179pub struct PromptReplacement {
180    pub modality: Modality,
181    pub placeholder_token: String,
182    pub tokens: Vec<TokenId>,
183}
184
185impl PromptReplacement {
186    pub fn repeated(
187        modality: Modality,
188        placeholder_token: &str,
189        token_id: TokenId,
190        count: usize,
191    ) -> Self {
192        Self {
193            modality,
194            placeholder_token: placeholder_token.to_string(),
195            tokens: vec![token_id; count],
196        }
197    }
198
199    pub fn sequence(modality: Modality, placeholder_token: &str, sequence: Vec<TokenId>) -> Self {
200        Self {
201            modality,
202            placeholder_token: placeholder_token.to_string(),
203            tokens: sequence,
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// A `PlaceholderRange` serializes both of its fields and the resulting
    /// JSON parses back to the same values.
    #[test]
    fn placeholder_range_serializes() {
        let range = PlaceholderRange {
            offset: 10,
            length: 4,
        };
        let json = serde_json::to_string(&range).unwrap();
        assert!(json.contains("offset"));
        assert!(json.contains("length"));

        let parsed: PlaceholderRange = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.offset, 10);
        assert_eq!(parsed.length, 4);
    }

    /// Both `PromptReplacement` constructors populate every field.
    #[test]
    fn prompt_replacement_builders() {
        let rep = PromptReplacement::repeated(Modality::Image, "<image>", 100, 3);
        assert_eq!(rep.modality, Modality::Image);
        assert_eq!(rep.placeholder_token, "<image>");
        assert_eq!(rep.tokens, vec![100, 100, 100]);

        let seq = PromptReplacement::sequence(Modality::Image, "<image>", vec![1, 2, 3]);
        assert_eq!(seq.modality, Modality::Image);
        assert_eq!(seq.placeholder_token, "<image>");
        assert_eq!(seq.tokens, vec![1, 2, 3]);
    }
}