vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Lexically aware source rewriting helpers.

use std::ops::Range;

/// Return source byte ranges that are Rust code, excluding strings and comments.
///
/// The scan walks the raw bytes of `source` once. Line comments, nested block
/// comments, raw strings, ordinary/byte/C strings and character literals are
/// skipped; every stretch between them is reported as a half-open byte range.
/// Empty ranges are never emitted. The newline that ends a line comment is left
/// in the following code range, and an unterminated comment or literal swallows
/// the rest of the input.
#[inline]
pub(crate) fn code_spans(source: &str) -> Vec<Range<usize>> {
    let bytes = source.as_bytes();
    let mut spans = Vec::new();
    let mut start = 0;
    let mut i = 0;

    while i < bytes.len() {
        // Where scanning resumes if a non-code token begins at `i`; `None` for code.
        let resume_at = if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'/') {
            // Line comment: runs up to, but not including, the newline.
            let mut end = i + 2;
            while end < bytes.len() && bytes[end] != b'\n' {
                end += 1;
            }
            Some(end)
        } else if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
            Some(skip_block_comment(bytes, i + 2))
        } else if let Some(end) = raw_string_end(bytes, i) {
            Some(end)
        } else if starts_normal_string(bytes, i) {
            // Step over any one-byte prefix so the skip begins at the opening quote.
            Some(skip_quoted(bytes, i + string_quote_offset(bytes, i), b'"'))
        } else if starts_char_literal(bytes, i) {
            // A byte literal starts with `b`; the skip has to begin at the apostrophe
            // itself, otherwise the opening quote is mistaken for the closing one and
            // the literal's body would be scanned as code.
            let quote = if bytes[i] == b'\'' { i } else { i + 1 };
            Some(skip_quoted(bytes, quote, b'\''))
        } else {
            None
        };

        match resume_at {
            Some(end) => {
                // Close the code run that preceded the token, then jump past it.
                push_span(&mut spans, start, i);
                i = end;
                start = end;
            }
            None => i += 1,
        }
    }

    push_span(&mut spans, start, source.len());
    spans
}

/// Substitute `replacement` for up to `count` occurrences of `pattern` found in code.
///
/// Every match located inside a code span is accepted; text outside the code
/// spans is copied through unchanged.
#[inline]
pub(crate) fn replace_code(source: &str, pattern: &str, replacement: &str, count: usize) -> String {
    let accept_all = |_: &str, _: usize| true;
    replace_code_where(source, pattern, replacement, count, accept_all)
}

/// Replace code-only occurrences whose surrounding bytes are not identifier chars.
#[inline]
pub(crate) fn replace_code_word(
    source: &str,
    pattern: &str,
    replacement: &str,
    count: usize,
) -> String {
    replace_code_where(source, pattern, replacement, count, |src, idx| {
        let left = previous_char(src, idx);
        let right = next_char(src, idx + pattern.len());
        !left.is_some_and(is_ident_char) && !right.is_some_and(is_ident_char)
    })
}

/// Locate the earliest occurrence of `pattern` that lies entirely inside one code span.
///
/// Returns the byte offset of the match within `source`, or `None` when no code
/// span contains the pattern.
#[inline]
pub(crate) fn find_code(source: &str, pattern: &str) -> Option<usize> {
    for span in code_spans(source) {
        let base = span.start;
        if let Some(relative) = source[base..span.end].find(pattern) {
            return Some(base + relative);
        }
    }
    None
}

/// Report whether byte index `idx` falls inside any code span of `source`.
///
/// The code spans are recomputed from `source` on every call.
#[inline]
pub(crate) fn is_code_index(source: &str, idx: usize) -> bool {
    code_spans(source).iter().any(|span| span.contains(&idx))
}

/// Return the nearest non-whitespace character before byte index `idx`.
///
/// Returns `None` when only whitespace (or nothing) precedes `idx`, and also when
/// `idx` lies past the end of `source` or is not on a character boundary, which
/// matches how `next_non_ws` treats an unusable index.
#[inline]
pub(crate) fn previous_non_ws(source: &str, idx: usize) -> Option<char> {
    // `get` instead of indexing: an unusable index yields `None` rather than a panic.
    source
        .get(..idx)?
        .chars()
        .rev()
        .find(|c| !c.is_whitespace())
}

/// Return the first non-whitespace character at or after byte index `idx`.
///
/// Returns `None` when only whitespace follows, or when `idx` is past the end of
/// `source` or not on a character boundary.
#[inline]
pub(crate) fn next_non_ws(source: &str, idx: usize) -> Option<char> {
    let tail = source.get(idx..)?;
    tail.chars().skip_while(|ch| ch.is_whitespace()).next()
}

/// Replace up to `count` code-only occurrences of `pattern` that `accept` approves.
///
/// `accept` receives the whole `source` and the byte index of a candidate match.
/// A rejected candidate is stepped over by one character so that overlapping
/// candidates are still considered. An empty `pattern` or a `count` of zero
/// returns `source` unchanged.
fn replace_code_where(
    source: &str,
    pattern: &str,
    replacement: &str,
    count: usize,
    accept: impl Fn(&str, usize) -> bool,
) -> String {
    if count == 0 || pattern.is_empty() {
        return source.to_string();
    }

    let needle_len = pattern.len();
    let mut result = String::with_capacity(source.len());
    // Everything before `copied_to` has already been written to `result`.
    let mut copied_to = 0;
    let mut remaining = count;

    'spans: for span in code_spans(source) {
        let mut from = span.start;
        while from + needle_len <= span.end {
            if remaining == 0 {
                // Quota used up; the tail is appended verbatim below.
                break 'spans;
            }
            // Searching only up to `span.end` keeps every match inside the span.
            let Some(hit) = source[from..span.end].find(pattern).map(|rel| from + rel) else {
                break;
            };
            if !accept(source, hit) {
                from = hit + next_char_len(source, hit);
                continue;
            }
            result.push_str(&source[copied_to..hit]);
            result.push_str(replacement);
            copied_to = hit + needle_len;
            from = copied_to;
            remaining -= 1;
        }
    }

    result.push_str(&source[copied_to..]);
    result
}

/// Append `start..end` to `spans`, ignoring empty or inverted ranges.
fn push_span(spans: &mut Vec<Range<usize>>, start: usize, end: usize) {
    if end <= start {
        return;
    }
    spans.push(Range { start, end });
}

/// Return the index just past the block comment whose body starts at `i`.
///
/// `i` points at the first byte after the opening `/*`. Nested `/* ... */` pairs
/// are tracked with a depth counter; if the comment never closes, the length of
/// `bytes` is returned.
fn skip_block_comment(bytes: &[u8], mut i: usize) -> usize {
    // The caller has already consumed one opening delimiter.
    let mut depth = 1usize;
    while let Some(&current) = bytes.get(i) {
        match (current, bytes.get(i + 1).copied()) {
            (b'/', Some(b'*')) => {
                depth += 1;
                i += 2;
            }
            (b'*', Some(b'/')) => {
                i += 2;
                depth -= 1;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    bytes.len()
}

/// If a raw string literal starts at `i`, return the index just past its end.
///
/// Recognises `r"…"`, `br"…"` and `cr"…"` with any number of `#` guards. The
/// literal ends at the first `"` followed by as many `#` as opened it. Returns
/// `None` when no raw string starts at `i`, and the length of `bytes` when the
/// literal is never closed.
fn raw_string_end(bytes: &[u8], i: usize) -> Option<usize> {
    // Optional one-byte prefix: `b` (byte string) or `c` (C string).
    let after_prefix = match bytes.get(i) {
        Some(b'b' | b'c') => i + 1,
        _ => i,
    };
    if bytes.get(after_prefix) != Some(&b'r') {
        return None;
    }

    let hashes_start = after_prefix + 1;
    let hashes = bytes[hashes_start..]
        .iter()
        .take_while(|&&b| b == b'#')
        .count();
    let open_quote = hashes_start + hashes;
    // Without a quote here this is an identifier such as `r#foo` or `crate`.
    if bytes.get(open_quote) != Some(&b'"') {
        return None;
    }

    let mut cursor = open_quote + 1;
    while cursor < bytes.len() {
        if bytes[cursor] == b'"' {
            // A quote only closes the literal when the full `#` guard follows it.
            let end = cursor + 1 + hashes;
            if end <= bytes.len() && bytes[cursor + 1..end].iter().all(|b| *b == b'#') {
                return Some(end);
            }
        }
        cursor += 1;
    }
    Some(bytes.len())
}

/// Report whether an ordinary (`"…"`), byte (`b"…"`) or C (`c"…"`) string starts at `i`.
fn starts_normal_string(bytes: &[u8], i: usize) -> bool {
    match bytes.get(i) {
        Some(b'"') => true,
        // A one-byte prefix only counts when the quote follows immediately.
        Some(b'b' | b'c') => bytes.get(i + 1) == Some(&b'"'),
        _ => false,
    }
}

/// Return how many bytes lie between `i` and the opening `"` of the string at `i`.
///
/// The result is `0` when `bytes[i]` is the quote itself and `1` when a one-byte
/// prefix (such as `b` or `c`) precedes it. It is a relative offset: callers add
/// it to `i` to obtain the position of the quote.
fn string_quote_offset(bytes: &[u8], i: usize) -> usize {
    // Must stay relative to `i`; returning the absolute index would make the
    // caller's `i + offset` land far beyond the literal.
    usize::from(bytes.get(i) != Some(&b'"'))
}

/// Report whether a character literal (`'x'`) or byte literal (`b'x'`) starts at `i`.
///
/// An apostrophe only opens a literal when it is followed either by a backslash
/// escape that closes within the bounded window, or by exactly one UTF-8 encoded
/// character and then the closing apostrophe. Requiring the closing quote right
/// after a single character keeps lifetimes and loop labels (`'a`, `'outer:`)
/// classified as code even when another apostrophe follows shortly afterwards,
/// as in `<'a, 'b>`.
fn starts_char_literal(bytes: &[u8], i: usize) -> bool {
    let quote = match (bytes.get(i), bytes.get(i + 1)) {
        (Some(b'\''), _) => i,
        (Some(b'b'), Some(b'\'')) => i + 1,
        _ => return false,
    };
    let body = quote + 1;
    match bytes.get(body) {
        // Escapes vary in length, so defer to the bounded scan for the closing quote.
        Some(b'\\') => bounded_char_end(bytes, quote).is_some(),
        // Empty literal, a line break, or end of input: not a literal.
        None | Some(b'\'' | b'\n') => false,
        Some(&lead) => {
            // Encoded length of the single character, taken from its leading byte.
            let width = match lead {
                0xC0..=0xDF => 2,
                0xE0..=0xEF => 3,
                0xF0..=0xF7 => 4,
                _ => 1,
            };
            bytes.get(body + width) == Some(&b'\'')
        }
    }
}

/// Find the closing apostrophe for the opening one at `quote`, within a short window.
///
/// Returns the index just past the closing apostrophe. The search honours
/// backslash escapes, stops at a newline, and gives up after 12 bytes counted
/// from `quote`, which is wide enough for the longest escape, `'\u{10FFFF}'`.
fn bounded_char_end(bytes: &[u8], quote: usize) -> Option<usize> {
    let mut i = quote + 1;
    let mut escaped = false;
    let limit = bytes.len().min(quote + 12);
    while i < limit {
        if escaped {
            // The byte after a backslash never terminates the literal.
            escaped = false;
        } else if bytes[i] == b'\\' {
            escaped = true;
        } else if bytes[i] == b'\'' {
            return Some(i + 1);
        } else if bytes[i] == b'\n' {
            return None;
        }
        i += 1;
    }
    None
}

/// Return the index just past the literal opened by the delimiter at `quote`.
///
/// Scanning starts right after `quote`. A backslash hides the byte that follows
/// it, so an escaped delimiter does not end the literal. If no closing delimiter
/// is found, the length of `bytes` is returned.
fn skip_quoted(bytes: &[u8], quote: usize, delimiter: u8) -> usize {
    let mut pos = quote + 1;
    while pos < bytes.len() {
        match bytes[pos] {
            // Step over the backslash together with the byte it escapes.
            b'\\' => pos += 2,
            byte if byte == delimiter => return pos + 1,
            _ => pos += 1,
        }
    }
    bytes.len()
}

/// Return the character that ends immediately before byte index `idx`.
///
/// Returns `None` at the start of `source`, and also when `idx` is past the end
/// or not on a character boundary, mirroring how `next_char` handles such an index.
fn previous_char(source: &str, idx: usize) -> Option<char> {
    // `get` rather than indexing so an unusable index cannot panic.
    source.get(..idx)?.chars().next_back()
}

/// Return the character that starts at byte index `idx`.
///
/// Returns `None` at or past the end of `source`, or when `idx` is not on a
/// character boundary.
fn next_char(source: &str, idx: usize) -> Option<char> {
    source.get(idx..).and_then(|tail| tail.chars().next())
}

/// Return the UTF-8 length of the character starting at byte index `idx`.
///
/// Yields `1` when `idx` is exactly the end of `source`, so callers always
/// advance. `idx` must be a valid character boundary within `source`.
fn next_char_len(source: &str, idx: usize) -> usize {
    match source[idx..].chars().next() {
        Some(ch) => ch.len_utf8(),
        None => 1,
    }
}

/// Report whether `ch` is an ASCII letter, an ASCII digit, or an underscore.
fn is_ident_char(ch: char) -> bool {
    matches!(ch, '_' | '0'..='9' | 'a'..='z' | 'A'..='Z')
}

/// Convert a 1-based line and 0-based UTF-8 character column to a byte offset.
///
/// Columns are counted in characters, not bytes, and a `'\n'` starts the next
/// line at column 0. When the requested position is never reached while walking
/// the characters of `source`, the length of `source` is returned.
#[inline]
pub(crate) fn line_column_to_byte_offset(source: &str, line: usize, column: usize) -> usize {
    // Position of the character about to be examined.
    let mut row = 1;
    let mut col = 0;
    source
        .char_indices()
        .find_map(|(offset, ch)| {
            if row == line && col == column {
                return Some(offset);
            }
            if ch == '\n' {
                row += 1;
                col = 0;
            } else {
                col += 1;
            }
            None
        })
        .unwrap_or(source.len())
}