1use crate::document::{Item, MatrixList, Node};
46use crate::error::{HedlError, HedlResult};
47use crate::header::Header;
48use crate::inference::{infer_quoted_value, infer_value, InferenceContext};
49use crate::lex::row::parse_csv_row;
50use crate::lex::{calculate_indent, strip_comment};
51use crate::limits::Limits;
52use crate::reference::TypeRegistry;
53use crate::value::Value;
54use rayon::prelude::*;
55use std::collections::BTreeMap;
56use std::ops::Range;
57use std::sync::atomic::{AtomicUsize, Ordering};
58
/// Thresholds that decide when work is large enough to be parsed in parallel.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Master switch: when `false`, both `should_parallelize_*` checks return `false`.
    pub enabled: bool,
    /// Minimum number of root entities accepted by
    /// [`ParallelConfig::should_parallelize_entities`].
    pub min_root_entities: usize,
    /// Minimum number of list rows accepted by
    /// [`ParallelConfig::should_parallelize_rows`].
    pub min_list_rows: usize,
    /// Optional thread-pool size. Nothing in this type reads it; presumably the
    /// caller that builds the pool does — TODO confirm.
    pub thread_pool_size: Option<usize>,
}

impl Default for ParallelConfig {
    /// Enabled, with thresholds of 50 root entities and 100 list rows.
    fn default() -> Self {
        Self::with_thresholds(50, 100)
    }
}

impl ParallelConfig {
    /// Builds an enabled configuration with the given thresholds and no
    /// explicit thread-pool size.
    fn with_thresholds(min_root_entities: usize, min_list_rows: usize) -> Self {
        Self {
            enabled: true,
            min_root_entities,
            min_list_rows,
            thread_pool_size: None,
        }
    }

    /// Preset with higher thresholds than the default (100 entities / 200 rows).
    pub fn conservative() -> Self {
        Self::with_thresholds(100, 200)
    }

    /// Preset with lower thresholds than the default (20 entities / 50 rows).
    pub fn aggressive() -> Self {
        Self::with_thresholds(20, 50)
    }

    /// Returns `true` when parallelism is enabled and `entity_count` reaches
    /// `min_root_entities`.
    pub fn should_parallelize_entities(&self, entity_count: usize) -> bool {
        if !self.enabled {
            return false;
        }
        entity_count >= self.min_root_entities
    }

    /// Returns `true` when parallelism is enabled and `row_count` reaches
    /// `min_list_rows`.
    pub fn should_parallelize_rows(&self, row_count: usize) -> bool {
        if !self.enabled {
            return false;
        }
        row_count >= self.min_list_rows
    }
}
133
134pub struct AtomicSecurityCounters {
139 node_count: AtomicUsize,
140 total_keys: AtomicUsize,
141}
142
143impl AtomicSecurityCounters {
144 pub fn new() -> Self {
146 Self {
147 node_count: AtomicUsize::new(0),
148 total_keys: AtomicUsize::new(0),
149 }
150 }
151
152 pub fn increment_nodes(&self, limits: &Limits, line_num: usize) -> HedlResult<()> {
156 let count = self.node_count.fetch_add(1, Ordering::Relaxed);
157 if count >= limits.max_nodes {
158 return Err(HedlError::security(
159 format!("too many nodes: exceeds limit of {}", limits.max_nodes),
160 line_num,
161 ));
162 }
163 Ok(())
164 }
165
166 pub fn increment_keys(&self, limits: &Limits, line_num: usize) -> HedlResult<()> {
168 let count = self.total_keys.fetch_add(1, Ordering::Relaxed);
169 if count >= limits.max_total_keys {
170 return Err(HedlError::security(
171 format!(
172 "too many total keys: {} exceeds limit {}",
173 count + 1,
174 limits.max_total_keys
175 ),
176 line_num,
177 ));
178 }
179 Ok(())
180 }
181
182 pub fn node_count(&self) -> usize {
184 self.node_count.load(Ordering::Relaxed)
185 }
186
187 pub fn key_count(&self) -> usize {
189 self.total_keys.load(Ordering::Relaxed)
190 }
191}
192
193impl Default for AtomicSecurityCounters {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
/// One top-level entity located in a slice of source lines.
#[derive(Clone, Debug)]
pub struct EntityBoundary {
    /// Entity key: the text before the first `:` on the opening line, cut at
    /// the first `(` if there is one, with surrounding whitespace removed.
    pub key: String,
    /// Half-open range of indices into the scanned line slice (slice
    /// positions, not source line numbers).
    pub line_range: Range<usize>,
    /// Kind of entity, decided from the text after the `:` on the opening line.
    pub entity_type: EntityType,
}

/// Kind of a top-level entity.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EntityType {
    /// Nothing follows the `:` on the opening line.
    Object,
    /// The text after the `:` starts with `@`.
    List,
    /// Anything else follows the `:`.
    Scalar,
}
223
/// The rows of one matrix list together with the information needed to parse them.
#[derive(Debug)]
pub struct MatrixRowBatch<'a> {
    /// Type name taken from the list declaration (empty when none was found).
    pub type_name: String,
    /// Column names taken from the list declaration (empty when none were found).
    pub schema: Vec<String>,
    /// `(line number, trimmed row text)` pairs in source order.
    pub rows: Vec<(usize, &'a str)>,
    /// Whether any row was flagged as containing a ditto cell, i.e. a cell
    /// that depends on the previous row.
    pub has_ditto: bool,
}

impl<'a> MatrixRowBatch<'a> {
    /// Returns `true` when the batch has at least one row and no row was
    /// flagged as containing a ditto cell.
    pub fn can_parallelize(&self) -> bool {
        !(self.has_ditto || self.rows.is_empty())
    }
}
247
248pub fn identify_entity_boundaries(lines: &[(usize, &str)]) -> Vec<EntityBoundary> {
253 let mut boundaries = Vec::new();
254 let mut current_key: Option<String> = None;
255 let mut current_start: usize = 0;
256 let mut current_type = EntityType::Object;
257
258 for (idx, &(line_num, line)) in lines.iter().enumerate() {
259 let trimmed = line.trim();
261 if trimmed.is_empty() || trimmed.starts_with('#') {
262 continue;
263 }
264
265 let indent_info = calculate_indent(line, line_num as u32).ok().flatten();
267 let indent = indent_info.map(|i| i.level).unwrap_or(0);
268
269 if indent == 0 && line.contains(':') {
270 if let Some(key) = current_key.take() {
272 boundaries.push(EntityBoundary {
274 key,
275 line_range: current_start..idx,
276 entity_type: current_type,
277 });
278 }
279
280 if let Some(colon_pos) = line.find(':') {
282 let key_part = &line[..colon_pos].trim();
283 let key = if let Some(paren_pos) = key_part.find('(') {
285 key_part[..paren_pos].trim()
286 } else {
287 key_part
288 };
289
290 current_key = Some(key.to_string());
291 current_start = idx;
292
293 let after_colon = line[colon_pos + 1..].trim();
295 current_type = if after_colon.starts_with('@') {
296 EntityType::List
297 } else if after_colon.is_empty() {
298 EntityType::Object
299 } else {
300 EntityType::Scalar
301 };
302 }
303 }
304 }
305
306 if let Some(key) = current_key {
308 boundaries.push(EntityBoundary {
309 key,
310 line_range: current_start..lines.len(),
311 entity_type: current_type,
312 });
313 }
314
315 boundaries
316}
317
318pub fn collect_matrix_rows<'a>(
323 lines: &'a [(usize, &str)],
324 list_start: usize,
325 expected_indent: usize,
326) -> MatrixRowBatch<'a> {
327 let mut rows = Vec::new();
328 let mut has_ditto = false;
329 let mut type_name = String::new();
330 let mut schema = Vec::new();
331
332 if list_start < lines.len() {
334 let (_, decl_line) = lines[list_start];
335 if let Some(colon_pos) = decl_line.find(':') {
336 let after_colon = decl_line[colon_pos + 1..].trim();
337 if let Some(rest) = after_colon.strip_prefix('@') {
338 if let Some(bracket_pos) = rest.find('[') {
340 type_name = rest[..bracket_pos].to_string();
341 let schema_str = &rest[bracket_pos..];
342 if schema_str.starts_with('[') && schema_str.ends_with(']') {
343 let inner = &schema_str[1..schema_str.len() - 1];
344 schema = inner.split(',').map(|s| s.trim().to_string()).collect();
345 }
346 } else {
347 type_name = rest.trim().to_string();
348 }
349 }
350 }
351 }
352
353 let mut i = list_start + 1;
355 while i < lines.len() {
356 let (line_num, line) = lines[i];
357
358 let indent_info = calculate_indent(line, line_num as u32).ok().flatten();
360 let indent = indent_info.map(|info| info.level).unwrap_or(0);
361
362 if indent < expected_indent {
364 break;
365 }
366
367 let content = line.trim();
368
369 if content.starts_with('|') {
371 if content.contains('"') && !content.contains("\"\"") {
373 let unquoted_part: String = content
375 .chars()
376 .scan(false, |in_quote, c| {
377 if c == '"' {
378 *in_quote = !*in_quote;
379 }
380 Some(if *in_quote { ' ' } else { c })
381 })
382 .collect();
383 if unquoted_part.contains('"') {
384 has_ditto = true;
385 }
386 }
387
388 rows.push((line_num, content));
389 } else if !content.is_empty() && !content.starts_with('#') {
390 break;
392 }
393
394 i += 1;
395 }
396
397 MatrixRowBatch {
398 type_name,
399 schema,
400 rows,
401 has_ditto,
402 }
403}
404
405pub fn parse_matrix_rows_parallel(
410 batch: &MatrixRowBatch<'_>,
411 header: &Header,
412 limits: &Limits,
413 counters: &AtomicSecurityCounters,
414) -> HedlResult<Vec<Node>> {
415 if !batch.can_parallelize() {
416 return parse_matrix_rows_sequential(batch, header, limits);
418 }
419
420 let nodes: Vec<HedlResult<Node>> = batch
422 .rows
423 .par_iter()
424 .map(|(line_num, row_content)| {
425 counters.increment_nodes(limits, *line_num)?;
427
428 parse_single_matrix_row(
430 row_content,
431 &batch.schema,
432 &batch.type_name,
433 header,
434 *line_num,
435 )
436 })
437 .collect();
438
439 nodes.into_iter().collect()
441}
442
443fn parse_matrix_rows_sequential(
445 batch: &MatrixRowBatch<'_>,
446 header: &Header,
447 limits: &Limits,
448) -> HedlResult<Vec<Node>> {
449 let mut nodes = Vec::with_capacity(batch.rows.len());
450 let mut prev_values: Option<Vec<Value>> = None;
451 let mut node_count = 0usize;
452
453 for (line_num, row_content) in &batch.rows {
454 node_count = node_count
456 .checked_add(1)
457 .ok_or_else(|| HedlError::security("node count overflow", *line_num))?;
458 if node_count > limits.max_nodes {
459 return Err(HedlError::security(
460 format!("too many nodes: exceeds limit of {}", limits.max_nodes),
461 *line_num,
462 ));
463 }
464
465 let node = parse_single_matrix_row_with_ditto(
467 row_content,
468 &batch.schema,
469 &batch.type_name,
470 header,
471 *line_num,
472 prev_values.as_deref(),
473 )?;
474
475 prev_values = Some(node.fields.to_vec());
476 nodes.push(node);
477 }
478
479 Ok(nodes)
480}
481
482fn parse_single_matrix_row(
486 row_content: &str,
487 schema: &[String],
488 type_name: &str,
489 header: &Header,
490 line_num: usize,
491) -> HedlResult<Node> {
492 parse_single_matrix_row_with_ditto(row_content, schema, type_name, header, line_num, None)
493}
494
495fn parse_single_matrix_row_with_ditto(
497 row_content: &str,
498 schema: &[String],
499 type_name: &str,
500 header: &Header,
501 line_num: usize,
502 prev_values: Option<&[Value]>,
503) -> HedlResult<Node> {
504 let content = row_content.strip_prefix('|').unwrap_or(row_content);
506
507 let csv_content = strip_comment(content).trim();
509
510 let fields =
512 parse_csv_row(csv_content).map_err(|e| HedlError::syntax(e.to_string(), line_num))?;
513
514 if fields.len() != schema.len() {
516 return Err(HedlError::shape(
517 format!("expected {} columns, got {}", schema.len(), fields.len()),
518 line_num,
519 ));
520 }
521
522 let mut values = Vec::with_capacity(fields.len());
524 for (col_idx, field) in fields.iter().enumerate() {
525 let ctx =
526 InferenceContext::for_matrix_cell(&header.aliases, col_idx, prev_values, type_name)
527 .with_version(header.version)
528 .with_null_char(header.null_char);
529
530 let value = if field.is_quoted {
531 infer_quoted_value(&field.value)
532 } else {
533 infer_value(&field.value, &ctx, line_num)?
534 };
535
536 values.push(value);
537 }
538
539 let id = match &values[0] {
541 Value::String(s) => s.clone(),
542 _ => {
543 return Err(HedlError::semantic("ID column must be a string", line_num));
544 }
545 };
546
547 let node = Node::new(type_name, &*id, values);
549
550 Ok(node)
551}
552
553pub fn collect_ids_parallel(
558 items: &BTreeMap<String, Item>,
559 limits: &Limits,
560) -> HedlResult<TypeRegistry> {
561 if items.len() < 10 {
563 let mut registry = TypeRegistry::new();
564 collect_ids_sequential(items, &mut registry, 0, limits.max_nest_depth, limits)?;
565 return Ok(registry);
566 }
567
568 let local_registries: Vec<HedlResult<TypeRegistry>> = items
570 .par_iter()
571 .map(|(_, item)| {
572 let mut local = TypeRegistry::new();
573 collect_item_ids(item, &mut local, 0, limits.max_nest_depth, limits)?;
574 Ok(local)
575 })
576 .collect();
577
578 let mut merged = TypeRegistry::new();
580 for result in local_registries {
581 let local = result?;
582 merge_registry_into(&mut merged, local, limits)?;
583 }
584
585 Ok(merged)
586}
587
588fn collect_ids_sequential(
590 items: &BTreeMap<String, Item>,
591 registry: &mut TypeRegistry,
592 depth: usize,
593 max_depth: usize,
594 limits: &Limits,
595) -> HedlResult<()> {
596 if depth > max_depth {
597 return Err(HedlError::security(
598 format!(
599 "NEST hierarchy depth {} exceeds maximum allowed depth {}",
600 depth, max_depth
601 ),
602 0,
603 ));
604 }
605
606 for item in items.values() {
607 collect_item_ids(item, registry, depth, max_depth, limits)?;
608 }
609
610 Ok(())
611}
612
613fn collect_item_ids(
615 item: &Item,
616 registry: &mut TypeRegistry,
617 depth: usize,
618 max_depth: usize,
619 limits: &Limits,
620) -> HedlResult<()> {
621 match item {
622 Item::List(list) => {
623 collect_list_ids(list, registry, depth, max_depth, limits)?;
624 }
625 Item::Object(obj) => {
626 collect_ids_sequential(obj, registry, depth + 1, max_depth, limits)?;
627 }
628 Item::Scalar(_) => {}
629 }
630 Ok(())
631}
632
633fn collect_list_ids(
635 list: &MatrixList,
636 registry: &mut TypeRegistry,
637 depth: usize,
638 max_depth: usize,
639 limits: &Limits,
640) -> HedlResult<()> {
641 for node in &list.rows {
642 registry.register(&list.type_name, &node.id, 0, limits)?;
643 }
644
645 for node in &list.rows {
647 if let Some(children) = node.children() {
648 for child_list in children.values() {
649 for child in child_list {
650 collect_node_ids(child, registry, depth + 1, max_depth, limits)?;
651 }
652 }
653 }
654 }
655
656 Ok(())
657}
658
659fn collect_node_ids(
661 node: &Node,
662 registry: &mut TypeRegistry,
663 depth: usize,
664 max_depth: usize,
665 limits: &Limits,
666) -> HedlResult<()> {
667 if depth > max_depth {
668 return Err(HedlError::security(
669 format!(
670 "NEST hierarchy depth {} exceeds maximum allowed depth {}",
671 depth, max_depth
672 ),
673 0,
674 ));
675 }
676
677 registry.register(&node.type_name, &node.id, 0, limits)?;
678
679 if let Some(children) = node.children() {
680 for child_list in children.values() {
681 for child in child_list {
682 collect_node_ids(child, registry, depth + 1, max_depth, limits)?;
683 }
684 }
685 }
686
687 Ok(())
688}
689
690fn merge_registry_into(
692 target: &mut TypeRegistry,
693 source: TypeRegistry,
694 limits: &Limits,
695) -> HedlResult<()> {
696 for (id, types) in source.by_id_iter() {
698 for type_name in types {
699 if target.contains_in_type(type_name, id) {
701 return Err(HedlError::collision(
702 format!(
703 "duplicate ID '{}' in type '{}' detected during parallel merge",
704 id, type_name
705 ),
706 0,
707 ));
708 }
709 target.register(type_name, id, 0, limits)?;
711 }
712 }
713 Ok(())
714}
715
716pub fn validate_references_parallel(
721 items: &BTreeMap<String, Item>,
722 registries: &TypeRegistry,
723 strict: bool,
724 max_depth: usize,
725) -> HedlResult<()> {
726 if items.len() < 10 {
728 return validate_references_sequential(items, registries, strict, None, 0, max_depth);
729 }
730
731 let results: Vec<HedlResult<()>> = items
733 .par_iter()
734 .map(|(_, item)| validate_item_refs(item, registries, strict, None, 0, max_depth))
735 .collect();
736
737 for result in results {
739 result?;
740 }
741
742 Ok(())
743}
744
745fn validate_references_sequential(
747 items: &BTreeMap<String, Item>,
748 registries: &TypeRegistry,
749 strict: bool,
750 current_type: Option<&str>,
751 depth: usize,
752 max_depth: usize,
753) -> HedlResult<()> {
754 if depth > max_depth {
755 return Err(HedlError::security(
756 format!(
757 "NEST hierarchy depth {} exceeds maximum allowed depth {}",
758 depth, max_depth
759 ),
760 0,
761 ));
762 }
763
764 for item in items.values() {
765 validate_item_refs(item, registries, strict, current_type, depth, max_depth)?;
766 }
767
768 Ok(())
769}
770
771fn validate_item_refs(
773 item: &Item,
774 registries: &TypeRegistry,
775 strict: bool,
776 current_type: Option<&str>,
777 depth: usize,
778 max_depth: usize,
779) -> HedlResult<()> {
780 match item {
781 Item::Scalar(value) => {
782 validate_value_ref(value, registries, strict, current_type)?;
783 }
784 Item::List(list) => {
785 for node in &list.rows {
786 validate_node_refs(node, registries, strict, depth, max_depth)?;
787 }
788 }
789 Item::Object(obj) => {
790 validate_references_sequential(
791 obj,
792 registries,
793 strict,
794 current_type,
795 depth + 1,
796 max_depth,
797 )?;
798 }
799 }
800 Ok(())
801}
802
803fn validate_node_refs(
805 node: &Node,
806 registries: &TypeRegistry,
807 strict: bool,
808 depth: usize,
809 max_depth: usize,
810) -> HedlResult<()> {
811 if depth > max_depth {
812 return Err(HedlError::security(
813 format!(
814 "NEST hierarchy depth {} exceeds maximum allowed depth {}",
815 depth, max_depth
816 ),
817 0,
818 ));
819 }
820
821 for value in &node.fields {
822 validate_value_ref(value, registries, strict, Some(&node.type_name))?;
823 }
824
825 if let Some(children) = node.children() {
826 for child_list in children.values() {
827 for child in child_list {
828 validate_node_refs(child, registries, strict, depth + 1, max_depth)?;
829 }
830 }
831 }
832
833 Ok(())
834}
835
836fn validate_value_ref(
838 value: &Value,
839 registries: &TypeRegistry,
840 strict: bool,
841 current_type: Option<&str>,
842) -> HedlResult<()> {
843 if let Value::Reference(ref_val) = value {
844 let resolved = match &ref_val.type_name {
845 Some(t) => registries.contains_in_type(t, &ref_val.id),
846 None => match current_type {
847 Some(type_name) => registries.contains_in_type(type_name, &ref_val.id),
848 None => {
849 let matching_types = registries.lookup_unqualified(&ref_val.id).unwrap_or(&[]);
850 match matching_types.len() {
851 0 => false,
852 1 => true,
853 _ => {
854 return Err(HedlError::reference(
855 format!(
856 "Ambiguous unqualified reference '@{}' matches multiple types: [{}]",
857 ref_val.id,
858 matching_types.join(", ")
859 ),
860 0,
861 ));
862 }
863 }
864 }
865 },
866 };
867
868 if !resolved && strict {
869 return Err(HedlError::reference(
870 format!("unresolved reference {}", ref_val.to_ref_string()),
871 0,
872 ));
873 }
874 }
875
876 Ok(())
877}
878
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-column `User` batch with the given rows and ditto flag.
    fn user_batch(rows: Vec<(usize, &'static str)>, has_ditto: bool) -> MatrixRowBatch<'static> {
        MatrixRowBatch {
            type_name: "User".to_string(),
            schema: vec!["id".to_string(), "name".to_string()],
            rows,
            has_ditto,
        }
    }

    /// Parses the row `|alice, Alice` against an `id, name` schema.
    fn parse_alice(header: &Header) -> Node {
        let schema = vec!["id".to_string(), "name".to_string()];
        parse_single_matrix_row("|alice, Alice", &schema, "User", header, 1).unwrap()
    }

    #[test]
    fn test_parallel_config_default() {
        let ParallelConfig {
            enabled,
            min_root_entities,
            min_list_rows,
            thread_pool_size,
        } = ParallelConfig::default();

        assert!(enabled);
        assert_eq!(min_root_entities, 50);
        assert_eq!(min_list_rows, 100);
        assert!(thread_pool_size.is_none());
    }

    #[test]
    fn test_parallel_config_conservative() {
        let config = ParallelConfig::conservative();
        assert_eq!(
            (config.min_root_entities, config.min_list_rows),
            (100, 200)
        );
    }

    #[test]
    fn test_parallel_config_aggressive() {
        let config = ParallelConfig::aggressive();
        assert_eq!((config.min_root_entities, config.min_list_rows), (20, 50));
    }

    #[test]
    fn test_parallel_config_thresholds() {
        let config = ParallelConfig::default();

        for &(count, expected) in &[(10, false), (49, false), (50, true), (100, true)] {
            assert_eq!(config.should_parallelize_entities(count), expected);
        }

        for &(count, expected) in &[(10, false), (99, false), (100, true), (1000, true)] {
            assert_eq!(config.should_parallelize_rows(count), expected);
        }
    }

    #[test]
    fn test_atomic_counters() {
        let limits = Limits::default();
        let counters = AtomicSecurityCounters::new();

        // Fresh counters start at zero.
        assert_eq!(counters.node_count(), 0);
        assert_eq!(counters.key_count(), 0);

        counters.increment_nodes(&limits, 0).unwrap();
        assert_eq!(counters.node_count(), 1);

        counters.increment_keys(&limits, 0).unwrap();
        assert_eq!(counters.key_count(), 1);
    }

    #[test]
    fn test_atomic_counters_limit_exceeded() {
        let limits = Limits {
            max_nodes: 1,
            ..Default::default()
        };
        let counters = AtomicSecurityCounters::new();

        // The first node fits within the limit; the second does not.
        counters.increment_nodes(&limits, 0).unwrap();
        assert!(counters.increment_nodes(&limits, 0).is_err());
    }

    #[test]
    fn test_entity_boundary_identification() {
        let lines: Vec<(usize, &str)> = [
            "users:@User",
            "| alice",
            "| bob",
            "settings:",
            " debug: true",
            "count: 42",
        ]
        .iter()
        .enumerate()
        .map(|(idx, text)| (idx + 1, *text))
        .collect();

        let expected = [
            ("users", EntityType::List, 0..3),
            ("settings", EntityType::Object, 3..5),
            ("count", EntityType::Scalar, 5..6),
        ];

        let boundaries = identify_entity_boundaries(&lines);
        assert_eq!(boundaries.len(), 3);

        for (boundary, (key, entity_type, line_range)) in boundaries.iter().zip(expected.iter()) {
            assert_eq!(boundary.key, *key);
            assert_eq!(boundary.entity_type, *entity_type);
            assert_eq!(boundary.line_range, *line_range);
        }
    }

    #[test]
    fn test_matrix_row_batch_no_ditto() {
        let batch = user_batch(vec![(1, "|alice, Alice"), (2, "|bob, Bob")], false);

        assert!(batch.can_parallelize());
    }

    #[test]
    fn test_matrix_row_batch_with_ditto() {
        let batch = user_batch(vec![(1, "|alice, Alice"), (2, "|bob, \"")], true);

        assert!(!batch.can_parallelize());
    }

    #[test]
    fn test_parse_single_matrix_row() {
        let node = parse_alice(&Header::new((1, 0)));

        assert_eq!(node.id, "alice");
        assert_eq!(node.type_name, "User");
        assert_eq!(node.fields.len(), 2);
    }

    #[test]
    fn test_parse_single_matrix_row_basic() {
        let node = parse_alice(&Header::new((2, 0)));

        assert_eq!(node.id, "alice");
        assert_eq!(node.type_name, "User");
        assert_eq!(node.child_count, 0);
    }
}