Skip to main content

rlx_runtime/
validators.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Composable request validators (plan #84).
17//!
18//! Borrowed from MAX's `pipelines/core/context_validators.py`. Each
19//! validation rule is a small composable function returning
20//! `Result<(), ValidationError>`. Pipelines pick which to apply for
21//! their request type.
22//!
23//! Why composable instead of a single `validate_request()`?
24//!   - Different request types (text embed, image embed, generation)
25//!     share some rules (max length) but not others (image bounds).
26//!   - Adding a rule = one function, not editing a monolith.
27//!   - Easy to test rules in isolation.
28//!
29//! Used by future serving paths (#31, #32 in PLAN.md). Today the
30//! benchmark runners can use it for input sanity checks.
31
32use std::fmt;
33
/// Error produced when a request fails a validation rule.
///
/// The `Display` form is `[rule] message`. The type is comparable
/// (`PartialEq`/`Eq`) so callers and tests can assert on whole errors or
/// whole [`ValidationResult`]s instead of picking fields apart.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Identifier of the rule that rejected the request. The validators in
    /// this module fill it from their own `Validator::name()`.
    pub rule: &'static str,
    /// Human-readable description of the violation.
    pub message: String,
}

impl fmt::Display for ValidationError {
    /// Writes the error as `[rule] message`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[{}] {}", self.rule, self.message)
    }
}

// No `source()` override: a validation error carries no underlying cause.
impl std::error::Error for ValidationError {}

/// Outcome of a validation check: `Ok(())` when the rule is satisfied,
/// otherwise the [`ValidationError`] describing the failed rule.
pub type ValidationResult = Result<(), ValidationError>;
49
50/// A single check on a context value.
51pub trait Validator<C>: Send + Sync {
52    fn check(&self, ctx: &C) -> ValidationResult;
53    fn name(&self) -> &'static str;
54}
55
56/// Run a chain of validators; return the first error or `Ok(())`.
57pub fn run_chain<C>(ctx: &C, chain: &[&dyn Validator<C>]) -> ValidationResult {
58    for v in chain {
59        v.check(ctx)?;
60    }
61    Ok(())
62}
63
64// ── Text-context validators ─────────────────────────────────────
65
/// Summary of a text request, as inspected by the text validators below.
#[derive(Debug, Clone)]
pub struct TextContext {
    /// Sequence length of the request; `MaxSeqLen` compares it to its limit.
    pub seq_len: usize,
    /// Batch size of the request; `MaxBatchSize` compares it to its limit.
    pub batch_size: usize,
    /// Vocabulary bound. `TokenIdsInVocab` requires `max_token_id_seen` to
    /// be strictly less than this value.
    pub vocab_id_max: usize,
    /// Token id checked against `vocab_id_max`; by its name, the largest id
    /// present in the request (the caller is responsible for computing it).
    pub max_token_id_seen: usize,
}
74
75pub struct MaxSeqLen(pub usize);
76impl Validator<TextContext> for MaxSeqLen {
77    fn check(&self, ctx: &TextContext) -> ValidationResult {
78        if ctx.seq_len > self.0 {
79            Err(ValidationError {
80                rule: self.name(),
81                message: format!("seq_len {} exceeds max {}", ctx.seq_len, self.0),
82            })
83        } else {
84            Ok(())
85        }
86    }
87    fn name(&self) -> &'static str {
88        "max_seq_len"
89    }
90}
91
92pub struct MaxBatchSize(pub usize);
93impl Validator<TextContext> for MaxBatchSize {
94    fn check(&self, ctx: &TextContext) -> ValidationResult {
95        if ctx.batch_size > self.0 {
96            Err(ValidationError {
97                rule: self.name(),
98                message: format!("batch_size {} exceeds max {}", ctx.batch_size, self.0),
99            })
100        } else {
101            Ok(())
102        }
103    }
104    fn name(&self) -> &'static str {
105        "max_batch_size"
106    }
107}
108
109pub struct TokenIdsInVocab;
110impl Validator<TextContext> for TokenIdsInVocab {
111    fn check(&self, ctx: &TextContext) -> ValidationResult {
112        if ctx.max_token_id_seen >= ctx.vocab_id_max {
113            Err(ValidationError {
114                rule: self.name(),
115                message: format!(
116                    "saw token_id {} but vocab is {}",
117                    ctx.max_token_id_seen, ctx.vocab_id_max
118                ),
119            })
120        } else {
121            Ok(())
122        }
123    }
124    fn name(&self) -> &'static str {
125        "token_ids_in_vocab"
126    }
127}
128
129// ── Image-context validators ────────────────────────────────────
130
/// Summary of an image request, as inspected by the image validators below.
#[derive(Debug, Clone)]
pub struct ImageContext {
    /// Image width (presumably in pixels); compared to `ImageMaxBounds::max_w`.
    pub width: u32,
    /// Image height (presumably in pixels); compared to `ImageMaxBounds::max_h`.
    pub height: u32,
    /// Channel count; must be one of the values held by `ChannelsAllowed`.
    pub channels: u32,
}
137
138pub struct ImageMaxBounds {
139    pub max_w: u32,
140    pub max_h: u32,
141}
142impl Validator<ImageContext> for ImageMaxBounds {
143    fn check(&self, ctx: &ImageContext) -> ValidationResult {
144        if ctx.width > self.max_w || ctx.height > self.max_h {
145            Err(ValidationError {
146                rule: self.name(),
147                message: format!(
148                    "{}×{} exceeds max {}×{}",
149                    ctx.width, ctx.height, self.max_w, self.max_h
150                ),
151            })
152        } else {
153            Ok(())
154        }
155    }
156    fn name(&self) -> &'static str {
157        "image_max_bounds"
158    }
159}
160
161pub struct ChannelsAllowed(pub &'static [u32]);
162impl Validator<ImageContext> for ChannelsAllowed {
163    fn check(&self, ctx: &ImageContext) -> ValidationResult {
164        if !self.0.contains(&ctx.channels) {
165            Err(ValidationError {
166                rule: self.name(),
167                message: format!("channels={} not in allowed set {:?}", ctx.channels, self.0),
168            })
169        } else {
170            Ok(())
171        }
172    }
173    fn name(&self) -> &'static str {
174        "channels_allowed"
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Always-passing validator that records how often it was invoked, so
    /// tests can observe whether the chain actually reached it.
    #[derive(Default)]
    struct CallCounter {
        calls: AtomicUsize,
    }

    impl CallCounter {
        fn count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    impl Validator<TextContext> for CallCounter {
        fn check(&self, _ctx: &TextContext) -> ValidationResult {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        fn name(&self) -> &'static str {
            "call_counter"
        }
    }

    /// Builds a text context with a fixed 30000-entry vocabulary bound.
    fn text_ctx(seq_len: usize, batch_size: usize, max_token_id_seen: usize) -> TextContext {
        TextContext {
            seq_len,
            batch_size,
            vocab_id_max: 30000,
            max_token_id_seen,
        }
    }

    #[test]
    fn text_chain_short_circuits() {
        // Two rules are violated (seq_len AND batch_size); the chain must
        // report the first one in chain order, not the last.
        let ctx = text_ctx(600, 128, 100);
        let max_seq = MaxSeqLen(512);
        let max_batch = MaxBatchSize(64);
        let tok = TokenIdsInVocab;
        let chain: Vec<&dyn Validator<TextContext>> = vec![&max_seq, &max_batch, &tok];
        let err = run_chain(&ctx, &chain).unwrap_err();
        assert_eq!(err.rule, "max_seq_len");
    }

    #[test]
    fn text_chain_does_not_invoke_validators_after_failure() {
        let ctx = text_ctx(600, 1, 100);
        let before = CallCounter::default();
        let max_seq = MaxSeqLen(512);
        let after = CallCounter::default();
        let chain: Vec<&dyn Validator<TextContext>> = vec![&before, &max_seq, &after];
        let err = run_chain(&ctx, &chain).unwrap_err();
        assert_eq!(err.rule, "max_seq_len");
        assert_eq!(before.count(), 1);
        assert_eq!(after.count(), 0);
    }

    #[test]
    fn text_chain_passes_at_inclusive_limits() {
        // seq_len == limit, batch_size == limit and token id == bound - 1
        // are all still valid.
        let ctx = text_ctx(512, 64, 29999);
        let max_seq = MaxSeqLen(512);
        let max_batch = MaxBatchSize(64);
        let tok = TokenIdsInVocab;
        let chain: Vec<&dyn Validator<TextContext>> = vec![&max_seq, &max_batch, &tok];
        assert!(run_chain(&ctx, &chain).is_ok());
    }

    #[test]
    fn token_id_equal_to_vocab_bound_is_rejected() {
        let ctx = text_ctx(16, 1, 30000);
        let err = TokenIdsInVocab.check(&ctx).unwrap_err();
        assert_eq!(err.rule, "token_ids_in_vocab");
        assert_eq!(err.message, "saw token_id 30000 but vocab is 30000");
    }

    #[test]
    fn empty_chain_is_ok() {
        let ctx = text_ctx(600, 128, 40000);
        let chain: Vec<&dyn Validator<TextContext>> = Vec::new();
        assert!(run_chain(&ctx, &chain).is_ok());
    }

    #[test]
    fn image_chain_passes() {
        let ctx = ImageContext {
            width: 224,
            height: 224,
            channels: 3,
        };
        let bounds = ImageMaxBounds {
            max_w: 1024,
            max_h: 1024,
        };
        let chans = ChannelsAllowed(&[1, 3, 4]);
        let chain: Vec<&dyn Validator<ImageContext>> = vec![&bounds, &chans];
        assert!(run_chain(&ctx, &chain).is_ok());
    }

    #[test]
    fn image_chain_catches_oversized_image() {
        let ctx = ImageContext {
            width: 2048,
            height: 224,
            channels: 3,
        };
        let bounds = ImageMaxBounds {
            max_w: 1024,
            max_h: 1024,
        };
        let chans = ChannelsAllowed(&[1, 3, 4]);
        let chain: Vec<&dyn Validator<ImageContext>> = vec![&bounds, &chans];
        let err = run_chain(&ctx, &chain).unwrap_err();
        assert_eq!(err.rule, "image_max_bounds");
        assert_eq!(err.message, "2048×224 exceeds max 1024×1024");
    }

    #[test]
    fn image_chain_catches_bad_channels() {
        let ctx = ImageContext {
            width: 224,
            height: 224,
            channels: 2,
        };
        let chans = ChannelsAllowed(&[1, 3, 4]);
        let chain: Vec<&dyn Validator<ImageContext>> = vec![&chans];
        let err = run_chain(&ctx, &chain).unwrap_err();
        assert_eq!(err.rule, "channels_allowed");
    }

    #[test]
    fn error_display_is_bracketed_rule_then_message() {
        let ctx = ImageContext {
            width: 224,
            height: 224,
            channels: 2,
        };
        let err = ChannelsAllowed(&[1, 3, 4]).check(&ctx).unwrap_err();
        assert_eq!(
            err.to_string(),
            "[channels_allowed] channels=2 not in allowed set [1, 3, 4]"
        );
    }
}