use crate::Segment;
use smallvec::SmallVec;
use std::ops::Range;
/// Static schema entry describing one kind of segment group.
///
/// The grouping functions in this module open a new group named `name`
/// whenever they meet a segment whose tag equals `trigger`.
#[derive(Clone, Copy, Debug)]
pub struct GroupDef {
    /// Name stored as the `definition` of every group built from this entry.
    pub name: &'static str,
    /// Segment tag that opens a new instance of this group.
    pub trigger: &'static str,
    /// Definitions of the groups that may be nested inside this one.
    pub children: &'static [GroupDef],
}
/// One node of the tree produced by [`group_segments`].
#[derive(Debug)]
pub struct SegmentGroup<'a> {
    /// Name of the schema definition this group was built from, or the root
    /// name for the top-level node.
    pub definition: &'static str,
    /// Segments owned directly by this group (cloned from the input slice);
    /// segments of nested groups live in `children`, not here.
    pub segments: Vec<Segment<'a>>,
    /// Nested groups, in the order their trigger segments were encountered.
    pub children: Vec<SegmentGroup<'a>>,
}
impl<'a> SegmentGroup<'a> {
fn new(definition: &'static str) -> Self {
Self {
definition,
segments: Vec::new(),
children: Vec::new(),
}
}
pub fn all_segments(&self) -> impl Iterator<Item = &Segment<'a>> + '_ {
AllSegmentsIter::new(self)
}
pub fn find_segment(&self, tag: &str) -> Option<&Segment<'a>> {
self.segments.iter().find(|s| s.tag == tag)
}
}
/// Explicit-stack traversal state behind [`SegmentGroup::all_segments`].
///
/// Every frame is `(group, next segment index, next child index)`; the last
/// frame is the group currently being visited.
struct AllSegmentsIter<'g, 'a> {
    stack: SmallVec<[(&'g SegmentGroup<'a>, usize, usize); 8]>,
}
impl<'g, 'a> AllSegmentsIter<'g, 'a> {
fn new(root: &'g SegmentGroup<'a>) -> Self {
Self {
stack: smallvec::smallvec![(root, 0, 0)],
}
}
}
impl<'g, 'a> Iterator for AllSegmentsIter<'g, 'a> {
type Item = &'g Segment<'a>;
fn next(&mut self) -> Option<Self::Item> {
loop {
let (group, seg_idx, child_idx) = self.stack.last_mut()?;
if *seg_idx < group.segments.len() {
let seg = &group.segments[*seg_idx];
*seg_idx += 1;
return Some(seg);
}
if *child_idx < group.children.len() {
let child = &group.children[*child_idx];
*child_idx += 1;
self.stack.push((child, 0, 0));
continue;
}
self.stack.pop();
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let lower: usize = self
.stack
.iter()
.map(|(g, seg_idx, _)| g.segments.len().saturating_sub(*seg_idx))
.sum();
(lower, None)
}
}
/// Groups a flat list of segments into a [`SegmentGroup`] tree.
///
/// * `segments` – input segments in document order; they are cloned into the
///   resulting tree.
/// * `schema` – group definitions allowed at the top level.
/// * `root_name` – `definition` given to the returned root node.
///
/// Segments that do not open a group at the top level are stored directly on
/// the root.
pub fn group_segments<'a>(
    segments: &[Segment<'a>],
    schema: &'static [GroupDef],
    root_name: &'static str,
) -> SegmentGroup<'a> {
    let mut tree = SegmentGroup::new(root_name);
    // The consumed count only matters to recursive callers.
    let _consumed = group_recursive(segments, &mut tree, schema);
    tree
}
/// Top-level entry into the recursive grouping: runs
/// [`group_recursive_inner`] with an empty stop-trigger list, so nothing
/// terminates the scan early.
///
/// Returns the number of segments consumed.
fn group_recursive<'a>(
    segments: &[Segment<'a>],
    parent: &mut SegmentGroup<'a>,
    schema: &'static [GroupDef],
) -> usize {
    let no_stop_triggers: &[&'static str] = &[];
    group_recursive_inner(segments, parent, schema, no_stop_triggers)
}
/// Fills `parent` from the front of `segments` and returns how many segments
/// were consumed.
///
/// For each segment, in order:
/// * if its tag is one of `stop_triggers`, scanning ends and the segment is
///   left for the caller;
/// * if its tag is the trigger of an entry in `schema`, a child group is
///   opened with that segment as its first segment and the following
///   segments are grouped into it recursively using the entry's `children`;
/// * otherwise the segment is cloned into `parent.segments`.
///
/// A child group is closed by any of this level's stop triggers or by any
/// trigger declared in `schema` (its own included), which is how siblings
/// and repeated groups are separated.
fn group_recursive_inner<'a>(
    segments: &[Segment<'a>],
    parent: &mut SegmentGroup<'a>,
    schema: &'static [GroupDef],
    stop_triggers: &[&'static str],
) -> usize {
    // Stop list handed to child groups: inherited triggers first, then the
    // triggers of this level, without duplicates.
    let mut child_stops: SmallVec<[&'static str; 16]> = SmallVec::from_slice(stop_triggers);
    for trigger in schema.iter().map(|def| def.trigger) {
        if child_stops.iter().all(|known| *known != trigger) {
            child_stops.push(trigger);
        }
    }

    let mut cursor = 0;
    while let Some(segment) = segments.get(cursor) {
        let tag = segment.tag;
        if stop_triggers.iter().any(|&stop| stop == tag) {
            break;
        }

        // The current segment is consumed on both remaining paths.
        cursor += 1;
        match schema.iter().find(|def| def.trigger == tag) {
            Some(def) => {
                let mut group = SegmentGroup::new(def.name);
                group.segments.push(segment.clone());
                let nested =
                    group_recursive_inner(&segments[cursor..], &mut group, def.children, &child_stops);
                cursor += nested;
                parent.children.push(group);
            }
            None => parent.segments.push(segment.clone()),
        }
    }
    cursor
}
/// Index-only group tree: stores positions into the flat segment slice
/// instead of cloned segments.
#[derive(Debug)]
pub struct SegmentGroupIndexed {
    /// Name of the schema definition this group was built from, or the root
    /// name for the top-level node.
    pub definition: &'static str,
    /// Half-open range of segment positions covered by this group, nested
    /// groups included.
    pub total_span: Range<usize>,
    /// Nested groups; each one covers a sub-range of `total_span`.
    pub children: Vec<SegmentGroupIndexed>,
}

impl SegmentGroupIndexed {
    /// Iterates, in ascending order, over the positions in `total_span` that
    /// are not covered by the span of any direct child.
    ///
    /// Each position is checked against every direct child.
    pub fn direct_segment_indices(&self) -> impl Iterator<Item = usize> + '_ {
        let child_groups = &self.children;
        self.total_span.clone().filter(move |index| {
            child_groups
                .iter()
                .all(|child| !child.total_span.contains(index))
        })
    }
}
pub fn group_segments_indexed<'a>(
segments: &[Segment<'a>],
schema: &'static [GroupDef],
root_name: &'static str,
) -> SegmentGroupIndexed {
let mut root = SegmentGroupIndexed {
definition: root_name,
total_span: 0..0,
children: Vec::new(),
};
group_recursive_indexed(segments, &mut root, schema, &[], 0);
root
}
/// Index-based twin of `group_recursive_inner`: walks `segments`, attaches
/// child groups to `parent` and finally sets `parent.total_span`.
///
/// * `stop_triggers` – tags that end this group without being consumed.
/// * `offset` – absolute position of `segments[0]` in the original slice.
///
/// If `parent.total_span` is non-empty on entry (the caller has already
/// accounted for the trigger segment), its start is preserved; otherwise the
/// span starts at `offset`. Returns the number of segments consumed from
/// `segments`.
fn group_recursive_indexed<'a>(
    segments: &[Segment<'a>],
    parent: &mut SegmentGroupIndexed,
    schema: &'static [GroupDef],
    stop_triggers: &[&'static str],
    offset: usize,
) -> usize {
    // Stop list handed to child groups: inherited triggers first, then the
    // triggers of this level, without duplicates.
    let mut child_stops: SmallVec<[&'static str; 16]> = SmallVec::from_slice(stop_triggers);
    for trigger in schema.iter().map(|def| def.trigger) {
        if !child_stops.contains(&trigger) {
            child_stops.push(trigger);
        }
    }

    let span_start = if parent.total_span.is_empty() {
        offset
    } else {
        parent.total_span.start
    };

    let mut cursor = 0;
    while let Some(segment) = segments.get(cursor) {
        let tag = segment.tag;
        if stop_triggers.iter().any(|&stop| stop == tag) {
            break;
        }

        let absolute = offset + cursor;
        // The current segment is consumed whether or not it opens a group.
        cursor += 1;
        if let Some(def) = schema.iter().find(|def| def.trigger == tag) {
            let mut group = SegmentGroupIndexed {
                definition: def.name,
                // Covers just the trigger for now; the recursion extends it.
                total_span: absolute..absolute + 1,
                children: Vec::new(),
            };
            let nested = group_recursive_indexed(
                &segments[cursor..],
                &mut group,
                def.children,
                &child_stops,
                offset + cursor,
            );
            cursor += nested;
            parent.children.push(group);
        }
    }

    parent.total_span = span_start..offset + cursor;
    cursor
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Span;
    use crate::model::Element;

    /// Builds a minimal segment that carries only the given tag.
    fn seg(tag: &'static str) -> Segment<'static> {
        Segment {
            tag,
            span: Span::new(0, 0),
            tag_span: Span::new(0, 0),
            elements: vec![Element::of(&["x"])],
        }
    }

    /// SG1 (NAD) with nested SG2 (CTA), plus sibling SG3 (LIN).
    static SCHEMA: &[GroupDef] = &[
        GroupDef {
            name: "SG1",
            trigger: "NAD",
            children: &[GroupDef {
                name: "SG2",
                trigger: "CTA",
                children: &[],
            }],
        },
        GroupDef {
            name: "SG3",
            trigger: "LIN",
            children: &[],
        },
    ];

    #[test]
    fn root_segments_before_first_trigger() {
        let segs = vec![seg("UNH"), seg("BGM"), seg("NAD")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        assert_eq!(tree.segments.len(), 2, "UNH + BGM should be in root");
        assert_eq!(tree.children.len(), 1);
        assert_eq!(tree.children[0].definition, "SG1");
    }

    #[test]
    fn repeated_trigger_creates_multiple_children() {
        let segs = vec![seg("UNH"), seg("NAD"), seg("NAD"), seg("UNT")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        assert_eq!(
            tree.children
                .iter()
                .filter(|c| c.definition == "SG1")
                .count(),
            2
        );
    }

    #[test]
    fn nested_child_groups() {
        let segs = vec![seg("NAD"), seg("CTA"), seg("CTA")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        let sg1 = &tree.children[0];
        assert_eq!(sg1.definition, "SG1");
        assert_eq!(sg1.children.len(), 2);
        assert!(sg1.children.iter().all(|c| c.definition == "SG2"));
    }

    #[test]
    fn sibling_trigger_closes_open_group() {
        // LIN is a top-level trigger, so it must end SG1 (and the SG2 nested
        // inside it) instead of being swallowed by them.
        let segs = vec![seg("NAD"), seg("CTA"), seg("LIN")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        let names: Vec<_> = tree.children.iter().map(|c| c.definition).collect();
        assert_eq!(names, vec!["SG1", "SG3"]);
        assert_eq!(tree.children[0].children.len(), 1);
        assert_eq!(tree.children[0].children[0].definition, "SG2");
        assert!(tree.children[1].children.is_empty());
    }

    #[test]
    fn all_segments_iterator_depth_first() {
        let segs = vec![seg("UNH"), seg("NAD"), seg("CTA")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        let tags: Vec<_> = tree.all_segments().map(|s| s.tag).collect();
        // Check the order, not just membership: root segments first, then
        // SG1, then the SG2 nested inside it.
        assert_eq!(tags, vec!["UNH", "NAD", "CTA"]);
    }

    #[test]
    fn find_segment_only_searches_direct_segments() {
        let segs = vec![seg("UNH"), seg("NAD")];
        let tree = group_segments(&segs, SCHEMA, "ROOT");
        assert!(tree.find_segment("UNH").is_some());
        assert!(
            tree.find_segment("NAD").is_none(),
            "NAD belongs to SG1, not to the root"
        );
        assert!(tree.children[0].find_segment("NAD").is_some());
    }

    #[test]
    fn indexed_spans_cover_nested_groups() {
        let segs = vec![seg("UNH"), seg("NAD"), seg("CTA"), seg("LIN")];
        let tree = group_segments_indexed(&segs, SCHEMA, "ROOT");
        assert_eq!(tree.definition, "ROOT");
        assert_eq!(tree.total_span, 0..4);
        assert_eq!(tree.children.len(), 2);

        let sg1 = &tree.children[0];
        assert_eq!(sg1.definition, "SG1");
        assert_eq!(sg1.total_span, 1..3);
        assert_eq!(sg1.children.len(), 1);
        assert_eq!(sg1.children[0].definition, "SG2");
        assert_eq!(sg1.children[0].total_span, 2..3);

        let sg3 = &tree.children[1];
        assert_eq!(sg3.definition, "SG3");
        assert_eq!(sg3.total_span, 3..4);

        assert_eq!(tree.direct_segment_indices().collect::<Vec<_>>(), vec![0]);
        assert_eq!(sg1.direct_segment_indices().collect::<Vec<_>>(), vec![1]);
    }
}