use std::borrow::Cow;
use super::utils::is_inside_code_block;
/// Escapes the `>` of list items whose content starts with a numeric
/// comparison (for example `- > 25` becomes `- \> 25`).
///
/// Each line is checked with `find_list_comparison`; a match is escaped unless
/// `is_inside_code_block` reports that the `>` position lies inside a code
/// block. At most one `>` is escaped per line.
///
/// Returns `Cow::Borrowed(text)` when nothing had to change, so the common
/// case performs no allocation; otherwise returns an owned, rewritten copy.
pub fn handle(text: &str) -> Cow<'_, str> {
    // Fast path: without any `>` there is nothing that could need escaping.
    if !text.contains('>') {
        return Cow::Borrowed(text);
    }

    // Output buffer, created lazily on the first escape.
    let mut escaped: Option<String> = None;
    // Everything in `text[..copied_upto]` has already been written to `escaped`.
    let mut copied_upto = 0;
    // Byte offset in `text` of the line currently being inspected.
    let mut offset = 0;

    for line in text.split('\n') {
        if let Some(rel) = find_list_comparison(line.as_bytes()) {
            let gt_at = offset + rel;
            if !is_inside_code_block(text, gt_at) {
                let out = escaped.get_or_insert_with(|| String::with_capacity(text.len() + 8));
                // `gt_at` points at an ASCII `>`, so both slice ends are char boundaries.
                out.push_str(&text[copied_upto..gt_at]);
                out.push_str("\\>");
                copied_upto = gt_at + 1;
            }
        }
        // Step over the line and the `\n` that terminated it.
        offset += line.len() + 1;
    }

    if let Some(mut out) = escaped {
        // Flush the untouched tail after the last escaped `>`.
        out.push_str(&text[copied_upto..]);
        Cow::Owned(out)
    } else {
        Cow::Borrowed(text)
    }
}
/// Detects a list item whose content begins with a numeric `>` / `>=`
/// comparison and returns the byte offset of that `>` within `line`.
///
/// `line` is a single line without its terminating `\n`. The accepted shape,
/// matched byte-wise, is:
///
/// 1. optional indentation made of spaces and tabs,
/// 2. a list marker: `-`, `*`, `+`, or ASCII digits followed by `.` or `)`,
/// 3. one or more spaces,
/// 4. `>` optionally followed by `=`,
/// 5. optional spaces and tabs,
/// 6. an optional `$`,
/// 7. an ASCII digit.
///
/// Returns `None` when the line does not have this shape (for instance
/// `- > text`, or a `>` that is not preceded by a list marker).
fn find_list_comparison(line: &[u8]) -> Option<usize> {
    /// Returns the first index at or after `from` whose byte fails `pred`.
    /// `from` must not exceed `line.len()`.
    fn skip(line: &[u8], from: usize, pred: impl Fn(u8) -> bool) -> usize {
        from + line[from..].iter().take_while(|&&b| pred(b)).count()
    }
    let is_blank = |b: u8| matches!(b, b' ' | b'\t');

    // 1. Indentation.
    let mut pos = skip(line, 0, is_blank);

    // 2. List marker: a single bullet character or `<digits>.` / `<digits>)`.
    match *line.get(pos)? {
        b'-' | b'*' | b'+' => pos += 1,
        b if b.is_ascii_digit() => {
            pos = skip(line, pos, |b| b.is_ascii_digit());
            if !matches!(line.get(pos).copied(), Some(b'.') | Some(b')')) {
                return None;
            }
            pos += 1;
        }
        _ => return None,
    }

    // 3. At least one space must separate the marker from the content.
    let content_start = skip(line, pos, |b| b == b' ');
    if content_start == pos {
        return None;
    }
    pos = content_start;

    // 4. The comparison operator itself.
    if line.get(pos).copied() != Some(b'>') {
        return None;
    }
    let gt_offset = pos;
    pos += 1;
    if line.get(pos).copied() == Some(b'=') {
        pos += 1;
    }

    // 5. Any amount of blank space may sit between the operator and the
    //    number; accepting only a single space would miss `- >  25`.
    pos = skip(line, pos, is_blank);

    // 6. Optional currency sign.
    if line.get(pos).copied() == Some(b'$') {
        pos += 1;
    }

    // 7. A digit must follow, otherwise this is not a numeric comparison.
    match line.get(pos) {
        Some(b) if b.is_ascii_digit() => Some(gt_offset),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::handle;
    use std::borrow::Cow;

    /// Asserts that `handle` rewrites `input` into exactly `expected`.
    #[track_caller]
    fn assert_escaped(input: &str, expected: &str) {
        let actual = handle(input);
        assert_eq!(actual.as_ref(), expected);
    }

    /// Asserts that `handle` hands `input` back borrowed, i.e. unchanged and
    /// without allocating.
    #[track_caller]
    fn assert_untouched(input: &str) {
        let actual = handle(input);
        assert!(matches!(actual, Cow::Borrowed(_)));
    }

    // `>` followed by a number inside a bullet item gets escaped.
    #[test]
    fn escapes_comparison_in_list() {
        assert_escaped("- > 25", "- \\> 25");
    }

    // Same for `>=`; only the `>` itself receives the backslash.
    #[test]
    fn escapes_gte_in_list() {
        assert_escaped("- >= 25", "- \\>= 25");
    }

    // `>` followed by text rather than a number is left alone.
    #[test]
    fn leaves_blockquote() {
        assert_untouched("- > text");
    }

    // Without a list marker the line is left alone.
    #[test]
    fn leaves_non_list() {
        assert_untouched("> 25");
    }

    // Numbered list markers are handled like bullets.
    #[test]
    fn ordered_list() {
        assert_escaped("1. > 25", "1. \\> 25");
    }

    // Input without any `>` takes the borrowed fast path.
    #[test]
    fn no_angle_bracket() {
        assert_untouched("- hello");
    }
}