1use std::ops::Range;
2
3use pulldown_cmark::{Event, LinkType, Options, Parser, Tag, TagEnd};
4
5use crate::page::{BlockId, Heading, PageId, WikilinkFragment, WikilinkOccurrence};
6
/// The kind of Markdown construct a classified byte range belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeKind {
    /// Text that is not inside any of the other tracked constructs.
    Prose,
    /// Text inside a heading.
    Heading,
    /// A metadata block at the parser level (YAML-style metadata blocks are
    /// enabled in this file's parser options).
    Frontmatter,
    /// Text inside a code block.
    CodeBlock,
    /// An inline code span.
    InlineCode,
    /// A link whose link type is `LinkType::WikiLink`.
    Wikilink,
    /// An image whose link type is `LinkType::WikiLink`.
    Embed,
    /// A link whose link type is `LinkType::Autolink` or `LinkType::Email`.
    Url,
    /// Content belonging to an HTML block (`Tag::HtmlBlock`).
    HtmlBlock,
    /// An inline HTML fragment (`Event::InlineHtml`).
    HtmlInline,
}
31
/// A span of the source document together with the construct it was
/// classified as.
#[derive(Debug, Clone)]
pub struct ClassifiedRange {
    /// What kind of construct the span belongs to.
    pub kind: RangeKind,
    /// Byte offsets into the source string that was classified
    /// (start inclusive, end exclusive).
    pub byte_range: Range<usize>,
}
38
39fn parser_options() -> Options {
40 Options::ENABLE_WIKILINKS
41 | Options::ENABLE_YAML_STYLE_METADATA_BLOCKS
42 | Options::ENABLE_TABLES
43 | Options::ENABLE_STRIKETHROUGH
44 | Options::ENABLE_HEADING_ATTRIBUTES
45}
46
47pub fn classify_ranges(source: &str) -> Vec<ClassifiedRange> {
52 let parser = Parser::new_ext(source, parser_options());
53 let offset_iter = parser.into_offset_iter();
54
55 let mut ranges = Vec::new();
56 let mut context_stack: Vec<RangeKind> = Vec::new();
59
60 for (event, range) in offset_iter {
61 match event {
62 Event::Start(Tag::MetadataBlock(_)) => {
63 context_stack.push(RangeKind::Frontmatter);
64 }
65 Event::End(TagEnd::MetadataBlock(_)) => {
66 context_stack.pop();
67 ranges.push(ClassifiedRange {
68 kind: RangeKind::Frontmatter,
69 byte_range: range,
70 });
71 }
72
73 Event::Start(Tag::Heading { .. }) => {
74 context_stack.push(RangeKind::Heading);
75 }
76 Event::End(TagEnd::Heading(_)) => {
77 context_stack.pop();
78 }
79
80 Event::Start(Tag::CodeBlock(_)) => {
81 context_stack.push(RangeKind::CodeBlock);
82 }
83 Event::End(TagEnd::CodeBlock) => {
84 context_stack.pop();
85 }
86
87 Event::Start(Tag::HtmlBlock) => {
88 context_stack.push(RangeKind::HtmlBlock);
89 }
90 Event::End(TagEnd::HtmlBlock) => {
91 context_stack.pop();
92 }
93
94 Event::Start(Tag::Link {
96 link_type: LinkType::WikiLink { .. },
97 ..
98 }) => {
99 ranges.push(ClassifiedRange {
100 kind: RangeKind::Wikilink,
101 byte_range: range,
102 });
103 context_stack.push(RangeKind::Wikilink);
104 }
105 Event::End(TagEnd::Link) if context_stack.last() == Some(&RangeKind::Wikilink) => {
106 context_stack.pop();
107 }
108
109 Event::Start(Tag::Image {
111 link_type: LinkType::WikiLink { .. },
112 ..
113 }) => {
114 ranges.push(ClassifiedRange {
115 kind: RangeKind::Embed,
116 byte_range: range,
117 });
118 context_stack.push(RangeKind::Embed);
119 }
120 Event::End(TagEnd::Image) if context_stack.last() == Some(&RangeKind::Embed) => {
121 context_stack.pop();
122 }
123
124 Event::Start(Tag::Link {
126 link_type: LinkType::Autolink | LinkType::Email,
127 ..
128 }) => {
129 ranges.push(ClassifiedRange {
130 kind: RangeKind::Url,
131 byte_range: range,
132 });
133 context_stack.push(RangeKind::Url);
134 }
135 Event::End(TagEnd::Link) if context_stack.last() == Some(&RangeKind::Url) => {
136 context_stack.pop();
137 }
138
139 Event::Code(_) => {
141 ranges.push(ClassifiedRange {
142 kind: RangeKind::InlineCode,
143 byte_range: range,
144 });
145 }
146
147 Event::InlineHtml(_) => {
149 ranges.push(ClassifiedRange {
150 kind: RangeKind::HtmlInline,
151 byte_range: range,
152 });
153 }
154
155 Event::Text(_) => {
159 let kind = context_stack.last().copied().unwrap_or(RangeKind::Prose);
160 match kind {
161 RangeKind::Wikilink | RangeKind::Embed | RangeKind::Url => {}
162 _ => {
163 ranges.push(ClassifiedRange {
164 kind,
165 byte_range: range,
166 });
167 }
168 }
169 }
170
171 _ => {}
174 }
175 }
176
177 ranges.sort_by_key(|r| r.byte_range.start);
178 ranges
179}
180
181pub fn extract_wikilinks(source: &str) -> Vec<WikilinkOccurrence> {
183 let parser = Parser::new_ext(source, parser_options());
184 let offset_iter = parser.into_offset_iter();
185 let mut wikilinks = Vec::new();
186
187 for (event, range) in offset_iter {
188 let (dest_url, is_embed) = match &event {
189 Event::Start(Tag::Link {
190 link_type: LinkType::WikiLink { .. },
191 dest_url,
192 ..
193 }) => (dest_url.as_ref(), false),
194 Event::Start(Tag::Image {
195 link_type: LinkType::WikiLink { .. },
196 dest_url,
197 ..
198 }) => (dest_url.as_ref(), true),
199 _ => continue,
200 };
201
202 let (page_str, fragment) = match dest_url.split_once('#') {
203 Some((page, frag)) => {
204 let fragment = if let Some(block) = frag.strip_prefix('^') {
205 WikilinkFragment::Block(BlockId::from(block))
206 } else {
207 WikilinkFragment::Heading(frag.to_owned())
208 };
209 (page, Some(fragment))
210 }
211 None => (dest_url, None),
212 };
213
214 wikilinks.push(WikilinkOccurrence {
215 page: PageId::from(page_str),
216 fragment,
217 is_embed,
218 byte_range: range,
219 });
220 }
221
222 wikilinks
223}
224
225pub fn extract_headings(source: &str) -> Vec<Heading> {
227 let parser = Parser::new_ext(source, parser_options());
228 let offset_iter = parser.into_offset_iter();
229 let mut headings = Vec::new();
230 let mut in_heading: Option<(u8, Range<usize>)> = None;
231 let mut heading_text = String::new();
232
233 for (event, range) in offset_iter {
234 match event {
235 Event::Start(Tag::Heading { level, .. }) => {
236 in_heading = Some((level as u8, range));
237 heading_text.clear();
238 }
239 Event::Text(text) if in_heading.is_some() => {
240 heading_text.push_str(&text);
241 }
242 Event::End(TagEnd::Heading(_)) => {
243 if let Some((level, start_range)) = in_heading.take() {
244 headings.push(Heading {
245 level,
246 text: std::mem::take(&mut heading_text),
247 byte_range: start_range.start..range.end,
248 });
249 }
250 }
251 _ => {}
252 }
253 }
254
255 headings
256}
257
258pub fn extract_block_ids(source: &str) -> Vec<BlockId> {
260 let mut block_ids = Vec::new();
264 for line in source.lines() {
265 let trimmed = line.trim();
266 if let Some(id) = trimmed.strip_prefix('^')
267 && !id.is_empty()
268 && id.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
269 {
270 block_ids.push(BlockId::from(id));
271 }
272 if let Some(pos) = trimmed.rfind(" ^") {
274 let candidate = &trimmed[pos + 2..];
275 if !candidate.is_empty()
276 && candidate
277 .chars()
278 .all(|c| c.is_ascii_alphanumeric() || c == '-')
279 {
280 block_ids.push(BlockId::from(candidate));
281 }
282 }
283 }
284 block_ids
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the entries of `ranges` whose kind equals `kind`, keeping their order.
    fn of_kind(ranges: &[ClassifiedRange], kind: RangeKind) -> Vec<&ClassifiedRange> {
        ranges.iter().filter(|r| r.kind == kind).collect()
    }

    /// Returns the part of `source` covered by `range`.
    fn covered<'a>(source: &'a str, range: &ClassifiedRange) -> &'a str {
        &source[range.byte_range.clone()]
    }

    #[test]
    fn classifies_frontmatter_as_non_prose() {
        let source = "---\ntitle: Test\ntags: [a]\n---\n\nSome prose here.";
        let ranges = classify_ranges(source);

        let frontmatter = of_kind(&ranges, RangeKind::Frontmatter);
        assert!(!frontmatter.is_empty(), "should have frontmatter ranges");

        let prose = of_kind(&ranges, RangeKind::Prose);
        assert!(!prose.is_empty(), "should have prose ranges");

        // At least one prose range has to cover the paragraph after the frontmatter.
        let paragraph_found = prose
            .iter()
            .any(|r| covered(source, r).contains("Some prose"));
        assert!(
            paragraph_found,
            "prose range should contain 'Some prose here.'"
        );
    }

    #[test]
    fn classifies_wikilinks() {
        let source = "Text with [[GRPO]] and more.";
        let ranges = classify_ranges(source);
        let wikilinks = of_kind(&ranges, RangeKind::Wikilink);
        assert_eq!(wikilinks.len(), 1);
    }

    #[test]
    fn classifies_headings_as_non_prose() {
        let source = "# My Heading\n\nParagraph text.";
        let ranges = classify_ranges(source);
        let heading_ranges = of_kind(&ranges, RangeKind::Heading);
        assert!(!heading_ranges.is_empty());

        for text in heading_ranges.iter().map(|r| covered(source, r)) {
            assert!(
                text.contains("My Heading"),
                "heading range should contain heading text, got: {text:?}"
            );
        }
    }

    #[test]
    fn classifies_code_blocks() {
        let source = "Text before\n\n```rust\nlet x = 1;\n```\n\nText after";
        let ranges = classify_ranges(source);
        let code_blocks = of_kind(&ranges, RangeKind::CodeBlock);
        assert!(!code_blocks.is_empty());
    }

    #[test]
    fn classifies_inline_code() {
        let source = "Use `GRPO` here.";
        let ranges = classify_ranges(source);
        let inline_code = of_kind(&ranges, RangeKind::InlineCode);
        assert_eq!(inline_code.len(), 1);
        assert_eq!(covered(source, inline_code[0]), "`GRPO`");
    }

    #[test]
    fn extracts_wikilinks_with_fragments() {
        let source = "See [[post-training#^method-comparison]] for details.";
        let wikilinks = extract_wikilinks(source);
        assert_eq!(wikilinks.len(), 1);

        let link = &wikilinks[0];
        assert_eq!(link.page.as_str(), "post-training");
        assert_eq!(
            link.fragment,
            Some(WikilinkFragment::Block(BlockId::from("method-comparison")))
        );
        assert!(!link.is_embed);
    }

    #[test]
    fn extracts_embed_wikilinks() {
        let source = "![[post-training#^method-comparison]]";
        let wikilinks = extract_wikilinks(source);
        assert_eq!(wikilinks.len(), 1);
        assert!(wikilinks[0].is_embed);
    }

    #[test]
    fn extracts_heading_fragment() {
        let source = "See [[page#Some Heading]] for details.";
        let wikilinks = extract_wikilinks(source);
        assert_eq!(wikilinks.len(), 1);
        assert_eq!(
            wikilinks[0].fragment,
            Some(WikilinkFragment::Heading("Some Heading".to_owned()))
        );
    }

    #[test]
    fn extracts_headings() {
        let source = "# Title\n\nParagraph\n\n## Section One\n\nMore text\n\n### Sub Section";
        let headings = extract_headings(source);
        assert_eq!(headings.len(), 3);

        let expected = [(1, "Title"), (2, "Section One"), (3, "Sub Section")];
        for (heading, (level, text)) in headings.iter().zip(expected) {
            assert_eq!(heading.level, level);
            assert_eq!(heading.text, text);
        }
    }

    #[test]
    fn extracts_block_ids() {
        let source = "Some content\n\n^method-comparison\n\nMore content ^inline-block";
        let block_ids = extract_block_ids(source);
        assert_eq!(block_ids.len(), 2);

        let names: Vec<&str> = block_ids.iter().map(|id| id.as_str()).collect();
        assert_eq!(names[0], "method-comparison");
        assert_eq!(names[1], "inline-block");
    }
}