c2rust-refactor 0.15.0

C2Rust refactoring tool implementation
'''
Labeled points.

This is similar to the `annot` module, but works with individual points instead
of spans.
'''
from typing import List, Tuple, Optional, Callable, Generic, TypeVar

from literate.annot import Span, Annot


# Generic type variables for the label types carried by points and spans.
# `T` is the label type of the primary input; `U` is a second, independent
# label type (the result labels of `map_points`, or the labels of the `cut`
# argument in `cut_points` / `cut_annot_at_points`).
T = TypeVar('T')
U = TypeVar('U')

class Point(Generic[T]):
    '''An index in some sequence, with an arbitrary label attached.

    Adding or subtracting an integer produces a new point at the shifted
    position carrying the same label; the original point is not modified.'''
    # `pos`: the index in the sequence; `label`: the attached payload.
    __slots__ = ('pos', 'label')

    def __init__(self, pos: int, label: Optional[T] = None) -> None:
        # The default is `None`, so the annotation must admit `None`
        # explicitly; a bare `T` with a `None` default is rejected by type
        # checkers that disallow implicit `Optional`.
        self.pos = pos
        self.label = label

    def __add__(self, x: int) -> 'Point[T]':
        '''Return a new point `x` positions further on, with the same label.'''
        return Point(self.pos + x, self.label)

    def __sub__(self, x: int) -> 'Point[T]':
        '''Return a new point `x` positions earlier, with the same label.'''
        return Point(self.pos - x, self.label)

    def __str__(self) -> str:
        return 'Point(%d, %r)' % (self.pos, self.label)

    def __repr__(self) -> str:
        return self.__str__()

    def copy(self) -> 'Point[T]':
        '''Return a new, independent point with the same position and label.'''
        return Point(self.pos, self.label)

def annot_starts(annot: Annot[T]) -> List[Point[T]]:
    '''Return one point per span of `annot`, placed at that span's start
    position and carrying that span's label, in span order.'''
    starts = []
    for span in annot:
        starts.append(Point(span.start, span.label))
    return starts

def annot_ends(annot: Annot[T]) -> List[Point[T]]:
    '''Return one point per span of `annot`, placed at that span's end
    position and carrying that span's label, in span order.'''
    ends = []
    for span in annot:
        ends.append(Point(span.end, span.label))
    return ends

def annot_to_deltas(annot: Annot[T]) -> List[Point[Tuple[Optional[T], Optional[T]]]]:
    '''Describe `annot` as a sequence of label transitions.

    Each returned point sits on a span boundary and is labeled with a
    `(label_before, label_after)` pair, where `None` stands for "no span".
    Adjacent spans share a single transition point; a gap between spans
    yields two points (leaving the first span, entering the next).  An empty
    annotation produces an empty list.'''
    if len(annot) == 0:
        return []

    head = annot[0]
    # Entering the first span: nothing precedes it.
    deltas = [Point(head.start, (None, head.label))]

    prev = head
    for cur in annot[1:]:
        if prev.end != cur.start:
            # Gap: leave `prev` into an unlabeled region, then enter `cur`.
            deltas.append(Point(prev.end, (prev.label, None)))
            deltas.append(Point(cur.start, (None, cur.label)))
        else:
            # Adjacent: a single point marks the direct hand-off.
            deltas.append(Point(prev.end, (prev.label, cur.label)))
        prev = cur

    # `prev` is now the final span (or `head` itself for a one-span
    # annotation); leaving it transitions to "no span".
    deltas.append(Point(prev.end, (prev.label, None)))
    return deltas


def merge_points(p1: List[Point[T]], p2: List[Point[T]],
        *args: List[Point[T]]) -> List[Point[T]]:
    '''Merge sorted lists of points into one new sorted list.

    Lists are merged left to right: `p1` with `p2`, then that result with
    each list in `args`.  Ties are stable: among points at the same position,
    those from an earlier argument come first, each group keeping its
    original relative order.  The inputs are not modified.'''
    def merge_two(left, right):
        # Classic two-finger merge of two position-sorted lists.
        merged = []
        a = b = 0
        n_left, n_right = len(left), len(right)
        while a < n_left and b < n_right:
            # Take from `right` only when strictly smaller, so equal
            # positions drain `left` first.
            if right[b].pos < left[a].pos:
                merged.append(right[b])
                b += 1
            else:
                merged.append(left[a])
                a += 1
        # At most one of these tails is non-empty.
        merged += left[a:]
        merged += right[b:]
        return merged

    acc = merge_two(p1, p2)
    for extra in args:
        acc = merge_two(acc, extra)
    return acc

def map_points(ps: List[Point[T]], f: Callable[[T], U]) -> List[Point[U]]:
    '''Return new points at the same positions as `ps`, each relabeled with
    `f` applied to its old label.  `f` is called once per point, in order;
    the points in `ps` are left untouched.'''
    relabeled = []
    for point in ps:
        relabeled.append(Point(point.pos, f(point.label)))
    return relabeled

def cut_points(orig: List[Point[T]], cut: Annot[U],
        include_start: bool=True, include_end: bool=False) -> List[Tuple[Span[U], List[Point[T]]]]:
    '''Split the points of `orig` into one group per span of `cut`.

    Returns a list with one `(cut_span, points)` pair for each span in `cut`,
    in order.  `points` holds the members of `orig` that fall inside
    `cut_span`, shifted so that position 0 corresponds to `cut_span.start`
    in the overall text.

    A point exactly on `cut_span.start` is kept by default, and a point
    exactly on `cut_span.end` is dropped by default.  `include_start` and
    `include_end` override these defaults.  With both set to `True`, a point
    on the shared boundary of two adjacent cut spans appears in both groups.

    Both lists are walked once with a single forward cursor, so this assumes
    `orig` and `cut` are each sorted by position.'''
    n = len(orig)
    idx = 0
    pieces = []

    for cut_span in cut:
        lo = cut_span.start
        hi = cut_span.end
        piece = []

        # Discard points lying strictly before this span.
        while idx < n and orig[idx].pos < lo:
            idx += 1

        # Points exactly on the start are consumed either way, but only
        # recorded when `include_start` is set.
        while idx < n and orig[idx].pos == lo:
            if include_start:
                piece.append(orig[idx] - lo)
            idx += 1

        # Points strictly inside the span are always recorded.
        while idx < n and orig[idx].pos < hi:
            piece.append(orig[idx] - lo)
            idx += 1

        if include_end:
            # Look ahead at points exactly on the end.  They are committed
            # (cursor advanced past them) only when the next span will not
            # want them as its own start points; otherwise they stay
            # available for an adjacent span beginning at `hi`.
            ahead = idx
            while ahead < n and orig[ahead].pos == hi:
                piece.append(orig[ahead] - lo)
                ahead += 1
            if not include_start:
                idx = ahead

        pieces.append((cut_span, piece))

    return pieces

def cut_annot_at_points(orig: Annot[T], cut: List[Point[U]]) -> Annot[T]:
    '''Split the spans of `orig` at every position listed in `cut`.

    The result labels the same regions with the same labels as `orig`, but a
    span containing one or more cut positions strictly inside it is replaced
    by consecutive subspans that meet at those positions.  Zero-length spans
    never appear in the output: this drops the empty pieces produced by
    repeated cut positions, and also any zero-length spans already in `orig`.

    Both lists are walked once with a single forward cursor, so this assumes
    `orig` and `cut` are each sorted by position.'''
    pieces = []
    k = 0
    n_cut = len(cut)

    for span in orig:
        # Cut points at or before the span's start cannot split it (a cut
        # exactly at the start would only yield an empty prefix), so skip
        # them.
        while k < n_cut and cut[k].pos <= span.start:
            k += 1

        # Peel off a prefix at each cut point strictly inside the span;
        # `span` shrinks to the remaining suffix after every cut.
        while k < n_cut and cut[k].pos < span.end:
            at = cut[k].pos
            head = Span(span.start, at, span.label)
            if len(head) > 0:
                pieces.append(head)
            span = Span(at, span.end, span.label)
            k += 1

        # Whatever is left: the final suffix, or the untouched original span
        # when no cut point fell inside it.
        if len(span) > 0:
            pieces.append(span)

    return pieces