1use super::generic;
7
8pub trait MapDataVisitor<T, U> {
10 fn map_data(&mut self, data: T) -> U;
12
13 fn visit_document(&mut self, doc: generic::Document<T>) -> generic::Document<U> {
15 generic::Document {
16 blocks: doc
17 .blocks
18 .into_iter()
19 .map(|b| self.visit_block(b))
20 .collect(),
21 user_data: self.map_data(doc.user_data),
22 }
23 }
24
25 fn visit_block(&mut self, block: generic::Block<T>) -> generic::Block<U> {
27 match block {
28 generic::Block::Paragraph { content, user_data } => generic::Block::Paragraph {
29 content: content.into_iter().map(|i| self.visit_inline(i)).collect(),
30 user_data: self.map_data(user_data),
31 },
32 generic::Block::Heading(heading) => {
33 generic::Block::Heading(self.visit_heading(heading))
34 }
35 generic::Block::ThematicBreak { user_data } => generic::Block::ThematicBreak {
36 user_data: self.map_data(user_data),
37 },
38 generic::Block::BlockQuote { blocks, user_data } => generic::Block::BlockQuote {
39 blocks: blocks.into_iter().map(|b| self.visit_block(b)).collect(),
40 user_data: self.map_data(user_data),
41 },
42 generic::Block::List(list) => generic::Block::List(self.visit_list(list)),
43 generic::Block::CodeBlock(code_block) => {
44 generic::Block::CodeBlock(self.visit_code_block(code_block))
45 }
46 generic::Block::HtmlBlock { content, user_data } => generic::Block::HtmlBlock {
47 content,
48 user_data: self.map_data(user_data),
49 },
50 generic::Block::Definition(def) => {
51 generic::Block::Definition(self.visit_link_definition(def))
52 }
53 generic::Block::Table(table) => generic::Block::Table(self.visit_table(table)),
54 generic::Block::FootnoteDefinition(footnote) => {
55 generic::Block::FootnoteDefinition(self.visit_footnote_definition(footnote))
56 }
57 generic::Block::GitHubAlert(alert) => {
58 generic::Block::GitHubAlert(self.visit_github_alert(alert))
59 }
60 generic::Block::Empty { user_data } => generic::Block::Empty {
61 user_data: self.map_data(user_data),
62 },
63 }
64 }
65
66 fn visit_inline(&mut self, inline: generic::Inline<T>) -> generic::Inline<U> {
68 match inline {
69 generic::Inline::Text { content, user_data } => generic::Inline::Text {
70 content,
71 user_data: self.map_data(user_data),
72 },
73 generic::Inline::LineBreak { user_data } => generic::Inline::LineBreak {
74 user_data: self.map_data(user_data),
75 },
76 generic::Inline::Code { content, user_data } => generic::Inline::Code {
77 content,
78 user_data: self.map_data(user_data),
79 },
80 generic::Inline::Html { content, user_data } => generic::Inline::Html {
81 content,
82 user_data: self.map_data(user_data),
83 },
84 generic::Inline::Link(link) => generic::Inline::Link(self.visit_link(link)),
85 generic::Inline::LinkReference(link_ref) => {
86 generic::Inline::LinkReference(self.visit_link_reference(link_ref))
87 }
88 generic::Inline::Image(image) => generic::Inline::Image(self.visit_image(image)),
89 generic::Inline::Emphasis { content, user_data } => generic::Inline::Emphasis {
90 content: content.into_iter().map(|i| self.visit_inline(i)).collect(),
91 user_data: self.map_data(user_data),
92 },
93 generic::Inline::Strong { content, user_data } => generic::Inline::Strong {
94 content: content.into_iter().map(|i| self.visit_inline(i)).collect(),
95 user_data: self.map_data(user_data),
96 },
97 generic::Inline::Strikethrough { content, user_data } => {
98 generic::Inline::Strikethrough {
99 content: content.into_iter().map(|i| self.visit_inline(i)).collect(),
100 user_data: self.map_data(user_data),
101 }
102 }
103 generic::Inline::Autolink { url, user_data } => generic::Inline::Autolink {
104 url,
105 user_data: self.map_data(user_data),
106 },
107 generic::Inline::FootnoteReference { label, user_data } => {
108 generic::Inline::FootnoteReference {
109 label,
110 user_data: self.map_data(user_data),
111 }
112 }
113 generic::Inline::Empty { user_data } => generic::Inline::Empty {
114 user_data: self.map_data(user_data),
115 },
116 }
117 }
118
119 fn visit_heading(&mut self, heading: generic::Heading<T>) -> generic::Heading<U> {
121 generic::Heading {
122 kind: heading.kind,
123 content: heading
124 .content
125 .into_iter()
126 .map(|i| self.visit_inline(i))
127 .collect(),
128 user_data: self.map_data(heading.user_data),
129 }
130 }
131
132 fn visit_list(&mut self, list: generic::List<T>) -> generic::List<U> {
134 generic::List {
135 kind: list.kind,
136 items: list
137 .items
138 .into_iter()
139 .map(|i| self.visit_list_item(i))
140 .collect(),
141 user_data: self.map_data(list.user_data),
142 }
143 }
144
145 fn visit_list_item(&mut self, item: generic::ListItem<T>) -> generic::ListItem<U> {
147 generic::ListItem {
148 task: item.task,
149 blocks: item
150 .blocks
151 .into_iter()
152 .map(|b| self.visit_block(b))
153 .collect(),
154 user_data: self.map_data(item.user_data),
155 }
156 }
157
158 fn visit_code_block(&mut self, code_block: generic::CodeBlock<T>) -> generic::CodeBlock<U> {
160 generic::CodeBlock {
161 kind: code_block.kind,
162 literal: code_block.literal,
163 user_data: self.map_data(code_block.user_data),
164 }
165 }
166
167 fn visit_link_definition(
169 &mut self,
170 def: generic::LinkDefinition<T>,
171 ) -> generic::LinkDefinition<U> {
172 generic::LinkDefinition {
173 label: def
174 .label
175 .into_iter()
176 .map(|i| self.visit_inline(i))
177 .collect(),
178 destination: def.destination,
179 title: def.title,
180 user_data: self.map_data(def.user_data),
181 }
182 }
183
184 fn visit_table(&mut self, table: generic::Table<T>) -> generic::Table<U> {
186 generic::Table {
187 rows: table
188 .rows
189 .into_iter()
190 .map(|row| {
191 row.into_iter()
192 .map(|cell| cell.into_iter().map(|i| self.visit_inline(i)).collect())
193 .collect()
194 })
195 .collect(),
196 alignments: table.alignments,
197 user_data: self.map_data(table.user_data),
198 }
199 }
200
201 fn visit_footnote_definition(
203 &mut self,
204 footnote: generic::FootnoteDefinition<T>,
205 ) -> generic::FootnoteDefinition<U> {
206 generic::FootnoteDefinition {
207 label: footnote.label,
208 blocks: footnote
209 .blocks
210 .into_iter()
211 .map(|b| self.visit_block(b))
212 .collect(),
213 user_data: self.map_data(footnote.user_data),
214 }
215 }
216
217 fn visit_github_alert(
219 &mut self,
220 alert: generic::GitHubAlertNode<T>,
221 ) -> generic::GitHubAlertNode<U> {
222 generic::GitHubAlertNode {
223 alert_type: alert.alert_type,
224 blocks: alert
225 .blocks
226 .into_iter()
227 .map(|b| self.visit_block(b))
228 .collect(),
229 user_data: self.map_data(alert.user_data),
230 }
231 }
232
233 fn visit_link(&mut self, link: generic::Link<T>) -> generic::Link<U> {
235 generic::Link {
236 destination: link.destination,
237 title: link.title,
238 children: link
239 .children
240 .into_iter()
241 .map(|i| self.visit_inline(i))
242 .collect(),
243 user_data: self.map_data(link.user_data),
244 }
245 }
246
247 fn visit_image(&mut self, image: generic::Image<T>) -> generic::Image<U> {
249 generic::Image {
250 destination: image.destination,
251 title: image.title,
252 alt: image.alt,
253 user_data: self.map_data(image.user_data),
254 }
255 }
256
257 fn visit_link_reference(
259 &mut self,
260 link_ref: generic::LinkReference<T>,
261 ) -> generic::LinkReference<U> {
262 generic::LinkReference {
263 label: link_ref
264 .label
265 .into_iter()
266 .map(|i| self.visit_inline(i))
267 .collect(),
268 text: link_ref
269 .text
270 .into_iter()
271 .map(|i| self.visit_inline(i))
272 .collect(),
273 user_data: self.map_data(link_ref.user_data),
274 }
275 }
276}
277
/// A [`MapDataVisitor`] that converts each `user_data` value by calling the
/// wrapped closure `f`.
///
/// The `F: FnMut(T) -> U` requirement is enforced on the impls that need it
/// rather than on the type definition itself.
pub struct ClosureMapDataVisitor<T, U, F> {
    /// Closure applied to every `user_data` value.
    f: F,
    // The visitor never stores a `T` or a `U`. `fn() -> (T, U)` keeps both
    // parameters covariant (as `(T, U)` did) without claiming ownership. The
    // visitor therefore stays `Send`/`Sync` even when `T`/`U` are not (e.g.
    // `Rc`), and `T`/`U` are not pulled into drop-check.
    _phantom: std::marker::PhantomData<fn() -> (T, U)>,
}
286
287impl<T, U, F> ClosureMapDataVisitor<T, U, F>
288where
289 F: FnMut(T) -> U,
290{
291 pub fn new(f: F) -> Self {
292 Self {
293 f,
294 _phantom: std::marker::PhantomData,
295 }
296 }
297}
298
299impl<T, U, F> MapDataVisitor<T, U> for ClosureMapDataVisitor<T, U, F>
300where
301 F: FnMut(T) -> U,
302{
303 fn map_data(&mut self, data: T) -> U {
304 (self.f)(data)
305 }
306}
307
308pub fn map_user_data<T, U, F>(doc: generic::Document<T>, f: F) -> generic::Document<U>
310where
311 F: FnMut(T) -> U,
312{
313 let mut visitor = ClosureMapDataVisitor::new(f);
314 visitor.visit_document(doc)
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapping numeric ids to strings converts the document, block and inline data.
    #[test]
    fn test_map_data_visitor_with_u32() {
        let hello = generic::Inline::Text {
            content: "Hello".to_string(),
            user_data: 1u32,
        };
        let paragraph = generic::Block::Paragraph {
            content: vec![hello],
            user_data: 2u32,
        };
        let doc = generic::Document {
            blocks: vec![paragraph],
            user_data: 0u32,
        };

        let transformed = map_user_data(doc, |id| format!("element_{id}"));

        assert_eq!(transformed.user_data, "element_0");

        let (para_data, inlines) = match &transformed.blocks[0] {
            generic::Block::Paragraph { user_data, content } => (user_data, content),
            _ => panic!("Expected paragraph"),
        };
        assert_eq!(para_data, "element_2");

        if let generic::Inline::Text { user_data, .. } = &inlines[0] {
            assert_eq!(user_data, "element_1");
        } else {
            panic!("Expected text");
        }
    }

    /// Nested headings, strong spans, lists and list items all get converted.
    #[test]
    fn test_complex_ast_transformation() {
        let bold = generic::Inline::Strong {
            content: vec![generic::Inline::Text {
                content: "Bold".to_string(),
                user_data: 2u32,
            }],
            user_data: 3u32,
        };
        let title = generic::Heading {
            kind: crate::ast::HeadingKind::Atx(1),
            content: vec![
                generic::Inline::Text {
                    content: "Title".to_string(),
                    user_data: 1u32,
                },
                bold,
            ],
            user_data: 4u32,
        };

        let item = generic::ListItem {
            task: None,
            blocks: vec![generic::Block::Paragraph {
                content: vec![generic::Inline::Text {
                    content: "Item".to_string(),
                    user_data: 5u32,
                }],
                user_data: 6u32,
            }],
            user_data: 7u32,
        };
        let bullets = generic::List {
            kind: generic::ListKind::Bullet(crate::ast::ListBulletKind::Dash),
            items: vec![item],
            user_data: 8u32,
        };

        let doc = generic::Document {
            blocks: vec![generic::Block::Heading(title), generic::Block::List(bullets)],
            user_data: 9u32,
        };

        let transformed = map_user_data(doc, |n| n * 10);

        assert_eq!(transformed.user_data, 90);

        // First block: the heading and its nested strong/text inlines.
        let mapped_heading = match &transformed.blocks[0] {
            generic::Block::Heading(heading) => heading,
            _ => panic!("Expected heading"),
        };
        assert_eq!(mapped_heading.user_data, 40);

        let (strong_data, strong_content) = match &mapped_heading.content[1] {
            generic::Inline::Strong { user_data, content } => (*user_data, content),
            _ => panic!("Expected strong"),
        };
        assert_eq!(strong_data, 30);

        match &strong_content[0] {
            generic::Inline::Text { user_data, .. } => assert_eq!(*user_data, 20),
            _ => panic!("Expected text"),
        }

        // Second block: the list and its single item.
        let mapped_list = match &transformed.blocks[1] {
            generic::Block::List(list) => list,
            _ => panic!("Expected list"),
        };
        assert_eq!(mapped_list.user_data, 80);
        assert_eq!(mapped_list.items[0].user_data, 70);
    }

    /// A stateful visitor hands out one distinct id per node.
    #[test]
    fn test_custom_visitor() {
        /// Replaces every payload with the next value of a running counter.
        struct CountingVisitor {
            count: u32,
        }

        impl CountingVisitor {
            fn new() -> Self {
                Self { count: 0 }
            }
        }

        impl MapDataVisitor<String, u32> for CountingVisitor {
            fn map_data(&mut self, _data: String) -> u32 {
                let issued = self.count;
                self.count += 1;
                issued
            }
        }

        let text = |content: &str, data: &str| generic::Inline::Text {
            content: content.to_string(),
            user_data: data.to_string(),
        };

        let doc = generic::Document {
            blocks: vec![generic::Block::Paragraph {
                content: vec![text("First", "text1"), text("Second", "text2")],
                user_data: "paragraph".to_string(),
            }],
            user_data: "document".to_string(),
        };

        let mut visitor = CountingVisitor::new();
        let transformed = visitor.visit_document(doc);

        let doc_id = transformed.user_data;

        let (para_id, inlines) = match &transformed.blocks[0] {
            generic::Block::Paragraph { user_data, content } => (*user_data, content),
            _ => panic!("Expected paragraph"),
        };
        let text_id = |index: usize| match &inlines[index] {
            generic::Inline::Text { user_data, .. } => *user_data,
            _ => panic!("Expected text"),
        };

        let mut ids = vec![doc_id, para_id, text_id(0), text_id(1)];
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), 4, "All IDs should be unique");

        assert!(ids.iter().all(|&id| id <= 3), "IDs should be 0-3");
    }
}