raxb_xmlschema/
writer.rs

1use crate::cnst::{Order, MAGIC_BYTE};
2use lz4_flex::block::compress_prepend_size;
3use std::{collections::BTreeMap, path::PathBuf, string::FromUtf8Error};
4use thiserror::Error;
5
6#[derive(Error, Debug)]
7pub enum WriterError {
8    #[error(transparent)]
9    Io(#[from] std::io::Error),
10    #[error(transparent)]
11    Utf8Error(#[from] FromUtf8Error),
12    #[error(transparent)]
13    UrlParse(#[from] url::ParseError),
14    #[cfg(feature = "writer")]
15    #[error(transparent)]
16    Reqwest(#[from] reqwest::Error),
17    #[error("no entrypoint")]
18    NoEntrypoint,
19    #[error("invalid file format")]
20    InvalidFormat,
21    #[error("invalid file header")]
22    InvalidHead,
23}
24
25pub type WriterResult<T> = Result<T, WriterError>;
26
27fn create_uuid(b: &[u8]) -> String {
28    uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_URL, b)
29        .as_simple()
30        .to_string()
31}
32
33pub fn create_filepath<P: AsRef<std::path::Path>>(path: P, target_namespace: &str) -> PathBuf {
34    path.as_ref()
35        .join(format!("{}.xsdb", create_uuid(target_namespace.as_bytes())))
36}
37
38#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
39pub enum SchemaLocation {
40    Path(PathBuf),
41    Url(url::Url),
42}
43
44impl SchemaLocation {
45    pub fn get_content(&self, cache_dir: &std::path::Path) -> WriterResult<String> {
46        Ok(match self {
47            Self::Url(url) => {
48                let cache_name = format!("{url}");
49                let cached_file = cache_dir.join(create_uuid(cache_name.as_bytes()));
50                if cached_file.exists() {
51                    return Ok(std::fs::read_to_string(&cached_file)?);
52                }
53                let result = reqwest::blocking::get(url.as_ref())?.text()?;
54                std::fs::write(&cached_file, &result)?;
55                result
56            }
57            Self::Path(path) => std::fs::read_to_string(path)?,
58        })
59    }
60
61    pub fn try_join(&self, other: &str) -> WriterResult<Self> {
62        Ok(match self {
63            Self::Url(u) => Self::Url(u.join(other)?),
64            Self::Path(u) => Self::Path(u.parent().unwrap().join(other)),
65        })
66    }
67}
68
69impl std::fmt::Display for SchemaLocation {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            Self::Url(u) => u.fmt(f),
73            Self::Path(u) => u.display().fmt(f),
74        }
75    }
76}
77
78impl std::str::FromStr for SchemaLocation {
79    type Err = WriterError;
80    fn from_str(s: &str) -> Result<Self, Self::Err> {
81        Ok(if s.starts_with("http") {
82            SchemaLocation::Url(s.parse()?)
83        } else {
84            SchemaLocation::Path(PathBuf::from(s))
85        })
86    }
87}
88
/// Accumulates one `.xsdb` archive in memory; starts out empty via `Default`.
#[derive(Default)]
pub struct SchemaWriter {
    /// In-memory buffer that receives the uncompressed archive bytes.
    w: std::io::Cursor<Vec<u8>>,
}
93
/// One schema document destined for an archive.
#[derive(Debug)]
pub struct SchemaEntry {
    /// Target namespace associated with this schema.
    target_namespace: String,
    /// Whether this schema is the archive's entrypoint.
    entrypoint: bool,
    /// Full schema text.
    content: String,
}

impl SchemaEntry {
    /// Bundles a namespace, an entrypoint flag, and the schema text into an entry.
    pub fn new(target_namespace: String, entrypoint: bool, content: String) -> Self {
        Self {
            content,
            entrypoint,
            target_namespace,
        }
    }
}
110
111impl SchemaWriter {
112    pub fn write(mut self, map: BTreeMap<SchemaLocation, SchemaEntry>) -> WriterResult<Vec<u8>> {
113        eprintln!("write {map:#?}");
114        use byteorder::WriteBytesExt;
115        use std::io::Write;
116        let m: Vec<(String, SchemaEntry)> = map
117            .into_iter()
118            .map(|(k, v)| {
119                let s = match k {
120                    SchemaLocation::Path(p) => p.file_name().unwrap().to_str().unwrap().to_string(),
121                    SchemaLocation::Url(u) => u.to_string(),
122                };
123                (s, v)
124            })
125            .collect::<Vec<_>>();
126        let (entrypoint_name, v) = m
127            .iter()
128            .find(|v| v.1.entrypoint)
129            .ok_or(WriterError::NoEntrypoint)?;
130        let initial_headsize = 4 + 8 + 4 + entrypoint_name.len() + 4 + v.target_namespace.len();
131        let head_size = m.iter().fold(initial_headsize, |state, (e, _)| {
132            state + 1 + 8 + 8 + 4 + e.len()
133        });
134        self.w.write_u32::<Order>(MAGIC_BYTE)?;
135        self.w.write_u64::<Order>(head_size as u64)?;
136        self.w.write_u32::<Order>(entrypoint_name.len() as u32)?;
137        self.w.write_all(entrypoint_name.as_bytes())?;
138        self.w.write_u32::<Order>(v.target_namespace.len() as u32)?;
139        self.w.write_all(v.target_namespace.as_bytes())?;
140        let mut pos = 0;
141        for (name, v) in m.iter() {
142            let end = pos + v.content.len();
143            self.w.write_u8(if v.entrypoint { 1 } else { 0 })?; // is entrypoint?
144            self.w.write_u64::<Order>(pos as u64)?; // start
145            self.w.write_u64::<Order>(end as u64)?; // end
146            self.w.write_u32::<Order>(name.len() as u32)?; // name length
147            self.w.write_all(name.as_bytes())?;
148            pos = end;
149        }
150        for (_, v) in m.iter() {
151            self.w.write_all(v.content.as_bytes())?;
152        }
153        self.w.flush()?;
154        Ok(compress_prepend_size(&self.w.into_inner()))
155    }
156}