use crate::error::MmCifError;
use crate::model::{Atom, AtomName, Chain, ChainId, Model, Residue, ResidueName, Structure};
use smallvec::SmallVec;
use std::io::BufRead;
use std::sync::OnceLock;
/// Capacity constant equal to `256 * 1024` (262 144).
///
/// NOTE(review): nothing in the visible part of this file reads it; presumably callers use it
/// to size the buffered reader handed to `parse_structure` — confirm at the call sites.
pub(crate) const READ_BUFFER_CAPACITY: usize = 256 * 1024;
/// Parses mmCIF text from `reader` and fills the first model of a fresh [`Structure`].
///
/// Each `loop_` keyword hands control to `parse_loop`; every other top-level token is
/// skipped. Once the input is exhausted, the residues of each chain are sorted by `seq_id`
/// if (and only if) they are not already in non-decreasing order.
///
/// # Errors
///
/// * `MmCifError::Parse` when the input yields no token at all.
/// * `MmCifError::MissingField("_atom_site")` when no loop carried `_atom_site.` tags.
/// * Any error propagated from the tokenizer or from `parse_loop`.
pub(crate) fn parse_structure<R: BufRead>(reader: R) -> Result<Structure, MmCifError> {
    let mut tokenizer = Tokenizer::new(reader);
    let mut structure = Structure::default();
    structure.models.push(Model::default());

    let mut token_seen = false;
    let mut found_atom_site = false;
    loop {
        // Decide what the token is before touching the tokenizer again, so the borrow of
        // the returned slice has ended by the time `parse_loop` needs `&mut tokenizer`.
        let opens_loop = match tokenizer.next_token()? {
            Some(token) => is_keyword(token, b"loop_"),
            None => break,
        };
        token_seen = true;
        if opens_loop && parse_loop(&mut tokenizer, &mut structure.models[0])? {
            found_atom_site = true;
        }
    }

    if !token_seen {
        return Err(MmCifError::Parse("empty or non-mmCIF input".into()));
    }
    if !found_atom_site {
        return Err(MmCifError::MissingField("_atom_site"));
    }

    for chain in &mut structure.models[0].chains {
        // `windows(2)` is empty for chains with fewer than two residues, so those are
        // never sorted.
        let out_of_order = chain
            .residues
            .windows(2)
            .any(|pair| pair[0].seq_id > pair[1].seq_id);
        if out_of_order {
            chain.residues.sort_by_key(|res| res.seq_id);
        }
    }

    Ok(structure)
}
/// Consumes one `loop_` block; the `loop_` keyword itself has already been read.
///
/// Leading tokens that start with `_` are collected as tags. The loop is treated as an
/// atom-site loop when any tag starts with `_atom_site.` (ASCII case-insensitive). For such
/// loops every complete row is captured into an `AtomRow` and emitted into `model`; rows of
/// any other loop are read and discarded. A terminator token found at the start of a row is
/// pushed back onto the tokenizer so the caller reads it again.
///
/// Returns `Ok(true)` for an atom-site loop and `Ok(false)` otherwise.
///
/// # Errors
///
/// * `MmCifError::Parse("loop_ without tags")` when no tag follows `loop_`.
/// * `MmCifError::MissingField(..)` when an atom-site loop lacks a required tag.
/// * `MmCifError::Parse("incomplete loop row")` when input ends in the middle of a row.
/// * Any error propagated from the tokenizer or from `AtomRow::emit`.
fn parse_loop<R: BufRead>(
    tokenizer: &mut Tokenizer<R>,
    model: &mut Model,
) -> Result<bool, MmCifError> {
    // Tags required in every atom-site loop, checked in this order.
    const REQUIRED: [(Field, &str); 8] = [
        (Field::GroupPdb, "_atom_site.group_PDB"),
        (Field::LabelAtomId, "_atom_site.label_atom_id"),
        (Field::LabelCompId, "_atom_site.label_comp_id"),
        (Field::LabelAsymId, "_atom_site.label_asym_id"),
        (Field::LabelSeqId, "_atom_site.label_seq_id"),
        (Field::CartnX, "_atom_site.Cartn_x"),
        (Field::CartnY, "_atom_site.Cartn_y"),
        (Field::CartnZ, "_atom_site.Cartn_z"),
    ];

    // Header: gather tags until the first token that is not a tag.
    let mut tags: Vec<Vec<u8>> = Vec::new();
    let mut pending: Option<Vec<u8>> = None;
    while let Some(token) = tokenizer.next_token()? {
        if token.first() != Some(&b'_') {
            pending = Some(token.to_vec());
            break;
        }
        tags.push(token.to_vec());
    }
    if tags.is_empty() {
        return Err(MmCifError::Parse("loop_ without tags".into()));
    }

    let tag_count = tags.len();
    let atom_site = tags.iter().any(|tag| starts_with_ci(tag, b"_atom_site."));
    let dispatch: Vec<Field> = if atom_site {
        tags.iter().map(|tag| Field::from_tag(tag)).collect()
    } else {
        vec![Field::Ignore; tag_count]
    };
    if atom_site {
        for (field, name) in REQUIRED {
            require_field(&dispatch, field, name)?;
        }
    }

    // Replay the token that ended the header so that the first value goes through exactly
    // the same path as every later value.
    if let Some(first) = pending {
        tokenizer.push_back(&first);
    }

    let mut column = 0usize;
    let mut row = AtomRow::default();
    let mut chain_cache_pos: Option<usize> = None;
    while let Some(token) = tokenizer.next_token()? {
        // A terminator is only recognised on a row boundary.
        if column == 0 && is_loop_terminator(token) {
            let owned: Vec<u8> = token.to_vec();
            tokenizer.push_back(&owned);
            break;
        }
        if atom_site {
            row.capture(dispatch[column], token);
        }
        column += 1;
        if column == tag_count {
            if atom_site {
                row.emit(model, &mut chain_cache_pos)?;
            }
            row.reset();
            column = 0;
        }
    }

    // Leaving the loop through `break` always happens with `column == 0`, so a non-zero
    // column here means the input ended part-way through a row.
    if column != 0 {
        return Err(MmCifError::Parse("incomplete loop row".into()));
    }
    Ok(atom_site)
}
/// Checks that `field` occurs somewhere in `dispatch`.
///
/// # Errors
///
/// Returns `MmCifError::MissingField(name)` when it does not.
fn require_field(
    dispatch: &[Field],
    field: Field,
    name: &'static str,
) -> Result<(), MmCifError> {
    let present = dispatch.iter().any(|&slot| slot == field);
    if !present {
        return Err(MmCifError::MissingField(name));
    }
    Ok(())
}
/// Returns `true` when `token` equals `keyword`, ignoring ASCII case.
fn is_keyword(token: &[u8], keyword: &[u8]) -> bool {
    <[u8]>::eq_ignore_ascii_case(token, keyword)
}
/// Returns `true` when `token` ends the value section of a loop.
///
/// That is the case for a token starting with `_`, for the keywords `loop_`, `stop_` and
/// `global_`, and for anything beginning with `data_` or `save_` (all ASCII
/// case-insensitive).
fn is_loop_terminator(token: &[u8]) -> bool {
    const KEYWORDS: [&[u8]; 3] = [b"loop_", b"stop_", b"global_"];
    const PREFIXES: [&[u8]; 2] = [b"data_", b"save_"];

    token.first() == Some(&b'_')
        || KEYWORDS.iter().any(|keyword| is_keyword(token, keyword))
        || PREFIXES.iter().any(|prefix| starts_with_ci(token, prefix))
}
/// Returns `true` when `token` begins with `prefix`, ignoring ASCII case.
#[inline]
fn starts_with_ci(token: &[u8], prefix: &[u8]) -> bool {
    // `get` yields `None` when the token is shorter than the prefix.
    token
        .get(..prefix.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
}
/// Meaning assigned to one column of a loop, as resolved by `Field::from_tag`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Field {
/// Column whose values are not captured.
Ignore,
/// `_atom_site.group_PDB`
GroupPdb,
/// `_atom_site.label_atom_id`
LabelAtomId,
/// `_atom_site.label_comp_id`
LabelCompId,
/// `_atom_site.label_asym_id`
LabelAsymId,
/// `_atom_site.label_seq_id`
LabelSeqId,
/// `_atom_site.Cartn_x`
CartnX,
/// `_atom_site.Cartn_y`
CartnY,
/// `_atom_site.Cartn_z`
CartnZ,
/// `_atom_site.label_alt_id`
LabelAltId,
/// `_atom_site.pdbx_PDB_model_num`
ModelNum,
}
impl Field {
fn from_tag(tag: &[u8]) -> Self {
if eq_ci(tag, b"_atom_site.group_PDB") {
Self::GroupPdb
} else if eq_ci(tag, b"_atom_site.label_atom_id") {
Self::LabelAtomId
} else if eq_ci(tag, b"_atom_site.label_comp_id") {
Self::LabelCompId
} else if eq_ci(tag, b"_atom_site.label_asym_id") {
Self::LabelAsymId
} else if eq_ci(tag, b"_atom_site.label_seq_id") {
Self::LabelSeqId
} else if eq_ci(tag, b"_atom_site.Cartn_x") {
Self::CartnX
} else if eq_ci(tag, b"_atom_site.Cartn_y") {
Self::CartnY
} else if eq_ci(tag, b"_atom_site.Cartn_z") {
Self::CartnZ
} else if eq_ci(tag, b"_atom_site.label_alt_id") {
Self::LabelAltId
} else if eq_ci(tag, b"_atom_site.pdbx_PDB_model_num") {
Self::ModelNum
} else {
Self::Ignore
}
}
}
/// Compares two byte strings for equality, ignoring ASCII case.
#[inline]
fn eq_ci(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x.eq_ignore_ascii_case(y))
}
/// Accumulates the values of one `_atom_site` row while its tokens are being read.
///
/// For fields that can fail to parse, the value is stored on success and the raw token is
/// stored in the matching `*_error` field on failure; `emit` reports the failure later.
struct AtomRow {
/// `Some(true)` when the `group_PDB` token was exactly `ATOM`, `Some(false)` for any other
/// token, `None` until the column has been captured.
group_is_atom: Option<bool>,
/// Outer `Option`: whether the column was captured. Inner `Option`: the result of
/// `AtomName::from_label_atom_id_bytes`; `emit` silently skips the row when it is `None`.
atom_name: Option<Option<AtomName>>,
/// Result of `ResidueName::from_label_comp_id_bytes`, once captured.
residue_name: Option<ResidueName>,
/// Chain id when `ChainId::from_bytes` accepted the token.
chain_id: Option<ChainId>,
/// Raw `label_asym_id` token when `ChainId::from_bytes` rejected it.
chain_error: Option<Vec<u8>>,
/// Outer `Option`: whether the column was captured. Inner `Option`: `None` when the token
/// was `.` or `?` (see `parse_seq_id`).
seq_id: Option<Option<i32>>,
/// Raw `label_seq_id` token when it failed to parse.
seq_error: Option<Vec<u8>>,
/// Parsed `Cartn_x`.
x: Option<f32>,
/// Raw `Cartn_x` token when it failed to parse.
x_error: Option<Vec<u8>>,
/// Parsed `Cartn_y`.
y: Option<f32>,
/// Raw `Cartn_y` token when it failed to parse.
y_error: Option<Vec<u8>>,
/// Parsed `Cartn_z`.
z: Option<f32>,
/// Raw `Cartn_z` token when it failed to parse.
z_error: Option<Vec<u8>>,
/// Result of `matches_altloc` for `label_alt_id`; stays `true` when the column is absent.
keep_altloc: bool,
/// Result of `matches_first_model` for `pdbx_PDB_model_num`; stays `true` when the column
/// is absent.
keep_model: bool,
}
impl Default for AtomRow {
    /// Builds an empty row: nothing captured, no pending errors, both filters passing.
    fn default() -> Self {
        Self {
            // Filters default to "keep" so rows pass when the optional columns are absent.
            keep_altloc: true,
            keep_model: true,
            // Captured values.
            group_is_atom: None,
            atom_name: None,
            residue_name: None,
            chain_id: None,
            seq_id: None,
            x: None,
            y: None,
            z: None,
            // Raw tokens of values that failed to parse.
            chain_error: None,
            seq_error: None,
            x_error: None,
            y_error: None,
            z_error: None,
        }
    }
}
impl AtomRow {
    /// Returns the row to its freshly constructed state.
    fn reset(&mut self) {
        *self = Self::default();
    }

    /// Parses a coordinate token into `value`, or keeps the raw token in `error`.
    fn store_coord(token: &[u8], value: &mut Option<f32>, error: &mut Option<Vec<u8>>) {
        match parse_f32(token) {
            Ok(parsed) => *value = Some(parsed),
            Err(_) => *error = Some(token.to_vec()),
        }
    }

    /// Records `token` as the value of the column identified by `field`.
    ///
    /// Parse failures are not reported here; the raw token is kept so that `emit` can
    /// report it, and only for rows that are actually kept.
    fn capture(&mut self, field: Field, token: &[u8]) {
        match field {
            Field::Ignore => {}
            Field::GroupPdb => self.group_is_atom = Some(token == b"ATOM"),
            Field::LabelAtomId => {
                self.atom_name = Some(AtomName::from_label_atom_id_bytes(token));
            }
            Field::LabelCompId => {
                self.residue_name = Some(ResidueName::from_label_comp_id_bytes(token));
            }
            Field::LabelAsymId => {
                if let Some(id) = ChainId::from_bytes(token) {
                    self.chain_id = Some(id);
                } else {
                    self.chain_error = Some(token.to_vec());
                }
            }
            Field::LabelSeqId => {
                if let Ok(parsed) = parse_seq_id(token) {
                    self.seq_id = Some(parsed);
                } else {
                    self.seq_error = Some(token.to_vec());
                }
            }
            Field::CartnX => Self::store_coord(token, &mut self.x, &mut self.x_error),
            Field::CartnY => Self::store_coord(token, &mut self.y, &mut self.y_error),
            Field::CartnZ => Self::store_coord(token, &mut self.z, &mut self.z_error),
            Field::LabelAltId => self.keep_altloc = matches_altloc(token),
            Field::ModelNum => self.keep_model = matches_first_model(token),
        }
    }

    /// Appends the captured row to `model` as an atom.
    ///
    /// The row is dropped without error when its group is not `ATOM`, when the model or
    /// alt-loc filter rejected it, or when the atom name resolved to `None`. Otherwise the
    /// atom is added to the last residue of its chain if that residue has the same `seq_id`
    /// and name, or to a newly pushed residue.
    ///
    /// `chain_cache_pos` is the chain-index cache maintained by `locate_or_append_chain`.
    ///
    /// # Errors
    ///
    /// * `MmCifError::MissingField` when a required value was never captured.
    /// * `MmCifError::InvalidChainId` when the chain id token was rejected.
    /// * `MmCifError::Parse` when `label_seq_id` or a coordinate failed to parse.
    fn emit(
        &mut self,
        model: &mut Model,
        chain_cache_pos: &mut Option<usize>,
    ) -> Result<(), MmCifError> {
        let Some(is_atom) = self.group_is_atom else {
            return Err(MmCifError::MissingField("_atom_site.group_PDB"));
        };
        // Filtered rows are skipped before any validation, so malformed values in rows
        // that would be discarded anyway never raise an error.
        let kept = is_atom && self.keep_model && self.keep_altloc;
        if !kept {
            return Ok(());
        }

        let Some(maybe_name) = self.atom_name else {
            return Err(MmCifError::MissingField("_atom_site.label_atom_id"));
        };
        let Some(atom_name) = maybe_name else {
            return Ok(());
        };
        let Some(residue_name) = self.residue_name else {
            return Err(MmCifError::MissingField("_atom_site.label_comp_id"));
        };

        if let Some(raw) = self.chain_error.take() {
            return Err(MmCifError::InvalidChainId(
                String::from_utf8_lossy(&raw).into_owned(),
            ));
        }
        let Some(chain_id) = self.chain_id else {
            return Err(MmCifError::MissingField("_atom_site.label_asym_id"));
        };

        if let Some(raw) = self.seq_error.take() {
            return Err(MmCifError::Parse(format!(
                "invalid label_seq_id: {}",
                String::from_utf8_lossy(&raw)
            )));
        }
        let Some(seq_id) = self.seq_id else {
            return Err(MmCifError::MissingField("_atom_site.label_seq_id"));
        };

        if let Some(raw) = self.x_error.take() {
            return Err(MmCifError::Parse(format!(
                "invalid Cartn_x: {}",
                String::from_utf8_lossy(&raw)
            )));
        }
        let Some(x) = self.x else {
            return Err(MmCifError::MissingField("_atom_site.Cartn_x"));
        };

        if let Some(raw) = self.y_error.take() {
            return Err(MmCifError::Parse(format!(
                "invalid Cartn_y: {}",
                String::from_utf8_lossy(&raw)
            )));
        }
        let Some(y) = self.y else {
            return Err(MmCifError::MissingField("_atom_site.Cartn_y"));
        };

        if let Some(raw) = self.z_error.take() {
            return Err(MmCifError::Parse(format!(
                "invalid Cartn_z: {}",
                String::from_utf8_lossy(&raw)
            )));
        }
        let Some(z) = self.z else {
            return Err(MmCifError::MissingField("_atom_site.Cartn_z"));
        };

        let chain_index = locate_or_append_chain(model, chain_id, chain_cache_pos);
        let residues = &mut model.chains[chain_index].residues;
        // Only the most recent residue of the chain is a candidate for reuse.
        let continues_last = matches!(
            residues.last(),
            Some(prev) if prev.seq_id == seq_id && prev.name == residue_name
        );
        if !continues_last {
            residues.push(Residue {
                name: residue_name,
                seq_id,
                atoms: SmallVec::new(),
            });
        }
        let target = residues.last_mut().expect("residue exists");
        target.atoms.push(Atom {
            name: atom_name,
            x,
            y,
            z,
        });
        Ok(())
    }
}
/// Returns the index of the chain with id `chain_id` in `model.chains`, appending an empty
/// chain when none exists.
///
/// `chain_cache_pos` is consulted first and, on a miss, updated to the returned index. When
/// several chains share the id, the search picks the last one.
#[inline]
fn locate_or_append_chain(
    model: &mut Model,
    chain_id: ChainId,
    chain_cache_pos: &mut Option<usize>,
) -> usize {
    // Fast path: the cached index still points at a chain with the requested id.
    if let Some(cached) = *chain_cache_pos {
        let hit = model
            .chains
            .get(cached)
            .is_some_and(|chain| chain.id == chain_id);
        if hit {
            return cached;
        }
    }

    let index = match model.chains.iter().rposition(|chain| chain.id == chain_id) {
        Some(found) => found,
        None => {
            model.chains.push(Chain {
                id: chain_id,
                residues: Vec::new(),
            });
            model.chains.len() - 1
        }
    };
    *chain_cache_pos = Some(index);
    index
}
/// Returns `true` for the alternate-location tokens that are kept: `.`, `?` and `A`.
#[inline]
fn matches_altloc(token: &[u8]) -> bool {
    matches!(token, b"." | b"A" | b"?")
}
/// Returns `true` when `token` denotes model number 1.
///
/// The token is decoded as UTF-8 and trimmed. It is then read as an integer, which must
/// equal 1, or failing that as a float, which must lie within `1e-9` of 1.0. Anything else,
/// including non-UTF-8 input, yields `false`.
#[inline]
fn matches_first_model(token: &[u8]) -> bool {
    let Some(text) = std::str::from_utf8(token).ok().map(str::trim) else {
        return false;
    };
    match text.parse::<i64>() {
        Ok(number) => number == 1,
        Err(_) => text
            .parse::<f64>()
            .is_ok_and(|value| (value - 1.0).abs() < 1e-9),
    }
}
/// Parses a `label_seq_id` token.
///
/// Returns `Ok(None)` for the placeholders `.` and `?`, otherwise the token parsed as `i32`.
///
/// # Errors
///
/// Returns `MmCifError::Parse` for non-UTF-8 input; integer syntax errors are converted
/// into `MmCifError` by `?`.
fn parse_seq_id(token: &[u8]) -> Result<Option<i32>, MmCifError> {
    if matches!(token, b"." | b"?") {
        return Ok(None);
    }
    let text = match std::str::from_utf8(token) {
        Ok(text) => text,
        Err(_) => return Err(MmCifError::Parse("non-UTF-8 label_seq_id".into())),
    };
    Ok(Some(text.parse::<i32>()?))
}
/// Parses a coordinate token as a finite `f32`.
///
/// # Errors
///
/// Returns `MmCifError::Parse` when the token is one of the placeholders `.` or `?`, is not
/// valid UTF-8, or evaluates to a non-finite value. Float syntax errors are converted into
/// `MmCifError` by `?`.
fn parse_f32(token: &[u8]) -> Result<f32, MmCifError> {
    if token == b"." || token == b"?" {
        return Err(MmCifError::Parse("missing float".into()));
    }
    let text = std::str::from_utf8(token)
        .map_err(|_| MmCifError::Parse("non-UTF-8 float".into()))?;
    let parsed: f32 = text.parse()?;
    // `f32::from_str` accepts spellings such as "inf", "infinity" and "nan", and rounds
    // out-of-range literals to infinity. None of these is a usable coordinate, so reject
    // them here instead of letting NaN/inf leak into the model.
    if !parsed.is_finite() {
        return Err(MmCifError::Parse("non-finite float".into()));
    }
    Ok(parsed)
}
/// Line-oriented tokenizer over a buffered reader.
///
/// Tokens are returned as byte slices borrowed from one of the internal buffers, so a token
/// must be copied if it has to outlive the next call on the tokenizer.
struct Tokenizer<R: BufRead> {
/// Source of input lines.
reader: R,
/// The line currently being tokenized, including its line terminator.
line: Vec<u8>,
/// Offset of the next unread byte in `line`.
line_pos: usize,
/// Assembled content of the most recent semicolon-delimited text field.
scratch: Vec<u8>,
/// Copy of the token handed to `push_back`.
pushback: Vec<u8>,
/// Whether `pushback` holds a token that has not been re-read yet.
has_pushback: bool,
}
impl<R: BufRead> Tokenizer<R> {
/// Creates a tokenizer over `reader` with a 4096-byte initial line buffer.
fn new(reader: R) -> Self {
Self {
reader,
line: Vec::with_capacity(4096),
line_pos: 0,
scratch: Vec::new(),
pushback: Vec::new(),
has_pushback: false,
}
}
/// Stores a copy of `token` so that the next `next_token` call returns it again.
///
/// Only one token is held: a second call before the first is re-read overwrites it.
fn push_back(&mut self, token: &[u8]) {
self.pushback.clear();
self.pushback.extend_from_slice(token);
self.has_pushback = true;
}
/// Returns the next token, or `Ok(None)` at end of input.
///
/// A pushed-back token is returned first. Otherwise spaces, tabs, CR and LF are skipped,
/// and a `#` met while skipping discards the rest of the line. A line whose first byte is
/// `;` opens a multi-line text field that runs until the next line starting with `;`; its
/// content is returned as a single token.
///
/// Quotes and semicolon delimiters are stripped from the returned slice and no flag
/// records that they were present, so callers cannot tell a quoted value such as `'loop_'`
/// from the bare keyword.
///
/// # Errors
///
/// Returns `MmCifError::Parse` for an unterminated text field or quoted string; reader
/// errors are converted into `MmCifError` by `?`.
fn next_token(&mut self) -> Result<Option<&[u8]>, MmCifError> {
if self.has_pushback {
self.has_pushback = false;
return Ok(Some(&self.pushback));
}
loop {
// Skip whitespace; a `#` reached here starts a comment that runs to end of line.
while self.line_pos < self.line.len() {
let ch = self.line[self.line_pos];
match ch {
b' ' | b'\t' | b'\r' | b'\n' => self.line_pos += 1,
b'#' => {
self.line_pos = self.line.len();
}
_ => break,
}
}
if self.line_pos >= self.line.len() {
// Current line exhausted: fetch the next one.
self.line.clear();
self.line_pos = 0;
let n = self.reader.read_until(b'\n', &mut self.line)?;
if n == 0 {
return Ok(None);
}
// `;` counts as a text-field delimiter only in the first column of a fresh line.
if self.line.first() == Some(&b';') {
self.scratch.clear();
let body_end = trim_eol(&self.line);
// Text on the opening line, after the `;`, is the first piece of the content.
if body_end > 1 {
self.scratch.extend_from_slice(&self.line[1..body_end]);
}
loop {
self.line.clear();
let n2 = self.reader.read_until(b'\n', &mut self.line)?;
if n2 == 0 {
return Err(MmCifError::Parse(
"unterminated semicolon text".into(),
));
}
if self.line.first() == Some(&b';') {
// Closing delimiter: resume ordinary tokenizing just after the `;`.
self.line_pos = 1;
break;
}
// Lines are joined with `\n`, but no separator is added while the content is still
// empty, so empty lines at the start of the field contribute nothing.
if !self.scratch.is_empty() {
self.scratch.push(b'\n');
}
let end = trim_eol(&self.line);
self.scratch.extend_from_slice(&self.line[..end]);
}
return Ok(Some(&self.scratch));
}
continue;
}
return self.read_token_at_pos();
}
}
/// Reads the token starting at `line_pos`, which must index a byte of `line` that is
/// neither whitespace nor `#` (the indexing below panics when `line_pos` is out of range).
///
/// A token starting with `'` or `"` is a quoted string and is returned without its quotes;
/// any other token extends to the next delimiter found by `find_token_end`.
fn read_token_at_pos(&mut self) -> Result<Option<&[u8]>, MmCifError> {
let bytes = self.line.as_slice();
let pos = self.line_pos;
let first = bytes[pos];
if first == b'\'' || first == b'"' {
let quote = first;
let body_start = pos + 1;
let mut i = body_start;
while i < bytes.len() {
if bytes[i] == quote {
// A quote closes the string only when followed by the end of the buffer, whitespace
// or `#`; otherwise it is part of the value.
let next = bytes.get(i + 1).copied();
let closes = matches!(
next,
None | Some(b' ') | Some(b'\t') | Some(b'\r') | Some(b'\n') | Some(b'#')
);
if closes {
let token_end = i;
self.line_pos = i + 1;
return Ok(Some(&self.line[body_start..token_end]));
}
}
i += 1;
}
return Err(MmCifError::Parse("unterminated quoted string".into()));
}
let start = pos;
let end = find_token_end(bytes, start);
self.line_pos = end;
Ok(Some(&self.line[start..end]))
}
}
/// Returns the length of `line` once all trailing `\n` and `\r` bytes are dropped.
#[inline]
fn trim_eol(line: &[u8]) -> usize {
    // One past the last byte that is not a line terminator; 0 if there is none.
    line.iter()
        .rposition(|&byte| byte != b'\n' && byte != b'\r')
        .map_or(0, |last| last + 1)
}
/// Returns `true` for the bytes that end an unquoted token: space, tab, CR, LF and `#`.
#[inline]
fn is_delim(ch: u8) -> bool {
    matches!(ch, b' ' | b'\t' | b'\r' | b'\n' | b'#')
}
/// Returns the index of the first delimiter byte at or after `start`, or `bytes.len()` when
/// there is none.
///
/// Picks an implementation at run time: on x86_64 the AVX-512 variant is preferred over the
/// AVX2 one, on aarch64 the NEON variant is used, and the scalar loop is the fallback.
#[inline]
fn find_token_end(bytes: &[u8], start: usize) -> usize {
#[cfg(target_arch = "x86_64")]
{
if cpu_has_avx512bw() {
// SAFETY: `cpu_has_avx512bw` just confirmed AVX-512F and AVX-512BW support.
return unsafe { find_token_end_avx512(bytes, start) };
}
if cpu_has_avx2() {
// SAFETY: `cpu_has_avx2` just confirmed AVX2 support.
return unsafe { find_token_end_avx2(bytes, start) };
}
}
#[cfg(target_arch = "aarch64")]
{
if cpu_has_neon() {
// SAFETY: `cpu_has_neon` just confirmed NEON support.
return unsafe { find_token_end_neon(bytes, start) };
}
}
find_token_end_scalar(bytes, start)
}
/// Byte-at-a-time search for the first delimiter at or after `start`.
///
/// Returns `bytes.len()` when no delimiter follows, including when `start` lies at or past
/// the end of `bytes`.
#[inline]
fn find_token_end_scalar(bytes: &[u8], start: usize) -> usize {
    bytes
        .get(start..)
        .and_then(|tail| tail.iter().position(|&byte| is_delim(byte)))
        .map_or(bytes.len(), |offset| start + offset)
}
/// Reports whether the running CPU supports AVX2; the answer is computed once and cached.
#[cfg(target_arch = "x86_64")]
#[inline]
fn cpu_has_avx2() -> bool {
    static SUPPORTED: OnceLock<bool> = OnceLock::new();

    fn detect() -> bool {
        std::arch::is_x86_feature_detected!("avx2")
    }

    *SUPPORTED.get_or_init(detect)
}
/// Reports whether the running CPU supports both AVX-512F and AVX-512BW; the answer is
/// computed once and cached.
#[cfg(target_arch = "x86_64")]
#[inline]
fn cpu_has_avx512bw() -> bool {
    static SUPPORTED: OnceLock<bool> = OnceLock::new();

    fn detect() -> bool {
        let foundation = std::arch::is_x86_feature_detected!("avx512f");
        foundation && std::arch::is_x86_feature_detected!("avx512bw")
    }

    *SUPPORTED.get_or_init(detect)
}
/// Reports whether the running CPU supports NEON; the answer is computed once and cached.
#[cfg(target_arch = "aarch64")]
#[inline]
fn cpu_has_neon() -> bool {
    static SUPPORTED: OnceLock<bool> = OnceLock::new();

    fn detect() -> bool {
        std::arch::is_aarch64_feature_detected!("neon")
    }

    *SUPPORTED.get_or_init(detect)
}
/// AVX2 search for the first delimiter (space, tab, CR, LF or `#`) at or after `start`.
///
/// Examines 32 bytes per iteration and hands the remaining tail to
/// `find_token_end_scalar`. Returns `bytes.len()` when no delimiter is found.
///
/// # Safety
///
/// The caller must ensure that the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn find_token_end_avx2(bytes: &[u8], start: usize) -> usize {
use std::arch::x86_64::*;
let len = bytes.len();
let mut i = start;
// One broadcast vector per delimiter byte.
let space = _mm256_set1_epi8(b' ' as i8);
let tab = _mm256_set1_epi8(b'\t' as i8);
let cr = _mm256_set1_epi8(b'\r' as i8);
let lf = _mm256_set1_epi8(b'\n' as i8);
let hash = _mm256_set1_epi8(b'#' as i8);
while i + 32 <= len {
// SAFETY: `i + 32 <= len`, so the 32 bytes read lie inside `bytes`; the unaligned load
// has no alignment requirement.
let chunk = unsafe { _mm256_loadu_si256(bytes.as_ptr().add(i) as *const _) };
let m1 = _mm256_cmpeq_epi8(chunk, space);
let m2 = _mm256_cmpeq_epi8(chunk, tab);
let m3 = _mm256_cmpeq_epi8(chunk, cr);
let m4 = _mm256_cmpeq_epi8(chunk, lf);
let m5 = _mm256_cmpeq_epi8(chunk, hash);
let mask = _mm256_or_si256(
_mm256_or_si256(_mm256_or_si256(m1, m2), _mm256_or_si256(m3, m4)),
m5,
);
// One bit per byte of the chunk; the lowest set bit marks the first delimiter.
let bits = _mm256_movemask_epi8(mask);
if bits != 0 {
return i + bits.trailing_zeros() as usize;
}
i += 32;
}
find_token_end_scalar(bytes, i)
}
/// AVX-512 search for the first delimiter (space, tab, CR, LF or `#`) at or after `start`.
///
/// Examines 64 bytes per iteration. A tail of at least 32 bytes is handed to the AVX2
/// variant when AVX2 is available; anything else goes to `find_token_end_scalar`. Returns
/// `bytes.len()` when no delimiter is found.
///
/// # Safety
///
/// The caller must ensure that the running CPU supports AVX-512F and AVX-512BW.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
unsafe fn find_token_end_avx512(bytes: &[u8], start: usize) -> usize {
use std::arch::x86_64::*;
let len = bytes.len();
let mut i = start;
// One broadcast vector per delimiter byte.
let space = _mm512_set1_epi8(b' ' as i8);
let tab = _mm512_set1_epi8(b'\t' as i8);
let cr = _mm512_set1_epi8(b'\r' as i8);
let lf = _mm512_set1_epi8(b'\n' as i8);
let hash = _mm512_set1_epi8(b'#' as i8);
while i + 64 <= len {
// SAFETY: `i + 64 <= len`, so the 64 bytes read lie inside `bytes`; the unaligned load
// has no alignment requirement.
let chunk = unsafe { _mm512_loadu_si512(bytes.as_ptr().add(i) as *const _) };
let m1 = _mm512_cmpeq_epi8_mask(chunk, space);
let m2 = _mm512_cmpeq_epi8_mask(chunk, tab);
let m3 = _mm512_cmpeq_epi8_mask(chunk, cr);
let m4 = _mm512_cmpeq_epi8_mask(chunk, lf);
let m5 = _mm512_cmpeq_epi8_mask(chunk, hash);
// One bit per byte of the chunk; the lowest set bit marks the first delimiter.
let mask = m1 | m2 | m3 | m4 | m5;
if mask != 0 {
return i + mask.trailing_zeros() as usize;
}
i += 64;
}
if cpu_has_avx2() && i + 32 <= len {
// SAFETY: `cpu_has_avx2` just confirmed AVX2 support.
return unsafe { find_token_end_avx2(bytes, i) };
}
find_token_end_scalar(bytes, i)
}
/// NEON search for the first delimiter (space, tab, CR, LF or `#`) at or after `start`.
///
/// Examines 16 bytes per iteration and hands the remaining tail to
/// `find_token_end_scalar`. Returns `bytes.len()` when no delimiter is found.
///
/// # Safety
///
/// The caller must ensure that the running CPU supports NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn find_token_end_neon(bytes: &[u8], start: usize) -> usize {
use std::arch::aarch64::*;
let len = bytes.len();
let mut i = start;
// One broadcast vector per delimiter byte.
let space = vdupq_n_u8(b' ');
let tab = vdupq_n_u8(b'\t');
let cr = vdupq_n_u8(b'\r');
let lf = vdupq_n_u8(b'\n');
let hash = vdupq_n_u8(b'#');
while i + 16 <= len {
// SAFETY: `i + 16 <= len`, so the 16 bytes read lie inside `bytes`.
let chunk = unsafe { vld1q_u8(bytes.as_ptr().add(i)) };
let m1 = vceqq_u8(chunk, space);
let m2 = vceqq_u8(chunk, tab);
let m3 = vceqq_u8(chunk, cr);
let m4 = vceqq_u8(chunk, lf);
let m5 = vceqq_u8(chunk, hash);
let mask = vorrq_u8(vorrq_u8(vorrq_u8(m1, m2), vorrq_u8(m3, m4)), m5);
// Narrow each 16-bit lane by a 4-bit shift: the 128-bit byte mask becomes a 64-bit value
// holding four bits per input byte.
let nibble = vshrn_n_u16(vreinterpretq_u16_u8(mask), 4);
let bits = vget_lane_u64(vreinterpret_u64_u8(nibble), 0);
if bits != 0 {
// Four mask bits per byte, hence the division by 4.
return i + (bits.trailing_zeros() as usize) / 4;
}
i += 16;
}
find_token_end_scalar(bytes, i)
}