use crate::error::MmCifError;
use crate::model::{Atom, AtomName, Chain, ChainId, Model, Residue, ResidueName, Structure};
use smallvec::SmallVec;
use std::io::Read;
use std::path::Path;
/// Upper bound on the element count of any decoded array: `rowCount`,
/// run-length output, string-array indices/offsets and string data bytes.
/// 2^26 = 64 Mi elements.
const MAX_DECODED_ELEMENTS: usize = 1 << 26;
/// Largest capacity reserved up front for a msgpack array/map, whose length
/// prefix comes from the (untrusted) input stream.
const MAX_TRUSTED_LEN: usize = 1 << 6;
/// Deepest container nesting `parse_msgpack` accepts before giving up.
const MAX_MSGPACK_DEPTH: u32 = 1 << 6;
/// Reads the whole BinaryCIF file at `path` into memory and parses it.
///
/// # Errors
/// I/O errors from opening or reading the file, plus anything
/// [`read_bcif_bytes`] reports.
pub(crate) fn read_bcif_structure<P: AsRef<Path>>(path: P) -> Result<Structure, MmCifError> {
    let mut contents = Vec::new();
    std::fs::File::open(path)?.read_to_end(&mut contents)?;
    read_bcif_bytes(&contents)
}
/// Parses an in-memory BinaryCIF (msgpack) document into a [`Structure`].
pub(crate) fn read_bcif_bytes(bytes: &[u8]) -> Result<Structure, MmCifError> {
    parse_bcif(bytes)
}
fn parse_bcif(bytes: &[u8]) -> Result<Structure, MmCifError> {
let mut pos = 0usize;
let root = parse_msgpack(bytes, &mut pos, 0)?;
let data_blocks = as_array(
get_key(&root, "dataBlocks")
.ok_or_else(|| MmCifError::Parse("BinaryCIF: missing dataBlocks".into()))?,
"BinaryCIF.dataBlocks",
)?;
let block = data_blocks
.first()
.ok_or_else(|| MmCifError::Parse("BinaryCIF: empty dataBlocks".into()))?;
let categories = as_array(
get_key(block, "categories")
.ok_or_else(|| MmCifError::Parse("BinaryCIF: missing categories".into()))?,
"DataBlock.categories",
)?;
let atom_site = categories
.iter()
.find(|cat| {
get_key(cat, "name")
.and_then(MpValue::as_str)
.map(|name| name == "atom_site" || name == "_atom_site")
.unwrap_or(false)
})
.ok_or(MmCifError::MissingField("_atom_site"))?;
let row_count = get_key(atom_site, "rowCount")
.and_then(MpValue::as_usize)
.ok_or_else(|| MmCifError::Parse("BinaryCIF: atom_site.rowCount is missing".into()))?;
if row_count > MAX_DECODED_ELEMENTS {
return Err(MmCifError::ResourceLimit("BinaryCIF: rowCount too large"));
}
let columns = as_array(
get_key(atom_site, "columns")
.ok_or_else(|| MmCifError::Parse("BinaryCIF: atom_site.columns is missing".into()))?,
"atom_site.columns",
)?;
let mut decoded = AtomSiteColumns::default();
for column in columns {
let Some(name) = get_key(column, "name").and_then(MpValue::as_str) else {
continue;
};
let parsed = decode_column(column, row_count)?;
match name {
"group_PDB" => decoded.group_pdb = Some(parsed),
"label_atom_id" => decoded.label_atom_id = Some(parsed),
"label_comp_id" => decoded.label_comp_id = Some(parsed),
"label_asym_id" => decoded.label_asym_id = Some(parsed),
"label_seq_id" => decoded.label_seq_id = Some(parsed),
"Cartn_x" => decoded.cartn_x = Some(parsed),
"Cartn_y" => decoded.cartn_y = Some(parsed),
"Cartn_z" => decoded.cartn_z = Some(parsed),
"label_alt_id" => decoded.label_alt_id = Some(parsed),
"pdbx_PDB_model_num" => decoded.model_num = Some(parsed),
_ => {}
}
}
decoded.validate_required()?;
let view = decoded.as_view()?;
let mut structure = Structure::default();
structure.models.push(Model::default());
let model = &mut structure.models[0];
let mut chain_cache_pos: Option<usize> = None;
let mut tmp = String::new();
for row in 0..row_count {
let group = view.group_pdb.str_at(row, &mut tmp)?;
if group != "ATOM" {
continue;
}
if let Some(model_col) = view.model_num {
let model_num = model_col.str_at(row, &mut tmp)?;
if !is_first_model(model_num) {
continue;
}
}
if let Some(alt_col) = view.label_alt_id {
let alt_id = alt_col.str_at(row, &mut tmp)?;
if alt_id != "." && alt_id != "A" && alt_id != "?" {
continue;
}
}
let atom_name_token = view.label_atom_id.str_at(row, &mut tmp)?;
let Some(atom_name) = AtomName::from_label_atom_id(atom_name_token) else {
continue;
};
let comp_token = view.label_comp_id.str_at(row, &mut tmp)?;
let residue_name = ResidueName::from_label_comp_id(comp_token);
let chain_token = view.label_asym_id.str_at(row, &mut tmp)?;
let chain_id = ChainId::from_label_asym_id(chain_token)
.ok_or_else(|| MmCifError::InvalidChainId(chain_token.to_string()))?;
let seq_id = view.label_seq_id.int_at(row)?;
let x = view.cartn_x.float_at(row)?;
let y = view.cartn_y.float_at(row)?;
let z = view.cartn_z.float_at(row)?;
let chain_pos = locate_or_append_chain(model, chain_id, &mut chain_cache_pos);
let residues = &mut model.chains[chain_pos].residues;
let use_existing = residues
.last()
.map(|res| res.seq_id == seq_id && res.name == residue_name)
.unwrap_or(false);
if !use_existing {
residues.push(Residue {
name: residue_name,
seq_id,
atoms: SmallVec::new(),
});
}
let residue = residues.last_mut().expect("residue exists");
residue.atoms.push(Atom {
name: atom_name,
x,
y,
z,
});
}
for chain in &mut model.chains {
if chain.residues.len() <= 1 {
continue;
}
let mut needs_sort = false;
for i in 1..chain.residues.len() {
if chain.residues[i - 1].seq_id > chain.residues[i].seq_id {
needs_sort = true;
break;
}
}
if needs_sort {
chain.residues.sort_by_key(|res| res.seq_id);
}
}
Ok(structure)
}
/// Returns the index of the chain whose id is `chain_id` in `model`,
/// appending an empty chain when none exists yet.
///
/// `chain_cache_pos` remembers the last index handed out, so runs of rows
/// from the same chain skip the search. On a cache miss the chains are
/// searched from the back (`rposition`), and the cache is updated with
/// whatever index is returned.
#[inline]
fn locate_or_append_chain(
    model: &mut Model,
    chain_id: ChainId,
    chain_cache_pos: &mut Option<usize>,
) -> usize {
    let cached = *chain_cache_pos;
    let cache_hit =
        cached.filter(|&idx| model.chains.get(idx).is_some_and(|chain| chain.id == chain_id));
    if let Some(idx) = cache_hit {
        return idx;
    }
    let idx = match model.chains.iter().rposition(|chain| chain.id == chain_id) {
        Some(found) => found,
        None => {
            model.chains.push(Chain {
                id: chain_id,
                residues: Vec::new(),
            });
            model.chains.len() - 1
        }
    };
    *chain_cache_pos = Some(idx);
    idx
}
/// Returns `true` when a `pdbx_PDB_model_num` token denotes model 1.
///
/// Surrounding whitespace is ignored. The token is tried as an integer first
/// and then as a float, which is accepted when within `1e-9` of `1.0`.
/// Anything else (including `.`/`?`) is not model 1.
#[inline]
fn is_first_model(token: &str) -> bool {
    let token = token.trim();
    match token.parse::<i64>() {
        Ok(number) => number == 1,
        Err(_) => token
            .parse::<f64>()
            .is_ok_and(|value| (value - 1.0).abs() < 1e-9),
    }
}
#[derive(Default)]
struct AtomSiteColumns {
group_pdb: Option<DecodedColumn>,
label_atom_id: Option<DecodedColumn>,
label_comp_id: Option<DecodedColumn>,
label_asym_id: Option<DecodedColumn>,
label_seq_id: Option<DecodedColumn>,
cartn_x: Option<DecodedColumn>,
cartn_y: Option<DecodedColumn>,
cartn_z: Option<DecodedColumn>,
label_alt_id: Option<DecodedColumn>,
model_num: Option<DecodedColumn>,
}
impl AtomSiteColumns {
fn validate_required(&self) -> Result<(), MmCifError> {
if self.group_pdb.is_none() {
return Err(MmCifError::MissingField("_atom_site.group_PDB"));
}
if self.label_atom_id.is_none() {
return Err(MmCifError::MissingField("_atom_site.label_atom_id"));
}
if self.label_comp_id.is_none() {
return Err(MmCifError::MissingField("_atom_site.label_comp_id"));
}
if self.label_asym_id.is_none() {
return Err(MmCifError::MissingField("_atom_site.label_asym_id"));
}
if self.label_seq_id.is_none() {
return Err(MmCifError::MissingField("_atom_site.label_seq_id"));
}
if self.cartn_x.is_none() {
return Err(MmCifError::MissingField("_atom_site.Cartn_x"));
}
if self.cartn_y.is_none() {
return Err(MmCifError::MissingField("_atom_site.Cartn_y"));
}
if self.cartn_z.is_none() {
return Err(MmCifError::MissingField("_atom_site.Cartn_z"));
}
Ok(())
}
fn as_view(&self) -> Result<AtomSiteView<'_>, MmCifError> {
Ok(AtomSiteView {
group_pdb: self
.group_pdb
.as_ref()
.ok_or(MmCifError::MissingField("_atom_site.group_PDB"))?,
label_atom_id: self
.label_atom_id
.as_ref()
.ok_or(MmCifError::MissingField("_atom_site.label_atom_id"))?,
label_comp_id: self
.label_comp_id
.as_ref()
.ok_or(MmCifError::MissingField("_atom_site.label_comp_id"))?,
label_asym_id: self
.label_asym_id
.as_ref()
.ok_or(MmCifError::MissingField("_atom_site.label_asym_id"))?,
label_seq_id: self
.label_seq_id
.as_ref()
.ok_or(MmCifError::MissingField("_atom_site.label_seq_id"))?,
cartn_x: self
.cartn_x
.as_ref()
.ok_or(MmCifError::MissingField("_atom_site.Cartn_x"))?,
cartn_y: self
.cartn_y
.as_ref()
.ok_or(MmCifError::MissingField("_atom_site.Cartn_y"))?,
cartn_z: self
.cartn_z
.as_ref()
.ok_or(MmCifError::MissingField("_atom_site.Cartn_z"))?,
label_alt_id: self.label_alt_id.as_ref(),
model_num: self.model_num.as_ref(),
})
}
}
/// Borrowed view over `AtomSiteColumns` (built by `AtomSiteColumns::as_view`):
/// required columns are plain references, optional ones stay `Option`.
struct AtomSiteView<'a> {
    /// `_atom_site.group_PDB`.
    group_pdb: &'a DecodedColumn,
    /// `_atom_site.label_atom_id`.
    label_atom_id: &'a DecodedColumn,
    /// `_atom_site.label_comp_id`.
    label_comp_id: &'a DecodedColumn,
    /// `_atom_site.label_asym_id`.
    label_asym_id: &'a DecodedColumn,
    /// `_atom_site.label_seq_id`.
    label_seq_id: &'a DecodedColumn,
    /// `_atom_site.Cartn_x`.
    cartn_x: &'a DecodedColumn,
    /// `_atom_site.Cartn_y`.
    cartn_y: &'a DecodedColumn,
    /// `_atom_site.Cartn_z`.
    cartn_z: &'a DecodedColumn,
    /// `_atom_site.label_alt_id`, when the file has that column.
    label_alt_id: Option<&'a DecodedColumn>,
    /// `_atom_site.pdbx_PDB_model_num`, when the file has that column.
    model_num: Option<&'a DecodedColumn>,
}
/// One fully decoded BinaryCIF column plus its optional presence mask.
struct DecodedColumn {
    /// Decoded values, one per row.
    data: DecodedData,
    /// Per-row mask flag: `1` reads as `"."`, `2` reads as `"?"`; any other
    /// value means the entry in `data` is used. `None` when the column has no mask.
    mask: Option<Vec<u8>>,
}
impl DecodedColumn {
    /// Error for a `row` beyond the decoded data.
    ///
    /// Built lazily (`ok_or_else`) so the per-row accessors, which run several
    /// times per atom, do not allocate an error `String` on the success path.
    #[cold]
    fn row_out_of_range() -> MmCifError {
        MmCifError::Parse("BinaryCIF: row out of range".into())
    }
    /// Mask flag for `row`; `0` when there is no mask or the row is past its end.
    fn mask_at(&self, row: usize) -> u8 {
        self.mask
            .as_ref()
            .and_then(|mask| mask.get(row).copied())
            .unwrap_or(0)
    }
    /// Returns the value at `row` as text.
    ///
    /// Masked rows yield `"."`/`"?"`. Numeric values are formatted into
    /// `scratch`, and the returned slice borrows from it.
    ///
    /// # Errors
    /// `row` out of range, or a byte column (which has no text form).
    fn str_at<'a>(
        &'a self,
        row: usize,
        scratch: &'a mut String,
    ) -> Result<&'a str, MmCifError> {
        use std::fmt::Write as _;
        match self.mask_at(row) {
            1 => return Ok("."),
            2 => return Ok("?"),
            _ => {}
        }
        match &self.data {
            DecodedData::Strings(v) => v
                .get(row)
                .map(String::as_str)
                .ok_or_else(Self::row_out_of_range),
            DecodedData::Ints(v) => {
                let value = v.get(row).copied().ok_or_else(Self::row_out_of_range)?;
                scratch.clear();
                // Formatting into a `String` cannot fail.
                write!(scratch, "{value}").ok();
                Ok(scratch.as_str())
            }
            DecodedData::Floats(v) => {
                let value = v.get(row).copied().ok_or_else(Self::row_out_of_range)?;
                scratch.clear();
                write!(scratch, "{value}").ok();
                Ok(scratch.as_str())
            }
            DecodedData::Bytes(_) => Err(MmCifError::Parse(
                "BinaryCIF: cannot interpret byte column as text".into(),
            )),
        }
    }
    /// Returns the value at `row` as an integer.
    ///
    /// Masked rows and the text tokens `"."`/`"?"` yield `None`. Floats are
    /// truncated with `as i32`.
    ///
    /// # Errors
    /// `row` out of range, an unparsable string, or a byte column.
    fn int_at(&self, row: usize) -> Result<Option<i32>, MmCifError> {
        if matches!(self.mask_at(row), 1 | 2) {
            return Ok(None);
        }
        match &self.data {
            DecodedData::Ints(v) => v
                .get(row)
                .copied()
                .map(Some)
                .ok_or_else(Self::row_out_of_range),
            DecodedData::Floats(v) => v
                .get(row)
                .map(|f| Some(*f as i32))
                .ok_or_else(Self::row_out_of_range),
            DecodedData::Strings(v) => {
                let s = v.get(row).ok_or_else(Self::row_out_of_range)?;
                if s == "." || s == "?" {
                    return Ok(None);
                }
                let parsed: i32 = s.parse()?;
                Ok(Some(parsed))
            }
            DecodedData::Bytes(_) => Err(MmCifError::Parse(
                "BinaryCIF: cannot interpret byte column as integer".into(),
            )),
        }
    }
    /// Returns the value at `row` as an `f32`.
    ///
    /// # Errors
    /// - masked rows and `"."`/`"?"` tokens ("missing float");
    /// - `row` out of range;
    /// - an unparsable string;
    /// - a byte column.
    fn float_at(&self, row: usize) -> Result<f32, MmCifError> {
        if matches!(self.mask_at(row), 1 | 2) {
            return Err(MmCifError::Parse("BinaryCIF: missing float".into()));
        }
        match &self.data {
            DecodedData::Floats(v) => v.get(row).copied().ok_or_else(Self::row_out_of_range),
            DecodedData::Ints(v) => v
                .get(row)
                .map(|i| *i as f32)
                .ok_or_else(Self::row_out_of_range),
            DecodedData::Strings(v) => {
                let s = v.get(row).ok_or_else(Self::row_out_of_range)?;
                if s == "." || s == "?" {
                    return Err(MmCifError::Parse("BinaryCIF: missing float".into()));
                }
                let parsed: f32 = s.parse()?;
                Ok(parsed)
            }
            DecodedData::Bytes(_) => Err(MmCifError::Parse(
                "BinaryCIF: cannot interpret byte column as float".into(),
            )),
        }
    }
}
/// Decodes a BinaryCIF column (its `data` and optional `mask`) and checks
/// that both hold exactly `row_count` entries.
///
/// Mask values are clamped into `0..=255` before being stored as bytes.
///
/// # Errors
/// Missing `data`, any decoding failure, a non-integer mask, or a length mismatch.
fn decode_column(column: &MpValue, row_count: usize) -> Result<DecodedColumn, MmCifError> {
    let Some(encoded_values) = get_key(column, "data") else {
        return Err(MmCifError::Parse("BinaryCIF: column missing data".into()));
    };
    let data = decode_encoded_data(encoded_values)?;
    let actual = data.len();
    if actual != row_count {
        return Err(MmCifError::Parse(format!(
            "BinaryCIF: column length mismatch: expected {row_count}, got {actual}"
        )));
    }
    let mask = get_key(column, "mask")
        .map(|encoded_mask| -> Result<Vec<u8>, MmCifError> {
            let flags = decode_encoded_data(encoded_mask)?.into_ints()?;
            if flags.len() != row_count {
                return Err(MmCifError::Parse("BinaryCIF: mask length mismatch".into()));
            }
            Ok(flags.into_iter().map(|flag| flag.clamp(0, 255) as u8).collect())
        })
        .transpose()?;
    Ok(DecodedColumn { data, mask })
}
/// Decodes a BinaryCIF `EncodedData` map.
///
/// The raw `data` payload is decoded first. Then the steps listed in
/// `encoding` are undone last-to-first.
///
/// # Errors
/// Missing/ill-typed `encoding` or `data`, or any failing decoding step.
fn decode_encoded_data(encoded: &MpValue) -> Result<DecodedData, MmCifError> {
    let Some(step_list) = get_key(encoded, "encoding") else {
        return Err(MmCifError::Parse("BinaryCIF: EncodedData.encoding missing".into()));
    };
    let steps = as_array(step_list, "EncodedData.encoding")?;
    let Some(payload) = get_key(encoded, "data") else {
        return Err(MmCifError::Parse("BinaryCIF: EncodedData.data missing".into()));
    };
    steps
        .iter()
        .rev()
        .try_fold(decode_raw_data(payload)?, |acc, step| apply_encoding(acc, step))
}
/// Converts a raw msgpack payload into [`DecodedData`].
///
/// - `bin` becomes `Bytes`.
/// - An array becomes `Strings`, `Ints` or `Floats`, checked in that order
///   and requiring every element to fit that kind. An empty array therefore
///   yields `Strings`.
/// - Integers are cast to `i32` and floats to `f32`.
///
/// # Errors
/// Any other payload, or an array mixing incompatible element types.
fn decode_raw_data(data: &MpValue) -> Result<DecodedData, MmCifError> {
    let items = match data {
        MpValue::Bin(raw) => return Ok(DecodedData::Bytes(raw.clone())),
        MpValue::Array(items) => items,
        _ => {
            return Err(MmCifError::Parse(
                "BinaryCIF: unsupported raw data payload".into(),
            ));
        }
    };
    if items.iter().all(|item| matches!(item, MpValue::Str(_))) {
        let texts = items
            .iter()
            .filter_map(MpValue::as_str)
            .map(str::to_owned)
            .collect();
        return Ok(DecodedData::Strings(texts));
    }
    if items.iter().all(|item| item.as_i64().is_some()) {
        let ints = items
            .iter()
            .filter_map(MpValue::as_i64)
            .map(|value| value as i32)
            .collect();
        return Ok(DecodedData::Ints(ints));
    }
    if items.iter().all(|item| item.as_f64().is_some()) {
        let floats = items
            .iter()
            .filter_map(MpValue::as_f64)
            .map(|value| value as f32)
            .collect();
        return Ok(DecodedData::Floats(floats));
    }
    Err(MmCifError::Parse(
        "BinaryCIF: unsupported raw array element type".into(),
    ))
}
/// Undoes a single encoding step, dispatching on its `kind` string.
///
/// # Errors
/// A missing/non-string `kind`, an unknown kind
/// (`MmCifError::UnsupportedEncoding`), or the step's own decoding error.
fn apply_encoding(data: DecodedData, encoding: &MpValue) -> Result<DecodedData, MmCifError> {
    let Some(kind) = get_key(encoding, "kind").and_then(MpValue::as_str) else {
        return Err(MmCifError::Parse("BinaryCIF: encoding.kind missing".into()));
    };
    match kind {
        "ByteArray" => decode_byte_array(data, encoding),
        "Delta" => decode_delta(data, encoding),
        "FixedPoint" => decode_fixed_point(data, encoding),
        "IntegerPacking" => decode_integer_packing(data, encoding),
        "IntervalQuantization" => decode_interval_quantization(data, encoding),
        "RunLength" => decode_run_length(data),
        "StringArray" => decode_string_array(data, encoding),
        unknown => Err(MmCifError::UnsupportedEncoding(unknown.to_string())),
    }
}
/// Reinterprets a byte payload as little-endian numbers according to the
/// `ByteArray` `type` code:
/// - 1 = i8, 2 = i16, 3 = i32 → `Ints`;
/// - 4 = u8, 5 = u16, 6 = u32 → `Ints` (u32 values wrap into `i32`);
/// - 32 = f32, 33 = f64 → `Floats` (f64 narrowed to `f32`).
///
/// # Errors
/// - missing `type`;
/// - a non-byte input;
/// - a byte length that is not a multiple of the element width;
/// - an unknown type code.
fn decode_byte_array(data: DecodedData, encoding: &MpValue) -> Result<DecodedData, MmCifError> {
    let Some(type_code) = get_key(encoding, "type").and_then(MpValue::as_i64) else {
        return Err(MmCifError::Parse("ByteArray: missing type".into()));
    };
    let raw = data.into_bytes()?;
    let decoded = match type_code {
        1 => DecodedData::Ints(raw.iter().map(|&b| i32::from(b as i8)).collect()),
        2 => DecodedData::Ints(le_words(&raw, "i16", |w: [u8; 2]| {
            i32::from(i16::from_le_bytes(w))
        })?),
        3 => DecodedData::Ints(le_words(&raw, "i32", |w: [u8; 4]| i32::from_le_bytes(w))?),
        4 => DecodedData::Ints(raw.iter().map(|&b| i32::from(b)).collect()),
        5 => DecodedData::Ints(le_words(&raw, "u16", |w: [u8; 2]| {
            i32::from(u16::from_le_bytes(w))
        })?),
        6 => DecodedData::Ints(le_words(&raw, "u32", |w: [u8; 4]| {
            u32::from_le_bytes(w) as i32
        })?),
        32 => DecodedData::Floats(le_words(&raw, "f32", |w: [u8; 4]| f32::from_le_bytes(w))?),
        33 => DecodedData::Floats(le_words(&raw, "f64", |w: [u8; 8]| {
            f64::from_le_bytes(w) as f32
        })?),
        other => {
            return Err(MmCifError::UnsupportedEncoding(format!(
                "ByteArray.type={other}"
            )));
        }
    };
    Ok(decoded)
}
/// Splits `raw` into `N`-byte words and maps each through `convert`.
///
/// # Errors
/// `raw.len()` is not a multiple of `N`. The message names the element type via `label`.
fn le_words<const N: usize, T>(
    raw: &[u8],
    label: &str,
    convert: impl Fn([u8; N]) -> T,
) -> Result<Vec<T>, MmCifError> {
    if !raw.len().is_multiple_of(N) {
        return Err(MmCifError::Parse(format!(
            "ByteArray {label}: invalid byte length"
        )));
    }
    Ok(raw
        .chunks_exact(N)
        .map(|chunk| {
            let mut word = [0u8; N];
            word.copy_from_slice(chunk);
            convert(word)
        })
        .collect())
}
/// Undoes `FixedPoint` encoding: each stored integer is `value * factor`
/// rounded, so decoding divides by `factor`.
///
/// The division is done in `f64` and then narrowed to `f32`. Dividing
/// directly (as the reference decoder does) gives a correctly-rounded
/// quotient. Multiplying by a precomputed `1 / factor` would add an extra
/// rounding step, since the reciprocal is generally inexact.
///
/// # Errors
/// A missing, non-finite or zero `factor`, or non-integer input.
fn decode_fixed_point(data: DecodedData, encoding: &MpValue) -> Result<DecodedData, MmCifError> {
    let factor = get_key(encoding, "factor")
        .and_then(MpValue::as_f64)
        .ok_or_else(|| MmCifError::Parse("FixedPoint: missing factor".into()))?;
    if !factor.is_finite() || factor == 0.0 {
        return Err(MmCifError::Parse("FixedPoint: factor is invalid".into()));
    }
    let ints = data.into_ints()?;
    let out = ints
        .into_iter()
        .map(|v| (f64::from(v) / factor) as f32)
        .collect();
    Ok(DecodedData::Floats(out))
}
/// Undoes `IntervalQuantization`.
///
/// The encoder maps `[min, max]` onto the integer grid `0..numSteps`, so step
/// `0` is `min` and step `numSteps - 1` is `max`. Decoding is therefore
/// `min + v * (max - min) / (numSteps - 1)`, as in the BinaryCIF
/// specification and reference decoder. (Dividing by `numSteps`, as this
/// function previously did, shrinks every value towards `min`.)
///
/// # Errors
/// - missing `min`/`max`/`numSteps`;
/// - non-finite values;
/// - fewer than two steps (the grid spacing is then undefined);
/// - non-integer input.
fn decode_interval_quantization(
    data: DecodedData,
    encoding: &MpValue,
) -> Result<DecodedData, MmCifError> {
    let min = get_key(encoding, "min")
        .and_then(MpValue::as_f64)
        .ok_or_else(|| MmCifError::Parse("IntervalQuantization: missing min".into()))?;
    let max = get_key(encoding, "max")
        .and_then(MpValue::as_f64)
        .ok_or_else(|| MmCifError::Parse("IntervalQuantization: missing max".into()))?;
    let steps = get_key(encoding, "numSteps")
        .and_then(MpValue::as_f64)
        .ok_or_else(|| MmCifError::Parse("IntervalQuantization: missing numSteps".into()))?;
    if !min.is_finite() || !max.is_finite() || !steps.is_finite() || steps < 2.0 {
        return Err(MmCifError::Parse(
            "IntervalQuantization: invalid min/max/numSteps".into(),
        ));
    }
    let ints = data.into_ints()?;
    // Distance between two adjacent grid points.
    let delta = (max - min) / (steps - 1.0);
    let out = ints
        .into_iter()
        .map(|v| (min + delta * f64::from(v)) as f32)
        .collect();
    Ok(DecodedData::Floats(out))
}
/// Expands `RunLength` data: the input is a flat list of `(value, count)` pairs.
///
/// Counts are validated, and the total output size is bounded by
/// `MAX_DECODED_ELEMENTS`, before anything is allocated.
///
/// # Errors
/// - non-integer input;
/// - an odd number of integers;
/// - a negative count;
/// - output over the limit.
fn decode_run_length(data: DecodedData) -> Result<DecodedData, MmCifError> {
    let pairs = data.into_ints()?;
    if !pairs.len().is_multiple_of(2) {
        return Err(MmCifError::Parse("RunLength: odd integer count".into()));
    }
    let total = pairs.chunks_exact(2).try_fold(0u64, |sum, pair| {
        let count = pair[1];
        if count < 0 {
            return Err(MmCifError::Parse("RunLength: negative count".into()));
        }
        let sum = sum.saturating_add(count as u64);
        if sum as usize > MAX_DECODED_ELEMENTS {
            return Err(MmCifError::ResourceLimit(
                "RunLength: decoded output exceeds limit",
            ));
        }
        Ok(sum)
    })?;
    let mut expanded: Vec<i32> = Vec::with_capacity(total as usize);
    for pair in pairs.chunks_exact(2) {
        let (value, repeats) = (pair[0], pair[1] as usize);
        expanded.resize(expanded.len() + repeats, value);
    }
    Ok(DecodedData::Ints(expanded))
}
/// Undoes `Delta` encoding with a running (wrapping) `i32` sum that starts
/// at `origin`.
///
/// `origin` is cast to `i32` with `as`, so out-of-range origins wrap.
///
/// # Errors
/// A missing/non-integer `origin`, or non-integer input.
fn decode_delta(data: DecodedData, encoding: &MpValue) -> Result<DecodedData, MmCifError> {
    let Some(origin) = get_key(encoding, "origin").and_then(MpValue::as_i64) else {
        return Err(MmCifError::Parse("Delta: missing origin".into()));
    };
    let mut running = origin as i32;
    let restored = data
        .into_ints()?
        .into_iter()
        .map(|delta| {
            running = running.wrapping_add(delta);
            running
        })
        .collect();
    Ok(DecodedData::Ints(restored))
}
/// Undoes `IntegerPacking`.
///
/// The encoder stores 32-bit integers in narrow slots (`byteCount` bytes). A
/// value that does not fit is written as a run of saturated "limit" slots
/// followed by a remainder slot, and decoding sums each run. The limits are:
/// - unsigned data: the slot maximum;
/// - signed data: both the slot maximum and the slot minimum.
///
/// When the optional `srcSize` equals the packed length, no value was
/// split, so the data is returned unchanged (as the reference decoder does).
/// This matters for `byteCount == 4`, where a genuine `i32::MAX`, `i32::MIN`
/// or `u32::MAX` value would otherwise be mistaken for a continuation slot.
///
/// # Errors
/// - non-integer input;
/// - a missing `byteCount`, or one other than 1, 2 or 4;
/// - data that ends in the middle of a continuation run (truncated stream).
fn decode_integer_packing(
    data: DecodedData,
    encoding: &MpValue,
) -> Result<DecodedData, MmCifError> {
    let ints = data.into_ints()?;
    let byte_count = get_key(encoding, "byteCount")
        .and_then(MpValue::as_i64)
        .ok_or_else(|| MmCifError::Parse("IntegerPacking: missing byteCount".into()))?;
    let is_unsigned = get_key(encoding, "isUnsigned")
        .and_then(MpValue::as_bool)
        .unwrap_or(false);
    if !matches!(byte_count, 1 | 2 | 4) {
        return Err(MmCifError::UnsupportedEncoding(format!(
            "IntegerPacking.byteCount={byte_count}"
        )));
    }
    let src_size = get_key(encoding, "srcSize").and_then(MpValue::as_usize);
    if src_size == Some(ints.len()) {
        return Ok(DecodedData::Ints(ints));
    }
    let bits = byte_count * 8;
    let unsigned_limit = ((1i64 << bits) - 1) as i32;
    let signed_upper = ((1i64 << (bits - 1)) - 1) as i32;
    let signed_lower = (-(1i64 << (bits - 1))) as i32;
    let is_continuation = |v: i32| {
        if is_unsigned {
            v == unsigned_limit
        } else {
            v == signed_upper || v == signed_lower
        }
    };
    // Output never exceeds the input length, so no separate size limit is needed here.
    let mut out: Vec<i32> = Vec::with_capacity(ints.len());
    let mut acc: i32 = 0;
    let mut in_run = false;
    for v in ints {
        acc = acc.wrapping_add(v);
        if is_continuation(v) {
            in_run = true;
            continue;
        }
        out.push(acc);
        acc = 0;
        in_run = false;
    }
    if in_run {
        return Err(MmCifError::Parse(
            "IntegerPacking: data ends inside a continuation run".into(),
        ));
    }
    Ok(DecodedData::Ints(out))
}
fn decode_string_array(data: DecodedData, encoding: &MpValue) -> Result<DecodedData, MmCifError> {
let indices = data.into_ints()?;
if indices.len() > MAX_DECODED_ELEMENTS {
return Err(MmCifError::ResourceLimit(
"StringArray: indices exceed limit",
));
}
let string_data = get_key(encoding, "stringData")
.ok_or_else(|| MmCifError::Parse("StringArray: missing stringData".into()))?;
let data_enc = get_key(encoding, "dataEncoding")
.ok_or_else(|| MmCifError::Parse("StringArray: missing dataEncoding".into()))?;
let data_enc = as_array(data_enc, "StringArray.dataEncoding")?;
let mut decoded_string_data = decode_raw_data(string_data)?;
for enc in data_enc.iter().rev() {
decoded_string_data = apply_encoding(decoded_string_data, enc)?;
}
let string_bytes = decoded_string_data.into_bytes()?;
if string_bytes.len() > MAX_DECODED_ELEMENTS {
return Err(MmCifError::ResourceLimit(
"StringArray: string data exceeds limit",
));
}
let offsets = get_key(encoding, "offsets")
.ok_or_else(|| MmCifError::Parse("StringArray: missing offsets".into()))?;
let offset_enc = get_key(encoding, "offsetEncoding")
.ok_or_else(|| MmCifError::Parse("StringArray: missing offsetEncoding".into()))?;
let offset_enc = as_array(offset_enc, "StringArray.offsetEncoding")?;
let mut decoded_offsets = decode_raw_data(offsets)?;
for enc in offset_enc.iter().rev() {
decoded_offsets = apply_encoding(decoded_offsets, enc)?;
}
let offsets = decoded_offsets.into_ints()?;
if offsets.len() < 2 {
return Err(MmCifError::Parse(
"StringArray: offsets are too short".into(),
));
}
if offsets.len() > MAX_DECODED_ELEMENTS {
return Err(MmCifError::ResourceLimit(
"StringArray: offsets exceed limit",
));
}
let mut dict: Vec<String> = Vec::with_capacity(offsets.len() - 1);
for pair in offsets.windows(2) {
let start = pair[0];
let end = pair[1];
if start < 0 || end < start {
return Err(MmCifError::Parse("StringArray: invalid offsets".into()));
}
let start = start as usize;
let end = end as usize;
if end > string_bytes.len() {
return Err(MmCifError::Parse(
"StringArray: offset out of bounds for string data".into(),
));
}
let text = std::str::from_utf8(&string_bytes[start..end])
.map_err(|_| MmCifError::Parse("StringArray: invalid UTF-8".into()))?;
dict.push(text.to_string());
}
let mut out: Vec<String> = Vec::with_capacity(indices.len());
for idx in indices {
if idx < 0 {
out.push(String::new());
continue;
}
let idx = idx as usize;
let value = dict.get(idx).ok_or_else(|| {
MmCifError::Parse("StringArray: dictionary index out of range".into())
})?;
out.push(value.clone());
}
Ok(DecodedData::Strings(out))
}
/// Payload flowing through the decoding pipeline; each encoding step
/// consumes one variant and produces another.
enum DecodedData {
    /// Raw bytes (from msgpack `bin`).
    Bytes(Vec<u8>),
    /// Integers, widened or cast to `i32`.
    Ints(Vec<i32>),
    /// Floating-point values, narrowed to `f32`.
    Floats(Vec<f32>),
    /// Text values.
    Strings(Vec<String>),
}
impl DecodedData {
    /// Number of elements, whatever the payload kind.
    fn len(&self) -> usize {
        match self {
            Self::Bytes(bytes) => bytes.len(),
            Self::Ints(ints) => ints.len(),
            Self::Floats(floats) => floats.len(),
            Self::Strings(strings) => strings.len(),
        }
    }
    /// Takes the payload as bytes.
    ///
    /// Integer payloads are accepted too; each value keeps only its low
    /// 8 bits (`as u8`).
    fn into_bytes(self) -> Result<Vec<u8>, MmCifError> {
        let ints = match self {
            Self::Bytes(bytes) => return Ok(bytes),
            Self::Ints(ints) => ints,
            _ => {
                return Err(MmCifError::Parse(
                    "BinaryCIF: expected bytes but got another data kind".into(),
                ));
            }
        };
        Ok(ints.into_iter().map(|value| value as u8).collect())
    }
    /// Takes the payload as integers; any other kind is an error.
    fn into_ints(self) -> Result<Vec<i32>, MmCifError> {
        if let Self::Ints(ints) = self {
            Ok(ints)
        } else {
            Err(MmCifError::Parse(
                "BinaryCIF: expected integers but got another data kind".into(),
            ))
        }
    }
}
/// A parsed MessagePack value.
#[derive(Debug, Clone)]
enum MpValue {
    Nil,
    Bool(bool),
    /// Signed integer markers and negative fixints.
    Int(i64),
    /// Unsigned integer markers and positive fixints.
    UInt(u64),
    F32(f32),
    F64(f64),
    Str(String),
    Bin(Vec<u8>),
    Array(Vec<MpValue>),
    /// Map entries as `(key, value)` pairs in stream order.
    Map(Vec<(MpValue, MpValue)>),
}
impl MpValue {
    /// The text of a `Str`, otherwise `None`.
    fn as_str(&self) -> Option<&str> {
        match self {
            Self::Str(text) => Some(text.as_str()),
            _ => None,
        }
    }
    /// Any integer that fits in `i64`, otherwise `None`.
    fn as_i64(&self) -> Option<i64> {
        match self {
            Self::Int(signed) => Some(*signed),
            Self::UInt(unsigned) => (*unsigned).try_into().ok(),
            _ => None,
        }
    }
    /// A non-negative integer that fits in `usize`, otherwise `None`.
    fn as_usize(&self) -> Option<usize> {
        self.as_i64().and_then(|value| value.try_into().ok())
    }
    /// Any numeric value as `f64`; integers are converted with `as`.
    fn as_f64(&self) -> Option<f64> {
        match self {
            Self::F32(single) => Some(f64::from(*single)),
            Self::F64(double) => Some(*double),
            Self::Int(signed) => Some(*signed as f64),
            Self::UInt(unsigned) => Some(*unsigned as f64),
            _ => None,
        }
    }
    /// The flag of a `Bool`, otherwise `None`.
    fn as_bool(&self) -> Option<bool> {
        match self {
            Self::Bool(flag) => Some(*flag),
            _ => None,
        }
    }
}
/// Borrows the elements of an `Array` value.
///
/// # Errors
/// `MmCifError::Parse` mentioning `ctx` when `value` is not an array.
fn as_array<'a>(value: &'a MpValue, ctx: &str) -> Result<&'a [MpValue], MmCifError> {
    match value {
        MpValue::Array(items) => Ok(items.as_slice()),
        _ => Err(MmCifError::Parse(format!("{ctx} is not an array"))),
    }
}
/// Looks up the first entry of a `Map` whose key is the string `key`.
///
/// Returns `None` for non-map values or when no string key matches.
fn get_key<'a>(value: &'a MpValue, key: &str) -> Option<&'a MpValue> {
    match value {
        MpValue::Map(entries) => entries
            .iter()
            .find(|(entry_key, _)| entry_key.as_str() == Some(key))
            .map(|(_, entry_value)| entry_value),
        _ => None,
    }
}
fn parse_msgpack(bytes: &[u8], pos: &mut usize, depth: u32) -> Result<MpValue, MmCifError> {
if depth > MAX_MSGPACK_DEPTH {
return Err(MmCifError::ResourceLimit("BinaryCIF: msgpack nesting too deep"));
}
let marker = read_u8(bytes, pos)?;
match marker {
0x00..=0x7f => Ok(MpValue::UInt(marker as u64)),
0x80..=0x8f => {
let len = (marker & 0x0f) as usize;
parse_map(bytes, pos, len, depth + 1)
}
0x90..=0x9f => {
let len = (marker & 0x0f) as usize;
parse_array(bytes, pos, len, depth + 1)
}
0xa0..=0xbf => {
let len = (marker & 0x1f) as usize;
Ok(MpValue::Str(read_string(bytes, pos, len)?))
}
0xc0 => Ok(MpValue::Nil),
0xc2 => Ok(MpValue::Bool(false)),
0xc3 => Ok(MpValue::Bool(true)),
0xc4 => {
let len = read_u8(bytes, pos)? as usize;
Ok(MpValue::Bin(read_bytes(bytes, pos, len)?.to_vec()))
}
0xc5 => {
let len = read_u16(bytes, pos)? as usize;
Ok(MpValue::Bin(read_bytes(bytes, pos, len)?.to_vec()))
}
0xc6 => {
let len = read_u32(bytes, pos)? as usize;
Ok(MpValue::Bin(read_bytes(bytes, pos, len)?.to_vec()))
}
0xca => Ok(MpValue::F32(f32::from_bits(read_u32(bytes, pos)?))),
0xcb => Ok(MpValue::F64(f64::from_bits(read_u64(bytes, pos)?))),
0xcc => Ok(MpValue::UInt(read_u8(bytes, pos)? as u64)),
0xcd => Ok(MpValue::UInt(read_u16(bytes, pos)? as u64)),
0xce => Ok(MpValue::UInt(read_u32(bytes, pos)? as u64)),
0xcf => Ok(MpValue::UInt(read_u64(bytes, pos)?)),
0xd0 => Ok(MpValue::Int(read_i8(bytes, pos)? as i64)),
0xd1 => Ok(MpValue::Int(read_i16(bytes, pos)? as i64)),
0xd2 => Ok(MpValue::Int(read_i32(bytes, pos)? as i64)),
0xd3 => Ok(MpValue::Int(read_i64(bytes, pos)?)),
0xd9 => {
let len = read_u8(bytes, pos)? as usize;
Ok(MpValue::Str(read_string(bytes, pos, len)?))
}
0xda => {
let len = read_u16(bytes, pos)? as usize;
Ok(MpValue::Str(read_string(bytes, pos, len)?))
}
0xdb => {
let len = read_u32(bytes, pos)? as usize;
Ok(MpValue::Str(read_string(bytes, pos, len)?))
}
0xdc => {
let len = read_u16(bytes, pos)? as usize;
parse_array(bytes, pos, len, depth + 1)
}
0xdd => {
let len = read_u32(bytes, pos)? as usize;
parse_array(bytes, pos, len, depth + 1)
}
0xde => {
let len = read_u16(bytes, pos)? as usize;
parse_map(bytes, pos, len, depth + 1)
}
0xdf => {
let len = read_u32(bytes, pos)? as usize;
parse_map(bytes, pos, len, depth + 1)
}
0xe0..=0xff => Ok(MpValue::Int((marker as i8) as i64)),
_ => Err(MmCifError::Parse(format!(
"BinaryCIF: unsupported msgpack marker 0x{marker:02x}"
))),
}
}
/// Parses `len` consecutive values into an `Array`, each at nesting `depth`.
///
/// `len` comes from the stream and is untrusted, so at most
/// `MAX_TRUSTED_LEN` slots are reserved up front; the vector grows only as
/// elements actually parse.
fn parse_array(
    bytes: &[u8],
    pos: &mut usize,
    len: usize,
    depth: u32,
) -> Result<MpValue, MmCifError> {
    let mut items = Vec::with_capacity(len.min(MAX_TRUSTED_LEN));
    while items.len() < len {
        items.push(parse_msgpack(bytes, pos, depth)?);
    }
    Ok(MpValue::Array(items))
}
/// Parses `len` key/value pairs into a `Map`, each at nesting `depth`.
///
/// As with arrays, the untrusted `len` only influences the initial
/// reservation up to `MAX_TRUSTED_LEN`.
fn parse_map(
    bytes: &[u8],
    pos: &mut usize,
    len: usize,
    depth: u32,
) -> Result<MpValue, MmCifError> {
    let mut entries = Vec::with_capacity(len.min(MAX_TRUSTED_LEN));
    while entries.len() < len {
        let key = parse_msgpack(bytes, pos, depth)?;
        let value = parse_msgpack(bytes, pos, depth)?;
        entries.push((key, value));
    }
    Ok(MpValue::Map(entries))
}
/// Returns the next `len` bytes at `*pos` and advances `pos` past them.
///
/// # Errors
/// Position overflow, or fewer than `len` bytes remaining.
fn read_bytes<'a>(bytes: &'a [u8], pos: &mut usize, len: usize) -> Result<&'a [u8], MmCifError> {
    let Some(end) = pos.checked_add(len) else {
        return Err(MmCifError::Parse("BinaryCIF: position overflow".into()));
    };
    let slice = bytes.get(*pos..end).ok_or_else(|| {
        MmCifError::Parse("BinaryCIF: unexpected end of msgpack stream".into())
    })?;
    *pos = end;
    Ok(slice)
}
/// Reads `len` bytes and returns them as an owned UTF-8 `String`.
///
/// # Errors
/// Truncated input, or bytes that are not valid UTF-8.
fn read_string(bytes: &[u8], pos: &mut usize, len: usize) -> Result<String, MmCifError> {
    std::str::from_utf8(read_bytes(bytes, pos, len)?)
        .map(str::to_owned)
        .map_err(|_| MmCifError::Parse("BinaryCIF: invalid UTF-8 string".into()))
}
/// Reads a single byte.
fn read_u8(bytes: &[u8], pos: &mut usize) -> Result<u8, MmCifError> {
    match read_bytes(bytes, pos, 1)? {
        [byte] => Ok(*byte),
        _ => Err(MmCifError::Parse("BinaryCIF: missing byte".into())),
    }
}
/// Reads a big-endian `u16` (MessagePack multi-byte numbers are big-endian).
fn read_u16(bytes: &[u8], pos: &mut usize) -> Result<u16, MmCifError> {
    let raw = read_bytes(bytes, pos, 2)?;
    Ok(raw.iter().fold(0, |acc, &byte| (acc << 8) | u16::from(byte)))
}
/// Reads a big-endian `u32`.
fn read_u32(bytes: &[u8], pos: &mut usize) -> Result<u32, MmCifError> {
    let raw = read_bytes(bytes, pos, 4)?;
    Ok(raw.iter().fold(0, |acc, &byte| (acc << 8) | u32::from(byte)))
}
/// Reads a big-endian `u64`.
fn read_u64(bytes: &[u8], pos: &mut usize) -> Result<u64, MmCifError> {
    let raw = read_bytes(bytes, pos, 8)?;
    Ok(raw.iter().fold(0, |acc, &byte| (acc << 8) | u64::from(byte)))
}
/// Reads an `i8` (two's-complement reinterpretation of one byte).
fn read_i8(bytes: &[u8], pos: &mut usize) -> Result<i8, MmCifError> {
    read_u8(bytes, pos).map(|value| value as i8)
}
/// Reads a big-endian `i16`.
fn read_i16(bytes: &[u8], pos: &mut usize) -> Result<i16, MmCifError> {
    read_u16(bytes, pos).map(|value| value as i16)
}
/// Reads a big-endian `i32`.
fn read_i32(bytes: &[u8], pos: &mut usize) -> Result<i32, MmCifError> {
    read_u32(bytes, pos).map(|value| value as i32)
}
/// Reads a big-endian `i64`.
fn read_i64(bytes: &[u8], pos: &mut usize) -> Result<i64, MmCifError> {
    read_u64(bytes, pos).map(|value| value as i64)
}