rlx_runtime/
validators.rs1use std::fmt;
33
/// Error returned when a context fails a [`Validator`] rule.
///
/// Derives `PartialEq`/`Eq` so errors (and whole [`ValidationResult`]s) can be
/// compared directly by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Identifier of the rule that failed; the validators in this module fill
    /// it with their own `name()`.
    pub rule: &'static str,
    /// Human-readable explanation of the violation.
    pub message: String,
}

impl fmt::Display for ValidationError {
    /// Formats the error as `[rule] message`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[{}] {}", self.rule, self.message)
    }
}

impl std::error::Error for ValidationError {}

/// Outcome of a validation: `Ok(())` when the check passes, otherwise the
/// [`ValidationError`] describing why it failed.
pub type ValidationResult = Result<(), ValidationError>;
49
50pub trait Validator<C>: Send + Sync {
52 fn check(&self, ctx: &C) -> ValidationResult;
53 fn name(&self) -> &'static str;
54}
55
56pub fn run_chain<C>(ctx: &C, chain: &[&dyn Validator<C>]) -> ValidationResult {
58 for v in chain {
59 v.check(ctx)?;
60 }
61 Ok(())
62}
63
/// Summary of a text input that the text validators inspect.
///
/// Derives `PartialEq`/`Eq` so contexts can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextContext {
    /// Sequence length; checked by [`MaxSeqLen`].
    pub seq_len: usize,
    /// Batch size; checked by [`MaxBatchSize`].
    pub batch_size: usize,
    /// Exclusive upper bound on token ids: [`TokenIdsInVocab`] rejects any
    /// `max_token_id_seen >= vocab_id_max`.
    pub vocab_id_max: usize,
    /// Largest token id the caller observed in the input.
    pub max_token_id_seen: usize,
}
74
75pub struct MaxSeqLen(pub usize);
76impl Validator<TextContext> for MaxSeqLen {
77 fn check(&self, ctx: &TextContext) -> ValidationResult {
78 if ctx.seq_len > self.0 {
79 Err(ValidationError {
80 rule: self.name(),
81 message: format!("seq_len {} exceeds max {}", ctx.seq_len, self.0),
82 })
83 } else {
84 Ok(())
85 }
86 }
87 fn name(&self) -> &'static str {
88 "max_seq_len"
89 }
90}
91
92pub struct MaxBatchSize(pub usize);
93impl Validator<TextContext> for MaxBatchSize {
94 fn check(&self, ctx: &TextContext) -> ValidationResult {
95 if ctx.batch_size > self.0 {
96 Err(ValidationError {
97 rule: self.name(),
98 message: format!("batch_size {} exceeds max {}", ctx.batch_size, self.0),
99 })
100 } else {
101 Ok(())
102 }
103 }
104 fn name(&self) -> &'static str {
105 "max_batch_size"
106 }
107}
108
109pub struct TokenIdsInVocab;
110impl Validator<TextContext> for TokenIdsInVocab {
111 fn check(&self, ctx: &TextContext) -> ValidationResult {
112 if ctx.max_token_id_seen >= ctx.vocab_id_max {
113 Err(ValidationError {
114 rule: self.name(),
115 message: format!(
116 "saw token_id {} but vocab is {}",
117 ctx.max_token_id_seen, ctx.vocab_id_max
118 ),
119 })
120 } else {
121 Ok(())
122 }
123 }
124 fn name(&self) -> &'static str {
125 "token_ids_in_vocab"
126 }
127}
128
/// Image metadata that the image validators inspect.
///
/// Derives `PartialEq`/`Eq` so contexts can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImageContext {
    /// Image width; compared against [`ImageMaxBounds::max_w`].
    pub width: u32,
    /// Image height; compared against [`ImageMaxBounds::max_h`].
    pub height: u32,
    /// Channel count; checked by [`ChannelsAllowed`].
    pub channels: u32,
}
137
138pub struct ImageMaxBounds {
139 pub max_w: u32,
140 pub max_h: u32,
141}
142impl Validator<ImageContext> for ImageMaxBounds {
143 fn check(&self, ctx: &ImageContext) -> ValidationResult {
144 if ctx.width > self.max_w || ctx.height > self.max_h {
145 Err(ValidationError {
146 rule: self.name(),
147 message: format!(
148 "{}×{} exceeds max {}×{}",
149 ctx.width, ctx.height, self.max_w, self.max_h
150 ),
151 })
152 } else {
153 Ok(())
154 }
155 }
156 fn name(&self) -> &'static str {
157 "image_max_bounds"
158 }
159}
160
161pub struct ChannelsAllowed(pub &'static [u32]);
162impl Validator<ImageContext> for ChannelsAllowed {
163 fn check(&self, ctx: &ImageContext) -> ValidationResult {
164 if !self.0.contains(&ctx.channels) {
165 Err(ValidationError {
166 rule: self.name(),
167 message: format!("channels={} not in allowed set {:?}", ctx.channels, self.0),
168 })
169 } else {
170 Ok(())
171 }
172 }
173 fn name(&self) -> &'static str {
174 "channels_allowed"
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text context with `vocab_id_max` fixed at 30000.
    fn text_ctx(seq_len: usize, batch_size: usize, max_token_id_seen: usize) -> TextContext {
        TextContext {
            seq_len,
            batch_size,
            vocab_id_max: 30000,
            max_token_id_seen,
        }
    }

    /// Validator that fails the test if it is ever invoked; placed after a
    /// failing rule to prove that `run_chain` really stops early.
    struct MustNotRun;

    impl Validator<TextContext> for MustNotRun {
        fn check(&self, _ctx: &TextContext) -> ValidationResult {
            panic!("validator after a failing rule must not be checked");
        }

        fn name(&self) -> &'static str {
            "must_not_run"
        }
    }

    #[test]
    fn text_chain_short_circuits() {
        let ctx = text_ctx(600, 1, 100);
        let max_seq = MaxSeqLen(512);
        let max_batch = MaxBatchSize(64);
        let tok = TokenIdsInVocab;
        let never = MustNotRun;
        let chain: Vec<&dyn Validator<TextContext>> = vec![&max_seq, &max_batch, &tok, &never];
        let err = run_chain(&ctx, &chain).unwrap_err();
        assert_eq!(err.rule, "max_seq_len");
    }

    #[test]
    fn empty_chain_passes() {
        let chain: Vec<&dyn Validator<TextContext>> = Vec::new();
        assert!(run_chain(&text_ctx(10_000, 10_000, 10_000), &chain).is_ok());
    }

    #[test]
    fn max_seq_len_accepts_exact_limit() {
        assert!(MaxSeqLen(512).check(&text_ctx(512, 1, 0)).is_ok());
        assert!(MaxSeqLen(512).check(&text_ctx(513, 1, 0)).is_err());
    }

    #[test]
    fn batch_size_over_limit_is_rejected() {
        assert!(MaxBatchSize(64).check(&text_ctx(8, 64, 0)).is_ok());
        let err = MaxBatchSize(64).check(&text_ctx(8, 65, 0)).unwrap_err();
        assert_eq!(err.rule, "max_batch_size");
        assert_eq!(err.message, "batch_size 65 exceeds max 64");
    }

    #[test]
    fn token_id_bound_is_exclusive() {
        assert!(TokenIdsInVocab.check(&text_ctx(8, 1, 29999)).is_ok());
        let err = TokenIdsInVocab.check(&text_ctx(8, 1, 30000)).unwrap_err();
        assert_eq!(err.rule, "token_ids_in_vocab");
        assert_eq!(err.message, "saw token_id 30000 but vocab is 30000");
    }

    #[test]
    fn error_display_includes_rule() {
        let err = MaxSeqLen(512).check(&text_ctx(600, 1, 0)).unwrap_err();
        assert_eq!(err.to_string(), "[max_seq_len] seq_len 600 exceeds max 512");
    }

    #[test]
    fn image_chain_passes() {
        let ctx = ImageContext {
            width: 224,
            height: 224,
            channels: 3,
        };
        let bounds = ImageMaxBounds {
            max_w: 1024,
            max_h: 1024,
        };
        let chans = ChannelsAllowed(&[1, 3, 4]);
        let chain: Vec<&dyn Validator<ImageContext>> = vec![&bounds, &chans];
        assert!(run_chain(&ctx, &chain).is_ok());
    }

    #[test]
    fn image_bounds_reject_oversized_width() {
        let ctx = ImageContext {
            width: 2048,
            height: 224,
            channels: 3,
        };
        let bounds = ImageMaxBounds {
            max_w: 1024,
            max_h: 1024,
        };
        let err = bounds.check(&ctx).unwrap_err();
        assert_eq!(err.rule, "image_max_bounds");
        assert_eq!(err.message, "2048×224 exceeds max 1024×1024");
    }

    #[test]
    fn image_chain_catches_bad_channels() {
        let ctx = ImageContext {
            width: 224,
            height: 224,
            channels: 2,
        };
        let chans = ChannelsAllowed(&[1, 3, 4]);
        let chain: Vec<&dyn Validator<ImageContext>> = vec![&chans];
        let err = run_chain(&ctx, &chain).unwrap_err();
        assert_eq!(err.rule, "channels_allowed");
    }
}