use std::ops::Range;
use mdwright_document::{Document, ParseError, ParseOptions};
/// One block boundary, recorded as a byte offset into the original source.
#[derive(Clone, Copy, Debug)]
pub(crate) struct BlockCheckpoint {
    /// Offset of the boundary in the original (caller-supplied) source text.
    pub(crate) byte: u32,
    /// Parser state captured at this boundary. It is carried through from the
    /// document's checkpoint facts but not yet consumed by this crate.
    #[expect(dead_code, reason = "reserved for LSP incremental-rebuild")]
    pub(crate) parser_state: u64,
}
/// Block-boundary lookup table for one source document.
///
/// Maps arbitrary byte ranges onto ranges whose ends fall on block
/// checkpoints (see `snap_to_block_boundaries`).
#[derive(Debug)]
pub struct CheckpointTable {
    /// Length of the original source in bytes, saturated to `u32::MAX`.
    source_len: u32,
    /// Checkpoints, expected in ascending byte order because lookups
    /// binary-search this list.
    points: Vec<BlockCheckpoint>,
}
impl CheckpointTable {
    /// Parses `source` with default [`ParseOptions`] and builds its checkpoint table.
    ///
    /// # Errors
    ///
    /// Propagates the [`ParseError`] returned by [`Document::parse_with_options`].
    pub fn build(source: &str) -> Result<Self, ParseError> {
        Self::build_with_options(source, ParseOptions::default())
    }

    /// Parses `source` with the given `parse_options` and builds its checkpoint table.
    ///
    /// # Errors
    ///
    /// Propagates the [`ParseError`] returned by [`Document::parse_with_options`].
    pub fn build_with_options(
        source: &str,
        parse_options: ParseOptions,
    ) -> Result<Self, ParseError> {
        let doc = Document::parse_with_options(source, parse_options)?;
        Ok(Self::from_document(&doc))
    }

    /// Builds a checkpoint table from an already-parsed document.
    ///
    /// The document reports its block checkpoints as byte offsets into its
    /// canonical source. Each one is mapped back to the corresponding offset in
    /// the original source, so the table can be used against the text the
    /// caller holds. Offsets that do not fit in `u32` saturate to `u32::MAX`.
    #[must_use]
    pub fn from_document(doc: &Document) -> Self {
        let facts = doc
            .block_checkpoints()
            .iter()
            .map(|point| {
                let byte = usize::try_from(point.byte).unwrap_or(usize::MAX);
                let original = doc.canonical_to_original_range(byte..byte).start;
                mdwright_document::BlockCheckpointFact {
                    byte: u32::try_from(original).unwrap_or(u32::MAX),
                    parser_state: point.parser_state,
                }
            })
            .collect();
        Self::from_facts(doc.original_source().len(), facts)
    }

    /// Converts checkpoint facts (already in original-source offsets) into a table.
    ///
    /// Points are put in ascending byte order, because the lookups in
    /// `snap_to_block_boundaries` depend on that. A sentinel checkpoint at
    /// `source_len` is appended when the last point falls short of the end of
    /// the source, so the final checkpoint is never before the end of the source.
    fn from_facts(source_len: usize, facts: Vec<mdwright_document::BlockCheckpointFact>) -> Self {
        let source_len = u32::try_from(source_len).unwrap_or(u32::MAX);
        let mut points: Vec<BlockCheckpoint> = facts
            .into_iter()
            .map(|point| BlockCheckpoint {
                byte: point.byte,
                parser_state: point.parser_state,
            })
            .collect();
        // Stable sort: leaves already-ordered facts untouched and keeps
        // equal-offset points in the order they were reported.
        points.sort_by_key(|point| point.byte);
        if points.last().is_none_or(|last| last.byte < source_len) {
            points.push(BlockCheckpoint {
                byte: source_len,
                parser_state: 0,
            });
        }
        Self { source_len, points }
    }

    /// Widens `range` outward so that both ends land on block checkpoints.
    ///
    /// The request is first clamped into the source, and `end` is raised to at
    /// least `start`. The returned start is the last checkpoint at or before the
    /// requested start, or `0` if no checkpoint precedes it. The returned end is
    /// the first checkpoint at or after the requested end, capped at the source
    /// length. The result therefore always covers the clamped request and never
    /// extends past the end of the source.
    #[must_use]
    pub fn snap_to_block_boundaries(&self, range: Range<u32>) -> Range<u32> {
        let req_start = range.start.min(self.source_len);
        let req_end = range.end.min(self.source_len).max(req_start);

        // The last checkpoint at or before `req_start` is the start boundary.
        // With none, fall back to the start of the source. Jumping forward to
        // the first checkpoint would cut off the beginning of the request.
        let at_or_before_start = self.points.partition_point(|p| p.byte <= req_start);
        let lo = at_or_before_start
            .checked_sub(1)
            .and_then(|i| self.points.get(i))
            .map_or(0, |p| p.byte);

        // The first checkpoint at or after `req_end` is the end boundary. It is
        // capped so the range can always be used to slice the source.
        let first_at_or_after_end = self.points.partition_point(|p| p.byte < req_end);
        let hi = self
            .points
            .get(first_at_or_after_end)
            .map_or(self.source_len, |p| p.byte)
            .min(self.source_len);

        lo..hi
    }

    /// Number of checkpoints in the table, including the end-of-source checkpoint.
    #[must_use]
    pub fn len(&self) -> usize {
        self.points.len()
    }

    /// Returns `true` when the table holds at most two checkpoints.
    ///
    /// NOTE: this is deliberately *not* `len() == 0`. A table always holds at
    /// least the end-of-source checkpoint, so that definition would never be
    /// true. An empty source yields `len() == 1` and `is_empty() == true`.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.points.len() <= 2
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, reason = "tests should fail loudly on parse or fixture errors")]
mod tests {
    use std::ops::Range;

    use super::CheckpointTable;

    /// Byte offset of `needle` in `src`; fails the test if the fixture lacks it,
    /// rather than silently probing offset 0.
    fn offset_of(src: &str, needle: &str) -> u32 {
        let at = src.find(needle).expect("fixture contains the needle");
        u32::try_from(at).expect("fixture offsets fit in u32")
    }

    /// The slice of `src` covered by a snapped range.
    fn slice(src: &str, range: Range<u32>) -> &str {
        &src[range.start as usize..range.end as usize]
    }

    #[test]
    fn empty_source() {
        let t = CheckpointTable::build("").expect("checkpoint source parses");
        assert_eq!(t.len(), 1);
        assert!(t.is_empty());
        assert_eq!(t.snap_to_block_boundaries(0..0), 0..0);
    }

    #[test]
    fn three_paragraphs() {
        let src = "a\n\nb\n\nc\n";
        let t = CheckpointTable::build(src).expect("checkpoint source parses");
        assert!(t.len() >= 4);
        let snapped = t.snap_to_block_boundaries(3..4);
        assert_eq!(slice(src, snapped), "b\n\n");
    }

    #[test]
    fn range_inside_list_snaps_to_list_boundaries() {
        let src = "para\n\n1. one\n2. two\n3. three\n\ntail\n";
        let t = CheckpointTable::build(src).expect("checkpoint source parses");
        let two_at = offset_of(src, "two");
        let snapped = slice(src, t.snap_to_block_boundaries(two_at..two_at + 1));
        assert!(snapped.contains("1. one"), "slice should include list start: {snapped:?}");
        assert!(snapped.contains("3. three"), "slice should include list end: {snapped:?}");
    }

    #[test]
    fn range_past_end_snaps_to_empty() {
        let src = "a\n";
        let t = CheckpointTable::build(src).expect("checkpoint source parses");
        let end = u32::try_from(src.len()).expect("fixture length fits in u32");
        assert_eq!(t.snap_to_block_boundaries(99..100), end..end);
    }

    #[test]
    fn frontmatter_is_one_prelude_region() {
        let src = "---\ntitle: x\n---\n# heading\n\npara\n";
        let t = CheckpointTable::build(src).expect("checkpoint source parses");
        let heading_at = offset_of(src, "# heading");
        let snapped = t.snap_to_block_boundaries(2..3);
        assert!(snapped.start <= heading_at);
        assert!(snapped.end >= heading_at);
    }

    #[test]
    fn crlf_source_preserves_original_offsets() {
        let src = "a\r\n\r\nb\r\n";
        let t = CheckpointTable::build(src).expect("checkpoint source parses");
        let b_at = offset_of(src, "b");
        let snapped = t.snap_to_block_boundaries(b_at..b_at + 1);
        assert!(slice(src, snapped).contains('b'));
    }
}