use rayon::iter::IndexedParallelIterator;
use crate::{prelude::*, selection::utils::{difference_sorted, intersection_sorted, union_sorted}};
/// Types that can create a [`Sel`] from a selection definition.
pub trait Selectable {
    /// Builds a selection from the definition `def`.
    ///
    /// # Errors
    /// Returns a [`SelectionError`] if the selection cannot be built.
    fn select(&self, def: impl SelectionDef) -> Result<Sel, SelectionError>;
}
/// [`Selectable`] types that refer to a system and can also produce
/// selections whose lifetime is tied to `self`.
pub trait SelectableBound: SystemProvider + Selectable {
    /// Like [`Selectable::select`], but returns a [`SelOwnBound`] that
    /// borrows `self` (see the `'_` lifetime in the return type).
    ///
    /// # Errors
    /// Returns a [`SelectionError`] if the selection cannot be built.
    fn select_bound(&self, def: impl SelectionDef) -> Result<SelOwnBound<'_>, SelectionError>;
}
pub trait AtomPosAnalysis: LenProvider + IndexProvider + Sized {
fn atoms_ptr(&self) -> *const Atom;
fn coords_ptr(&self) -> *const Pos;
fn split_par<F, R>(&self, func: F) -> Result<ParSplit, SelectionError>
where
F: Fn(Particle) -> Option<R>,
R: Default + PartialOrd,
Self: Sized,
{
let selections: Vec<Sel> = self.split(func).collect();
if selections.is_empty() {
return Err(SelectionError::EmptySplit);
}
let max_index = selections
.iter()
.map(|sel| *sel.0.last().unwrap())
.max()
.unwrap();
Ok(ParSplit {
selections,
max_index,
})
}
fn split<RT, F>(&self, func: F) -> impl Iterator<Item = Sel>
where
RT: Default + std::cmp::PartialEq,
F: Fn(Particle) -> Option<RT>,
{
let mut cur_val = RT::default();
let mut cur = 0usize;
let next_fn = move || {
let mut index = Vec::<usize>::new();
while cur < self.len() {
let p = unsafe { self.get_particle_unchecked(cur) };
let id = p.id;
let val = func(p);
if let Some(val) = val {
if val == cur_val {
index.push(id);
} else if index.is_empty() {
cur_val = val;
index.push(id);
} else {
cur_val = val; return Some(Sel(unsafe { SVec::from_sorted(index) }));
}
}
cur += 1;
}
if !index.is_empty() {
return Some(Sel(unsafe { SVec::from_sorted(index) }));
}
None
};
std::iter::from_fn(next_fn)
}
fn split_resindex(&self) -> impl Iterator<Item = Sel> {
self.split(|p| Some(p.atom.resindex))
}
fn whole_attr<T>(&self, attr_fn: fn(&Atom) -> &T) -> Sel
where
T: Eq + std::hash::Hash + Copy,
{
let mut properties = std::collections::HashSet::<T>::new();
for at in self.iter_atoms() {
properties.insert(*attr_fn(at));
}
let mut ind = vec![];
for (i, at) in self.iter_atoms().enumerate() {
let cur_prop = attr_fn(at);
if properties.contains(cur_prop) {
ind.push(i);
}
}
Sel(unsafe { SVec::from_sorted(ind) })
}
fn whole_residues(&self) -> Sel {
self.whole_attr(|at| &at.resindex)
}
fn whole_chains(&self) -> Sel {
self.whole_attr(|at| &at.chain)
}
}
/// Raw-pointer access to the topology and state behind an analysed object.
///
/// The blanket impls in this module use these pointers to provide box, time,
/// molecule and bond queries.
pub trait NonAtomPosAnalysis: LenProvider + IndexProvider + Sized {
    /// Raw pointer to the underlying [`Topology`].
    fn top_ptr(&self) -> *const Topology;
    /// Raw pointer to the underlying [`State`].
    fn st_ptr(&self) -> *const State;
}
/// Mutable counterpart of [`AtomPosAnalysis`].
///
/// The default methods obtain mutable pointers by casting the read-only ones.
pub trait AtomPosAnalysisMut: AtomPosAnalysis {
    /// Mutable raw pointer to the first element of the atom storage.
    fn atoms_ptr_mut(&mut self) -> *mut Atom {
        self.atoms_ptr().cast_mut()
    }

    /// Mutable raw pointer to the first element of the coordinate storage.
    fn coords_ptr_mut(&mut self) -> *mut Pos {
        self.coords_ptr().cast_mut()
    }
}
/// Mutable counterpart of [`NonAtomPosAnalysis`].
///
/// The default methods obtain mutable pointers by casting the read-only ones.
pub trait NonAtomPosAnalysisMut: NonAtomPosAnalysis {
    /// Mutable raw pointer to the underlying [`Topology`].
    fn top_ptr_mut(&mut self) -> *mut Topology {
        self.top_ptr().cast_mut()
    }

    /// Mutable raw pointer to the underlying [`State`].
    fn st_ptr_mut(&mut self) -> *mut State {
        self.st_ptr().cast_mut()
    }
}
impl<T: AtomPosAnalysis> PosIterProvider for T {
    /// Iterates over the positions of the indexed atoms, in index order.
    fn iter_pos(&self) -> impl PosIterator<'_> {
        let base = self.coords_ptr();
        // SAFETY: indices from `iter_index` are assumed in bounds of the
        // coordinate storage.
        self.iter_index().map(move |ind| unsafe { &*base.add(ind) })
    }
}
impl<T: AtomPosAnalysis + IndexParProvider> PosParIterProvider for T {
    /// Parallel iterator over the positions of the indexed atoms.
    fn par_iter_pos(&self) -> impl IndexedParallelIterator<Item = &Pos> {
        // Raw pointers are not `Send`, so the base address is carried as an integer.
        let base_addr = self.coords_ptr() as usize;
        self.par_iter_index().map(move |ind| {
            let base = base_addr as *const Pos;
            // SAFETY: indices from `par_iter_index` are assumed in bounds.
            unsafe { &*base.add(ind) }
        })
    }
}
impl<T: AtomPosAnalysis> AtomIterProvider for T {
    /// Iterates over the indexed atoms, in index order.
    fn iter_atoms(&self) -> impl AtomIterator<'_> {
        let base = self.atoms_ptr();
        // SAFETY: indices from `iter_index` are assumed in bounds of the atom storage.
        self.iter_index().map(move |ind| unsafe { &*base.add(ind) })
    }
}
impl<T: AtomPosAnalysis + IndexParProvider> AtomParIterProvider for T {
    /// Parallel iterator over the indexed atoms.
    fn par_iter_atoms(&self) -> impl IndexedParallelIterator<Item = &Atom> {
        // Raw pointers are not `Send`, so the base address is carried as an integer.
        let base_addr = self.atoms_ptr() as usize;
        self.par_iter_index().map(move |ind| {
            let base = base_addr as *const Atom;
            // SAFETY: indices from `par_iter_index` are assumed in bounds.
            unsafe { &*base.add(ind) }
        })
    }
}
impl<T: AtomPosAnalysis> RandomPosProvider for T {
    /// Position of the `i`-th indexed atom, without bounds checking.
    ///
    /// # Safety
    /// `i` must be a valid position in the index (presumably `i < self.len()`).
    unsafe fn get_pos_unchecked(&self, i: usize) -> &Pos {
        let global = unsafe { self.get_index_unchecked(i) };
        unsafe { &*self.coords_ptr().add(global) }
    }
}
impl<T: AtomPosAnalysis> RandomAtomProvider for T {
    /// The `i`-th indexed atom, without bounds checking.
    ///
    /// # Safety
    /// `i` must be a valid position in the index (presumably `i < self.len()`).
    unsafe fn get_atom_unchecked(&self, i: usize) -> &Atom {
        let global = unsafe { self.get_index_unchecked(i) };
        unsafe { &*self.atoms_ptr().add(global) }
    }
}
impl<T: AtomPosAnalysis> ParticleIterProvider for T {
    /// Iterates over the indexed atoms as [`Particle`]s (id, atom, position).
    fn iter_particle(&self) -> impl Iterator<Item = Particle<'_>> {
        let atoms = self.atoms_ptr();
        let coords = self.coords_ptr();
        self.iter_index().map(move |id| {
            // SAFETY: indices from `iter_index` are assumed in bounds of both storages.
            let (atom, pos) = unsafe { (&*atoms.add(id), &*coords.add(id)) };
            Particle { id, atom, pos }
        })
    }
}
impl<T: AtomPosAnalysis + IndexParProvider> ParticleParIterProvider for T {
    /// Parallel iterator over the indexed atoms as [`Particle`]s.
    fn par_iter_particle(&self) -> impl IndexedParallelIterator<Item = Particle<'_>> {
        // Raw pointers are not `Send`, so base addresses are carried as integers.
        let (coords_addr, atoms_addr) = (self.coords_ptr() as usize, self.atoms_ptr() as usize);
        self.par_iter_index().map(move |id| {
            let atoms = atoms_addr as *const Atom;
            let coords = coords_addr as *const Pos;
            // SAFETY: indices from `par_iter_index` are assumed in bounds.
            unsafe {
                Particle {
                    id,
                    atom: &*atoms.add(id),
                    pos: &*coords.add(id),
                }
            }
        })
    }
}
impl<T: AtomPosAnalysisMut> ParticleIterMutProvider for T {
    /// Iterates over the indexed atoms as mutable [`ParticleMut`] views.
    fn iter_particle_mut(&mut self) -> impl Iterator<Item = ParticleMut<'_>> {
        let coords = self.coords_ptr_mut();
        let atoms = self.atoms_ptr_mut();
        self.iter_index().map(move |id| {
            // SAFETY: indices are assumed unique and in bounds, so each
            // yielded `&mut` refers to a distinct element.
            let atom = unsafe { &mut *atoms.add(id) };
            let pos = unsafe { &mut *coords.add(id) };
            ParticleMut { id, atom, pos }
        })
    }
}
impl<T: AtomPosAnalysisMut + IndexParProvider> ParticleParIterMutProvider for T {
    /// Parallel iterator over the indexed atoms as mutable [`ParticleMut`] views.
    fn par_iter_particle_mut(&mut self) -> impl IndexedParallelIterator<Item = ParticleMut<'_>> {
        // Raw pointers are not `Send`, so base addresses are carried as integers.
        let coords_addr = self.coords_ptr_mut() as usize;
        let atoms_addr = self.atoms_ptr_mut() as usize;
        self.par_iter_index().map(move |id| {
            let atoms = atoms_addr as *mut Atom;
            let coords = coords_addr as *mut Pos;
            // SAFETY: indices are assumed unique and in bounds, so no two
            // items alias the same atom or position.
            unsafe {
                ParticleMut {
                    id,
                    atom: &mut *atoms.add(id),
                    pos: &mut *coords.add(id),
                }
            }
        })
    }
}
impl<T: AtomPosAnalysis> RandomParticleProvider for T {
    /// The `i`-th indexed atom as a [`Particle`] carrying its global index.
    ///
    /// # Safety
    /// `i` must be a valid position in the index (presumably `i < self.len()`).
    unsafe fn get_particle_unchecked(&self, i: usize) -> Particle<'_> {
        let id = unsafe { self.get_index_unchecked(i) };
        let atom = unsafe { &*self.atoms_ptr().add(id) };
        let pos = unsafe { &*self.coords_ptr().add(id) };
        Particle { id, atom, pos }
    }
}
impl<T: NonAtomPosAnalysis> BoxProvider for T {
    /// Periodic box of the underlying state, if it has one.
    fn get_box(&self) -> Option<&PeriodicBox> {
        // SAFETY: `st_ptr` is assumed to point to a live `State`.
        let state = unsafe { &*self.st_ptr() };
        state.get_box()
    }
}
impl<T: NonAtomPosAnalysis> TimeProvider for T {
    /// Time stamp stored in the underlying state.
    fn get_time(&self) -> f32 {
        // SAFETY: `st_ptr` is assumed to point to a live `State`.
        let state = unsafe { &*self.st_ptr() };
        state.time
    }
}
impl<T: NonAtomPosAnalysis> RandomMoleculeProvider for T {
    /// Number of molecules in the underlying topology.
    fn num_molecules(&self) -> usize {
        // SAFETY: `top_ptr` is assumed to point to a live `Topology`.
        let top = unsafe { &*self.top_ptr() };
        top.num_molecules()
    }

    /// The `i`-th molecule of the topology, delegated without bounds checks.
    ///
    /// # Safety
    /// Same contract as the topology's own `get_molecule_unchecked`.
    unsafe fn get_molecule_unchecked(&self, i: usize) -> &[usize; 2] {
        let top = unsafe { &*self.top_ptr() };
        top.get_molecule_unchecked(i)
    }
}
impl<T: NonAtomPosAnalysis> MoleculeIterProvider for T {
    /// Iterates over the molecules of the underlying topology.
    fn iter_molecules(&self) -> impl Iterator<Item = &[usize; 2]> {
        // SAFETY: `top_ptr` is assumed to point to a live `Topology`.
        let top = unsafe { &*self.top_ptr() };
        top.iter_molecules()
    }
}
impl<T: NonAtomPosAnalysis> RandomBondProvider for T {
    /// Number of bonds in the underlying topology.
    fn num_bonds(&self) -> usize {
        // SAFETY: `top_ptr` is assumed to point to a live `Topology`.
        let top = unsafe { &*self.top_ptr() };
        top.num_bonds()
    }

    /// The `i`-th bond of the topology, delegated without bounds checks.
    ///
    /// # Safety
    /// Same contract as the topology's own `get_bond_unchecked`.
    unsafe fn get_bond_unchecked(&self, i: usize) -> &[usize; 2] {
        let top = unsafe { &*self.top_ptr() };
        top.get_bond_unchecked(i)
    }
}
impl<T: NonAtomPosAnalysis> BondIterProvider for T {
    /// Iterates over the bonds of the underlying topology.
    fn iter_bonds(&self) -> impl Iterator<Item = &[usize; 2]> {
        // SAFETY: `top_ptr` is assumed to point to a live `Topology`.
        let top = unsafe { &*self.top_ptr() };
        top.iter_bonds()
    }
}
impl<T: AtomPosAnalysisMut> RandomAtomMutProvider for T {
    /// Mutable reference to the `i`-th indexed atom, without bounds checking.
    ///
    /// # Safety
    /// `i` must be a valid position in the index (presumably `i < self.len()`).
    unsafe fn get_atom_mut_unchecked(&mut self, i: usize) -> &mut Atom {
        let global = unsafe { self.get_index_unchecked(i) };
        let base = self.atoms_ptr_mut();
        unsafe { &mut *base.add(global) }
    }
}
impl<T: AtomPosAnalysisMut> AtomIterMutProvider for T {
    /// Iterates mutably over the indexed atoms, in index order.
    fn iter_atoms_mut(&mut self) -> impl AtomMutIterator<'_> {
        let base = self.atoms_ptr_mut();
        (0..self.len()).map(move |i| {
            let global = unsafe { self.get_index_unchecked(i) };
            // SAFETY: indices are assumed unique and in bounds.
            unsafe { &mut *base.add(global) }
        })
    }
}
impl<T: AtomPosAnalysisMut> PosIterMutProvider for T {
    /// Iterates mutably over the positions of the indexed atoms, in index order.
    fn iter_pos_mut(&mut self) -> impl PosMutIterator<'_> {
        let base = self.coords_ptr_mut();
        (0..self.len()).map(move |i| {
            let global = unsafe { self.get_index_unchecked(i) };
            // SAFETY: indices are assumed unique and in bounds.
            unsafe { &mut *base.add(global) }
        })
    }
}
impl<T: AtomPosAnalysisMut> RandomPosMutProvider for T {
    /// Mutable reference to the position of the `i`-th indexed atom.
    ///
    /// # Safety
    /// `i` must be a valid position in the index (presumably `i < self.len()`).
    unsafe fn get_pos_mut_unchecked(&mut self, i: usize) -> &mut Pos {
        let global = unsafe { self.get_index_unchecked(i) };
        let base = self.coords_ptr_mut();
        unsafe { &mut *base.add(global) }
    }
}
impl<T: AtomPosAnalysisMut + IndexParProvider> PosParIterMutProvider for T {
    /// Parallel mutable iterator over the positions of the indexed atoms.
    fn par_iter_pos_mut(&mut self) -> impl IndexedParallelIterator<Item = &mut Pos> {
        // Raw pointers are not `Send`, so the base address is carried as an integer.
        let base_addr = self.coords_ptr_mut() as usize;
        self.par_iter_index().map(move |ind| {
            let base = base_addr as *mut Pos;
            // SAFETY: indices are assumed unique and in bounds.
            unsafe { &mut *base.add(ind) }
        })
    }
}
impl<T: AtomPosAnalysisMut + IndexParProvider> AtomParIterMutProvider for T {
    /// Parallel mutable iterator over the indexed atoms.
    fn par_iter_atoms_mut(&mut self) -> impl IndexedParallelIterator<Item = &mut Atom> {
        // Raw pointers are not `Send`, so the base address is carried as an integer.
        let base_addr = self.atoms_ptr_mut() as usize;
        self.par_iter_index().map(move |ind| {
            let base = base_addr as *mut Atom;
            // SAFETY: indices are assumed unique and in bounds.
            unsafe { &mut *base.add(ind) }
        })
    }
}
impl<T: AtomPosAnalysisMut> RandomParticleMutProvider for T {
    /// The `i`-th indexed atom as a mutable [`ParticleMut`] view.
    ///
    /// # Safety
    /// `i` must be a valid position in the index (presumably `i < self.len()`).
    unsafe fn get_particle_mut_unchecked(&mut self, i: usize) -> ParticleMut<'_> {
        let id = unsafe { self.get_index_unchecked(i) };
        let atoms = self.atoms_ptr_mut();
        let coords = self.coords_ptr_mut();
        unsafe {
            ParticleMut {
                id,
                atom: &mut *atoms.add(id),
                pos: &mut *coords.add(id),
            }
        }
    }
}
impl<T: NonAtomPosAnalysisMut> TimeMutProvider for T {
    /// Overwrites the time stamp stored in the underlying state.
    fn set_time(&mut self, t: f32) {
        // SAFETY: `st_ptr_mut` is assumed to point to a live `State`.
        let state = unsafe { &mut *self.st_ptr_mut() };
        state.time = t;
    }
}
impl<T: NonAtomPosAnalysisMut> BoxMutProvider for T {
    /// Mutable access to the periodic box of the underlying state, if present.
    fn get_box_mut(&mut self) -> Option<&mut PeriodicBox> {
        // SAFETY: `st_ptr_mut` is assumed to point to a live `State`.
        let state = unsafe { &mut *self.st_ptr_mut() };
        state.pbox.as_mut()
    }
}
// Read-only measurements: atom/coordinate access is enough, except for the
// periodic ones, which also need the topology/state.
impl<T: AtomPosAnalysis> MeasurePos for T {}
impl<T: AtomPosAnalysis> MeasureMasses for T {}
impl<T: AtomPosAnalysis> MeasureRandomAccess for T {}
impl<T: AtomPosAnalysis> MeasureAtomPos for T {}
impl<T: AtomPosAnalysis + NonAtomPosAnalysis> MeasurePeriodic for T {}

// Modifications: require mutable atom/coordinate access. `AtomPosAnalysisMut`
// already implies `AtomPosAnalysis`, so it need not be repeated.
impl<T: AtomPosAnalysisMut> ModifyPos for T {}
impl<T: AtomPosAnalysisMut + NonAtomPosAnalysis> ModifyPeriodic for T {}
impl<T: AtomPosAnalysisMut + NonAtomPosAnalysis> ModifyRandomAccess for T {}
/// Types that refer to a [`System`] through a raw pointer.
///
/// Implementing this together with `IndexProvider` yields
/// [`AtomPosAnalysis`] and [`NonAtomPosAnalysis`] via the blanket impls in
/// this module.
pub trait SystemProvider {
    /// Raw pointer to the referenced system.
    fn get_system_ptr(&self) -> *const System;
}
impl<T: SystemProvider + IndexProvider> AtomPosAnalysis for T {
    /// Pointer to the system topology's atom storage.
    fn atoms_ptr(&self) -> *const Atom {
        // SAFETY: `get_system_ptr` is assumed to point to a live `System`.
        let system = unsafe { &*self.get_system_ptr() };
        system.top.atoms.as_ptr()
    }

    /// Pointer to the system state's coordinate storage.
    fn coords_ptr(&self) -> *const Pos {
        // SAFETY: `get_system_ptr` is assumed to point to a live `System`.
        let system = unsafe { &*self.get_system_ptr() };
        system.st.coords.as_ptr()
    }
}
impl<T: SystemProvider + IndexProvider> NonAtomPosAnalysis for T {
    /// Pointer to the system's topology.
    fn top_ptr(&self) -> *const Topology {
        let system = self.get_system_ptr();
        // SAFETY: `system` is assumed to point to a live `System`.
        let top: &Topology = unsafe { &(*system).top };
        top
    }

    /// Pointer to the system's state.
    fn st_ptr(&self) -> *const State {
        let system = self.get_system_ptr();
        // SAFETY: `system` is assumed to point to a live `System`.
        let st: &State = unsafe { &(*system).st };
        st
    }
}
/// [`SystemProvider`] types that may also hand out a mutable system pointer.
pub trait SystemMutProvider: SystemProvider {
    /// Mutable raw pointer to the referenced system, cast from
    /// [`SystemProvider::get_system_ptr`].
    fn get_system_mut(&mut self) -> *mut System {
        self.get_system_ptr().cast_mut()
    }
}
// Mutable system access plus an index unlocks the mutable analysis traits;
// their default methods supply the mutable pointers.
impl<T: SystemMutProvider + IndexProvider> NonAtomPosAnalysisMut for T {}
impl<T: SystemMutProvider + IndexProvider> AtomPosAnalysisMut for T {}
/// Set algebra (union, intersection, difference, complement) on selections.
///
/// Every operation combines sorted index slices and builds the result through
/// [`clone_with_index`](Self::clone_with_index).
pub trait SelectionLogic: IndexSliceProvider {
    /// Selection type produced by the operations of this trait.
    type DerivedSel;

    /// Creates a selection derived from `self` but holding `index`.
    fn clone_with_index(&self, index: SVec) -> Self::DerivedSel;

    /// Union of the indices of `self` and `rhs`.
    fn or(&self, rhs: &impl IndexSliceProvider) -> Self::DerivedSel {
        // SAFETY: the `*_sorted` helpers presumably require sorted input;
        // selection index slices are assumed to be kept sorted.
        let index = unsafe { union_sorted(self.get_index_slice(), rhs.get_index_slice()) };
        self.clone_with_index(index)
    }

    /// Intersection of the indices of `self` and `rhs`.
    ///
    /// # Errors
    /// Returns [`SelectionError::EmptyIntersection`] if they share no index.
    fn and(&self, rhs: &impl IndexSliceProvider) -> Result<Self::DerivedSel, SelectionError> {
        // SAFETY: see `or`.
        let index = unsafe { intersection_sorted(self.get_index_slice(), rhs.get_index_slice()) };
        if index.is_empty() {
            return Err(SelectionError::EmptyIntersection);
        }
        Ok(self.clone_with_index(index))
    }

    /// Indices of `self` that are not in `rhs`.
    ///
    /// # Errors
    /// Returns [`SelectionError::EmptyDifference`] if nothing remains.
    fn minus(&self, rhs: &impl IndexSliceProvider) -> Result<Self::DerivedSel, SelectionError> {
        // SAFETY: see `or`.
        let index = unsafe { difference_sorted(self.get_index_slice(), rhs.get_index_slice()) };
        if index.is_empty() {
            return Err(SelectionError::EmptyDifference);
        }
        Ok(self.clone_with_index(index))
    }

    /// All indices `0..N` of `self`'s system (`N` = system length) that are
    /// not in `rhs`.
    ///
    /// The index of `self` is not consulted: `self` only supplies the system
    /// size and is used to build the result.
    ///
    /// # Errors
    /// Returns [`SelectionError::EmptyDifference`] if `rhs` covers the whole system.
    fn invert(&self, rhs: &impl IndexSliceProvider) -> Result<Self::DerivedSel, SelectionError>
    where
        Self: SystemProvider,
    {
        // SAFETY: `get_system_ptr` is assumed to point to a live `System`.
        let n_total = unsafe { &*self.get_system_ptr() }.len();
        // A range is already an iterator; no `into_iter()` needed.
        let all: Vec<usize> = (0..n_total).collect();
        // SAFETY: `all` is ascending by construction; `rhs` is assumed sorted.
        let index = unsafe { difference_sorted(&all, rhs.get_index_slice()) };
        if index.is_empty() {
            return Err(SelectionError::EmptyDifference);
        }
        Ok(self.clone_with_index(index))
    }
}