mdbook_ai_pocket_reference/
ai_pocket_reference.rs

1use handlebars::{to_json, Handlebars};
2use mdbook::book::{Book, BookItem};
3use mdbook::preprocess::{Preprocessor, PreprocessorContext};
4use once_cell::sync::Lazy;
5use regex::{CaptureMatches, Captures, Regex};
6use serde::Serialize;
7use serde_json::value::Map;
8use std::collections::HashMap;
9
// Handlebars template source for the header helper; registered under the name
// "aipr_header" when a header is rendered.
10const AIPR_HEADER_TEMPLATE: &str = include_str!("./templates/header.hbs");
// Static HTML that is appended to the end of every chapter's content.
11const AIPR_FOOTER_HTML: &str = include_str!("./templates/footer.html");
// Handlebars template source used to render http(s) Markdown links as HTML.
12const MDLINK_TEMPLATE: &str = include_str!("./templates/md_link.hbs");
// Divisor that turns a chapter's word count into a reading-time estimate in minutes.
13const WORDS_PER_MINUTE: usize = 200;
14
/// A preprocessor for expanding AI-Pocket-Reference helpers.
///
/// Supported helpers are:
///
/// - `{{#aipr_header <param-str>}}` - Adds the ai-pocket-reference header (`param-str` is
///   optional)
#[derive(Default)]
pub struct AIPRPreprocessor;

impl AIPRPreprocessor {
    /// Name this preprocessor reports through `Preprocessor::name`.
    pub(crate) const NAME: &'static str = "ai-pocket-reference";

    /// Builds a new `AIPRPreprocessor`; the type carries no state.
    pub fn new() -> Self {
        Self
    }
}
31
32impl Preprocessor for AIPRPreprocessor {
33    fn name(&self) -> &str {
34        Self::NAME
35    }
36
37    fn run(&self, _ctx: &PreprocessorContext, mut book: Book) -> anyhow::Result<Book> {
38        // This run method's implementation follows the implementation of
39        // mdbook::preprocess::links::LinkPreprocessor.run().
40        book.for_each_mut(|section: &mut BookItem| {
41            if let BookItem::Chapter(ref mut ch) = *section {
42                let word_count = words_count::count(&ch.content);
43                let mut content = replace_all(&ch.content, word_count.words);
44
45                // add footer with logo
46                content.push_str(AIPR_FOOTER_HTML);
47
48                // mutate chapter content
49                ch.content = content;
50            }
51        });
52        Ok(book)
53    }
54}
55
56fn replace_all(s: &str, num_words: usize) -> String {
57    // First replace all AIPR links
58    let aipr_replaced = replace_all_aipr_links(s, num_words);
59
60    // Then replace all Markdown links
61    replace_all_md_links(&aipr_replaced)
62}
63
64fn replace_all_aipr_links(s: &str, num_words: usize) -> String {
65    // This implementation follows closely to the implementation of
66    // mdbook::preprocess::links::replace_all.
67    let mut previous_end_index = 0;
68    let mut replaced = String::new();
69
70    for link in find_aipr_links(s) {
71        replaced.push_str(&s[previous_end_index..link.start_index]);
72        let new_content = link.render(num_words).unwrap(); // todo: better error handling
73        replaced.push_str(&new_content);
74        previous_end_index = link.end_index;
75    }
76
77    replaced.push_str(&s[previous_end_index..]);
78    replaced
79}
80
81fn replace_all_md_links(s: &str) -> String {
82    let mut previous_end_index = 0;
83    let mut replaced = String::new();
84
85    for link in find_md_links(s) {
86        // Add text up to the current link
87        let prefix = &s[previous_end_index..link.start_index];
88        replaced.push_str(prefix);
89
90        // Check if the prefix ends with a backslash or exclamation mark
91        let last_char = prefix.chars().last();
92        let is_escaped = last_char == Some('\\') || last_char == Some('!');
93
94        if is_escaped {
95            // For escaped links, just add the original link text
96            replaced.push_str(&s[link.start_index..link.end_index]);
97        } else {
98            // For normal links, render as HTML
99            let new_content = link.render().unwrap();
100            replaced.push_str(&new_content);
101        }
102
103        previous_end_index = link.end_index;
104    }
105
106    replaced.push_str(&s[previous_end_index..]);
107    replaced
108}
109
110#[derive(PartialEq, Debug, Clone)]
111enum AIPRLinkType {
112    Header(AIPRHeaderSettings),
113}
114
/// Options that control what the rendered header contains.
#[derive(Clone, Debug, PartialEq)]
struct AIPRHeaderSettings {
    /// Whether a reading-time estimate is passed to the header template.
    reading_time: bool,
    /// Flag passed to the header template as `submit_issue`.
    submit_issue: bool,
    /// Notebook path passed to the header template, when one was given.
    colab: Option<String>,
}

impl Default for AIPRHeaderSettings {
    /// Reading time and issue flag switched on, no notebook path.
    fn default() -> Self {
        let (reading_time, submit_issue) = (true, true);
        Self {
            colab: None,
            submit_issue,
            reading_time,
        }
    }
}
131
/// Splits a comma-separated list of `key=value` pairs into a map.
///
/// Keys and values are trimmed of surrounding whitespace. A segment without `=` is ignored,
/// only the first `=` of a segment separates key from value, and a repeated key keeps the
/// value that appears last.
fn _parse_param_str(param_str: &str) -> HashMap<String, String> {
    let mut params = HashMap::new();
    for segment in param_str.split(',') {
        if let Some((key, value)) = segment.split_once('=') {
            params.insert(key.trim().to_string(), value.trim().to_string());
        }
    }
    params
}
141
142impl AIPRHeaderSettings {
143    fn from_param_str(param_str: &str) -> Self {
144        let param_map = _parse_param_str(param_str);
145        let colab = param_map.get("colab").map(|s| s.to_owned());
146        let reading_time =
147            !matches!(param_map.get("reading_time"), Some(bool_str) if (bool_str == "false"));
148        let submit_issue =
149            !matches!(param_map.get("submit_issue"), Some(bool_str) if (bool_str == "false"));
150
151        Self {
152            reading_time,
153            submit_issue,
154            colab,
155        }
156    }
157}
158
159#[derive(PartialEq, Debug, Clone)]
160struct AIPRLink<'a> {
161    start_index: usize,
162    end_index: usize,
163    link_type: AIPRLinkType,
164    link_text: &'a str,
165}
166
167impl<'a> AIPRLink<'a> {
168    #[allow(dead_code)]
169    fn from_capture(cap: Captures<'a>) -> Option<AIPRLink<'a>> {
170        let link_type = match (cap.get(0), cap.get(1), cap.get(2)) {
171            (_, Some(typ), None) if typ.as_str() == "aipr_header" => {
172                Some(AIPRLinkType::Header(AIPRHeaderSettings::default()))
173            }
174            (_, Some(typ), Some(param_str)) if typ.as_str() == "aipr_header" => {
175                Some(AIPRLinkType::Header(AIPRHeaderSettings::from_param_str(
176                    param_str.as_str().trim(),
177                )))
178            }
179            _ => None,
180        };
181
182        link_type.and_then(|lnk_type| {
183            cap.get(0).map(|mat| AIPRLink {
184                start_index: mat.start(),
185                end_index: mat.end(),
186                link_type: lnk_type,
187                link_text: mat.as_str(),
188            })
189        })
190    }
191
192    fn render(&self, num_words: usize) -> anyhow::Result<String> {
193        match &self.link_type {
194            AIPRLinkType::Header(settings) => {
195                let mut handlebars = Handlebars::new();
196                // register template from const str and assign a name to it
197                handlebars
198                    .register_template_string("aipr_header", AIPR_HEADER_TEMPLATE)
199                    .unwrap();
200
201                // create data for rendering handlebar
202                let mut data = Map::new();
203                if let Some(colab_path) = &settings.colab {
204                    let colab_nb = ColabNB {
205                        path: colab_path.to_owned(),
206                    };
207                    data.insert("colab_nb".to_string(), to_json(colab_nb));
208                }
209                data.insert("submit_issue".to_string(), to_json(settings.submit_issue));
210                if settings.reading_time {
211                    let rt_in_mins = (num_words as f32 / WORDS_PER_MINUTE as f32).round();
212                    let rt = ReadingTime {
213                        value: format!("{:.0} min", rt_in_mins),
214                    };
215                    data.insert("reading_time".to_string(), to_json(rt));
216                }
217
218                // render
219                let html_string = handlebars.render("aipr_header", &data)?;
220
221                Ok(html_string)
222            }
223        }
224    }
225}
226
227#[derive(PartialEq, Debug, Clone, Serialize)]
228pub struct ColabNB {
229    path: String,
230}
231
232#[derive(PartialEq, Debug, Clone, Serialize)]
233pub struct ReadingTime {
234    value: String,
235}
236
237struct AIPRLinkIter<'a>(CaptureMatches<'a, 'a>);
238
239impl<'a> Iterator for AIPRLinkIter<'a> {
240    type Item = AIPRLink<'a>;
241    fn next(&mut self) -> Option<AIPRLink<'a>> {
242        for cap in &mut self.0 {
243            if let Some(inc) = AIPRLink::from_capture(cap) {
244                return Some(inc);
245            }
246        }
247        None
248    }
249}
250
251fn find_aipr_links(contents: &str) -> AIPRLinkIter<'_> {
252    // lazily compute following regex
253    // r"\\\{\{#.*\}\}|\{\{#([a-zA-Z0-9]+)\s*([^}]+)\}\}")?;
254    static RE: Lazy<Regex> = Lazy::new(|| {
255        Regex::new(
256            r"(?x)              # insignificant whitespace mode
257        \\\{\{\#.*\}\}      # match escaped link
258        |                   # or
259        \{\{\s*             # link opening parens and whitespace
260        \#([a-zA-Z0-9_]+)   # link type
261        \s+                 # separating whitespace
262        ([^}]+)?            # link target path and space separated properties (optional)
263        \}\}                # link closing parens",
264        )
265        .unwrap()
266    });
267
268    AIPRLinkIter(RE.captures_iter(contents))
269}
270
271#[derive(PartialEq, Debug, Clone)]
272struct MDLink<'a> {
273    start_index: usize,
274    end_index: usize,
275    text: &'a str,
276    url: &'a str,
277}
278
279impl<'a> MDLink<'a> {
280    #[allow(dead_code)]
281    fn from_capture(cap: Captures<'a>) -> Option<MDLink<'a>> {
282        let md_tuple = match (cap.get(0), cap.get(1), cap.get(2)) {
283            (_, Some(text_str), Some(url_str))
284                if (url_str.as_str().starts_with("https://")
285                    || url_str.as_str().starts_with("http://")) =>
286            {
287                Some((text_str.as_str(), url_str.as_str()))
288            }
289            _ => None,
290        };
291
292        md_tuple.and_then(|(text, url)| {
293            cap.get(0).map(|mat| MDLink {
294                start_index: mat.start(),
295                end_index: mat.end(),
296                text,
297                url,
298            })
299        })
300    }
301
302    #[allow(dead_code)]
303    fn render(&self) -> anyhow::Result<String> {
304        let mut handlebars = Handlebars::new();
305
306        // register template
307        handlebars
308            .register_template_string("md_link_expansion", MDLINK_TEMPLATE.trim())
309            .unwrap();
310
311        // create data for rendering handlebar
312        let mut data = Map::new();
313        data.insert("text".to_string(), to_json(self.text));
314        data.insert("url".to_string(), to_json(self.url));
315
316        // render
317        let html_string = handlebars.render("md_link_expansion", &data)?;
318
319        Ok(html_string)
320    }
321}
322
323struct MDLinkIter<'a>(CaptureMatches<'a, 'a>);
324
325impl<'a> Iterator for MDLinkIter<'a> {
326    type Item = MDLink<'a>;
327    fn next(&mut self) -> Option<MDLink<'a>> {
328        for cap in &mut self.0 {
329            if let Some(inc) = MDLink::from_capture(cap) {
330                return Some(inc);
331            }
332        }
333        None
334    }
335}
336
337fn find_md_links(contents: &str) -> MDLinkIter<'_> {
338    static RE: Lazy<Regex> = Lazy::new(|| {
339        Regex::new(
340            r"(?x)
341            \[([^\]]*(?:\\.[^\]]*)*)\]    # link text in square brackets
342            \(([^)]*(?:\\.[^)]*)*)\)      # link URL in parentheses
343            ",
344        )
345        .unwrap()
346    });
347
348    MDLinkIter(RE.captures_iter(contents))
349}
350
// Unit tests for helper discovery, settings parsing, rendering and Markdown link rewriting.
351#[cfg(test)]
352mod tests {
353    use super::*;
354    use anyhow::Result;
355    use rstest::*;
356
    // Sample chapter text: two header helpers (one with a colab parameter) followed by one
    // https Markdown link.
357    #[fixture]
358    fn simple_book_content() -> String {
359        "{{ #aipr_header }} {{ #aipr_header colab=nlp/lora.ipynb }} Some random [text with](https://fake.io) and more text ..."
360            .to_string()
361    }
362
    // Plain text yields neither helpers nor Markdown links.
363    #[rstest]
364    fn test_find_links_no_author_links() -> Result<()> {
365        let s = "Some random text without link...";
366        assert!(find_aipr_links(s).collect::<Vec<_>>() == vec![]);
367        assert!(find_md_links(s).collect::<Vec<_>>() == vec![]);
368        Ok(())
369    }
370
    // Helpers with an unsupported type, or with no type at all, are not reported.
371    #[rstest]
372    fn test_find_links_empty_link() -> Result<()> {
373        let s = "Some random text with {{#colab  }} and {{}} {{#}}...";
374        println!("{:?}", find_aipr_links(s).collect::<Vec<_>>());
375        assert!(find_aipr_links(s).collect::<Vec<_>>() == vec![]);
376        Ok(())
377    }
378
    // Unknown helper types are skipped, and a Markdown link whose target is not http(s) is
    // not reported either.
379    #[rstest]
380    fn test_find_links_unknown_link_type() -> Result<()> {
381        let s = "Some random \\[text with\\](test) {{#my_author ar.rs}} and {{#auth}} {{baz}} {{#bar}}...";
382        assert!(find_aipr_links(s).collect::<Vec<_>>() == vec![]);
383        assert!(find_md_links(s).collect::<Vec<_>>() == vec![]);
384        Ok(())
385    }
386
    // Both header helpers in the fixture are found, with their offsets, settings and text.
387    #[rstest]
388    fn test_find_links_simple_author_links(simple_book_content: String) -> Result<()> {
389        let res = find_aipr_links(&simple_book_content[..]).collect::<Vec<_>>();
390        println!("\nOUTPUT: {res:?}\n");
391
392        assert_eq!(
393            res,
394            vec![
395                AIPRLink {
396                    start_index: 0,
397                    end_index: 18,
398                    link_type: AIPRLinkType::Header(AIPRHeaderSettings::default()),
399                    link_text: "{{ #aipr_header }}",
400                },
401                AIPRLink {
402                    start_index: 19,
403                    end_index: 58,
404                    link_type: AIPRLinkType::Header(AIPRHeaderSettings::from_param_str(
405                        "colab=nlp/lora.ipynb"
406                    )),
407                    link_text: "{{ #aipr_header colab=nlp/lora.ipynb }}",
408                },
409            ]
410        );
411        Ok(())
412    }
413
    // Parameter strings map onto settings; a value other than exactly "false" (here
    // "falsee") leaves the flag at its default.
414    #[rstest]
415    #[case(
416        "submit_issue=false,colab=nlp/lora.ipynb,reading_time=false",
417        AIPRHeaderSettings {
418            colab: Some("nlp/lora.ipynb".to_string()),
419            submit_issue: false,
420            reading_time: false
421        }
422    )]
423    #[case(
424        "colab=nlp/lora.ipynb",
425        AIPRHeaderSettings {
426            colab: Some("nlp/lora.ipynb".to_string()),
427            ..Default::default()
428        }
429    )]
430    #[case(
431        "reading_time=falsee",
432        AIPRHeaderSettings {
433            ..Default::default()
434        }
435    )]
436    fn test_aipr_header_settings(
437        #[case] param_str: &str,
438        #[case] expected_setting: AIPRHeaderSettings,
439    ) -> Result<()> {
440        let setting = AIPRHeaderSettings::from_param_str(param_str);
441        assert_eq!(setting, expected_setting);
442
443        Ok(())
444    }
445
    // Header with a colab notebook: 201 words is expected to render as "1 min".
446    #[rstest]
447    fn test_link_render() -> Result<()> {
448        let link = AIPRLink {
449            start_index: 19,
450            end_index: 58,
451            link_type: AIPRLinkType::Header(AIPRHeaderSettings::from_param_str(
452                "colab=nlp/lora.ipynb",
453            )),
454            link_text: "{{ #aipr_header colab=nlp/lora.ipynb }}",
455        };
456        let num_words = 201;
457
458        let html_string = link.render(num_words)?;
459        let expected = "<div style=\"display: flex; justify-content: \
460        space-between; align-items: center; margin-bottom: 2em;\">\n  <div>\n    \
461        <a target=\"_blank\" href=\"https://github.com/VectorInstitute/\
462        ai-pocket-reference/issues/new?template=edit-request.yml\">\n      \
463        <img src=\"https://img.shields.io/badge/Suggest_an_Edit-black?logo=\
464        github&style=flat\" alt=\"Suggest an Edit\"/>\n    </a>\n    \
465        <a target=\"_blank\" href=\"https://colab.research.google.com/github/\
466        VectorInstitute/ai-pocket-reference-code/blob/main/notebooks/nlp/lora.ipynb\
467        \">\n      <img src=\"https://colab.research.google.com/assets/colab-badge.svg\
468        \" alt=\"Open In Colab\"/>\n    </a>\n    <p style=\"margin: 0;\">\
469        <small>Reading time: 1 min</small></p>\n  </div>\n</div>\n";
470
471        println!("{:#?}", html_string);
472
473        assert_eq!(html_string, expected);
474
475        Ok(())
476    }
477
    // Default header settings (no colab badge): 301 words is expected to render as "2 min".
478    #[rstest]
479    fn test_link_render_no_colab() -> Result<()> {
480        let link = AIPRLink {
481            start_index: 19,
482            end_index: 58,
483            link_type: AIPRLinkType::Header(AIPRHeaderSettings::default()),
484            link_text: "{{ #aipr_header }}",
485        };
486        let num_words = 301;
487
488        let html_string = link.render(num_words)?;
489        let expected = "<div style=\"display: flex; justify-content: \
490        space-between; align-items: center; margin-bottom: 2em;\">\n  <div>\n    \
491        <a target=\"_blank\" href=\"https://github.com/VectorInstitute/\
492        ai-pocket-reference/issues/new?template=edit-request.yml\">\n      \
493        <img src=\"https://img.shields.io/badge/Suggest_an_Edit-black?logo=\
494        github&style=flat\" alt=\"Suggest an Edit\"/>\n    </a>\n    \
495        <p style=\"margin: 0;\"><small>Reading time: 2 min</small></p>\n  \
496        </div>\n</div>\n";
497
498        assert_eq!(html_string, expected);
499
500        Ok(())
501    }
502
    // With reading_time=false the expected output contains no reading-time paragraph.
503    #[rstest]
504    fn test_link_render_no_colab_no_reading_time() -> Result<()> {
505        let link = AIPRLink {
506            start_index: 19,
507            end_index: 58,
508            link_type: AIPRLinkType::Header(AIPRHeaderSettings::from_param_str(
509                "reading_time=false",
510            )),
511            link_text: "{{ #aipr_header reading_time=false }}",
512        };
513        let num_words = 200;
514
515        let html_string = link.render(num_words)?;
516        let expected = "<div style=\"display: flex; justify-content: \
517        space-between; align-items: center; margin-bottom: 2em;\">\n  <div>\n    \
518        <a target=\"_blank\" href=\"https://github.com/VectorInstitute/\
519        ai-pocket-reference/issues/new?template=edit-request.yml\">\n      \
520        <img src=\"https://img.shields.io/badge/Suggest_an_Edit-black?logo=\
521        github&style=flat\" alt=\"Suggest an Edit\"/>\n    </a>\n  \
522        </div>\n</div>\n";
523
524        assert_eq!(html_string, expected);
525
526        Ok(())
527    }
528
    // The single https Markdown link in the fixture is found with its offsets, text and URL.
529    #[rstest]
530    fn test_finds_md_link(simple_book_content: String) -> Result<()> {
531        let res = find_md_links(&simple_book_content[..]).collect::<Vec<_>>();
532        println!("\nOUTPUT: {res:?}\n");
533
534        assert_eq!(
535            res,
536            vec![MDLink {
537                start_index: 71,
538                end_index: 99,
539                text: "text with",
540                url: "https://fake.io"
541            }]
542        );
543
544        Ok(())
545    }
546
    // A Markdown link is expected to render as an anchor that opens in a new tab.
547    #[rstest]
548    fn test_md_link_render() -> Result<()> {
549        let link = MDLink {
550            start_index: 19,
551            end_index: 58,
552            text: "some text",
553            url: "https://fake.io",
554        };
555
556        let html_string = link.render()?;
557        let expected = "<a href=\"https://fake.io\" target=\"_blank\" \
558        rel=\"noopener noreferrer\">some text</a>";
559
560        assert_eq!(html_string, expected);
561
562        Ok(())
563    }
564
    // Only the plain link is rewritten; the `!`-prefixed link and the backslash-escaped link
    // are expected to pass through unchanged.
565    #[rstest]
566    fn test_replace_all_md_links() -> Result<()> {
567        let content = "This is [good link](https://good.io), \
568            whereas ![this](https://not-covered.io), and \
569            neither is \\[this\\](http://not-covered.io).";
570
571        let new_content = replace_all_md_links(content);
572        let expected = "This is <a href=\"https://good.io\" target=\"_blank\" \
573         rel=\"noopener noreferrer\">good link</a>, whereas ![this](https://not-covered.io), \
574         and neither is \\[this\\](http://not-covered.io).";
575
576        assert_eq!(new_content, expected);
577
578        Ok(())
579    }
580}