# c2rust-refactor 0.15.0
#
# C2Rust refactoring tool implementation
from typing import List, Tuple, Dict, Set, Optional, NamedTuple

from literate.annot import Span, Annot, fill_annot, zip_annot, cut_annot, \
        lookup_span, SpanMerger
from literate.file import File, Diff
from literate.points import Point, cut_points, annot_to_deltas


# This is the same info contained in a `marks.json` entry, just slightly more
# convenient to access.

class Mark(NamedTuple):
    '''A single marked node, carrying the same information as one
    `marks.json` entry but exposed through named fields.'''

    # Node ID of the marked node.
    id: int

    # The "original node ID" reported by `c2rust-refactor`.  It identifies a
    # node across refactoring steps, so an unchanged mark can still be
    # recognized after nodes have been renumbered.
    orig_id: int

    # Labels applied to the marked node.
    # NOTE(review): `convert_marks` stores the `labels` value of the raw
    # entry unchanged, so at runtime this may be a list rather than a set -
    # confirm before relying on set-only operations.
    labels: Set[str]

    # Kind of node that was marked: `"item"`, `"expr"`, `"stmt"`, etc.
    kind: str

    # Name of the marked node.  May be `None` or the empty string.  Mainly
    # useful on item nodes.
    name: Optional[str]

# Result of diffing two collections of labels, as returned by `diff_labels`:
# sorted lists of the labels added, removed, and kept, in that order.
LabelChanges = Tuple[List[str], List[str], List[str]]

def convert_marks(marks: List[Dict]) -> Dict[int, Mark]:
    '''Turn raw `marks.json` entries into `Mark` objects keyed by node ID.

    Entries whose `labels` collection is empty are dropped.  Each remaining
    entry must have a distinct `id`; a repeated `id` trips an assertion.'''
    by_id = {}
    for entry in marks:
        labels = entry['labels']
        if len(labels) == 0:
            # An entry with no labels carries no mark information.
            continue

        node_id = entry['id']
        assert node_id not in by_id
        by_id[node_id] = Mark(
                id=node_id,
                orig_id=entry['orig_id'],
                labels=labels,
                kind=entry['kind'],
                name=entry['name'],
                )
    return by_id

# Node ID that `build_mark_annot` checks to detect a mark on the crate itself;
# when present, that mark is applied to the entire text of the file.
CRATE_NODE_ID = 0

def build_mark_annot(f: File) -> Annot[Set[int]]:
    '''Annotate the whole text of `f` with the set of marked NodeIds that
    overlap each source location.

    The returned annotation covers the entire file, so it may include spans
    labeled with the empty set.'''
    # Every location starts out with the "background" label: the empty set,
    # or `{CRATE_NODE_ID}` if the crate itself is marked.
    crate_marked = CRATE_NODE_ID in f.marks
    default = frozenset((CRATE_NODE_ID,)) if crate_marked else frozenset()

    text_len = len(f.text)
    annot = [Span(0, text_len, default)]

    def union(a, b):
        return a | b

    # Fold in one marked node at a time: build a full-file annotation holding
    # just that node, then union it into the running result.
    for u_start, u_end, node_id in f.unformatted_nodes:
        if node_id in f.marks:
            # Node positions refer to the unformatted text, so map them to
            # positions in the formatted text first.
            start = f.fmt_map_translate(u_start)
            end = f.fmt_map_translate(u_end)

            single = [Span(start, end, frozenset((node_id,)))]
            node_annot = fill_annot(single, text_len, label=default)
            annot = zip_annot(annot, node_annot, f=union)

    return annot


def init_line_mark_bounds(f: File):
    '''Populate `mark_starts` and `mark_ends` on every `Line` of `f`, based on
    `f.mark_annot`.'''
    # Collect start/end points for the whole file first; they are split up
    # per line afterwards.
    file_starts = []
    file_ends = []

    no_marks = frozenset()

    # Each delta carries the mark sets in effect before and after a position
    # where `mark_annot` changes.  `None` on either side means there is no
    # span there, which contributes no marks.  A delta with `None` on both
    # sides therefore yields two empty sets and adds no points.
    for delta in annot_to_deltas(f.mark_annot):
        before, after = delta.label
        before = no_marks if before is None else before
        after = no_marks if after is None else after

        started = after - before
        ended = before - after

        if started:
            file_starts.append(Point(delta.pos, started))
        if ended:
            file_ends.append(Point(delta.pos, ended))

    # A start point sitting exactly on a line boundary belongs to the line
    # that begins there; an end point belongs to the line that ends there.
    start_cuts = cut_points(file_starts, f.line_annot,
            include_start=True, include_end=False)
    for line_span, line_starts in start_cuts:
        f.lines[line_span.label].set_mark_starts(line_starts)

    end_cuts = cut_points(file_ends, f.line_annot,
            include_start=False, include_end=True)
    for line_span, line_ends in end_cuts:
        f.lines[line_span.label].set_mark_ends(line_ends)

def mark_file(f: File):
    '''Run all per-file mark processing for `f`: sets `File.marks` and
    `File.mark_annot`, then `Line.mark_starts` and `Line.mark_ends`.'''
    marks = convert_marks(f.raw_marks)
    f.set_marks(marks)

    # Build the full-file annotation of mark sets, keeping only the spans
    # that are actually covered by at least one mark.
    marked_spans = [span for span in build_mark_annot(f)
            if len(span.label) > 0]
    f.set_mark_annot(marked_spans)

    init_line_mark_bounds(f)


def init_hunk_start_marks(f: File, lines: Span):
    '''Record on the first line of `lines` the set of marked nodes that
    overlap its start, via `Line.hunk_start_marks`.

    Does nothing if `lines.start` is past the last line of `f`, or if no mark
    crosses the start of that line.'''
    first = lines.start
    if first >= len(f.lines):
        return

    pos = f.line_annot[first].start
    # Both endpoints are excluded so that only marks crossing the line start
    # are found; marks that begin or end exactly there do not count.
    covering = lookup_span(f.mark_annot, pos,
            include_start=False, include_end=False)
    if covering is None or len(covering.label) == 0:
        return

    f.lines[first].set_hunk_start_marks(covering.label)

def init_hunk_end_marks(f: File, lines: Span):
    '''Record on the last line of `lines` the set of marked nodes that
    overlap its end, via `Line.hunk_end_marks`.

    Does nothing if `lines.end` is zero or negative, or if no mark crosses the
    end of that line.'''
    if lines.end <= 0:
        return

    last = lines.end - 1
    pos = f.line_annot[last].end
    # Both endpoints are excluded, so only marks that extend across the end
    # of the line are found.
    covering = lookup_span(f.mark_annot, pos,
            include_start=False, include_end=False)
    if covering is None or len(covering.label) == 0:
        return

    f.lines[last].set_hunk_end_marks(covering.label)

def init_hunk_boundary_marks(d: Diff):
    '''For every hunk in `d`, set `Line.hunk_start_marks` on the lines that
    open the hunk and `Line.hunk_end_marks` on the lines that close it, in
    both the old and the new file.'''
    for hunk in d.hunks:
        # The first block of the hunk determines where it starts.
        _, first_old, first_new = hunk.blocks[0]
        init_hunk_start_marks(d.old_file, first_old)
        init_hunk_start_marks(d.new_file, first_new)

        # The last block of the hunk determines where it ends.
        _, last_old, last_new = hunk.blocks[-1]
        init_hunk_end_marks(d.old_file, last_old)
        init_hunk_end_marks(d.new_file, last_new)


def diff_labels(l1: Set[str], l2: Set[str]) -> LabelChanges:
    '''Compare the labels in `l1` (before) against those in `l2` (after).

    Returns a tuple of three sorted lists: labels only in `l2` (added), labels
    only in `l1` (removed), and labels present in both (kept).  Either
    argument may be any iterable of labels; duplicates are ignored.'''
    before = set(l1)
    after = set(l2)

    added = sorted(after.difference(before))
    removed = sorted(before.difference(after))
    kept = sorted(before.intersection(after))
    return (added, removed, kept)

def init_mark_labels(d: Diff):
    '''Diff the marks of `d.old_file` against those of `d.new_file`, and store
    the per-node label changes in `File.mark_labels` of both files.

    Marks are matched between the two files by `orig_id`.  Each stored value
    is an `(added, removed, kept)` triple of sorted label lists, keyed by the
    mark's `id` in its own file.'''
    old_file, new_file = d.old_file, d.new_file
    old_by_orig = {m.orig_id: m for m in old_file.marks.values()}
    new_by_orig = {m.orig_id: m for m in new_file.marks.values()}

    old_labels = {}
    for mark in old_file.marks.values():
        counterpart = new_by_orig.get(mark.orig_id)
        if counterpart is None:
            # Nothing in the new file corresponds to this node, so the node
            # (and its mark) must have been deleted: every label was removed.
            old_labels[mark.id] = ([], sorted(mark.labels), [])
        else:
            # The same node is marked in the new file; compare label sets.
            old_labels[mark.id] = diff_labels(mark.labels, counterpart.labels)
    old_file.set_mark_labels(old_labels)

    new_labels = {}
    for mark in new_file.marks.values():
        counterpart = old_by_orig.get(mark.orig_id)
        if counterpart is None:
            # No corresponding mark in the old file: every label was added.
            new_labels[mark.id] = (sorted(mark.labels), [], [])
        else:
            new_labels[mark.id] = diff_labels(counterpart.labels, mark.labels)
    new_file.set_mark_labels(new_labels)


def init_file_keep_mark_lines(f: File):
    '''Set `f.keep_mark_lines` to an annotation covering the lines around the
    start of each node whose mark had a label added or removed.

    Reads `f.mark_labels`.  The region kept around each start line runs from
    3 lines before it to 6 lines after it (as a `Span` over line numbers) and
    is not clamped to the file's line range here.'''
    # Marks with any added or removed label are "changed"; their text is kept
    # in the output even when it lies outside every hunk's context.
    changed_nodes = {
            node_id
            for node_id, (added, removed, _kept) in f.mark_labels.items()
            if added or removed}

    # Find the line on which each changed node starts.
    start_lines = set()
    for u_start, _u_end, node_id in f.unformatted_nodes:
        if node_id in changed_nodes:
            pos = f.fmt_map_translate(u_start)
            start_lines.add(lookup_span(f.line_annot, pos).label)

    # Merge a fixed-size window around each start line into one annotation.
    lines_before = 3
    lines_after = 6
    merger = SpanMerger()
    for line in sorted(start_lines):
        merger.add(Span(line - lines_before, line + lines_after))

    f.set_keep_mark_lines(merger.finish())

def init_keep_mark_lines(d: Diff):
    '''Set `keep_mark_lines` on both the old and the new file of `d`.

    Each file is handled on its own, but the computation reads
    `File.mark_labels`, so that field must already have been filled in by
    earlier diff processing.'''
    for side in (d.old_file, d.new_file):
        init_file_keep_mark_lines(side)