1use crate::{LnmpField, LnmpRecord, LnmpValue};
8
/// Error reported when a record violates one of the configured
/// [`StructuralLimits`].
///
/// Every variant carries the configured maximum together with the value that
/// was observed at the moment the violation was detected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StructuralError {
    /// A nested record was reached at a depth greater than the allowed maximum.
    MaxDepthExceeded {
        /// Configured maximum nesting depth.
        max_depth: usize,
        /// Depth at which the violation was detected.
        seen_depth: usize,
    },
    /// The running field count went past the allowed maximum.
    MaxFieldsExceeded {
        /// Configured maximum number of fields.
        max_fields: usize,
        /// Number of fields counted when the limit was crossed.
        seen_fields: usize,
    },
    /// A string value (standalone or inside a string array) is too long.
    MaxStringLengthExceeded {
        /// Configured maximum string length.
        max_len: usize,
        /// Length of the offending string, as reported by its `len()`.
        seen_len: usize,
    },
    /// A string array or nested-record array holds too many items.
    MaxArrayLengthExceeded {
        /// Configured maximum number of array items.
        max_len: usize,
        /// Number of items in the offending array.
        seen_len: usize,
    },
}
41
42impl std::fmt::Display for StructuralError {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 StructuralError::MaxDepthExceeded {
46 max_depth,
47 seen_depth,
48 } => {
49 write!(
50 f,
51 "maximum nesting depth exceeded (max={}, saw={})",
52 max_depth, seen_depth
53 )
54 }
55 StructuralError::MaxFieldsExceeded {
56 max_fields,
57 seen_fields,
58 } => {
59 write!(
60 f,
61 "maximum field count exceeded (max={}, saw={})",
62 max_fields, seen_fields
63 )
64 }
65 StructuralError::MaxStringLengthExceeded { max_len, seen_len } => {
66 write!(
67 f,
68 "maximum string length exceeded (max={}, saw={})",
69 max_len, seen_len
70 )
71 }
72 StructuralError::MaxArrayLengthExceeded { max_len, seen_len } => {
73 write!(
74 f,
75 "maximum array length exceeded (max={}, saw={})",
76 max_len, seen_len
77 )
78 }
79 }
80 }
81}
82
// Marks `StructuralError` as a standard error type. The body is empty, so all
// of the trait's provided methods keep their default implementations.
impl std::error::Error for StructuralError {}
84
/// Configurable upper bounds applied when validating the structure of a record.
///
/// Derives `PartialEq`/`Eq` (in addition to `Debug` and `Clone`) so that
/// configurations can be compared directly, e.g. in tests or when detecting
/// configuration changes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuralLimits {
    /// Maximum nesting depth. Top-level fields are validated at depth 0 and
    /// every level of nested records adds one.
    pub max_depth: usize,
    /// Maximum total number of fields, counted across the top-level record
    /// and all nested records.
    pub max_fields: usize,
    /// Maximum length of any single string value, compared against the
    /// string's `len()`.
    pub max_string_len: usize,
    /// Maximum number of items in a string array or nested-record array.
    pub max_array_items: usize,
}
97
98impl Default for StructuralLimits {
99 fn default() -> Self {
100 Self {
101 max_depth: 32,
103 max_fields: 4096,
105 max_string_len: 16 * 1024,
107 max_array_items: 1024,
109 }
110 }
111}
112
113impl StructuralLimits {
114 pub fn validate_record(&self, record: &LnmpRecord) -> Result<(), StructuralError> {
116 let mut field_count = 0;
117 self.validate_fields(record.fields(), 0, &mut field_count)
118 }
119
120 fn validate_fields(
121 &self,
122 fields: &[LnmpField],
123 depth: usize,
124 field_count: &mut usize,
125 ) -> Result<(), StructuralError> {
126 if depth > self.max_depth {
127 return Err(StructuralError::MaxDepthExceeded {
128 max_depth: self.max_depth,
129 seen_depth: depth,
130 });
131 }
132
133 for field in fields {
134 *field_count += 1;
135 if *field_count > self.max_fields {
136 return Err(StructuralError::MaxFieldsExceeded {
137 max_fields: self.max_fields,
138 seen_fields: *field_count,
139 });
140 }
141 self.validate_value(&field.value, depth + 1, field_count)?;
142 }
143
144 Ok(())
145 }
146
147 fn validate_value(
148 &self,
149 value: &LnmpValue,
150 depth: usize,
151 field_count: &mut usize,
152 ) -> Result<(), StructuralError> {
153 match value {
154 LnmpValue::String(s) => {
155 if s.len() > self.max_string_len {
156 return Err(StructuralError::MaxStringLengthExceeded {
157 max_len: self.max_string_len,
158 seen_len: s.len(),
159 });
160 }
161 Ok(())
162 }
163 LnmpValue::StringArray(arr) => {
164 if arr.len() > self.max_array_items {
165 return Err(StructuralError::MaxArrayLengthExceeded {
166 max_len: self.max_array_items,
167 seen_len: arr.len(),
168 });
169 }
170 for s in arr {
171 if s.len() > self.max_string_len {
172 return Err(StructuralError::MaxStringLengthExceeded {
173 max_len: self.max_string_len,
174 seen_len: s.len(),
175 });
176 }
177 }
178 Ok(())
179 }
180 LnmpValue::NestedRecord(record) => {
181 self.validate_fields(record.fields(), depth, field_count)
182 }
183 LnmpValue::NestedArray(records) => {
184 if records.len() > self.max_array_items {
185 return Err(StructuralError::MaxArrayLengthExceeded {
186 max_len: self.max_array_items,
187 seen_len: records.len(),
188 });
189 }
190 for record in records {
191 self.validate_fields(record.fields(), depth, field_count)?;
192 }
193 Ok(())
194 }
195 LnmpValue::Int(_)
197 | LnmpValue::Float(_)
198 | LnmpValue::Bool(_)
199 | LnmpValue::Embedding(_)
200 | LnmpValue::EmbeddingDelta(_)
201 | LnmpValue::QuantizedEmbedding(_) => Ok(()),
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record by adding `fields` in the given order.
    fn record_of(fields: Vec<LnmpField>) -> LnmpRecord {
        let mut record = LnmpRecord::new();
        for field in fields {
            record.add_field(field);
        }
        record
    }

    /// Record with a single string field of `string_len` bytes.
    fn basic_record(string_len: usize) -> LnmpRecord {
        record_of(vec![LnmpField {
            fid: 1,
            value: LnmpValue::String("a".repeat(string_len)),
        }])
    }

    #[test]
    fn validates_within_limits() {
        let outcome = StructuralLimits::default().validate_record(&basic_record(4));
        assert!(outcome.is_ok());
    }

    #[test]
    fn rejects_oversized_string() {
        let limits = StructuralLimits {
            max_string_len: 2,
            ..StructuralLimits::default()
        };
        let outcome = limits.validate_record(&basic_record(3));
        assert!(matches!(
            outcome,
            Err(StructuralError::MaxStringLengthExceeded { .. })
        ));
    }

    #[test]
    fn rejects_excessive_depth() {
        let inner = record_of(vec![LnmpField {
            fid: 2,
            value: LnmpValue::Int(1),
        }]);
        let outer = record_of(vec![LnmpField {
            fid: 1,
            value: LnmpValue::NestedRecord(Box::new(inner)),
        }]);

        // A maximum depth of zero leaves no room for any nested record.
        let limits = StructuralLimits {
            max_depth: 0,
            ..StructuralLimits::default()
        };
        let outcome = limits.validate_record(&outer);
        assert!(matches!(
            outcome,
            Err(StructuralError::MaxDepthExceeded { .. })
        ));
    }

    #[test]
    fn rejects_field_count_overflow() {
        let record = record_of(vec![
            LnmpField {
                fid: 1,
                value: LnmpValue::Int(1),
            },
            LnmpField {
                fid: 2,
                value: LnmpValue::Int(2),
            },
        ]);

        let limits = StructuralLimits {
            max_fields: 1,
            ..StructuralLimits::default()
        };
        let outcome = limits.validate_record(&record);
        assert!(matches!(
            outcome,
            Err(StructuralError::MaxFieldsExceeded { .. })
        ));
    }

    #[test]
    fn rejects_array_length_overflow() {
        let record = record_of(vec![LnmpField {
            fid: 1,
            value: LnmpValue::StringArray(vec!["a".to_string(), "b".to_string()]),
        }]);
        let limits = StructuralLimits {
            max_array_items: 1,
            ..StructuralLimits::default()
        };
        let outcome = limits.validate_record(&record);
        assert!(matches!(
            outcome,
            Err(StructuralError::MaxArrayLengthExceeded { .. })
        ));
    }
}