// html2md/rewriter/writer.rs

1use super::handle::handle_tag;
2use super::quotes::rewrite_blockquote_text;
3use crate::clean_markdown_bytes;
4use crate::rewriter::{handle::handle_tag_send, quotes::rewrite_blockquote_text_send};
5use lol_html::{doc_comments, doctype, element, html_content::EndTag, text, RewriteStrSettings};
6use std::cell::Cell;
7use std::rc::Rc;
8use std::sync::{
9    atomic::{AtomicU8, AtomicUsize, Ordering},
10    Arc,
11};
12use url::Url;
13
/// Cookie banner patterns.
///
/// CSS selector list (comma-separated) of element IDs used by common cookie-consent
/// banners (e.g. OneTrust, Didomi); matching elements are removed before conversion when
/// the `ignore_cookies` feature is enabled.
///
/// A plain `const` is enough here: the value is a string literal, so there is nothing to
/// initialise lazily. It is defined unconditionally (with the dead-code lint silenced when
/// the feature is off) so the selector text is always compiled and type-checked.
#[cfg_attr(not(feature = "ignore_cookies"), allow(dead_code))]
const COOKIE_BANNER_SELECTOR: &str =
    "body > #onetrust-banner-sdk,#didomi-host,#qc-cmp2-container,#cookie-banner,#__rptl-cookiebanner";
20
21/// End tag handler type sync send.
22type EndHandler = Box<
23    dyn for<'b> FnOnce(
24            &mut EndTag<'b>,
25        ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>
26        + Send
27        + 'static,
28>;
29
30/// End tag local handler type sync send.
31type LocalEndHandler = Box<
32    dyn for<'b> FnOnce(
33            &mut EndTag<'b>,
34        ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>
35        + 'static,
36>;
37
// ===== perf helpers =====

/// Returns `true` when `s` contains only ASCII whitespace (space, tab, LF, form feed, CR).
///
/// The empty string counts as whitespace-only. This stands in for `s.trim().is_empty()` on
/// HTML formatting whitespace while skipping Unicode-aware trimming: non-ASCII whitespace such
/// as U+00A0 is treated as content.
#[inline]
fn is_ascii_ws_only(s: &str) -> bool {
    // `u8::is_ascii_whitespace` accepts exactly b' ', b'\t', b'\n', 0x0C and b'\r'.
    s.bytes().all(|byte| byte.is_ascii_whitespace())
}
48
/// Estimate the size of the markdown.
///
/// Returns an output-buffer capacity hint: `0` for empty input, otherwise half the HTML
/// length but never less than 50 bytes.
fn estimate_markdown(html: &str) -> usize {
    match html.len() {
        0 => 0,
        len => std::cmp::max(len / 2, 50),
    }
}
57
// ===== send flags packed into one atomic =====

/// Bit raised while the `Send` rewriter is inside a `<table>` element.
const F_IN_TABLE: u8 = 0b0000_0001;
/// Bit mirroring the `list_item_start` state reported by the `Send` element handler.
const F_LI_START: u8 = 0b0000_0010;

/// Raises every bit of `mask` in `flags` (relaxed; the previous value is discarded).
#[inline]
fn flag_set(flags: &AtomicU8, mask: u8) {
    flags.fetch_or(mask, Ordering::Relaxed);
}

/// Lowers every bit of `mask` in `flags` (relaxed; the previous value is discarded).
#[inline]
fn flag_clear(flags: &AtomicU8, mask: u8) {
    let keep = !mask;
    flags.fetch_and(keep, Ordering::Relaxed);
}
71
72/// Get the HTML rewriter settings to convert to markdown.
73pub fn get_rewriter_settings(
74    commonmark: bool,
75    custom: &Option<std::collections::HashSet<String>>,
76    url: Option<Url>,
77) -> RewriteStrSettings<'static, 'static> {
78    let mut list_type: Option<&'static str> = None;
79    let mut order_counter = 0usize;
80
81    let quote_depth = Rc::new(AtomicUsize::new(0));
82    let quote_depth1 = quote_depth.clone();
83
84    let repaired_head = Rc::new(std::sync::OnceLock::new());
85
86    // flags (non-send) are already fast
87    let list_item_start_flag = Rc::new(Cell::new(false));
88    let in_table_flag = Rc::new(Cell::new(false));
89
90    // state passed into handle_tag
91    let mut in_table = false;
92    let mut table_row_start = false;
93    let mut list_item_start = false;
94
95    let mut element_content_handlers = Vec::with_capacity(
96        4 + custom
97            .as_ref()
98            .map_or(0, |c| if c.is_empty() { 0 } else { 1 })
99            + {
100                #[cfg(feature = "ignore_cookies")]
101                {
102                    1
103                }
104                #[cfg(not(feature = "ignore_cookies"))]
105                {
106                    0
107                }
108            },
109    );
110
111    #[cfg(feature = "ignore_cookies")]
112    {
113        element_content_handlers.push(lol_html::element!(COOKIE_BANNER_SELECTOR, |el| {
114            el.remove();
115            Ok(())
116        }));
117    }
118
119    element_content_handlers.push(text!("blockquote, q, cite", move |el| {
120        let _ = rewrite_blockquote_text(el, &quote_depth1);
121        Ok(())
122    }));
123
124    // TEXT HANDLER: drop whitespace-only nodes inside tables + at list item start
125    let list_item_start_flag_text = list_item_start_flag.clone();
126    let in_table_flag_text = in_table_flag.clone();
127    element_content_handlers.push(text!(
128        "*:not(script):not(head):not(style):not(svg)",
129        move |el| {
130            let s = el.as_str();
131
132            // inside table: ignore formatting whitespace between cells
133            if in_table_flag_text.get() && is_ascii_ws_only(s) {
134                *el.as_mut_str() = String::new();
135                return Ok(());
136            }
137
138            // list marker fix: swallow whitespace-only nodes until first real text
139            if list_item_start_flag_text.get() {
140                if is_ascii_ws_only(s) {
141                    *el.as_mut_str() = String::new();
142                    return Ok(());
143                }
144                list_item_start_flag_text.set(false);
145            }
146
147            // Only allocate if escaping is actually needed
148            if let Some(escaped) = crate::replace_markdown_chars_opt(s) {
149                *el.as_mut_str() = escaped;
150            }
151            Ok(())
152        }
153    ));
154
155    element_content_handlers.push(element!(
156        "head, nav, footer, script, noscript, style",
157        move |el| {
158            let repaired_head_element: bool = repaired_head.get().is_some();
159            let head_element = el.tag_name() == "head";
160            if head_element && !repaired_head_element {
161                if let Some(hvec) = el.end_tag_handlers() {
162                    let repaired_head = repaired_head.clone();
163                    let h1: LocalEndHandler =
164                        Box::new(move |end: &mut lol_html::html_content::EndTag<'_>| {
165                            let repaired_element = repaired_head.get().is_some();
166                            if end.name() == "html" && !repaired_element {
167                                let _ = repaired_head.set(true);
168                                end.after("</head>", lol_html::html_content::ContentType::Html);
169                            } else {
170                                end.remove();
171                            }
172                            Ok(())
173                        });
174                    hvec.push(h1);
175                }
176            } else {
177                el.remove();
178            }
179            Ok(())
180        }
181    ));
182
183    // ELEMENT HANDLER: manage flags + call handle_tag
184    let list_item_start_flag_el = list_item_start_flag.clone();
185    let in_table_flag_el = in_table_flag.clone();
186
187    element_content_handlers.push(element!("*", move |el| {
188        // Table start: enable flag and add end-tag handler to disable.
189        if el.tag_name().as_str() == "table" {
190            in_table_flag_el.set(true);
191            if let Some(hvec) = el.end_tag_handlers() {
192                let in_table_flag_end = in_table_flag_el.clone();
193                let h: LocalEndHandler =
194                    Box::new(move |_end: &mut lol_html::html_content::EndTag<'_>| {
195                        in_table_flag_end.set(false);
196                        Ok(())
197                    });
198                hvec.push(h);
199            }
200        }
201
202        // sync state from flags
203        in_table = in_table_flag_el.get();
204        list_item_start = list_item_start_flag_el.get();
205
206        let _ = handle_tag(
207            el,
208            commonmark,
209            &url,
210            &mut list_type,
211            &mut order_counter,
212            quote_depth.clone(),
213            &mut in_table,
214            &mut table_row_start,
215            &mut list_item_start,
216        );
217
218        // mirror list flag for text handler
219        list_item_start_flag_el.set(list_item_start);
220
221        Ok(())
222    }));
223
224    if let Some(ignore) = custom {
225        if !ignore.is_empty() {
226            let ignore_handler = element!(
227                ignore.iter().cloned().collect::<Vec<String>>().join(","),
228                |el| {
229                    el.remove();
230                    Ok(())
231                }
232            );
233            element_content_handlers.push(ignore_handler);
234        }
235    }
236
237    RewriteStrSettings {
238        document_content_handlers: vec![
239            doc_comments!(|c| {
240                c.remove();
241                Ok(())
242            }),
243            doctype!(|c| {
244                c.remove();
245                Ok(())
246            }),
247        ],
248        element_content_handlers,
249        ..RewriteStrSettings::default()
250    }
251}
252
253/// Get the HTML rewriter settings to convert to markdown sync send.
254pub fn get_rewriter_settings_send(
255    commonmark: bool,
256    custom: &Option<std::collections::HashSet<String>>,
257    url: Option<Url>,
258) -> lol_html::send::Settings<'static, 'static> {
259    let mut list_type: Option<&'static str> = None;
260    let mut order_counter = 0usize;
261
262    let quote_depth = Arc::new(AtomicUsize::new(0));
263    let quote_depth1 = quote_depth.clone();
264
265    let repaired_head = Arc::new(std::sync::OnceLock::new());
266
267    // packed flags (single atomic load per handler call)
268    let flags = Arc::new(AtomicU8::new(0));
269
270    // state passed into handle_tag_send
271    let mut in_table = false;
272    let mut table_row_start = false;
273    let mut list_item_start = false;
274
275    let mut element_content_handlers = Vec::with_capacity(
276        4 + custom
277            .as_ref()
278            .map_or(0, |c| if c.is_empty() { 0 } else { 1 })
279            + {
280                #[cfg(feature = "ignore_cookies")]
281                {
282                    1
283                }
284                #[cfg(not(feature = "ignore_cookies"))]
285                {
286                    0
287                }
288            },
289    );
290
291    #[cfg(feature = "ignore_cookies")]
292    {
293        element_content_handlers.push(lol_html::element!(COOKIE_BANNER_SELECTOR, |el| {
294            el.remove();
295            Ok(())
296        }));
297    }
298
299    element_content_handlers.push(text!("blockquote, q, cite", move |el| {
300        let _ = rewrite_blockquote_text_send(el, &quote_depth);
301        Ok(())
302    }));
303
304    // TEXT HANDLER (send): single atomic load + ASCII whitespace scan
305    let flags_text = flags.clone();
306    element_content_handlers.push(text!(
307        "*:not(script):not(head):not(style):not(svg)",
308        move |el| {
309            let f = flags_text.load(Ordering::Relaxed);
310            let in_table_now = (f & F_IN_TABLE) != 0;
311            let li_start_now = (f & F_LI_START) != 0;
312
313            let s = el.as_str();
314
315            if in_table_now && is_ascii_ws_only(s) {
316                *el.as_mut_str() = String::new();
317                return Ok(());
318            }
319
320            if li_start_now {
321                if is_ascii_ws_only(s) {
322                    *el.as_mut_str() = String::new();
323                    return Ok(());
324                }
325                // clear li-start
326                flag_clear(&*flags_text, F_LI_START);
327            }
328
329            // Only allocate if escaping is actually needed
330            if let Some(escaped) = crate::replace_markdown_chars_opt(s) {
331                *el.as_mut_str() = escaped;
332            }
333            Ok(())
334        }
335    ));
336
337    element_content_handlers.push(element!(
338        "head, nav, footer, script, noscript, style",
339        move |el| {
340            let repaired_head_element: bool = repaired_head.get().is_some();
341            let head_element = el.tag_name() == "head";
342            if head_element && !repaired_head_element {
343                if let Some(hvec) = el.end_tag_handlers() {
344                    let repaired_head = repaired_head.clone();
345                    let h1: EndHandler =
346                        Box::new(move |end: &mut lol_html::html_content::EndTag<'_>| {
347                            let repaired_element = repaired_head.get().is_some();
348                            if end.name() == "html" && !repaired_element {
349                                let _ = repaired_head.set(true);
350                                end.after("</head>", lol_html::html_content::ContentType::Html);
351                            } else {
352                                end.remove();
353                            }
354                            Ok(())
355                        });
356                    hvec.push(h1);
357                }
358            } else {
359                el.remove();
360            }
361            Ok(())
362        }
363    ));
364
365    // ELEMENT HANDLER (send): set/clear packed flags + call handle_tag_send
366    let flags_el = flags.clone();
367    element_content_handlers.push(element!("*", move |el| {
368        // table start
369        if el.tag_name().as_str() == "table" {
370            flag_set(&*flags_el, F_IN_TABLE);
371
372            if let Some(hvec) = el.end_tag_handlers() {
373                let flags_end = flags_el.clone();
374                let h: EndHandler =
375                    Box::new(move |_end: &mut lol_html::html_content::EndTag<'_>| {
376                        flag_clear(&*flags_end, F_IN_TABLE);
377                        Ok(())
378                    });
379                hvec.push(h);
380            }
381        }
382
383        // local bools for handle_tag_send
384        let f = flags_el.load(Ordering::Relaxed);
385        in_table = (f & F_IN_TABLE) != 0;
386        list_item_start = (f & F_LI_START) != 0;
387
388        let _ = handle_tag_send(
389            el,
390            commonmark,
391            &url,
392            &mut list_type,
393            &mut order_counter,
394            quote_depth1.clone(),
395            &mut in_table,
396            &mut table_row_start,
397            &mut list_item_start,
398        );
399
400        // mirror li-start back into packed flags
401        if list_item_start {
402            flag_set(&*flags_el, F_LI_START);
403        } else {
404            flag_clear(&*flags_el, F_LI_START);
405        }
406
407        Ok(())
408    }));
409
410    if let Some(ignore) = custom {
411        if !ignore.is_empty() {
412            let ignore_handler = element!(
413                ignore.iter().cloned().collect::<Vec<String>>().join(","),
414                |el| {
415                    el.remove();
416                    Ok(())
417                }
418            );
419            element_content_handlers.push(ignore_handler);
420        }
421    }
422
423    lol_html::send::Settings {
424        document_content_handlers: vec![
425            doc_comments!(|c| {
426                c.remove();
427                Ok(())
428            }),
429            doctype!(|c| {
430                c.remove();
431                Ok(())
432            }),
433        ],
434        element_content_handlers,
435        ..lol_html::send::Settings::new_send()
436    }
437}
438
439/// Shortcut to rewrite string and encode correctly
440pub(crate) fn rewrite_str<'h, 's, H: lol_html::HandlerTypes>(
441    html: &str,
442    settings: impl Into<lol_html::Settings<'h, 's, H>>,
443) -> Result<Vec<u8>, lol_html::errors::RewritingError> {
444    let mut output = Vec::with_capacity(estimate_markdown(html));
445
446    let mut rewriter = lol_html::HtmlRewriter::new(settings.into(), |c: &[u8]| {
447        output.extend_from_slice(c);
448    });
449
450    rewriter.write(html.as_bytes())?;
451    rewriter.end()?;
452
453    Ok(output)
454}
455
456/// Convert to markdown streaming re-writer
457pub(crate) fn convert_html_to_markdown(
458    html: &str,
459    custom: &Option<std::collections::HashSet<String>>,
460    commonmark: bool,
461    url: &Option<Url>,
462) -> Result<String, Box<dyn std::error::Error>> {
463    let settings = get_rewriter_settings(commonmark, custom, url.clone());
464
465    match rewrite_str(html, settings) {
466        Ok(markdown) => Ok(clean_markdown_bytes(&markdown)),
467        Err(e) => Err(e.into()),
468    }
469}
470
/// Convert to markdown streaming re-writer with chunk size.
///
/// Feeds `html` to a `Send`-capable `lol_html` rewriter in slices of `chunk_size` bytes and
/// returns the cleaned markdown. A `chunk_size` of `0` writes the whole input in one piece
/// (previously this panicked inside `slice::chunks`).
///
/// Conversion is best-effort: if a write fails, the remaining input is skipped, `end()` is
/// not called, and whatever was produced so far is still returned as `Ok`. Errors from
/// `end()` are ignored as well, so this function currently never returns `Err`.
///
/// Although `async`, the body awaits nothing; all work happens on the first poll.
#[cfg(feature = "stream")]
pub async fn convert_html_to_markdown_send_with_size(
    html: &str,
    custom: &Option<std::collections::HashSet<String>>,
    commonmark: bool,
    url: &Option<Url>,
    chunk_size: usize,
) -> Result<String, Box<dyn std::error::Error>> {
    let settings = get_rewriter_settings_send(commonmark, custom, url.clone());
    let mut rewrited_bytes: Vec<u8> = Vec::with_capacity(estimate_markdown(html));

    let mut rewriter = lol_html::send::HtmlRewriter::new(settings.into(), |c: &[u8]| {
        rewrited_bytes.extend_from_slice(c);
    });

    let bytes = html.as_bytes();

    // `slice::chunks` panics on a size of 0; treat 0 as "write everything at once".
    // `.max(1)` keeps the step non-zero even for empty input (which yields no chunks).
    let step = if chunk_size == 0 {
        bytes.len().max(1)
    } else {
        chunk_size
    };

    // Process in chunks without async overhead for in-memory data. `all` stops at the first
    // failed write; an errored rewriter must not be ended, so `end()` only runs on success.
    let wrote_all = bytes.chunks(step).all(|chunk| rewriter.write(chunk).is_ok());
    if wrote_all {
        let _ = rewriter.end();
    }

    Ok(clean_markdown_bytes(&rewrited_bytes))
}
504
/// Convert to markdown streaming re-writer
///
/// Convenience wrapper around `convert_html_to_markdown_send_with_size` that feeds the input
/// in 8 KiB (8192-byte) chunks.
#[cfg(feature = "stream")]
pub async fn convert_html_to_markdown_send(
    html: &str,
    custom: &Option<std::collections::HashSet<String>>,
    commonmark: bool,
    url: &Option<Url>,
) -> Result<String, Box<dyn std::error::Error>> {
    const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;

    convert_html_to_markdown_send_with_size(html, custom, commonmark, url, DEFAULT_CHUNK_SIZE)
        .await
}
515
/// Error type for stream-based conversion.
///
/// `E` is the error type yielded by the caller's input stream.
#[cfg(feature = "stream")]
#[derive(Debug)]
pub enum StreamConvertError<E> {
    /// The input stream yielded an error.
    Stream(E),
    /// lol_html rewriting failed.
    Rewrite(lol_html::errors::RewritingError),
}

#[cfg(feature = "stream")]
impl<E: std::fmt::Display> std::fmt::Display for StreamConvertError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The prefix names the failing stage; the wrapped error supplies the detail.
        let (stage, cause) = match self {
            Self::Stream(e) => ("stream", e as &dyn std::fmt::Display),
            Self::Rewrite(e) => ("rewrite", e as &dyn std::fmt::Display),
        };
        write!(f, "{stage} error: {cause}")
    }
}

#[cfg(feature = "stream")]
impl<E: std::error::Error + 'static> std::error::Error for StreamConvertError<E> {
    /// Both variants wrap an error, which is exposed as the source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Self::Stream(e) => e,
            Self::Rewrite(e) => e,
        };
        Some(inner)
    }
}
545
/// Convert an async byte stream of HTML into markdown.
///
/// Genuinely async — yields to the executor between input chunks via `stream.next().await`.
/// Uses `lol_html::send::HtmlRewriter` which handles chunk-boundary splitting internally.
///
/// Any `Stream` is accepted: it is pinned on the stack inside this function, so it does not
/// need to be `Unpin`.
///
/// # Errors
///
/// * `StreamConvertError::Stream` - the first error yielded by `stream`; conversion stops there.
/// * `StreamConvertError::Rewrite` - a `lol_html` error from writing a chunk or finishing the
///   document.
#[cfg(feature = "stream")]
pub async fn convert_html_stream_to_markdown<S, B, E>(
    stream: S,
    custom: &Option<std::collections::HashSet<String>>,
    commonmark: bool,
    url: &Option<Url>,
) -> Result<String, StreamConvertError<E>>
where
    S: futures_util::Stream<Item = Result<B, E>>,
    B: AsRef<[u8]>,
{
    use futures_util::StreamExt;

    let settings = get_rewriter_settings_send(commonmark, custom, url.clone());
    let mut output: Vec<u8> = Vec::with_capacity(4096);

    let mut rewriter = lol_html::send::HtmlRewriter::new(settings.into(), |c: &[u8]| {
        output.extend_from_slice(c);
    });

    // `StreamExt::next` requires `Unpin`; `Pin<&mut S>` always is, whatever `S` is.
    futures_util::pin_mut!(stream);

    while let Some(chunk_result) = stream.next().await {
        let chunk = chunk_result.map_err(StreamConvertError::Stream)?;
        rewriter
            .write(chunk.as_ref())
            .map_err(StreamConvertError::Rewrite)?;
    }

    rewriter.end().map_err(StreamConvertError::Rewrite)?;

    Ok(clean_markdown_bytes(&output))
}