use crate::Segment;
use smallvec::SmallVec;
/// Static description of one segment group in a grouping schema.
///
/// A definition is matched against incoming segments by comparing
/// `trigger` with the segment tag; a match opens a new group labelled
/// with `name`, and `children` becomes the schema used inside that group.
#[derive(Clone, Copy, Debug)]
pub struct GroupDef {
    /// Label copied into the `definition` field of groups opened by this entry.
    pub name: &'static str,
    /// Segment tag that opens a new instance of this group.
    pub trigger: &'static str,
    /// Definitions of the groups that may be nested directly inside this one.
    pub children: &'static [GroupDef],
}
/// One node of the grouped segment tree.
///
/// A group owns the segments placed directly in it and any nested groups
/// opened while it was the current group.
#[derive(Debug)]
pub struct SegmentGroup<'a> {
    /// Label of this group: the `name` of the [`GroupDef`] that opened it,
    /// or the root name passed to [`group_segments`] for the root node.
    pub definition: &'static str,
    /// Segments that belong directly to this group (not to a nested group).
    pub segments: Vec<Segment<'a>>,
    /// Groups nested directly inside this one, in the order they were opened.
    pub children: Vec<SegmentGroup<'a>>,
}
impl<'a> SegmentGroup<'a> {
    /// Creates an empty group labelled `definition`.
    fn new(definition: &'static str) -> Self {
        let segments = Vec::new();
        let children = Vec::new();
        Self {
            definition,
            segments,
            children,
        }
    }

    /// Iterates over every segment in this group and all nested groups.
    ///
    /// For each group, its own segments are yielded first, followed by the
    /// segments of each child group (recursively) in child order.
    pub fn all_segments(&self) -> impl Iterator<Item = &Segment<'a>> + '_ {
        AllSegmentsIter::new(self)
    }

    /// Returns the first segment held directly by this group whose tag
    /// equals `tag`, or `None` if there is none.
    ///
    /// Nested groups are not searched.
    pub fn find_segment(&self, tag: &str) -> Option<&Segment<'a>> {
        self.segments
            .iter()
            .find(|candidate| candidate.tag == tag)
    }
}
/// Depth-first iterator over the segments of a [`SegmentGroup`] tree.
///
/// The traversal state is an explicit stack instead of recursion.
struct AllSegmentsIter<'g, 'a> {
    /// Groups currently being visited, innermost last. Each entry is
    /// `(group, next segment index, next child index)`.
    stack: Vec<(&'g SegmentGroup<'a>, usize, usize)>,
}
impl<'g, 'a> AllSegmentsIter<'g, 'a> {
    /// Starts a traversal at `root`, with no segment or child visited yet.
    fn new(root: &'g SegmentGroup<'a>) -> Self {
        let stack = Vec::from([(root, 0, 0)]);
        Self { stack }
    }
}
impl<'g, 'a> Iterator for AllSegmentsIter<'g, 'a> {
type Item = &'g Segment<'a>;
fn next(&mut self) -> Option<Self::Item> {
loop {
let (group, seg_idx, child_idx) = self.stack.last_mut()?;
if *seg_idx < group.segments.len() {
let seg = &group.segments[*seg_idx];
*seg_idx += 1;
return Some(seg);
}
if *child_idx < group.children.len() {
let child = &group.children[*child_idx];
*child_idx += 1;
self.stack.push((child, 0, 0));
continue;
}
self.stack.pop();
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let lower: usize = self
.stack
.iter()
.map(|(g, seg_idx, _)| g.segments.len().saturating_sub(*seg_idx))
.sum();
(lower, None)
}
}
/// Builds a [`SegmentGroup`] tree from a flat list of segments.
///
/// * `segments` - the segments to group, in order; each one is cloned into
///   the resulting tree.
/// * `schema` - the group definitions allowed at the top level.
/// * `root_name` - label stored in the `definition` field of the root group.
///
/// Returns the root group. Segments seen before any top-level trigger are
/// stored directly in the root.
pub fn group_segments<'a>(
    segments: &[Segment<'a>],
    schema: &'static [GroupDef],
    root_name: &'static str,
) -> SegmentGroup<'a> {
    let mut tree = SegmentGroup::new(root_name);
    // The consumed count is irrelevant at the root: with no stop triggers
    // the whole input is always processed.
    group_recursive(segments, &mut tree, schema);
    tree
}
/// Groups `segments` into `parent` using `schema`, with no stop triggers.
///
/// Returns the number of segments consumed from `segments`.
fn group_recursive<'a>(
    segments: &[Segment<'a>],
    parent: &mut SegmentGroup<'a>,
    schema: &'static [GroupDef],
) -> usize {
    let no_stop_triggers: &[&'static str] = &[];
    group_recursive_inner(segments, parent, schema, no_stop_triggers)
}
/// Core of the grouping pass: distributes `segments` into `parent`.
///
/// Segments are processed in order:
///
/// * a tag listed in `stop_triggers` ends this call without consuming the
///   segment, handing control back to an enclosing group;
/// * a tag equal to the `trigger` of an entry in `schema` opens a new child
///   group, which then receives the following segments recursively using
///   that entry's `children` as its schema;
/// * any other segment is cloned into `parent.segments`.
///
/// Stop triggers are checked before the schema, so a tag present in both is
/// treated as a stop. A segment that matches neither always lands in the
/// group that is currently open, i.e. the innermost one.
///
/// Returns the number of segments consumed from the front of `segments`.
fn group_recursive_inner<'a>(
    segments: &[Segment<'a>],
    parent: &mut SegmentGroup<'a>,
    schema: &'static [GroupDef],
    stop_triggers: &[&'static str],
) -> usize {
    // Stop set handed to every child opened by this call: the inherited
    // stops plus every trigger of `schema` (so a sibling, or a repeat of the
    // same group, closes the child). It depends only on the arguments, so it
    // is built lazily and at most once instead of once per opened child.
    let mut child_stops: Option<SmallVec<[&'static str; 16]>> = None;

    let mut i = 0;
    while i < segments.len() {
        let tag = segments[i].tag;
        if stop_triggers.iter().any(|t| *t == tag) {
            break;
        }
        if let Some(def) = schema.iter().find(|d| d.trigger == tag) {
            let mut child = SegmentGroup::new(def.name);
            // The trigger segment itself is the first segment of the group.
            child.segments.push(segments[i].clone());
            i += 1;

            let stops = child_stops.get_or_insert_with(|| {
                let mut combined: SmallVec<[&'static str; 16]> =
                    SmallVec::from_slice(stop_triggers);
                for d in schema {
                    if !combined.contains(&d.trigger) {
                        combined.push(d.trigger);
                    }
                }
                combined
            });

            let consumed =
                group_recursive_inner(&segments[i..], &mut child, def.children, stops.as_slice());
            i += consumed;
            parent.children.push(child);
        } else {
            parent.segments.push(segments[i].clone());
            i += 1;
        }
    }
    i
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Element;
    use crate::Span;

    /// Builds a minimal segment carrying `tag` and one placeholder element.
    fn seg(tag: &'static str) -> Segment<'static> {
        Segment {
            tag,
            span: Span::new(0, 0),
            tag_span: Span::new(0, 0),
            elements: vec![Element::of(&["x"])],
        }
    }

    /// Two top-level groups: SG1 (opened by NAD, with nested SG2 opened by
    /// CTA) and SG3 (opened by LIN).
    static SCHEMA: &[GroupDef] = &[
        GroupDef {
            name: "SG1",
            trigger: "NAD",
            children: &[GroupDef {
                name: "SG2",
                trigger: "CTA",
                children: &[],
            }],
        },
        GroupDef {
            name: "SG3",
            trigger: "LIN",
            children: &[],
        },
    ];

    #[test]
    fn root_segments_before_first_trigger() {
        let segs = vec![seg("UNH"), seg("BGM"), seg("NAD")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        assert_eq!(tree.segments.len(), 2, "UNH + BGM should be in root");
        assert_eq!(tree.children.len(), 1);
        assert_eq!(tree.children[0].definition, "SG1");
    }

    #[test]
    fn repeated_trigger_creates_multiple_children() {
        let segs = vec![seg("UNH"), seg("NAD"), seg("NAD"), seg("UNT")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        assert_eq!(
            tree.children
                .iter()
                .filter(|c| c.definition == "SG1")
                .count(),
            2
        );
    }

    #[test]
    fn nested_child_groups() {
        let segs = vec![seg("NAD"), seg("CTA"), seg("CTA")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        let sg1 = &tree.children[0];
        assert_eq!(sg1.definition, "SG1");
        assert_eq!(sg1.children.len(), 2);
        assert!(sg1.children.iter().all(|c| c.definition == "SG2"));
    }

    #[test]
    fn all_segments_iterator_depth_first() {
        let segs = vec![seg("UNH"), seg("NAD"), seg("CTA")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        let tags: Vec<_> = tree.all_segments().map(|s| s.tag).collect();
        // Check the exact sequence, not just membership: root segment first,
        // then SG1's trigger, then the segment of the nested SG2.
        assert_eq!(
            tags,
            ["UNH", "NAD", "CTA"],
            "segments must be yielded depth-first, parents before children"
        );
    }
}