from typing import List, Tuple, Iterator, Iterable, Callable, Optional, Any, Generic, TypeVar
# Type variables used to parameterise span labels in the generic class and
# helper functions of this module (T: primary labels; U, V: labels of a second
# annotation and of combined results).
T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V')
class Span(Generic[T]):
    """A half-open interval ``[start, end)`` of integer positions with a label.

    ``start``, ``end`` and ``label`` are plain mutable slots. No ``__eq__`` is
    defined, so two spans compare equal only if they are the same object.
    """

    __slots__ = ('start', 'end', 'label')

    def __init__(self, start: int, end: int, label: Optional[T] = None):
        """Create the span ``[start, end)`` carrying ``label``.

        Raises:
            ValueError: if ``start > end``.
        """
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently allow inverted spans.
        if start > end:
            raise ValueError(
                'Span start (%r) must not exceed end (%r)' % (start, end))
        self.start = start
        self.end = end
        self.label = label

    def is_empty(self) -> bool:
        """Return True if the span covers no positions."""
        return self.end == self.start

    def __len__(self) -> int:
        """Return the number of covered positions, ``end - start``."""
        return self.end - self.start

    def __iter__(self) -> Iterator[int]:
        """Iterate over the covered positions ``start, ..., end - 1``."""
        return iter(range(self.start, self.end))

    def __contains__(self, i: int) -> bool:
        """Return True if position ``i`` lies in ``[start, end)``."""
        return self.start <= i < self.end

    def overlaps(self, other: 'Span[Any]') -> bool:
        """Return True if ``other.start < self.end and self.start < other.end``.

        For non-empty spans this means they share at least one position; spans
        that merely touch (one ends where the other starts) do not overlap.
        Note that an empty span lying strictly inside ``other`` does count.
        """
        return other.start < self.end and self.start < other.end

    def overlaps_ends(self, other: 'Span[Any]') -> bool:
        """Like ``overlaps``, but spans that merely touch also count."""
        return other.start <= self.end and self.start <= other.end

    def intersect(self, other: 'Span[Any]') -> 'Span[T]':
        """Return the common part of both spans, keeping ``self.label``.

        Raises:
            ValueError: if the spans neither overlap nor touch, because the
                computed start would then exceed the computed end.
        """
        return Span(
            max(self.start, other.start),
            min(self.end, other.end),
            self.label)

    def contains(self, other: 'Span[Any]') -> bool:
        """Return True if ``other`` lies entirely within this span."""
        return self.start <= other.start and other.end <= self.end

    def __add__(self, x: int) -> 'Span[T]':
        """Return a new span shifted by ``+x`` positions, same label."""
        return Span(self.start + x, self.end + x, self.label)

    def __sub__(self, x: int) -> 'Span[T]':
        """Return a new span shifted by ``-x`` positions, same label."""
        return Span(self.start - x, self.end - x, self.label)

    def __str__(self) -> str:
        return 'Span(%d, %d, %r)' % (self.start, self.end, self.label)

    def __repr__(self) -> str:
        return self.__str__()

    def copy(self) -> 'Span[T]':
        """Return a new span with the same bounds and the same label object."""
        return Span(self.start, self.end, self.label)
# An "annotation": a list of Span objects whose labels have type T, so that
# e.g. Annot[int] means List[Span[int]]. The helpers that consume annotations
# walk them front to back, so they presumably expect spans sorted by position
# (NOTE(review): confirm this invariant with callers).
Annot = List[Span[T]]
def number_lines(lines: List[str]) -> Annot[int]:
    """Lay the lines end to end and return one span per line, labelled by index.

    Line ``k`` covers the characters that follow lines ``0 .. k-1``, so the
    offsets are character positions in the concatenation of ``lines``.
    """
    spans = []
    offset = 0
    for line_no, text in enumerate(lines):
        stop = offset + len(text)
        spans.append(Span(offset, stop, line_no))
        offset = stop
    return spans
def cut_annot(orig: Annot[T], cut: Annot[U]) -> List[Tuple[Span[U], Annot[T]]]:
    """Split ``orig`` along the spans of ``cut``.

    Returns one ``(cut_span, pieces)`` pair per span of ``cut``. ``pieces`` holds
    the parts of ``orig`` spans that overlap ``cut_span``, clipped to it and
    shifted so their positions are relative to ``cut_span.start``.

    A single forward cursor walks ``orig``, so the result is only meaningful when
    both lists are sorted by position and the ``cut`` spans do not overlap one
    another. An ``orig`` span that extends past the end of one cut span is kept
    under the cursor and revisited for the next.

    NOTE(review): because ``Span.overlaps`` is strict, an empty ``orig`` span
    sitting exactly on a cut span's start or end is consumed without being
    reported in any piece.
    """
    result = []
    cursor = 0
    n_orig = len(orig)
    for window in cut:
        clipped = []
        while cursor < n_orig:
            candidate = orig[cursor]
            if candidate.overlaps(window):
                clipped.append(candidate.intersect(window) - window.start)
            if candidate.end > window.end:
                # Still relevant beyond this window (or lies entirely after
                # it): leave the cursor here for the next cut span.
                break
            cursor += 1
        result.append((window, clipped))
    return result
def merge_annot(a1: Annot[T], a2: Annot[U]) -> Annot[Any]:
    """Return the union of two annotations as a list of non-overlapping spans.

    Both inputs are expected to be sorted by ``start``. They are merged by
    ``start`` and fed to ``SpanMerger``, which coalesces spans that overlap or
    touch. Each resulting span keeps the label of the first span in its run, so
    labels may come from either input.

    The inputs are not modified. ``SpanMerger`` widens the ``end`` of the span
    it stored in place, so it is handed copies rather than the caller's objects.
    """
    import heapq

    # heapq.merge is a stable merge: on equal starts the span from a1 comes
    # first, followed by the one from a2.
    ordered = heapq.merge(a1, a2, key=lambda s: s.start)
    result = SpanMerger()
    result.add_all(s.copy() for s in ordered)
    return result.finish()
def fill_annot(a: Annot[T], end: int, start: int = 0,
               label: Optional[T] = None) -> Annot[T]:
    """Return ``a`` with the uncovered gaps of ``[start, end)`` filled in.

    The spans of ``a`` (expected sorted by ``start``) are kept as-is, and a new
    span labelled ``label`` is inserted in order wherever a gap precedes one of
    them, plus a final one up to ``end`` if needed.
    """
    last_pos = start
    result = []
    for s in a:
        if s.start > last_pos:
            result.append(Span(last_pos, s.start, label))
        result.append(s)
        # max() keeps the cursor from moving backwards when a span is nested
        # inside an earlier one, or lies before ``start``. Either would
        # otherwise yield fillers over already-covered text or before ``start``.
        last_pos = max(last_pos, s.end)
    if end > last_pos:
        result.append(Span(last_pos, end, label))
    return result
def invert_annot(a: Annot[T], end: int, start: int = 0,
                 label: Optional[U] = None) -> Annot[U]:
    """Return spans labelled ``label`` covering the parts of ``[start, end)``
    that no span of ``a`` covers.

    ``a`` is expected to be sorted by ``start``; overlapping or nested spans
    are tolerated.

    NOTE(review): a gap ending at a span's start is not clipped to ``end``, so
    spans of ``a`` lying beyond ``end`` can produce gaps that extend past it.
    """
    last_pos = start
    result = []
    for s in a:
        if s.start > last_pos:
            result.append(Span(last_pos, s.start, label))
        # max() so a span nested inside an earlier one cannot move the cursor
        # back and report an already-covered region as a gap.
        last_pos = max(last_pos, s.end)
    if end > last_pos:
        result.append(Span(last_pos, end, label))
    return result
def sub_annot(a1: Annot[T], a2: Annot[U]) -> Annot[T]:
    """Return the parts of ``a1`` not covered by any span of ``a2``.

    Spans of ``a1`` are clipped to the gaps of ``a2``. A span that ``a2``
    splits yields several pieces, each keeping the original label. The result
    consists of new Span objects.
    """
    if not a1:
        return []
    # Use the furthest end, not a1[-1].end. If an earlier span extends past the
    # last one, stopping at a1[-1].end would silently drop its tail.
    end = max(s.end for s in a1)
    result = []
    for gap, pieces in cut_annot(a1, invert_annot(a2, end)):
        # Pieces are relative to the gap; move them back to absolute positions.
        result.extend(piece + gap.start for piece in pieces)
    return result
def zip_annot(a1: Annot[T], a2: Annot[U],
              f: Callable[[T, U], V] = lambda l1, l2: (l1, l2)) -> Annot[V]:
    """Relabel the parts of ``a1`` that fall inside spans of ``a2``.

    ``a1`` is cut along the spans of ``a2``. Each piece is moved back to
    absolute positions and labelled ``f(a1_label, a2_label)``, which by default
    is the pair of both labels. Parts of ``a1`` outside every ``a2`` span are
    dropped.
    """
    combined = []
    for outer, pieces in cut_annot(a1, a2):
        shift = outer.start
        combined.extend(
            Span(piece.start + shift, piece.end + shift,
                 f(piece.label, outer.label))
            for piece in pieces)
    return combined
def lookup_span(a: Annot[T], pos: int,
                include_start: bool = True,
                include_end: bool = False) -> Optional[Span[T]]:
    """Return the span of ``a`` containing position ``pos``, or None.

    ``include_start`` / ``include_end`` decide whether ``pos`` equal to a
    span's start / end counts as inside. The scan stops at the first span that
    ends after ``pos`` (or at ``pos`` when ``include_end``). It returns that
    span only if it also starts before ``pos`` (or at it when
    ``include_start``). Stopping early assumes ``a`` is sorted and
    non-overlapping.
    """
    for span in a:
        reaches_pos = span.end > pos or (include_end and span.end == pos)
        if not reaches_pos:
            continue
        begins_by_pos = span.start < pos or (include_start and span.start == pos)
        return span if begins_by_pos else None
    return None
class SpanMerger(Generic[T]):
    """Coalesce spans, added in order of ``start``, into non-overlapping spans.

    A span that overlaps or touches the last merged span
    (``span.start <= last.end``) extends it. Any other span starts a new merged
    span. A merged span keeps the label of the first span in its run. Call
    ``finish`` once to take the result.
    """

    def __init__(self) -> None:
        # Merged spans so far; set to None by finish() to mark the merger used up.
        self.acc = []  # type: Optional[List[Span[T]]]

    def add(self, span: Span[T]) -> None:
        """Add ``span``; it is expected not to start before earlier spans.

        The caller's ``span`` object is never modified.

        Raises:
            RuntimeError: if called after ``finish``.
        """
        if self.acc is None:
            raise RuntimeError('SpanMerger.add() called after finish()')
        if self.acc and span.start <= self.acc[-1].end:
            last = self.acc[-1]
            last.end = max(last.end, span.end)
        else:
            # Store a copy: its ``end`` may later be widened in place, and that
            # must not leak into the caller's Span object.
            self.acc.append(span.copy())

    def add_all(self, spans: Iterable[Span[T]]) -> None:
        """Add every span of ``spans``, in iteration order."""
        for s in spans:
            self.add(s)

    def finish(self) -> Annot[T]:
        """Return the merged spans. Further ``add`` calls then raise RuntimeError."""
        result = self.acc
        self.acc = None
        return result