epub/
xmlutils.rs

1use std::cell::RefCell;
2use std::rc::Rc;
3use std::rc::Weak;
4use xml::attribute::OwnedAttribute;
5use xml::reader::Error as ReaderError;
6use xml::reader::EventReader;
7use xml::reader::ParserConfig;
8
9use xml::reader::XmlEvent as ReaderEvent;
10use xml::writer::XmlEvent as WriterEvent;
11
12use std::fmt;
13use xml::writer::EmitterConfig;
14use xml::writer::Error as EmitterError;
15
16use std::borrow::Cow;
17
// `RefCell` is used because a node's `children` vector has to be edited while
// the document is being parsed.
// `Rc` is used because a node is referenced both by its parent and by its
// children; the child-to-parent link is a `Weak` reference.
type ChildNodeRef = Rc<RefCell<XMLNode>>;
type ParentNodeRef = Weak<RefCell<XMLNode>>;
22
/// Errors produced by the XML helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum XMLError {
    /// An error reported by the underlying XML reader.
    #[error("XML Reader Error: {0}")]
    Reader(#[from] ReaderError),
    /// An error reported by the underlying XML writer.
    #[error("XML Writer Error: {0}")]
    Emitter(#[from] EmitterError),
    /// An attribute was not found; the payload is interpolated into the
    /// message (presumably the attribute name; not constructed in this file).
    #[error("Attribute Not Found: {0}")]
    AttrNotFound(String),
    /// An internal invariant was violated.
    #[error("Invalid State; this is a bug")]
    InvalidState,
    /// The document did not contain any element.
    #[error("No XML Elements Found")]
    NoElements,
    /// The input was too short to be parsed.
    #[error("XML content is empty")]
    NoContent,
}
38
/// Builds an [`XMLNode`] tree from the events of an `xml` crate
/// `EventReader` that reads from a borrowed byte slice.
pub struct XMLReader<'a> {
    // Event source for the document being parsed.
    reader: EventReader<&'a [u8]>,
}
42
43impl<'a> XMLReader<'a> {
44    pub fn parse(content: &[u8]) -> Result<RefCell<XMLNode>, XMLError> {
45        // The operations below require at least 4 bytes to not panic
46        if content.is_empty() || content.len() < 4 {
47            return Err(XMLError::NoContent);
48        }
49
50        let content_str;
51        //If there is a UTF-8 BOM marker, ignore it
52        let content_slice = if content[0..3] == [0xefu8, 0xbbu8, 0xbfu8] {
53            &content[3..]
54        } else if content[0..2] == [0xfeu8, 0xffu8] || content[0..2] == [0xffu8, 0xfeu8] {
55            //handle utf-16
56            let (big_byte, small_byte) = if content[0] == 0xfeu8 {
57                (1, 0) //big endian utf-16
58            } else {
59                (0, 1) //little endian utf-16
60            };
61            let content_u16: Vec<u16> = content[2..]
62                .chunks_exact(2)
63                .into_iter()
64                .map(|a| u16::from_ne_bytes([a[big_byte], a[small_byte]]))
65                .collect();
66            content_str = String::from_utf16_lossy(content_u16.as_slice());
67            content_str.as_bytes()
68        } else {
69            content
70        };
71
72        let reader = XMLReader {
73            reader: ParserConfig::new()
74                .add_entity("nbsp", " ")
75                .add_entity("copy", "©")
76                .add_entity("reg", "®")
77                .create_reader(content_slice),
78        };
79
80        reader.parse_xml()
81    }
82
83    fn parse_xml(self) -> Result<RefCell<XMLNode>, XMLError> {
84        let mut root: Option<ChildNodeRef> = None;
85        let mut parents: Vec<ChildNodeRef> = vec![];
86
87        for e in self.reader {
88            match e {
89                Ok(ReaderEvent::StartElement {
90                    name,
91                    attributes,
92                    namespace,
93                }) => {
94                    let node = XMLNode {
95                        name,
96                        attrs: attributes,
97                        namespace,
98                        parent: None,
99                        text: None,
100                        cdata: None,
101                        children: vec![],
102                    };
103                    let arnode = Rc::new(RefCell::new(node));
104
105                    {
106                        let current = parents.last();
107                        if let Some(c) = current {
108                            c.borrow_mut().children.push(arnode.clone());
109                            arnode.borrow_mut().parent = Some(Rc::downgrade(c));
110                        }
111                    }
112                    parents.push(arnode.clone());
113
114                    if root.is_none() {
115                        root = Some(arnode.clone());
116                    }
117                }
118                Ok(ReaderEvent::EndElement { .. }) => {
119                    if !parents.is_empty() {
120                        parents.pop();
121                    }
122                }
123                Ok(ReaderEvent::Characters(text)) => {
124                    let current = parents.last();
125                    if let Some(c) = current {
126                        c.borrow_mut().text = Some(text);
127                    }
128                }
129                Ok(ReaderEvent::CData(text)) => {
130                    let current = parents.last();
131                    if let Some(c) = current {
132                        c.borrow_mut().cdata = Some(text);
133                    }
134                }
135                _ => continue,
136            }
137        }
138
139        if let Some(r) = root {
140            let a = Rc::try_unwrap(r);
141            match a {
142                Ok(n) => return Ok(n),
143                Err(_) => return Err(XMLError::InvalidState),
144            }
145        }
146        Err(XMLError::NoElements)
147    }
148}
149
/// A single element of a parsed XML document.
#[derive(Debug)]
pub struct XMLNode {
    /// Name of the element as reported by the reader.
    pub name: xml::name::OwnedName,
    /// Attributes of the element's start tag.
    pub attrs: Vec<xml::attribute::OwnedAttribute>,
    /// Namespace information reported with the element's start tag.
    pub namespace: xml::namespace::Namespace,
    /// Character data of the element, if any.
    pub text: Option<String>,
    /// CDATA content of the element, if any.
    pub cdata: Option<String>,
    /// Weak reference to the enclosing element; `None` for the root.
    pub parent: Option<ParentNodeRef>,
    /// Child elements, in the order they were added.
    pub children: Vec<ChildNodeRef>,
}
160
161impl XMLNode {
162    pub fn get_attr(&self, name: &str) -> Option<String> {
163        self.attrs
164            .iter()
165            .find(|a| a.name.local_name == name)
166            .map(|a| a.value.clone())
167    }
168
169    pub fn find(&self, tag: &str) -> Option<ChildNodeRef> {
170        for r in &self.children {
171            let c = r.borrow();
172            if c.name.local_name == tag {
173                return Some(r.clone());
174            } else if let Some(n) = c.find(tag) {
175                return Some(n);
176            }
177        }
178
179        None
180    }
181}
182
183impl fmt::Display for XMLNode {
184    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185        let childs: String = self.children.iter().fold(String::new(), |sum, x| {
186            format!("{}{}\n\t", sum, *x.borrow())
187        });
188        let attrs: String = self
189            .attrs
190            .iter()
191            .fold(String::new(), |sum, x| sum + &x.name.local_name + ", ");
192
193        let t = self.text.as_ref();
194        let mut text = String::new();
195        if let Some(t) = t {
196            text.clone_from(t);
197        }
198
199        write!(
200            f,
201            "<{} [{}]>\n\t{}{}",
202            self.name.local_name, attrs, childs, text
203        )
204    }
205}
206
207pub fn replace_attrs<F>(
208    xmldoc: &[u8],
209    closure: F,
210    extra_css: &[String],
211) -> Result<Vec<u8>, XMLError>
212where
213    F: Fn(&str, &str, &str) -> String,
214{
215    let mut b = Vec::new();
216
217    {
218        let reader = ParserConfig::new()
219            .add_entity("nbsp", " ")
220            .add_entity("copy", "©")
221            .add_entity("reg", "®")
222            .create_reader(xmldoc);
223        let mut writer = EmitterConfig::default()
224            .perform_indent(true)
225            .create_writer(&mut b);
226
227        for e in reader {
228            match e? {
229                ev @ ReaderEvent::StartElement { .. } => {
230                    let mut attrs: Vec<xml::attribute::OwnedAttribute> = vec![];
231
232                    if let Some(WriterEvent::StartElement {
233                        name,
234                        attributes,
235                        namespace,
236                    }) = ev.as_writer_event()
237                    {
238                        for i in 0..attributes.len() {
239                            let mut attr = attributes[i].to_owned();
240                            let repl = closure(name.local_name, &attr.name.local_name, &attr.value);
241                            attr.value = repl;
242                            attrs.push(attr);
243                        }
244
245                        let w = WriterEvent::StartElement {
246                            name,
247                            attributes: Cow::Owned(
248                                attrs.iter().map(OwnedAttribute::borrow).collect(),
249                            ),
250                            //attributes: attributes,
251                            namespace,
252                        };
253                        writer.write(w)?;
254                    }
255                }
256                ReaderEvent::EndElement { name: n } => {
257                    if n.local_name.to_lowercase() == "head" && !extra_css.is_empty() {
258                        // injecting here the extra css
259                        let mut allcss = extra_css.concat();
260                        allcss = format!("*/ {} /*", allcss);
261
262                        writer.write(WriterEvent::start_element("style"))?;
263                        writer.write("/*")?;
264                        writer.write(WriterEvent::cdata(&allcss))?;
265                        writer.write("*/")?;
266                        writer.write(WriterEvent::end_element())?;
267                    }
268                    writer.write(WriterEvent::end_element())?;
269                }
270                ev => {
271                    if let Some(e) = ev.as_writer_event() {
272                        writer.write(e)?;
273                    }
274                }
275            }
276        }
277    }
278
279    Ok(b)
280}