use super::{ImportArgs, ImportReport, InferredSchema, XmlAttrsMode};
use crate::convert::infer;
use crate::error::{NxsError, Result};
use crate::writer::{NxsWriter, Schema, Slot};
use quick_xml::{events::Event, Reader};
use std::io::{BufRead, Read, Write};
fn check_for_entity_expansion(src: &[u8]) -> Result<()> {
if let Some(pos) = find_bytes_ci(src, b"<!DOCTYPE") {
let slice_after = src.get(pos..).unwrap_or(&[]);
if find_bytes_ci(slice_after, b"<!ENTITY").is_some()
|| find_bytes_ci(slice_after, b"ENTITY ").is_some()
{
return Err(NxsError::ConvertEntityExpansion);
}
}
Ok(())
}
fn find_bytes_ci(haystack: &[u8], needle: &[u8]) -> Option<usize> {
if needle.is_empty() || haystack.len() < needle.len() {
return None;
}
let upper_needle: Vec<u8> = needle.iter().map(|b| b.to_ascii_uppercase()).collect();
'outer: for i in 0..=(haystack.len() - needle.len()) {
for j in 0..needle.len() {
let h = haystack
.get(i + j)
.copied()
.unwrap_or(0)
.to_ascii_uppercase();
if let Some(&n) = upper_needle.get(j) {
if h != n {
continue 'outer;
}
}
}
return Some(i);
}
None
}
fn parse_records<B: BufRead>(
mut reader: Reader<B>,
args: &ImportArgs,
record_tag: &str,
mut on_record: impl FnMut(Vec<(String, String)>) -> Result<()>,
) -> Result<()> {
let depth_limit = args.xml_max_depth.min(args.max_depth);
let attrs_mode = args.xml_attrs;
let mut buf = Vec::new();
let mut depth: usize = 0;
let mut in_record = false;
let mut record_depth: usize = 0;
let mut current_fields: Vec<(String, String)> = Vec::new();
let mut current_path: Vec<String> = Vec::new(); let mut current_text_key: Option<String> = None;
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Eof) => break,
Ok(Event::DocType(_)) => {
return Err(NxsError::ConvertEntityExpansion);
}
Ok(Event::Start(e)) => {
depth += 1;
if depth > depth_limit {
return Err(NxsError::ConvertDepthExceeded);
}
let local_name = e.local_name();
let tag = reader
.decoder()
.decode(local_name.as_ref())
.map_err(|e| NxsError::ConvertParseError {
offset: 0,
msg: e.to_string(),
})?
.to_string();
if !in_record && tag == record_tag {
in_record = true;
record_depth = depth;
current_fields.clear();
current_path.clear();
collect_attrs(&e, &reader, &mut current_fields, attrs_mode)?;
} else if in_record {
current_path.push(tag.clone());
collect_attrs(&e, &reader, &mut current_fields, attrs_mode)?;
current_text_key = Some(path_key(¤t_path));
}
}
Ok(Event::Empty(e)) => {
let local_name = e.local_name();
let tag = reader
.decoder()
.decode(local_name.as_ref())
.map_err(|e| NxsError::ConvertParseError {
offset: 0,
msg: e.to_string(),
})?
.to_string();
if !in_record && tag == record_tag {
let mut fields: Vec<(String, String)> = Vec::new();
collect_attrs(&e, &reader, &mut fields, attrs_mode)?;
on_record(fields)?;
} else if in_record {
let key = if current_path.is_empty() {
tag.clone()
} else {
format!("{}.{tag}", path_key(¤t_path))
};
let mut fields: Vec<(String, String)> = Vec::new();
collect_attrs_with_prefix(&e, &reader, &mut fields, attrs_mode, &key)?;
current_fields.extend(fields);
}
}
Ok(Event::End(_)) => {
if in_record && depth == record_depth {
on_record(std::mem::take(&mut current_fields))?;
in_record = false;
} else if in_record && !current_path.is_empty() {
current_path.pop();
current_text_key = None;
}
depth = depth.saturating_sub(1);
}
Ok(Event::Text(e)) => {
if in_record {
let text = e
.unescape()
.map_err(|e| NxsError::ConvertParseError {
offset: 0,
msg: e.to_string(),
})?
.to_string();
let trimmed = text.trim().to_string();
if !trimmed.is_empty() {
if let Some(key) = ¤t_text_key {
current_fields.push((key.clone(), trimmed));
}
}
}
}
Ok(_) => {}
Err(e) => {
return Err(NxsError::ConvertParseError {
offset: 0,
msg: e.to_string(),
});
}
}
buf.clear();
}
Ok(())
}
fn path_key(parts: &[String]) -> String {
parts.join(".")
}
fn collect_attrs<B: BufRead>(
e: &quick_xml::events::BytesStart<'_>,
reader: &Reader<B>,
fields: &mut Vec<(String, String)>,
mode: XmlAttrsMode,
) -> Result<()> {
collect_attrs_with_prefix(e, reader, fields, mode, "")
}
fn collect_attrs_with_prefix<B: BufRead>(
e: &quick_xml::events::BytesStart<'_>,
reader: &Reader<B>,
fields: &mut Vec<(String, String)>,
mode: XmlAttrsMode,
_parent: &str,
) -> Result<()> {
for attr in e.attributes().flatten() {
let key_bytes = attr.key.local_name();
let key = reader
.decoder()
.decode(key_bytes.as_ref())
.map_err(|e| NxsError::ConvertParseError {
offset: 0,
msg: e.to_string(),
})?
.to_string();
let val = attr
.decode_and_unescape_value(reader.decoder())
.map_err(|e| NxsError::ConvertParseError {
offset: 0,
msg: e.to_string(),
})?
.to_string();
let field_key = match mode {
XmlAttrsMode::AsFields => key,
XmlAttrsMode::Prefix => format!("@{key}"),
};
fields.push((field_key, val));
}
Ok(())
}
fn make_reader<R: Read>(reader: R) -> Reader<std::io::BufReader<R>> {
let mut r = Reader::from_reader(std::io::BufReader::new(reader));
r.config_mut().trim_text(true);
r
}
pub fn infer_schema<R: Read>(mut reader: R, args: &ImportArgs) -> Result<InferredSchema> {
let record_tag = args
.xml_record_tag
.as_deref()
.ok_or_else(|| NxsError::ConvertParseError {
offset: 0,
msg: "XML import requires --xml-record-tag".into(),
})?;
let mut raw = Vec::new();
reader
.read_to_end(&mut raw)
.map_err(|e| NxsError::IoError(e.to_string()))?;
check_for_entity_expansion(&raw)?;
let xml_reader = make_reader(std::io::Cursor::new(raw));
let mut acc = InferredSchema::default();
parse_records(xml_reader, args, record_tag, |fields| {
infer::merge(&mut acc, &fields);
Ok(())
})?;
infer::finalize(acc, args.conflict)
}
pub fn emit<R: Read, W: Write>(
mut reader: R,
mut writer: W,
schema: &InferredSchema,
args: &ImportArgs,
) -> Result<ImportReport> {
let record_tag = args
.xml_record_tag
.as_deref()
.ok_or_else(|| NxsError::ConvertParseError {
offset: 0,
msg: "XML import requires --xml-record-tag".into(),
})?;
let mut raw = Vec::new();
reader
.read_to_end(&mut raw)
.map_err(|e| NxsError::IoError(e.to_string()))?;
check_for_entity_expansion(&raw)?;
let key_names: Vec<&str> = schema.keys.iter().map(|k| k.name.as_str()).collect();
let nxs_schema = Schema::new(&key_names);
let mut nxs_writer = NxsWriter::new(&nxs_schema);
let mut records_written = 0usize;
let xml_reader = make_reader(std::io::Cursor::new(raw));
parse_records(xml_reader, args, record_tag, |fields| {
nxs_writer.begin_object();
for (key, value) in &fields {
let slot_idx = schema.keys.iter().position(|k| &k.name == key);
if let Some(idx) = slot_idx {
if value.is_empty() {
continue;
}
let slot = Slot(idx as u16);
let sigil = schema.keys.get(idx).map(|k| k.sigil).unwrap_or(b'"');
match sigil {
b'=' => {
if let Ok(i) = value.parse::<i64>() {
nxs_writer.write_i64(slot, i);
}
}
b'~' => {
if let Ok(f) = value.parse::<f64>() {
nxs_writer.write_f64(slot, f);
}
}
b'?' => {
nxs_writer.write_bool(slot, value == "true");
}
b'@' => {
if let Ok(t) = value.parse::<i64>() {
nxs_writer.write_time(slot, t);
}
}
b'<' => {
if value.len() % 2 == 0 {
if let Ok(bytes) = (0..value.len())
.step_by(2)
.map(|i| u8::from_str_radix(&value[i..i + 2], 16))
.collect::<std::result::Result<Vec<u8>, _>>()
{
nxs_writer.write_bytes(slot, &bytes);
}
}
}
b'^' => {
}
_ => {
nxs_writer.write_str(slot, value);
}
}
}
}
nxs_writer.end_object();
records_written += 1;
Ok(())
})?;
let bytes = nxs_writer.finish();
let output_bytes = bytes.len();
writer
.write_all(&bytes)
.map_err(|e| NxsError::IoError(e.to_string()))?;
Ok(ImportReport {
records_written,
output_bytes,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::convert::{ConflictPolicy, ImportArgs, ImportFormat, XmlAttrsMode};
fn args_with_tag(tag: &str) -> ImportArgs {
ImportArgs {
from: ImportFormat::Xml,
conflict: ConflictPolicy::Error,
xml_record_tag: Some(tag.to_owned()),
..ImportArgs::default()
}
}
#[test]
fn import_xml_record_tag_required() {
let xml: &[u8] = b"<root><item/></root>";
let args = ImportArgs {
from: ImportFormat::Xml,
conflict: ConflictPolicy::Error,
xml_record_tag: None, ..ImportArgs::default()
};
let result = infer_schema(xml, &args);
assert!(result.is_err(), "missing --xml-record-tag must fail");
}
#[test]
fn import_xml_attributes_as_fields() {
let xml: &[u8] = b"<users><user id=\"1\" name=\"alice\"/></users>";
let args = args_with_tag("user");
let schema = infer_schema(xml, &args).unwrap();
assert!(schema.keys.iter().any(|k| k.name == "id"));
assert!(schema.keys.iter().any(|k| k.name == "name"));
let id = schema.keys.iter().find(|k| k.name == "id").unwrap();
assert_eq!(id.sigil, b'=', "id=\"1\" should infer as int");
}
#[test]
fn import_xml_attributes_prefixed() {
let xml: &[u8] = b"<users><user id=\"1\" name=\"alice\"/></users>";
let args = ImportArgs {
from: ImportFormat::Xml,
conflict: ConflictPolicy::Error,
xml_record_tag: Some("user".into()),
xml_attrs: XmlAttrsMode::Prefix,
..ImportArgs::default()
};
let schema = infer_schema(xml, &args).unwrap();
assert!(schema.keys.iter().any(|k| k.name == "@id"));
assert!(schema.keys.iter().any(|k| k.name == "@name"));
}
#[test]
fn import_xml_nested_elements_become_nested_objects() {
let xml: &[u8] = b"<users><user><addr><city>NYC</city></addr></user></users>";
let args = args_with_tag("user");
let schema = infer_schema(xml, &args).unwrap();
assert!(
schema.keys.iter().any(|k| k.name == "addr.city"),
"nested element must become dot-notation key"
);
}
#[test]
fn import_xml_entity_expansion_rejected() {
let xml: &[u8] = b"<?xml version=\"1.0\"?>\
<!DOCTYPE lolz [\
<!ENTITY lol \"lol\">\
<!ENTITY lol2 \"&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;\">\
]><root><item/></root>";
let args = args_with_tag("item");
let result = infer_schema(xml, &args);
assert!(result.is_err(), "entity expansion must be rejected");
assert!(
matches!(result.unwrap_err(), NxsError::ConvertEntityExpansion),
"error must be ConvertEntityExpansion"
);
}
#[test]
fn import_xml_depth_limit_enforced() {
let depth = 70usize;
let open: String = "<a>".repeat(depth);
let close: String = "</a>".repeat(depth);
let xml = format!("<root><item>{open}text{close}</item></root>");
let args = ImportArgs {
from: ImportFormat::Xml,
xml_record_tag: Some("item".into()),
max_depth: 64,
xml_max_depth: 64,
..ImportArgs::default()
};
let result = infer_schema(xml.as_bytes(), &args);
assert!(result.is_err());
assert!(
matches!(result.unwrap_err(), NxsError::ConvertDepthExceeded),
"expected ConvertDepthExceeded"
);
}
#[test]
fn import_xml_utf16_bom_accepted() {
let xml: &[u8] = b"<?xml version=\"1.0\" encoding=\"UTF-8\"?><root><item id=\"1\"/></root>";
let args = args_with_tag("item");
let schema = infer_schema(xml, &args).unwrap();
assert!(schema.keys.iter().any(|k| k.name == "id"));
}
#[test]
fn import_xml_unsupported_encoding_exits_3() {
let xml: &[u8] =
b"<?xml version=\"1.0\" encoding=\"EBCDIC\"?><root><item id=\"1\"/></root>";
let args = args_with_tag("item");
let _ = infer_schema(xml, &args);
}
}