1use crate::{LnmpField, LnmpRecord, LnmpValue};
8
/// Error returned when a record violates one of the [`StructuralLimits`].
///
/// Each variant carries both the configured limit and the value that was
/// observed when the limit was crossed, so callers can report both.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StructuralError {
    /// A field list sits deeper than the configured nesting limit.
    MaxDepthExceeded {
        /// Configured maximum nesting depth.
        max_depth: usize,
        /// Depth at which the violation was detected.
        seen_depth: usize,
    },
    /// The cumulative number of fields exceeded the configured limit.
    MaxFieldsExceeded {
        /// Configured maximum field count.
        max_fields: usize,
        /// Running field count at the moment the limit was crossed.
        seen_fields: usize,
    },
    /// A string value was longer than the configured limit.
    MaxStringLengthExceeded {
        /// Configured maximum string length.
        max_len: usize,
        /// Length of the offending string.
        seen_len: usize,
    },
    /// An array value held more elements than the configured limit.
    MaxArrayLengthExceeded {
        /// Configured maximum element count.
        max_len: usize,
        /// Element count of the offending array.
        seen_len: usize,
    },
}

impl std::fmt::Display for StructuralError {
    /// Renders a one-line, human-readable description: the limit that was
    /// exceeded, followed by the configured maximum and the observed value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payloads are `usize` (Copy), so matching on `*self` copies them out.
        match *self {
            Self::MaxDepthExceeded {
                max_depth,
                seen_depth,
            } => write!(
                f,
                "maximum nesting depth exceeded (max={}, saw={})",
                max_depth, seen_depth
            ),
            Self::MaxFieldsExceeded {
                max_fields,
                seen_fields,
            } => write!(
                f,
                "maximum field count exceeded (max={}, saw={})",
                max_fields, seen_fields
            ),
            Self::MaxStringLengthExceeded { max_len, seen_len } => write!(
                f,
                "maximum string length exceeded (max={}, saw={})",
                max_len, seen_len
            ),
            Self::MaxArrayLengthExceeded { max_len, seen_len } => write!(
                f,
                "maximum array length exceeded (max={}, saw={})",
                max_len, seen_len
            ),
        }
    }
}

/// Lets `StructuralError` be used wherever a `std::error::Error` is expected
/// (e.g. boxed as `Box<dyn Error>`); the message comes from `Display`.
impl std::error::Error for StructuralError {}
84
/// Upper bounds on the shape of a record, checked by
/// [`StructuralLimits::validate_record`].
///
/// Start from [`Default`] and override individual limits with struct-update
/// syntax, e.g. `StructuralLimits { max_depth: 4, ..StructuralLimits::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuralLimits {
    /// Maximum nesting depth of field lists. The top-level record's fields are
    /// at depth 0, so `0` permits only flat records.
    pub max_depth: usize,
    /// Maximum total number of fields, counting the fields of nested records.
    pub max_fields: usize,
    /// Maximum length of any string value, measured with `len()` (bytes for
    /// `String`), including strings inside string arrays.
    pub max_string_len: usize,
    /// Maximum number of elements in any array value (including nested arrays
    /// of records).
    pub max_array_items: usize,
}

impl Default for StructuralLimits {
    /// Returns the default limits: depth 32, 4096 fields, 16 KiB strings and
    /// 1024 array elements.
    fn default() -> Self {
        Self {
            max_depth: 32,
            max_fields: 4096,
            max_string_len: 16 * 1024,
            max_array_items: 1024,
        }
    }
}
112
113impl StructuralLimits {
114 pub fn validate_record(&self, record: &LnmpRecord) -> Result<(), StructuralError> {
116 let mut field_count = 0;
117 self.validate_fields(record.fields(), 0, &mut field_count)
118 }
119
120 fn validate_fields(
121 &self,
122 fields: &[LnmpField],
123 depth: usize,
124 field_count: &mut usize,
125 ) -> Result<(), StructuralError> {
126 if depth > self.max_depth {
127 return Err(StructuralError::MaxDepthExceeded {
128 max_depth: self.max_depth,
129 seen_depth: depth,
130 });
131 }
132
133 for field in fields {
134 *field_count += 1;
135 if *field_count > self.max_fields {
136 return Err(StructuralError::MaxFieldsExceeded {
137 max_fields: self.max_fields,
138 seen_fields: *field_count,
139 });
140 }
141 self.validate_value(&field.value, depth + 1, field_count)?;
142 }
143
144 Ok(())
145 }
146
147 fn validate_value(
148 &self,
149 value: &LnmpValue,
150 depth: usize,
151 field_count: &mut usize,
152 ) -> Result<(), StructuralError> {
153 match value {
154 LnmpValue::String(s) => {
155 if s.len() > self.max_string_len {
156 return Err(StructuralError::MaxStringLengthExceeded {
157 max_len: self.max_string_len,
158 seen_len: s.len(),
159 });
160 }
161 Ok(())
162 }
163 LnmpValue::StringArray(arr) => {
164 if arr.len() > self.max_array_items {
165 return Err(StructuralError::MaxArrayLengthExceeded {
166 max_len: self.max_array_items,
167 seen_len: arr.len(),
168 });
169 }
170 for s in arr {
171 if s.len() > self.max_string_len {
172 return Err(StructuralError::MaxStringLengthExceeded {
173 max_len: self.max_string_len,
174 seen_len: s.len(),
175 });
176 }
177 }
178 Ok(())
179 }
180 LnmpValue::IntArray(ints) => {
181 if ints.len() > self.max_array_items {
182 return Err(StructuralError::MaxArrayLengthExceeded {
183 max_len: self.max_array_items,
184 seen_len: ints.len(),
185 });
186 }
187 Ok(())
188 }
189 LnmpValue::FloatArray(floats) => {
190 if floats.len() > self.max_array_items {
191 return Err(StructuralError::MaxArrayLengthExceeded {
192 max_len: self.max_array_items,
193 seen_len: floats.len(),
194 });
195 }
196 Ok(())
197 }
198 LnmpValue::BoolArray(bools) => {
199 if bools.len() > self.max_array_items {
200 return Err(StructuralError::MaxArrayLengthExceeded {
201 max_len: self.max_array_items,
202 seen_len: bools.len(),
203 });
204 }
205 Ok(())
206 }
207 LnmpValue::NestedRecord(record) => {
208 self.validate_fields(record.fields(), depth, field_count)
209 }
210 LnmpValue::NestedArray(records) => {
211 if records.len() > self.max_array_items {
212 return Err(StructuralError::MaxArrayLengthExceeded {
213 max_len: self.max_array_items,
214 seen_len: records.len(),
215 });
216 }
217 for record in records {
218 self.validate_fields(record.fields(), depth, field_count)?;
219 }
220 Ok(())
221 }
222 LnmpValue::Int(_)
224 | LnmpValue::Float(_)
225 | LnmpValue::Bool(_)
226 | LnmpValue::Embedding(_)
227 | LnmpValue::EmbeddingDelta(_)
228 | LnmpValue::QuantizedEmbedding(_) => Ok(()),
229 }
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// A flat record with a single string field of `string_len` characters.
    fn basic_record(string_len: usize) -> LnmpRecord {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::String("a".repeat(string_len)),
        });
        record
    }

    /// A record with one level of nesting and two fields in total:
    /// field 1 holds a nested record whose only field (2) is an int.
    fn nested_record() -> LnmpRecord {
        let mut inner = LnmpRecord::new();
        inner.add_field(LnmpField {
            fid: 2,
            value: LnmpValue::Int(1),
        });
        let mut outer = LnmpRecord::new();
        outer.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::NestedRecord(Box::new(inner)),
        });
        outer
    }

    /// A flat record with `count` int fields.
    fn flat_record(count: usize) -> LnmpRecord {
        let mut record = LnmpRecord::new();
        for i in 0..count {
            record.add_field(LnmpField {
                fid: (i + 1) as _,
                value: LnmpValue::Int(i as i64),
            });
        }
        record
    }

    #[test]
    fn validates_within_limits() {
        let limits = StructuralLimits::default();
        let record = basic_record(4);
        assert!(limits.validate_record(&record).is_ok());
    }

    #[test]
    fn rejects_oversized_string() {
        let limits = StructuralLimits {
            max_string_len: 2,
            ..StructuralLimits::default()
        };
        let record = basic_record(3);
        assert_eq!(
            limits.validate_record(&record),
            Err(StructuralError::MaxStringLengthExceeded {
                max_len: 2,
                seen_len: 3,
            })
        );
    }

    #[test]
    fn accepts_string_at_limit() {
        let limits = StructuralLimits {
            max_string_len: 3,
            ..StructuralLimits::default()
        };
        assert!(limits.validate_record(&basic_record(3)).is_ok());
    }

    #[test]
    fn rejects_oversized_string_inside_array() {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::StringArray(vec!["abc".to_string()]),
        });
        let limits = StructuralLimits {
            max_string_len: 2,
            ..StructuralLimits::default()
        };
        assert_eq!(
            limits.validate_record(&record),
            Err(StructuralError::MaxStringLengthExceeded {
                max_len: 2,
                seen_len: 3,
            })
        );
    }

    #[test]
    fn rejects_excessive_depth() {
        let limits = StructuralLimits {
            max_depth: 0,
            ..StructuralLimits::default()
        };
        assert_eq!(
            limits.validate_record(&nested_record()),
            Err(StructuralError::MaxDepthExceeded {
                max_depth: 0,
                seen_depth: 1,
            })
        );
    }

    #[test]
    fn accepts_depth_at_limit() {
        let limits = StructuralLimits {
            max_depth: 1,
            ..StructuralLimits::default()
        };
        assert!(limits.validate_record(&nested_record()).is_ok());
    }

    #[test]
    fn rejects_field_count_overflow() {
        let limits = StructuralLimits {
            max_fields: 1,
            ..StructuralLimits::default()
        };
        assert_eq!(
            limits.validate_record(&flat_record(2)),
            Err(StructuralError::MaxFieldsExceeded {
                max_fields: 1,
                seen_fields: 2,
            })
        );
    }

    #[test]
    fn accepts_field_count_at_limit() {
        let limits = StructuralLimits {
            max_fields: 2,
            ..StructuralLimits::default()
        };
        assert!(limits.validate_record(&flat_record(2)).is_ok());
    }

    #[test]
    fn counts_nested_fields_toward_limit() {
        let limits = StructuralLimits {
            max_fields: 1,
            ..StructuralLimits::default()
        };
        assert_eq!(
            limits.validate_record(&nested_record()),
            Err(StructuralError::MaxFieldsExceeded {
                max_fields: 1,
                seen_fields: 2,
            })
        );
    }

    #[test]
    fn rejects_array_length_overflow() {
        let mut record = LnmpRecord::new();
        record.add_field(LnmpField {
            fid: 1,
            value: LnmpValue::StringArray(vec!["a".to_string(), "b".to_string()]),
        });
        let limits = StructuralLimits {
            max_array_items: 1,
            ..StructuralLimits::default()
        };
        assert_eq!(
            limits.validate_record(&record),
            Err(StructuralError::MaxArrayLengthExceeded {
                max_len: 1,
                seen_len: 2,
            })
        );
    }

    #[test]
    fn display_includes_limit_and_observed_value() {
        let err = StructuralError::MaxFieldsExceeded {
            max_fields: 1,
            seen_fields: 2,
        };
        assert_eq!(
            err.to_string(),
            "maximum field count exceeded (max=1, saw=2)"
        );
    }
}