Skip to main content

rlx_cli/
mtmd.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Multimodal turn assembly (PLAN.md M7).
17//!
18//! Replaces `llama-cpp-4`'s `MtmdContext` end-to-end. The runner
19//! receives a list of [`MtmdTurn`]s — text + images + audio — and
20//! produces an [`AssembledTurn`] the per-family VL/Omni runner
21//! consumes via the [`rlx_vlm_base`] traits.
22//!
23//! **Status:** TYPE SKELETON. The shape is in place so `skill` can
24//! write code against `MtmdContext::build_turn(..)` today; the
25//! actual image-loading / audio-resampling implementations land
26//! alongside the per-family runners in M7.
27
28use anyhow::{Result, bail};
29use std::path::PathBuf;
30
/// Unsized text→token-id encoder callback. `MtmdContext::build_turn`
/// borrows one as `Option<&TokenizerFn<'_>>`, calls it with the
/// assembled prompt text, and propagates any error it returns.
type TokenizerFn<'a> = dyn Fn(&str) -> Result<Vec<u32>> + 'a;
32
/// Where one image / audio chunk lives.
///
/// Two sources compare equal only when they are the same variant with
/// equal payloads: paths are compared as paths (no filesystem access,
/// no canonicalisation) and byte buffers are compared byte-for-byte.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MediaSource {
    /// Read from a file path on disk.
    FilePath(PathBuf),
    /// Decoded bytes (e.g. base64 from a chat client).
    Bytes(Vec<u8>),
}
41
42/// One turn in a multimodal conversation. `text` is rendered through
43/// the same `ChatTemplate` as the text-only path; `images` / `audio`
44/// are interleaved into the LM stream by the per-family runner.
45#[derive(Debug, Clone)]
46pub struct MtmdTurn {
47    pub role: String,
48    pub text: String,
49    pub images: Vec<MediaSource>,
50    pub audio: Vec<MediaSource>,
51}
52
53impl MtmdTurn {
54    pub fn user(text: impl Into<String>) -> Self {
55        Self {
56            role: "user".into(),
57            text: text.into(),
58            images: Vec::new(),
59            audio: Vec::new(),
60        }
61    }
62    pub fn system(text: impl Into<String>) -> Self {
63        Self {
64            role: "system".into(),
65            text: text.into(),
66            images: Vec::new(),
67            audio: Vec::new(),
68        }
69    }
70    pub fn assistant(text: impl Into<String>) -> Self {
71        Self {
72            role: "assistant".into(),
73            text: text.into(),
74            images: Vec::new(),
75            audio: Vec::new(),
76        }
77    }
78    pub fn with_image_path(mut self, path: impl Into<PathBuf>) -> Self {
79        self.images.push(MediaSource::FilePath(path.into()));
80        self
81    }
82    pub fn with_image_bytes(mut self, bytes: Vec<u8>) -> Self {
83        self.images.push(MediaSource::Bytes(bytes));
84        self
85    }
86    pub fn with_audio_path(mut self, path: impl Into<PathBuf>) -> Self {
87        self.audio.push(MediaSource::FilePath(path.into()));
88        self
89    }
90    pub fn with_audio_bytes(mut self, bytes: Vec<u8>) -> Self {
91        self.audio.push(MediaSource::Bytes(bytes));
92        self
93    }
94
95    pub fn has_media(&self) -> bool {
96        !self.images.is_empty() || !self.audio.is_empty()
97    }
98}
99
/// Result of assembling a turn list into something the per-family
/// runner can feed into prefill. `text_tokens` is the assembled prompt
/// text run through the tokenizer; `image_refs` / `audio_refs`
/// retain order so the runner knows where to insert the embeddings.
#[derive(Debug, Clone, Default)]
pub struct AssembledTurn {
    /// Token ids returned by the caller-supplied tokenizer callback;
    /// empty when `build_turn` was called without one.
    pub text_tokens: Vec<u32>,
    /// Clones of every turn's images, in turn order and, within a
    /// turn, in the order they appear in `MtmdTurn::images`.
    pub image_refs: Vec<MediaSource>,
    /// Clones of every turn's audio clips, in turn order and, within a
    /// turn, in the order they appear in `MtmdTurn::audio`.
    pub audio_refs: Vec<MediaSource>,
}
110
111/// Context for assembling multimodal turns. Holds the chat template
112/// and (eventually) the tokenizer; per-family runners hand the
113/// resulting [`AssembledTurn`] into their prefill path.
114pub struct MtmdContext {
115    template_source: String,
116    bos_token: Option<String>,
117    eos_token: Option<String>,
118}
119
120impl MtmdContext {
121    /// Build a context from a Jinja chat template (typically loaded
122    /// from a GGUF via [`crate::ChatTemplate::from_gguf`]).
123    pub fn from_template_source(src: impl Into<String>) -> Self {
124        Self {
125            template_source: src.into(),
126            bos_token: None,
127            eos_token: None,
128        }
129    }
130
131    pub fn with_tokens(mut self, bos: Option<String>, eos: Option<String>) -> Self {
132        self.bos_token = bos;
133        self.eos_token = eos;
134        self
135    }
136
137    pub fn template_source(&self) -> &str {
138        &self.template_source
139    }
140    pub fn bos_token(&self) -> Option<&str> {
141        self.bos_token.as_deref()
142    }
143    pub fn eos_token(&self) -> Option<&str> {
144        self.eos_token.as_deref()
145    }
146
147    /// Assemble one turn list into [`AssembledTurn`].
148    ///
149    /// Renders the text using the registered chat template, replaces
150    /// each `<|image|>` / `<|audio|>` marker with the per-family
151    /// placeholder token id (resolved from the optional tokenizer
152    /// vocabulary), and records the order of media in `image_refs` /
153    /// `audio_refs` so the runner can insert the corresponding
154    /// embeddings at decode time.
155    ///
156    /// `tokenizer_fn` lets the caller plug in a per-family text→ids
157    /// encoder (typically `auto_tokenize`). Passing `None` populates
158    /// `text_tokens` with an empty vec — useful for callers that
159    /// own tokenization separately.
160    pub fn build_turn(
161        &self,
162        turns: &[MtmdTurn],
163        tokenizer_fn: Option<&TokenizerFn<'_>>,
164    ) -> Result<AssembledTurn> {
165        if turns.is_empty() {
166            bail!("MtmdContext::build_turn: empty turn list");
167        }
168        let mut text = String::new();
169        let mut image_refs = Vec::new();
170        let mut audio_refs = Vec::new();
171
172        // Minimal ChatML-style assembly. Real chat-template rendering
173        // (Jinja) lives in `crate::ChatTemplate` — when present,
174        // callers should pre-render and pass a single turn.
175        if let Some(bos) = self.bos_token.as_deref() {
176            text.push_str(bos);
177        }
178        for t in turns {
179            text.push_str("<|im_start|>");
180            text.push_str(&t.role);
181            text.push('\n');
182            text.push_str(&t.text);
183            // Insert image/audio markers after the text so the runner
184            // can interleave embeddings in order.
185            for img in &t.images {
186                text.push_str("<|image|>");
187                image_refs.push(img.clone());
188            }
189            for au in &t.audio {
190                text.push_str("<|audio|>");
191                audio_refs.push(au.clone());
192            }
193            text.push_str("<|im_end|>\n");
194        }
195        if let Some(eos) = self.eos_token.as_deref() {
196            text.push_str(eos);
197        }
198
199        let text_tokens = match tokenizer_fn {
200            Some(f) => f(&text)?,
201            None => Vec::new(),
202        };
203        Ok(AssembledTurn {
204            text_tokens,
205            image_refs,
206            audio_refs,
207        })
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::path::Path;

    /// `true` when `src` is a `FilePath` pointing at `expected`.
    fn is_path(src: &MediaSource, expected: &str) -> bool {
        matches!(src, MediaSource::FilePath(p) if p.as_path() == Path::new(expected))
    }

    /// Run `build_turn` with a capturing tokenizer and return the text
    /// the tokenizer was handed.
    fn rendered_text(ctx: &MtmdContext, turns: &[MtmdTurn]) -> String {
        let seen = RefCell::new(String::new());
        let capture = |s: &str| -> Result<Vec<u32>> {
            seen.borrow_mut().push_str(s);
            Ok(Vec::new())
        };
        ctx.build_turn(turns, Some(&capture as &dyn Fn(&str) -> Result<Vec<u32>>))
            .unwrap();
        seen.into_inner()
    }

    #[test]
    fn build_turn_records_media_order() {
        let ctx = MtmdContext::from_template_source("").with_tokens(None, None);
        let turn = MtmdTurn::user("describe")
            .with_image_path("/tmp/a.png")
            .with_audio_path("/tmp/b.wav")
            .with_image_path("/tmp/c.png");
        let out = ctx.build_turn(&[turn], None).unwrap();

        // Check the actual order, not just the counts.
        assert_eq!(out.image_refs.len(), 2);
        assert!(is_path(&out.image_refs[0], "/tmp/a.png"));
        assert!(is_path(&out.image_refs[1], "/tmp/c.png"));
        assert_eq!(out.audio_refs.len(), 1);
        assert!(is_path(&out.audio_refs[0], "/tmp/b.wav"));
        // Tokenizer absent → text_tokens empty.
        assert!(out.text_tokens.is_empty());
    }

    #[test]
    fn build_turn_collects_media_across_turns_in_turn_order() {
        let ctx = MtmdContext::from_template_source("");
        let turns = [
            MtmdTurn::user("first").with_image_path("/tmp/1.png"),
            MtmdTurn::assistant("reply"),
            MtmdTurn::user("second").with_image_path("/tmp/2.png"),
        ];
        let out = ctx.build_turn(&turns, None).unwrap();
        assert_eq!(out.image_refs.len(), 2);
        assert!(is_path(&out.image_refs[0], "/tmp/1.png"));
        assert!(is_path(&out.image_refs[1], "/tmp/2.png"));
        assert!(out.audio_refs.is_empty());
    }

    #[test]
    fn build_turn_invokes_tokenizer_callback() {
        let ctx = MtmdContext::from_template_source("");
        let seen = RefCell::new(String::new());
        let tokenize = |s: &str| -> Result<Vec<u32>> {
            seen.borrow_mut().push_str(s);
            Ok(vec![1, 2, 3])
        };
        let turn = MtmdTurn::user("hello");
        let out = ctx
            .build_turn(
                &[turn],
                Some(&tokenize as &dyn Fn(&str) -> Result<Vec<u32>>),
            )
            .unwrap();
        assert_eq!(out.text_tokens, vec![1, 2, 3]);
        assert_eq!(
            seen.borrow().as_str(),
            "<|im_start|>user\nhello<|im_end|>\n",
            "tokenizer must see the rendered text"
        );
    }

    #[test]
    fn build_turn_wraps_with_bos_eos_and_emits_media_markers() {
        let ctx = MtmdContext::from_template_source("")
            .with_tokens(Some("<s>".to_string()), Some("</s>".to_string()));
        let turns = [
            MtmdTurn::system("sys"),
            MtmdTurn::user("look")
                .with_audio_bytes(vec![1])
                .with_image_bytes(vec![0]),
        ];
        // Image markers precede audio markers regardless of builder order.
        assert_eq!(
            rendered_text(&ctx, &turns),
            "<s><|im_start|>system\nsys<|im_end|>\n\
             <|im_start|>user\nlook<|image|><|audio|><|im_end|>\n</s>"
        );
    }

    #[test]
    fn build_turn_propagates_tokenizer_error() {
        let ctx = MtmdContext::from_template_source("");
        let failing = |_: &str| -> Result<Vec<u32>> { bail!("boom") };
        let err = ctx
            .build_turn(
                &[MtmdTurn::user("hi")],
                Some(&failing as &dyn Fn(&str) -> Result<Vec<u32>>),
            )
            .unwrap_err();
        assert!(format!("{err}").contains("boom"));
    }

    #[test]
    fn build_turn_rejects_empty() {
        let ctx = MtmdContext::from_template_source("");
        let err = ctx.build_turn(&[], None).unwrap_err();
        assert!(format!("{err}").contains("empty turn list"));
    }
}
254}