from typing import List, Tuple, Optional, Callable, Generic, TypeVar
from literate.annot import Span, Annot
T = TypeVar('T')
U = TypeVar('U')


class Point(Generic[T]):
    """A single position (``pos``) carrying an arbitrary ``label``.

    Adding or subtracting an ``int`` yields a *new* ``Point`` whose position
    is shifted by that amount and whose label is the same object as this
    point's label; the original point is left untouched.
    """

    __slots__ = ('pos', 'label')

    def __init__(self, pos: int, label: Optional[T] = None):
        self.pos = pos
        self.label = label

    def __add__(self, x: int) -> 'Point[T]':
        """Return a new point moved forward by ``x``, keeping the label."""
        moved = self.pos + x
        return Point(pos=moved, label=self.label)

    def __sub__(self, x: int) -> 'Point[T]':
        """Return a new point moved back by ``x``, keeping the label."""
        moved = self.pos - x
        return Point(pos=moved, label=self.label)

    def __str__(self) -> str:
        shown = (self.pos, self.label)
        return 'Point(%d, %r)' % shown

    def __repr__(self) -> str:
        # Same text as str(); delegates so subclasses overriding __str__ agree.
        return self.__str__()

    def copy(self) -> 'Point[T]':
        """Return a new point with the same position and (shared) label."""
        return Point(pos=self.pos, label=self.label)
def annot_starts(annot: Annot[T]) -> List[Point[T]]:
    """Return a point at the start of every span in ``annot``.

    Each point carries the label of the span it came from; points are in
    the same order as the spans are iterated.
    """
    return [Point(span.start, span.label) for span in annot]
def annot_ends(annot: Annot[T]) -> List[Point[T]]:
    """Return a point at the end of every span in ``annot``.

    Each point carries the label of the span it came from; points are in
    the same order as the spans are iterated.
    """
    return [Point(span.end, span.label) for span in annot]
def annot_to_deltas(annot: Annot[T]) -> List[Point[Tuple[Optional[T], Optional[T]]]]:
    """Describe ``annot`` as a list of label transitions at span boundaries.

    Every returned point carries a pair ``(ending_label, starting_label)``:

    * ``(None, label)`` where a span starts and no span ends at that spot,
    * ``(label, None)`` where a span ends and the next span does not start
      at that spot,
    * ``(prev_label, next_label)`` where one span ends exactly where the next
      begins; the two boundaries collapse into a single point.

    An empty ``annot`` yields ``[]``. Spans are processed in the given order.
    This presumably expects them sorted and non-overlapping (not checked
    here). With overlapping spans the emitted points would not be in
    position order.
    """
    if len(annot) == 0:
        return []
    head, tail = annot[0], annot[-1]
    deltas = [Point(head.start, (None, head.label))]
    for prev, nxt in zip(annot, annot[1:]):
        if prev.end == nxt.start:
            # Adjacent spans: one point records both the end and the start.
            deltas.append(Point(prev.end, (prev.label, nxt.label)))
            continue
        # A gap between spans: close the previous one, then open the next.
        deltas.append(Point(prev.end, (prev.label, None)))
        deltas.append(Point(nxt.start, (None, nxt.label)))
    deltas.append(Point(tail.end, (tail.label, None)))
    return deltas
def merge_points(p1: List[Point[T]], p2: List[Point[T]],
                 *args: List[Point[T]]) -> List[Point[T]]:
    """Merge two or more ``pos``-sorted point lists into one sorted list.

    The merge is stable:

    * when points share a position, the one from the earlier list comes first;
    * points from the same list keep their relative order.

    Additional lists in ``args`` are folded in left to right. The inputs are
    assumed to be sorted by ``pos`` already and are not modified. A new list
    is returned, but the Point objects in it are the input objects (not copies).
    """
    def merge_two(left, right):
        # Two-finger merge; ``<=`` makes ties go to ``left``.
        out = []
        li = ri = 0
        while li < len(left) and ri < len(right):
            if left[li].pos <= right[ri].pos:
                out.append(left[li])
                li += 1
            else:
                out.append(right[ri])
                ri += 1
        # At most one of the two remainders is non-empty.
        out.extend(left[li:])
        out.extend(right[ri:])
        return out

    result = merge_two(p1, p2)
    for more in args:
        result = merge_two(result, more)
    return result
def map_points(ps: List[Point[T]], f: Callable[[T], U]) -> List[Point[U]]:
    """Return new points at the same positions with ``f`` applied to each label.

    ``f`` is called once per point, in list order.
    """
    return [Point(point.pos, f(point.label)) for point in ps]
def cut_points(orig: List[Point[T]], cut: Annot[U],
               include_start: bool=True, include_end: bool=False) -> List[Tuple[Span[U], List[Point[T]]]]:
    """Group the points of ``orig`` by the spans of ``cut``.

    For every span in ``cut``, returns ``(span, points)``. Here ``points`` are
    the points of ``orig`` falling in that span, shifted so positions are
    relative to ``span.start`` (via ``Point.__sub__``). Points strictly
    inside a span are always kept. Points exactly on a boundary are governed
    by the flags:

    include_start: keep points whose ``pos == span.start``.
    include_end: keep points whose ``pos == span.end``. When ``include_start``
        is also set, those points remain available to a following span that
        starts at the same position, so they can appear in both groups.

    ``orig`` is walked with a single forward cursor. Both ``orig`` (by
    ``pos``) and ``cut`` are therefore presumably sorted, with ``cut``
    non-overlapping. Points in an already-passed region are not revisited,
    and points outside every span are dropped.
    """
    pieces = []
    i = 0
    count = len(orig)
    for cut_span in cut:
        base = cut_span.start
        stop = cut_span.end
        local = []
        # Skip points lying before this span.
        while i < count and orig[i].pos < base:
            i += 1
        # Start-boundary points are always consumed, but only kept on request.
        while i < count and orig[i].pos == base:
            if include_start:
                local.append(orig[i] - base)
            i += 1
        # Points strictly inside the span.
        while i < count and orig[i].pos < stop:
            local.append(orig[i] - base)
            i += 1
        if include_end:
            rewind_to = i
            while i < count and orig[i].pos == stop:
                local.append(orig[i] - base)
                i += 1
            if include_start:
                # Hand the end-boundary points back so a span starting at
                # ``stop`` can also claim them as its start boundary.
                i = rewind_to
        pieces.append((cut_span, local))
    return pieces
def cut_annot_at_points(orig: Annot[T], cut: List[Point[U]]) -> Annot[T]:
    """Split each span of ``orig`` at every cut position strictly inside it.

    Each piece keeps the label of the span it came from. Pieces for which
    ``len(piece)`` is 0 are dropped. These are presumably zero-width spans,
    e.g. from duplicate cut positions or a zero-width input span. A span
    with no cut position inside it is emitted as the very same object.

    ``cut`` is walked with a single forward cursor. Both ``cut`` (by ``pos``)
    and ``orig`` are therefore presumably sorted, with ``orig``
    non-overlapping. Returns a new list.
    """
    pieces = []
    j = 0
    for span in orig:
        # Cut points at or before the span's start cannot split it.
        while j < len(cut) and cut[j].pos <= span.start:
            j += 1
        rest = span
        while j < len(cut) and cut[j].pos < rest.end:
            at = cut[j].pos
            head = Span(rest.start, at, rest.label)
            if len(head) > 0:
                pieces.append(head)
            rest = Span(at, rest.end, rest.label)
            j += 1
        if len(rest) > 0:
            pieces.append(rest)
    return pieces