1use crate::{LnmpField, LnmpRecord, LnmpValue};
8
/// A structural limit from [`StructuralLimits`] was exceeded while validating
/// an [`LnmpRecord`].
///
/// Every variant carries both the configured limit and the value that was
/// observed when validation stopped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StructuralError {
    /// Records were nested deeper than `StructuralLimits::max_depth`.
    MaxDepthExceeded {
        /// Configured maximum nesting depth.
        max_depth: usize,
        /// Nesting depth at which the violation was detected.
        seen_depth: usize,
    },
    /// The total number of fields in the record tree exceeded
    /// `StructuralLimits::max_fields`.
    MaxFieldsExceeded {
        /// Configured maximum total field count.
        max_fields: usize,
        /// Running field count at the moment the limit was crossed.
        seen_fields: usize,
    },
    /// A string value, or an element of a string array, was longer than
    /// `StructuralLimits::max_string_len`.
    MaxStringLengthExceeded {
        /// Configured maximum string length.
        max_len: usize,
        /// Length of the offending string.
        seen_len: usize,
    },
    /// An array value had more entries than `StructuralLimits::max_array_items`.
    MaxArrayLengthExceeded {
        /// Configured maximum number of array entries.
        max_len: usize,
        /// Number of entries in the offending array.
        seen_len: usize,
    },
}
41
42impl std::fmt::Display for StructuralError {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 StructuralError::MaxDepthExceeded {
46 max_depth,
47 seen_depth,
48 } => {
49 write!(
50 f,
51 "maximum nesting depth exceeded (max={}, saw={})",
52 max_depth, seen_depth
53 )
54 }
55 StructuralError::MaxFieldsExceeded {
56 max_fields,
57 seen_fields,
58 } => {
59 write!(
60 f,
61 "maximum field count exceeded (max={}, saw={})",
62 max_fields, seen_fields
63 )
64 }
65 StructuralError::MaxStringLengthExceeded { max_len, seen_len } => {
66 write!(
67 f,
68 "maximum string length exceeded (max={}, saw={})",
69 max_len, seen_len
70 )
71 }
72 StructuralError::MaxArrayLengthExceeded { max_len, seen_len } => {
73 write!(
74 f,
75 "maximum array length exceeded (max={}, saw={})",
76 max_len, seen_len
77 )
78 }
79 }
80 }
81}
82
83impl std::error::Error for StructuralError {}
84
/// Upper bounds enforced by [`StructuralLimits::validate_record`].
///
/// Start from `Default` and override individual limits with struct-update
/// syntax, e.g. `StructuralLimits { max_depth: 4, ..StructuralLimits::default() }`.
#[derive(Debug, Clone)]
pub struct StructuralLimits {
    /// Maximum nesting depth. The top-level record is depth 0 and each nested
    /// record adds one level.
    pub max_depth: usize,
    /// Maximum number of fields across the whole record tree, nested records'
    /// fields included.
    pub max_fields: usize,
    /// Maximum length, in bytes, of a string value or of any element of a
    /// string array.
    pub max_string_len: usize,
    /// Maximum number of entries in any string, int, float, bool or
    /// nested-record array value.
    pub max_array_items: usize,
}
97
98impl Default for StructuralLimits {
99 fn default() -> Self {
100 Self {
101 max_depth: 32,
103 max_fields: 4096,
105 max_string_len: 16 * 1024,
107 max_array_items: 1024,
109 }
110 }
111}
112
113impl StructuralLimits {
114 pub fn validate_record(&self, record: &LnmpRecord) -> Result<(), StructuralError> {
116 let mut field_count = 0;
117 self.validate_fields(record.fields(), 0, &mut field_count)
118 }
119
120 fn validate_fields(
121 &self,
122 fields: &[LnmpField],
123 depth: usize,
124 field_count: &mut usize,
125 ) -> Result<(), StructuralError> {
126 if depth > self.max_depth {
127 return Err(StructuralError::MaxDepthExceeded {
128 max_depth: self.max_depth,
129 seen_depth: depth,
130 });
131 }
132
133 for field in fields {
134 *field_count += 1;
135 if *field_count > self.max_fields {
136 return Err(StructuralError::MaxFieldsExceeded {
137 max_fields: self.max_fields,
138 seen_fields: *field_count,
139 });
140 }
141 self.validate_value(&field.value, depth + 1, field_count)?;
142 }
143
144 Ok(())
145 }
146
147 fn validate_value(
148 &self,
149 value: &LnmpValue,
150 depth: usize,
151 field_count: &mut usize,
152 ) -> Result<(), StructuralError> {
153 match value {
154 LnmpValue::String(s) => {
155 if s.len() > self.max_string_len {
156 return Err(StructuralError::MaxStringLengthExceeded {
157 max_len: self.max_string_len,
158 seen_len: s.len(),
159 });
160 }
161 Ok(())
162 }
163 LnmpValue::StringArray(arr) => {
164 if arr.len() > self.max_array_items {
165 return Err(StructuralError::MaxArrayLengthExceeded {
166 max_len: self.max_array_items,
167 seen_len: arr.len(),
168 });
169 }
170 for s in arr {
171 if s.len() > self.max_string_len {
172 return Err(StructuralError::MaxStringLengthExceeded {
173 max_len: self.max_string_len,
174 seen_len: s.len(),
175 });
176 }
177 }
178 Ok(())
179 }
180 LnmpValue::IntArray(ints) => {
181 if ints.len() > self.max_array_items {
182 return Err(StructuralError::MaxArrayLengthExceeded {
183 max_len: self.max_array_items,
184 seen_len: ints.len(),
185 });
186 }
187 Ok(())
188 }
189 LnmpValue::FloatArray(floats) => {
190 if floats.len() > self.max_array_items {
191 return Err(StructuralError::MaxArrayLengthExceeded {
192 max_len: self.max_array_items,
193 seen_len: floats.len(),
194 });
195 }
196 Ok(())
197 }
198 LnmpValue::BoolArray(bools) => {
199 if bools.len() > self.max_array_items {
200 return Err(StructuralError::MaxArrayLengthExceeded {
201 max_len: self.max_array_items,
202 seen_len: bools.len(),
203 });
204 }
205 Ok(())
206 }
207 LnmpValue::NestedRecord(record) => {
208 self.validate_fields(record.fields(), depth, field_count)
209 }
210 LnmpValue::NestedArray(records) => {
211 if records.len() > self.max_array_items {
212 return Err(StructuralError::MaxArrayLengthExceeded {
213 max_len: self.max_array_items,
214 seen_len: records.len(),
215 });
216 }
217 for record in records {
218 self.validate_fields(record.fields(), depth, field_count)?;
219 }
220 Ok(())
221 }
222 LnmpValue::Int(_)
224 | LnmpValue::Float(_)
225 | LnmpValue::Bool(_)
226 | LnmpValue::Embedding(_)
227 | LnmpValue::EmbeddingDelta(_) => Ok(()),
228 #[cfg(feature = "quant")]
229 LnmpValue::QuantizedEmbedding(_) => Ok(()),
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-field record whose only value is a string of `string_len` bytes.
    fn basic_record(string_len: usize) -> LnmpRecord {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::String("a".repeat(string_len)),
        });
        record
    }

    /// Builds a one-field record holding `Int(1)`.
    fn leaf_record() -> LnmpRecord {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 2,
            value: LnmpValue::Int(1),
        });
        record
    }

    /// Wraps `inner` as the single field of a new record, adding one nesting level.
    fn wrap(inner: LnmpRecord) -> LnmpRecord {
        let mut outer = LnmpRecord::new();
        outer.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::NestedRecord(Box::new(inner)),
        });
        outer
    }

    #[test]
    fn validates_within_limits() {
        let limits = StructuralLimits::default();
        let record = basic_record(4);
        assert!(limits.validate_record(&record).is_ok());
    }

    #[test]
    fn accepts_string_exactly_at_limit() {
        let limits = StructuralLimits {
            max_string_len: 3,
            ..StructuralLimits::default()
        };
        assert!(limits.validate_record(&basic_record(3)).is_ok());
    }

    #[test]
    fn rejects_oversized_string() {
        let limits = StructuralLimits {
            max_string_len: 2,
            ..StructuralLimits::default()
        };
        let record = basic_record(3);
        let err = limits.validate_record(&record).unwrap_err();
        assert_eq!(
            err,
            StructuralError::MaxStringLengthExceeded {
                max_len: 2,
                seen_len: 3,
            }
        );
    }

    #[test]
    fn accepts_nesting_at_max_depth() {
        let limits = StructuralLimits {
            max_depth: 1,
            ..StructuralLimits::default()
        };
        assert!(limits.validate_record(&wrap(leaf_record())).is_ok());
    }

    #[test]
    fn rejects_excessive_depth() {
        let limits = StructuralLimits {
            max_depth: 0,
            ..StructuralLimits::default()
        };
        let err = limits.validate_record(&wrap(leaf_record())).unwrap_err();
        assert_eq!(
            err,
            StructuralError::MaxDepthExceeded {
                max_depth: 0,
                seen_depth: 1,
            }
        );
    }

    #[test]
    fn rejects_one_level_past_max_depth() {
        let limits = StructuralLimits {
            max_depth: 1,
            ..StructuralLimits::default()
        };
        let err = limits
            .validate_record(&wrap(wrap(leaf_record())))
            .unwrap_err();
        assert_eq!(
            err,
            StructuralError::MaxDepthExceeded {
                max_depth: 1,
                seen_depth: 2,
            }
        );
    }

    #[test]
    fn rejects_field_count_overflow() {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::Int(1),
        });
        record.add_field(LnmpField {
            fid: 2,
            value: LnmpValue::Int(2),
        });

        let limits = StructuralLimits {
            max_fields: 1,
            ..StructuralLimits::default()
        };
        let err = limits.validate_record(&record).unwrap_err();
        assert_eq!(
            err,
            StructuralError::MaxFieldsExceeded {
                max_fields: 1,
                seen_fields: 2,
            }
        );
    }

    #[test]
    fn counts_nested_fields_toward_total() {
        // The outer field and the nested record's field both count.
        let limits = StructuralLimits {
            max_fields: 1,
            ..StructuralLimits::default()
        };
        let err = limits.validate_record(&wrap(leaf_record())).unwrap_err();
        assert_eq!(
            err,
            StructuralError::MaxFieldsExceeded {
                max_fields: 1,
                seen_fields: 2,
            }
        );
    }

    #[test]
    fn rejects_array_length_overflow() {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::StringArray(vec!["a".to_string(), "b".to_string()]),
        });
        let limits = StructuralLimits {
            max_array_items: 1,
            ..StructuralLimits::default()
        };
        let err = limits.validate_record(&record).unwrap_err();
        assert_eq!(
            err,
            StructuralError::MaxArrayLengthExceeded {
                max_len: 1,
                seen_len: 2,
            }
        );
    }

    #[test]
    fn rejects_oversized_string_inside_array() {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::StringArray(vec!["ok".to_string(), "abc".to_string()]),
        });
        let limits = StructuralLimits {
            max_string_len: 2,
            ..StructuralLimits::default()
        };
        let err = limits.validate_record(&record).unwrap_err();
        assert_eq!(
            err,
            StructuralError::MaxStringLengthExceeded {
                max_len: 2,
                seen_len: 3,
            }
        );
    }

    #[test]
    fn rejects_oversized_int_array() {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::IntArray(vec![1, 2, 3]),
        });
        let limits = StructuralLimits {
            max_array_items: 2,
            ..StructuralLimits::default()
        };
        let err = limits.validate_record(&record).unwrap_err();
        assert_eq!(
            err,
            StructuralError::MaxArrayLengthExceeded {
                max_len: 2,
                seen_len: 3,
            }
        );
    }

    #[test]
    fn error_display_reports_limit_and_observed_value() {
        assert_eq!(
            StructuralError::MaxDepthExceeded {
                max_depth: 0,
                seen_depth: 1,
            }
            .to_string(),
            "maximum nesting depth exceeded (max=0, saw=1)"
        );
        assert_eq!(
            StructuralError::MaxArrayLengthExceeded {
                max_len: 1,
                seen_len: 2,
            }
            .to_string(),
            "maximum array length exceeded (max=1, saw=2)"
        );
    }
}