Skip to main content

lol_html/rewriter/
mod.rs

1mod handlers_dispatcher;
2mod rewrite_controller;
3
4#[macro_use]
5pub(crate) mod settings;
6
7use self::rewrite_controller::{ElementDescriptor, HtmlRewriteController};
8pub use self::settings::*;
9use crate::base::SharedEncoding;
10use crate::memory::{MemoryLimitExceededError, SharedMemoryLimiter};
11use crate::parser::ParsingAmbiguityError;
12use crate::rewritable_units::Element;
13use crate::transform_stream::*;
14use encoding_rs::Encoding;
15use mime::Mime;
16use std::borrow::Cow;
17use std::error::Error as StdError;
18use std::fmt::{self, Debug};
19use thiserror::Error;
20
21/// This is an encoding known to be ASCII-compatible.
22///
23/// Non-ASCII-compatible encodings (`UTF-16LE`, `UTF-16BE`, `ISO-2022-JP` and
24/// `replacement`) are not supported by `lol_html`.
25#[derive(Copy, Clone, Debug, PartialEq, Eq)]
26pub struct AsciiCompatibleEncoding(&'static Encoding);
27
28impl AsciiCompatibleEncoding {
29    /// Returns `Some` if `Encoding` is ascii-compatible, or `None` otherwise.
30    #[must_use]
31    pub fn new(encoding: &'static Encoding) -> Option<Self> {
32        encoding.is_ascii_compatible().then_some(Self(encoding))
33    }
34
35    fn from_mimetype(mime: &Mime) -> Option<Self> {
36        let cs = mime.get_param("charset")?;
37        Self::new(Encoding::for_label_no_replacement(cs.as_str().as_bytes())?)
38    }
39
40    /// Returns the most commonly used UTF-8 encoding.
41    #[must_use]
42    pub fn utf_8() -> Self {
43        Self(encoding_rs::UTF_8)
44    }
45}
46
47impl From<AsciiCompatibleEncoding> for &'static Encoding {
48    fn from(ascii_enc: AsciiCompatibleEncoding) -> &'static Encoding {
49        ascii_enc.0
50    }
51}
52
53impl TryFrom<&'static Encoding> for AsciiCompatibleEncoding {
54    type Error = ();
55
56    fn try_from(enc: &'static Encoding) -> Result<Self, ()> {
57        Self::new(enc).ok_or(())
58    }
59}
60
61/// A compound error type that can be returned by [`write`] and [`end`] methods of the rewriter.
62///
63/// # Note
64/// This error is unrecoverable. The rewriter instance will panic on attempt to use it after such an
65/// error.
66///
67/// [`write`]: ../struct.HtmlRewriter.html#method.write
68/// [`end`]: ../struct.HtmlRewriter.html#method.end
69#[derive(Error, Debug)]
70pub enum RewritingError {
71    /// See [`MemoryLimitExceededError`].
72    ///
73    /// [`MemoryLimitExceededError`]: struct.MemoryLimitExceededError.html
74    #[error("{0}")]
75    MemoryLimitExceeded(MemoryLimitExceededError),
76
77    /// See [`ParsingAmbiguityError`].
78    ///
79    /// [`ParsingAmbiguityError`]: struct.ParsingAmbiguityError.html
80    #[error("{0}")]
81    ParsingAmbiguity(ParsingAmbiguityError),
82
83    /// An error that was propagated from one of the content handlers.
84    #[error("{0}")]
85    ContentHandlerError(Box<dyn StdError + Send + Sync + 'static>),
86}
87
88/// A streaming HTML rewriter.
89///
90/// # Example
91/// ```
92/// use lol_html::{element, HtmlRewriter, Settings};
93///
94/// let mut output = vec![];
95///
96/// {
97///     let mut rewriter = HtmlRewriter::new(
98///         Settings {
99///             element_content_handlers: vec![
100///                 // Rewrite insecure hyperlinks
101///                 element!("a[href]", |el| {
102///                     let href = el
103///                         .get_attribute("href")
104///                         .unwrap()
105///                         .replace("http:", "https:");
106///
107///                     el.set_attribute("href", &href).unwrap();
108///
109///                     Ok(())
110///                 })
111///             ],
112///             ..Settings::new()
113///         },
114///         |c: &[u8]| output.extend_from_slice(c)
115///     );
116///
117///     rewriter.write(b"<div><a href=").unwrap();
118///     rewriter.write(b"http://example.com>").unwrap();
119///     rewriter.write(b"</a></div>").unwrap();
120///     rewriter.end().unwrap();
121/// }
122///
123/// assert_eq!(
124///     String::from_utf8(output).unwrap(),
125///     r#"<div><a href="https://example.com"></a></div>"#
126/// );
127/// ```
128pub struct HtmlRewriter<'h, O: OutputSink, H: HandlerTypes = LocalHandlerTypes> {
129    stream: TransformStream<HtmlRewriteController<'h, H>, O>,
130    poisoned: bool,
131}
132
// Evaluates `$expr` on behalf of the rewriter `$self`, enforcing the poisoning protocol:
// it panics if `$self` is already poisoned, and poisons `$self` when `$expr` yields `Err`.
// The result of `$expr` is returned unchanged.
macro_rules! guarded {
    ($self:ident, $expr:expr) => {{
        assert!(
            !$self.poisoned,
            "Attempt to use the HtmlRewriter after a fatal error."
        );

        let outcome = $expr;

        // The assert above guarantees the flag is currently `false`, so storing
        // `is_err()` is equivalent to flipping it only on failure.
        $self.poisoned = outcome.is_err();

        outcome
    }};
}
149
150impl<'h, O: OutputSink, H: HandlerTypes> HtmlRewriter<'h, O, H> {
151    /// Constructs a new rewriter with the provided `settings` that writes
152    /// the output to the `output_sink`.
153    ///
154    /// # Note
155    ///
156    /// For the convenience the [`OutputSink`] trait is implemented for closures.
157    ///
158    /// [`OutputSink`]: trait.OutputSink.html
159    pub fn new<'s>(settings: Settings<'h, 's, H>, output_sink: O) -> Self {
160        let preallocated_parsing_buffer_size =
161            settings.memory_settings.preallocated_parsing_buffer_size;
162        let strict = settings.strict;
163
164        let encoding = SharedEncoding::new(settings.encoding);
165
166        let memory_limiter =
167            SharedMemoryLimiter::new(settings.memory_settings.max_allowed_memory_usage);
168
169        let stream = TransformStream::new(TransformStreamSettings {
170            transform_controller: HtmlRewriteController::from_settings(
171                settings,
172                &memory_limiter,
173                &encoding,
174            ),
175            output_sink,
176            preallocated_parsing_buffer_size,
177            memory_limiter,
178            encoding,
179            strict,
180        });
181
182        HtmlRewriter {
183            stream,
184            poisoned: false,
185        }
186    }
187
188    /// Writes a chunk of input data to the rewriter.
189    ///
190    /// # Panics
191    ///  * If previous invocation of the method returned a [`RewritingError`]
192    ///    (these errors are unrecovarable).
193    ///
194    /// [`RewritingError`]: errors/enum.RewritingError.html
195    /// [`end`]: struct.HtmlRewriter.html#method.end
196    #[inline]
197    pub fn write(&mut self, data: &[u8]) -> Result<(), RewritingError> {
198        guarded!(self, self.stream.write(data))
199    }
200
201    /// Finalizes the rewriting process.
202    ///
203    /// Should be called once the last chunk of the input is written.
204    ///
205    /// # Panics
206    ///  * If previous invocation of [`write`] returned a [`RewritingError`] (these errors
207    ///    are unrecovarable).
208    ///
209    /// [`RewritingError`]: errors/enum.RewritingError.html
210    /// [`write`]: struct.HtmlRewriter.html#method.write
211    #[inline]
212    pub fn end(mut self) -> Result<(), RewritingError> {
213        guarded!(self, self.stream.end())
214    }
215}
216
217// NOTE: this opaque Debug implementation is required to make
218// `.unwrap()` and `.expect()` methods available on Result
219// returned by the `HtmlRewriterBuilder.build()` method.
220impl<O: OutputSink, H: HandlerTypes> Debug for HtmlRewriter<'_, O, H> {
221    #[cold]
222    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223        write!(f, "HtmlRewriter")
224    }
225}
226
227fn handler_adjust_charset_on_meta_tag<'h, H: HandlerTypes>(
228    encoding: SharedEncoding,
229) -> (Cow<'h, crate::Selector>, ElementContentHandlers<'h, H>) {
230    // HTML5 allows encoding to be set only once
231    let mut found = false;
232
233    let handler = move |el: &mut Element<'_, '_, H>| {
234        if found {
235            return Ok(());
236        }
237
238        let charset = el.get_attribute("charset").and_then(|cs| {
239            AsciiCompatibleEncoding::new(Encoding::for_label_no_replacement(cs.as_bytes())?)
240        });
241
242        let charset = charset.or_else(|| {
243            el.get_attribute("http-equiv")
244                .filter(|http_equiv| http_equiv.eq_ignore_ascii_case("Content-Type"))
245                .and_then(|_| {
246                    AsciiCompatibleEncoding::from_mimetype(
247                        &el.get_attribute("content")?.parse::<Mime>().ok()?,
248                    )
249                })
250        });
251
252        if let Some(charset) = charset {
253            found = true;
254            encoding.set(charset);
255        }
256
257        Ok(())
258    };
259
260    let content_handlers = ElementContentHandlers {
261        element: Some(H::new_element_handler(handler)),
262        comments: None,
263        text: None,
264    };
265
266    (Cow::Owned("meta".parse().unwrap()), content_handlers)
267}
268
269/// Rewrites given `html` string with the provided `settings`.
270///
271/// # Example
272///
273/// ```
274/// use lol_html::{rewrite_str, element, RewriteStrSettings};
275///
276/// let element_content_handlers = vec![
277///     // Rewrite insecure hyperlinks
278///     element!("a[href]", |el| {
279///         let href = el
280///             .get_attribute("href")
281///             .unwrap()
282///             .replace("http:", "https:");
283///
284///          el.set_attribute("href", &href).unwrap();
285///
286///          Ok(())
287///     })
288/// ];
289/// let output = rewrite_str(
290///     r#"<div><a href="http://example.com"></a></div>"#,
291///     RewriteStrSettings {
292///         element_content_handlers,
293///         ..RewriteStrSettings::new()
294///     }
295/// ).unwrap();
296///
297/// assert_eq!(output, r#"<div><a href="https://example.com"></a></div>"#);
298/// ```
299pub fn rewrite_str<'h, 's, H: HandlerTypes>(
300    html: &str,
301    settings: impl Into<Settings<'h, 's, H>>,
302) -> Result<String, RewritingError> {
303    let mut output = vec![];
304
305    let mut rewriter = HtmlRewriter::new(settings.into(), |c: &[u8]| {
306        output.extend_from_slice(c);
307    });
308
309    rewriter.write(html.as_bytes())?;
310    rewriter.end()?;
311
312    // NOTE: it's ok to unwrap here as we guarantee encoding validity of the output
313    Ok(String::from_utf8(output).unwrap())
314}
315
316#[cfg(test)]
317mod tests {
318    use super::*;
319    use crate::html::TextType;
320    use crate::html_content::ContentType;
321    use crate::test_utils::{ASCII_COMPATIBLE_ENCODINGS, NON_ASCII_COMPATIBLE_ENCODINGS, Output};
322    use encoding_rs::Encoding;
323    use itertools::Itertools;
324    use static_assertions::assert_impl_all;
325    use std::convert::TryInto;
326    use std::sync::atomic::{AtomicUsize, Ordering};
327    use std::sync::{Arc, Mutex};
328
329    // Assert that HtmlRewriter with `SendHandlerTypes` is `Send`.
330    assert_impl_all!(crate::send::HtmlRewriter<'_, Box<dyn FnMut(&[u8]) + Send + 'static>>: Send);
331
332    fn write_chunks<O: OutputSink>(
333        mut rewriter: HtmlRewriter<'_, O>,
334        encoding: &'static Encoding,
335        chunks: &[&str],
336    ) {
337        for chunk in chunks {
338            let (chunk, _, _) = encoding.encode(chunk);
339
340            rewriter.write(&chunk).unwrap();
341        }
342
343        rewriter.end().unwrap();
344    }
345
346    fn rewrite_html_bytes(html: &[u8], settings: Settings<'_, '_>) -> Vec<u8> {
347        let mut out: Vec<u8> = Vec::with_capacity(html.len());
348
349        let mut rewriter = HtmlRewriter::new(settings, |c: &[u8]| out.extend_from_slice(c));
350
351        rewriter.write(html).unwrap();
352        rewriter.end().unwrap();
353
354        out
355    }
356
357    #[allow(clippy::drop_non_drop)]
358    #[test]
359    fn handlers_lifetime_covariance() {
360        // This test checks that if you have a handler with a lifetime larger than `'a` then you can
361        // use it in a place where a handler of lifetime `'a` is expected. If the code below
362        // compiles, then this condition holds.
363
364        let x = AtomicUsize::new(0);
365
366        let el_handler_static = element!("foo", |_| Ok(()));
367        let el_handler_local = element!("foo", |_| {
368            x.fetch_add(1, Ordering::Relaxed);
369            Ok(())
370        });
371
372        let doc_handler_static = end!(|_| Ok(()));
373        let doc_handler_local = end!(|_| {
374            x.fetch_add(1, Ordering::Relaxed);
375            Ok(())
376        });
377
378        let settings = Settings {
379            document_content_handlers: vec![doc_handler_static, doc_handler_local],
380            element_content_handlers: vec![el_handler_static, el_handler_local],
381            encoding: AsciiCompatibleEncoding::utf_8(),
382            strict: false,
383            adjust_charset_on_meta_tag: false,
384            ..Settings::new()
385        };
386        let rewriter = HtmlRewriter::new(settings, |_: &[u8]| ());
387
388        drop(rewriter);
389
390        drop(x);
391    }
392
393    #[test]
394    fn rewrite_html_str() {
395        let res = rewrite_str::<LocalHandlerTypes>(
396            "<!-- 42 --><div><!--hi--></div>",
397            RewriteStrSettings {
398                element_content_handlers: vec![
399                    element!("div", |el| {
400                        el.set_tag_name("span").unwrap();
401                        Ok(())
402                    }),
403                    comments!("div", |c| {
404                        c.set_text("hello").unwrap();
405                        Ok(())
406                    }),
407                ],
408                ..RewriteStrSettings::new()
409            },
410        )
411        .unwrap();
412
413        assert_eq!(res, "<!-- 42 --><span><!--hello--></span>");
414    }
415
416    #[test]
417    fn rewrite_incorrect_self_closing() {
418        let res = rewrite_str::<LocalHandlerTypes>(
419            "<title /></title><div/></div><style /></style><script /></script>
420            <br/><br><embed/><embed> <svg><a/><path/><path></path></svg>",
421            RewriteStrSettings {
422                element_content_handlers: vec![element!("*:not(svg)", |el| {
423                    el.set_attribute("s", if el.is_self_closing() { "y" } else { "n" })?;
424                    el.set_attribute("c", if el.can_have_content() { "y" } else { "n" })?;
425                    el.append("…", ContentType::Text);
426                    Ok(())
427                })],
428                ..RewriteStrSettings::new()
429            },
430        )
431        .unwrap();
432
433        assert_eq!(
434            res,
435            r#"<title s="y" c="y">…</title><div s="y" c="y">…</div><style s="y" c="y">…</style><script s="y" c="y">…</script>
436            <br s="y" c="n" /><br s="n" c="n"><embed s="y" c="n" /><embed s="n" c="n"> <svg><a s="y" c="n" /><path s="y" c="n" /><path s="n" c="y">…</path></svg>"#
437        );
438    }
439
440    #[test]
441    fn rewrite_arbitrary_settings() {
442        let res = rewrite_str("<span>Some text</span>", Settings::new()).unwrap();
443        assert_eq!(res, "<span>Some text</span>");
444    }
445
446    #[test]
447    fn non_ascii_compatible_encoding() {
448        for encoding in &NON_ASCII_COMPATIBLE_ENCODINGS {
449            assert_eq!(AsciiCompatibleEncoding::new(encoding), None);
450        }
451    }
452
453    #[test]
454    fn doctype_info() {
455        for &enc in &ASCII_COMPATIBLE_ENCODINGS {
456            let mut doctypes = Vec::default();
457
458            {
459                let rewriter = HtmlRewriter::new(
460                    Settings {
461                        document_content_handlers: vec![doctype!(|d| {
462                            doctypes.push((d.name(), d.public_id(), d.system_id()));
463                            Ok(())
464                        })],
465                        // NOTE: unwrap() here is intentional; it also tests `Ascii::new`.
466                        encoding: enc.try_into().unwrap(),
467                        ..Settings::new()
468                    },
469                    |_: &[u8]| {},
470                );
471
472                write_chunks(
473                    rewriter,
474                    enc,
475                    &[
476                        "<!doctype html1>",
477                        "<!-- test --><div>",
478                        r#"<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01//EN" "#,
479                        r#""http://www.w3.org/TR/html4/strict.dtd">"#,
480                        "</div><!DoCtYPe ",
481                    ],
482                );
483            }
484
485            assert_eq!(
486                doctypes,
487                &[
488                    (Some("html1".into()), None, None),
489                    (
490                        Some("html".into()),
491                        Some("-//W3C//DTD HTML 4.01//EN".into()),
492                        Some("http://www.w3.org/TR/html4/strict.dtd".into())
493                    ),
494                    (None, None, None),
495                ]
496            );
497        }
498    }
499
500    #[test]
501    fn rewrite_start_tags() {
502        for &enc in &ASCII_COMPATIBLE_ENCODINGS {
503            let actual: String = {
504                let mut output = Output::new(enc);
505
506                let rewriter = HtmlRewriter::new(
507                    Settings {
508                        element_content_handlers: vec![element!("*", |el| {
509                            el.set_attribute("foo", "bar").unwrap();
510                            el.prepend("<test></test>", ContentType::Html);
511                            Ok(())
512                        })],
513                        encoding: enc.try_into().unwrap(),
514                        ..Settings::new()
515                    },
516                    |c: &[u8]| output.push(c),
517                );
518
519                write_chunks(
520                    rewriter,
521                    enc,
522                    &[
523                        "<!doctype html>\n",
524                        "<html>\n",
525                        "   <head></head>\n",
526                        "   <body>\n",
527                        "       <div>Test</div>\n",
528                        "   </body>\n",
529                        "</html>",
530                    ],
531                );
532
533                output.into()
534            };
535
536            assert_eq!(
537                actual,
538                concat!(
539                    "<!doctype html>\n",
540                    "<html foo=\"bar\"><test></test>\n",
541                    "   <head foo=\"bar\"><test></test></head>\n",
542                    "   <body foo=\"bar\"><test></test>\n",
543                    "       <div foo=\"bar\"><test></test>Test</div>\n",
544                    "   </body>\n",
545                    "</html>",
546                )
547            );
548        }
549    }
550
551    #[test]
552    fn rewrite_document_content() {
553        for &enc in &ASCII_COMPATIBLE_ENCODINGS {
554            let actual: String = {
555                let mut output = Output::new(enc);
556
557                let rewriter = HtmlRewriter::new(
558                    Settings {
559                        element_content_handlers: vec![],
560                        document_content_handlers: vec![
561                            doc_comments!(|c| {
562                                c.set_text(&(c.text() + "1337")).unwrap();
563                                Ok(())
564                            }),
565                            doc_text!(|c| {
566                                if c.last_in_text_node() {
567                                    c.after("BAZ", ContentType::Text);
568                                }
569
570                                Ok(())
571                            }),
572                        ],
573                        encoding: enc.try_into().unwrap(),
574                        ..Settings::new()
575                    },
576                    |c: &[u8]| output.push(c),
577                );
578
579                write_chunks(
580                    rewriter,
581                    enc,
582                    &[
583                        "<!doctype html>\n",
584                        "<!-- hey -->\n",
585                        "<html>\n",
586                        "   <head><!-- aloha --></head>\n",
587                        "   <body>\n",
588                        "       <div>Test</div>\n",
589                        "   </body>\n",
590                        "   <!-- bonjour -->\n",
591                        "</html>Pshhh",
592                    ],
593                );
594
595                output.into()
596            };
597
598            assert_eq!(
599                actual,
600                concat!(
601                    "<!doctype html>\nBAZ",
602                    "<!-- hey 1337-->\nBAZ",
603                    "<html>\n",
604                    "   BAZ<head><!-- aloha 1337--></head>\n",
605                    "   BAZ<body>\n",
606                    "       BAZ<div>TestBAZ</div>\n",
607                    "   BAZ</body>\n",
608                    "   BAZ<!-- bonjour 1337-->\nBAZ",
609                    "</html>PshhhBAZ",
610                )
611            );
612        }
613    }
614
615    #[test]
616    fn rewrite_text_types() {
617        for &enc in &ASCII_COMPATIBLE_ENCODINGS {
618            let actual: String = {
619                let mut output = Output::new(enc);
620
621                let rewriter = HtmlRewriter::new(
622                    Settings {
623                        element_content_handlers: vec![],
624                        document_content_handlers: vec![doc_text!(|c| {
625                            let replace = match c.text_type() {
626                                TextType::PlainText => 'P',
627                                TextType::RCData => 'r',
628                                TextType::RawText => 'R',
629                                TextType::ScriptData => 'S',
630                                TextType::Data => '.',
631                                TextType::CDataSection => 'C',
632                            };
633                            let mut replaced: String = c
634                                .as_str()
635                                .chars()
636                                .map(|c| if c == '\n' { c } else { replace })
637                                .collect();
638                            if c.last_in_text_node() {
639                                replaced.push(';');
640                            }
641                            c.set_str(replaced);
642
643                            Ok(())
644                        })],
645                        encoding: enc.try_into().unwrap(),
646                        ..Settings::new()
647                    },
648                    |c: &[u8]| output.push(c),
649                );
650
651                write_chunks(
652                    rewriter,
653                    enc,
654                    &[
655                        "\n  <!doctype html> <title>rcdata</titlenot> <!--no comment rcdata</title>",
656                        "\n   <textarea>rc<x> --><!--no comment </TEXTAREA> ",
657                        "\n   body <!--> 1 </> 2 <noscript>nnnn</noscript>",
658                        "\n  <script>scr</script> <style>style</style>",
659                        "\n  <script><!-- scr --></script> <style>/*<![CDATA[*/ style /*]]>*/</style>",
660                        "\n  <svg> body <![CDATA[ cdata ]]> body",
661                        "\n  <script>scr</script> <style>style</style>",
662                        "\n  <script><!-- com -->s</script> <style>/*<![CDATA[*/ style /*]]>*/</style>",
663                        "\n  </svg>",
664                    ],
665                );
666
667                output.into()
668            };
669
670            assert_eq!(
671                actual,
672                "\
673                \n..;<!doctype html>.;<title>rrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr;</title>\
674                \n...;<textarea>rrrrrrrrrrrrrrrrrrrrrrrr;</TEXTAREA>.\
675                \n........;<!-->...;</>...;<noscript>RRRR;</noscript>\
676                \n..;<script>SSS;</script>.;<style>RRRRR;</style>\
677                \n..;<script>SSSSSSSSSSSS;</script>.;<style>RRRRRRRRRRRRRRRRRRRRRRRRRRR;</style>\
678                \n..;<svg>......;<![CDATA[CCCCCCC;]]>.....\
679                \n..;<script>...;</script>.;<style>.....;</style>\
680                \n..;<script><!-- com -->.;</script>.;<style>..;<![CDATA[CCCCCCCCCCC;]]>..;</style>\
681                \n..;</svg>\
682                "
683            );
684        }
685    }
686
687    #[test]
688    fn handler_invocation_order() {
689        let handlers_executed = Arc::new(Mutex::new(Vec::default()));
690
691        macro_rules! create_handlers {
692            ($sel:expr, $idx:expr) => {
693                element!($sel, {
694                    let handlers_executed = ::std::sync::Arc::clone(&handlers_executed);
695
696                    move |_| {
697                        handlers_executed.lock().unwrap().push($idx);
698                        Ok(())
699                    }
700                })
701            };
702        }
703
704        let _res = rewrite_str(
705            "<div><span foo></span></div>",
706            RewriteStrSettings {
707                element_content_handlers: vec![
708                    create_handlers!("div span", 0),
709                    create_handlers!("div > span", 1),
710                    create_handlers!("span", 2),
711                    create_handlers!("[foo]", 3),
712                    create_handlers!("div span[foo]", 4),
713                ],
714                ..RewriteStrSettings::new()
715            },
716        )
717        .unwrap();
718
719        assert_eq!(*handlers_executed.lock().unwrap(), vec![0, 1, 2, 3, 4]);
720    }
721
722    #[test]
723    fn write_esi_tags() {
724        let res = rewrite_str(
725            "<span><esi:include src=a></span>",
726            RewriteStrSettings {
727                element_content_handlers: vec![element!("esi\\:include", |el| {
728                    el.replace("?", ContentType::Text);
729                    Ok(())
730                })],
731                enable_esi_tags: true,
732                ..RewriteStrSettings::new()
733            },
734        )
735        .unwrap();
736
737        assert_eq!(res, "<span>?</span>");
738    }
739
740    #[test]
741    fn test_rewrite_adjust_charset_on_meta_tag_attribute_charset() {
742        use crate::html_content::{ContentType, TextChunk};
743
744        let enthusiastic_text_handler = || {
745            doc_text!(move |text: &mut TextChunk<'_>| {
746                let new_text = text.as_str().replace('!', "!!!");
747                text.replace(&new_text, ContentType::Text);
748                Ok(())
749            })
750        };
751
752        let html: Vec<u8> = [
753            r#"<meta charset="windows-1251"><html><head></head><body>I love "#
754                .as_bytes()
755                .to_vec(),
756            vec![0xd5, 0xec, 0xb3, 0xcb, 0xdc],
757            br"!</body></html>".to_vec(),
758        ]
759        .into_iter()
760        .concat();
761
762        let expected: Vec<u8> = html
763            .iter()
764            .copied()
765            .flat_map(|c| match c {
766                b'!' => vec![b'!', b'!', b'!'],
767                c => vec![c],
768            })
769            .collect();
770
771        let transformed_no_charset_adjustment: Vec<u8> = rewrite_html_bytes(
772            &html,
773            Settings {
774                document_content_handlers: vec![enthusiastic_text_handler()],
775                ..Settings::new()
776            },
777        );
778
779        // Without charset adjustment the response has to be corrupted:
780        assert_ne!(transformed_no_charset_adjustment, expected);
781
782        let transformed_charset_adjustment: Vec<u8> = rewrite_html_bytes(
783            &html,
784            Settings {
785                document_content_handlers: vec![enthusiastic_text_handler()],
786                adjust_charset_on_meta_tag: true,
787                ..Settings::new()
788            },
789        );
790
791        // If it adapts the charset according to the meta tag everything will be correctly
792        // encoded in windows-1251:
793        assert_eq!(transformed_charset_adjustment, expected);
794    }
795
796    #[test]
797    fn test_rewrite_adjust_charset_on_meta_tag_attribute_content_type() {
798        use crate::html_content::{ContentType, TextChunk};
799
800        let enthusiastic_text_handler = || {
801            doc_text!(move |text: &mut TextChunk<'_>| {
802                let new_text = text.as_str().replace('!', "!!!");
803                text.replace(&new_text, ContentType::Text);
804                Ok(())
805            })
806        };
807
808        let html: Vec<u8> = [
809            r#"<meta http-equiv="conTent-type" content="text/html; charset=windows-1251"><html><head>"#.as_bytes(),
810            br#"<meta charset="utf-8"></head><body>I love "#, // second one should be ignored
811            &[0xd5, 0xec, 0xb3, 0xcb, 0xdc],
812            br"!</body></html>",
813        ].concat();
814
815        let expected: Vec<u8> = html
816            .iter()
817            .copied()
818            .flat_map(|c| match c {
819                b'!' => vec![b'!', b'!', b'!'],
820                c => vec![c],
821            })
822            .collect();
823
824        let transformed_no_charset_adjustment: Vec<u8> = rewrite_html_bytes(
825            &html,
826            Settings {
827                document_content_handlers: vec![enthusiastic_text_handler()],
828                ..Settings::new()
829            },
830        );
831
832        // Without charset adjustment the response has to be corrupted:
833        assert_ne!(transformed_no_charset_adjustment, expected);
834
835        let transformed_charset_adjustment: Vec<u8> = rewrite_html_bytes(
836            &html,
837            Settings {
838                document_content_handlers: vec![enthusiastic_text_handler()],
839                adjust_charset_on_meta_tag: true,
840                ..Settings::new()
841            },
842        );
843
844        // If it adapts the charset according to the meta tag everything will be correctly
845        // encoded in windows-1251:
846        assert_eq!(transformed_charset_adjustment, expected);
847    }
848
849    mod fatal_errors {
850        use super::*;
851        use crate::html_content::Comment;
852        use crate::memory::MemoryLimitExceededError;
853        use crate::rewritable_units::{Element, TextChunk};
854
855        fn create_rewriter<O: OutputSink>(
856            max_allowed_memory_usage: usize,
857            output_sink: O,
858        ) -> HtmlRewriter<'static, O> {
859            HtmlRewriter::new(
860                Settings {
861                    element_content_handlers: vec![element!("*", |_| Ok(()))],
862                    memory_settings: MemorySettings {
863                        max_allowed_memory_usage,
864                        preallocated_parsing_buffer_size: 0,
865                    },
866                    ..Settings::new()
867                },
868                output_sink,
869            )
870        }
871
872        #[test]
873        fn buffer_capacity_limit() {
874            const MAX: usize = 100;
875
876            let mut rewriter = create_rewriter(MAX, |_: &[u8]| {});
877
878            // Use two chunks for the stream to force the usage of the buffer and
879            // make sure to overflow it.
880            let chunk_1 = format!("<img alt=\"{}", "l".repeat(MAX / 2));
881            let chunk_2 = format!("{}\" />", "r".repeat(MAX / 2));
882
883            rewriter.write(chunk_1.as_bytes()).unwrap();
884
885            let write_err = rewriter.write(chunk_2.as_bytes()).unwrap_err();
886
887            match write_err {
888                RewritingError::MemoryLimitExceeded(e) => assert_eq!(e, MemoryLimitExceededError),
889                _ => panic!("{}", write_err),
890            }
891        }
892
893        #[test]
894        #[should_panic(expected = "Attempt to use the HtmlRewriter after a fatal error.")]
895        fn poisoning_after_fatal_error() {
896            const MAX: usize = 10;
897
898            let mut rewriter = create_rewriter(MAX, |_: &[u8]| {});
899            let chunk = format!("<img alt=\"{}", "l".repeat(MAX));
900
901            rewriter.write(chunk.as_bytes()).unwrap_err();
902            rewriter.end().unwrap_err();
903        }
904
905        #[test]
906        fn content_handler_error_propagation() {
907            fn assert_err<'h>(
908                element_handlers: ElementContentHandlers<'h>,
909                document_handlers: DocumentContentHandlers<'h>,
910                expected_err: &'static str,
911            ) {
912                use std::borrow::Cow;
913
914                let mut rewriter = HtmlRewriter::new(
915                    Settings {
916                        element_content_handlers: vec![(
917                            Cow::Owned("*".parse().unwrap()),
918                            element_handlers,
919                        )],
920                        document_content_handlers: vec![document_handlers],
921                        ..Settings::new()
922                    },
923                    |_: &[u8]| {},
924                );
925
926                let chunks = [
927                    "<!--doc comment--> Doc text",
928                    "<div><!--el comment-->El text</div>",
929                ];
930
931                let mut err = None;
932
933                for chunk in &chunks {
934                    match rewriter.write(chunk.as_bytes()) {
935                        Ok(()) => (),
936                        Err(e) => {
937                            err = Some(e);
938                            break;
939                        }
940                    }
941                }
942
943                if err.is_none() {
944                    match rewriter.end() {
945                        Ok(()) => (),
946                        Err(e) => err = Some(e),
947                    }
948                }
949
950                let err = format!("{}", err.expect("Error expected"));
951
952                assert_eq!(err, expected_err);
953            }
954
955            assert_err(
956                ElementContentHandlers::default(),
957                doc_comments!(|_| Err("Error in doc comment handler".into())),
958                "Error in doc comment handler",
959            );
960
961            assert_err(
962                ElementContentHandlers::default(),
963                doc_text!(|_| Err("Error in doc text handler".into())),
964                "Error in doc text handler",
965            );
966
967            assert_err(
968                ElementContentHandlers::default(),
969                doc_text!(|_| Err("Error in doctype handler".into())),
970                "Error in doctype handler",
971            );
972
973            assert_err(
974                ElementContentHandlers::default()
975                    .element(|_: &mut Element<'_, '_, _>| Err("Error in element handler".into())),
976                DocumentContentHandlers::default(),
977                "Error in element handler",
978            );
979
980            assert_err(
981                ElementContentHandlers::default()
982                    .comments(|_: &mut Comment<'_>| Err("Error in element comment handler".into())),
983                DocumentContentHandlers::default(),
984                "Error in element comment handler",
985            );
986
987            assert_err(
988                ElementContentHandlers::default()
989                    .text(|_: &mut TextChunk<'_>| Err("Error in element text handler".into())),
990                DocumentContentHandlers::default(),
991                "Error in element text handler",
992            );
993        }
994    }
995}