1use crate::model::structure::Structure;
9use crate::model::types::{ResidueCategory, StandardResidue};
10use crate::ops::error::Error;
11use std::collections::HashSet;
12
/// Options controlling which residues and atoms [`clean_structure`] removes.
///
/// The derived [`Default`] leaves every flag `false` and both name sets empty.
/// `PartialEq`/`Eq` are derived so configurations can be compared by value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CleanConfig {
    /// Remove residues whose standard name is `StandardResidue::HOH`.
    pub remove_water: bool,
    /// Remove residues whose category is `ResidueCategory::Ion`.
    pub remove_ions: bool,
    /// Strip hydrogens from every residue that is retained.
    pub remove_hydrogens: bool,
    /// Remove residues whose category is `ResidueCategory::Hetero`.
    pub remove_hetero: bool,
    /// Residue names to remove; matched exactly against the residue's name.
    pub remove_residue_names: HashSet<String>,
    /// Residue names that are always retained; matched exactly and checked
    /// before every removal rule, including `remove_residue_names`.
    pub keep_residue_names: HashSet<String>,
}
32
33impl CleanConfig {
34 pub fn water_only() -> Self {
42 Self {
43 remove_water: true,
44 ..Default::default()
45 }
46 }
47
48 pub fn water_and_ions() -> Self {
56 Self {
57 remove_water: true,
58 remove_ions: true,
59 ..Default::default()
60 }
61 }
62}
63
64pub fn clean_structure(structure: &mut Structure, config: &CleanConfig) -> Result<(), Error> {
83 structure.par_retain_residues_mut(|_chain_id, residue| {
84 if config.keep_residue_names.contains(residue.name.as_str()) {
85 if config.remove_hydrogens {
86 residue.strip_hydrogens();
87 }
88 return true;
89 }
90
91 if config.remove_residue_names.contains(residue.name.as_str()) {
92 return false;
93 }
94
95 if config.remove_water && residue.standard_name == Some(StandardResidue::HOH) {
96 return false;
97 }
98
99 if config.remove_ions && residue.category == ResidueCategory::Ion {
100 return false;
101 }
102
103 if config.remove_hetero && residue.category == ResidueCategory::Hetero {
104 return false;
105 }
106
107 if config.remove_hydrogens {
108 residue.strip_hydrogens();
109 }
110
111 true
112 });
113
114 structure.prune_empty_chains();
115
116 Ok(())
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{
        atom::Atom,
        chain::Chain,
        residue::Residue,
        structure::Structure,
        types::{Element, Point},
    };

    /// Builds a structure with a single chain `"A"` holding one residue per
    /// tuple, numbered from 1 in input order; each residue gets one carbon
    /// atom named `"C"` at the origin.
    fn make_structure(
        residues: Vec<(String, ResidueCategory, Option<StandardResidue>)>,
    ) -> Structure {
        let mut structure = Structure::new();
        let mut chain = Chain::new("A");

        for (number, (name, category, standard)) in (1_i32..).zip(residues) {
            let mut residue = Residue::new(number, None, &name, standard, category);
            residue.add_atom(Atom::new("C", Element::C, Point::origin()));
            chain.add_residue(residue);
        }

        structure.add_chain(chain);
        structure
    }

    /// Builds a structure whose chain `"A"` holds a single GLY residue
    /// (number 1) with one carbon atom `CA` and one hydrogen atom `HA`.
    fn make_glycine_with_hydrogen() -> Structure {
        let mut structure = Structure::new();
        let mut chain = Chain::new("A");
        let mut residue = Residue::new(
            1,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        residue.add_atom(Atom::new("CA", Element::C, Point::origin()));
        residue.add_atom(Atom::new("HA", Element::H, Point::new(1.0, 0.0, 0.0)));
        chain.add_residue(residue);
        structure.add_chain(chain);
        structure
    }

    #[test]
    fn removes_hydrogens_when_flag_enabled() {
        let mut structure = make_glycine_with_hydrogen();

        let config = CleanConfig {
            remove_hydrogens: true,
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        let residue = structure.chain("A").unwrap().residue(1, None).unwrap();
        assert_eq!(residue.atom_count(), 1);
        assert!(residue.atom("HA").is_none());
    }

    #[test]
    fn default_config_leaves_structure_untouched() {
        let mut structure = make_glycine_with_hydrogen();

        clean_structure(&mut structure, &CleanConfig::default()).unwrap();

        let residue = structure.chain("A").unwrap().residue(1, None).unwrap();
        assert_eq!(residue.atom_count(), 2);
        assert!(residue.atom("HA").is_some());
    }

    #[test]
    fn removes_water_and_ions_when_configured() {
        let mut structure = make_structure(vec![
            (
                "HOH".to_string(),
                ResidueCategory::Standard,
                Some(StandardResidue::HOH),
            ),
            ("NA".to_string(), ResidueCategory::Ion, None),
            (
                "GLY".to_string(),
                ResidueCategory::Standard,
                Some(StandardResidue::GLY),
            ),
        ]);

        let config = CleanConfig {
            remove_water: true,
            remove_ions: true,
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        let chain = structure.chain("A").unwrap();
        assert_eq!(chain.residue_count(), 1);
        assert_eq!(chain.residue(3, None).unwrap().name, "GLY");
    }

    #[test]
    fn removes_named_residues_and_hetero_categories() {
        // SO4 is an ion here and `remove_ions` is off, so only the name rule
        // can remove it; the surviving NA ion proves the category rule for
        // ions was not what removed SO4. LIG is removed by the hetero rule.
        let mut structure = make_structure(vec![
            ("LIG".to_string(), ResidueCategory::Hetero, None),
            ("SO4".to_string(), ResidueCategory::Ion, None),
            ("NA".to_string(), ResidueCategory::Ion, None),
            (
                "ALA".to_string(),
                ResidueCategory::Standard,
                Some(StandardResidue::ALA),
            ),
        ]);

        let config = CleanConfig {
            remove_hetero: true,
            remove_residue_names: HashSet::from(["SO4".to_string()]),
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        let chain = structure.chain("A").unwrap();
        let mut names: Vec<_> = chain.iter_residues().map(|res| res.name.as_str()).collect();
        // Sort so the assertion does not depend on residue iteration order.
        names.sort_unstable();
        assert_eq!(names, vec!["ALA", "NA"]);
    }

    #[test]
    fn keep_list_overrides_removal_rules() {
        let mut structure = make_structure(vec![(
            "HOH".to_string(),
            ResidueCategory::Standard,
            Some(StandardResidue::HOH),
        )]);

        let config = CleanConfig {
            remove_water: true,
            keep_residue_names: HashSet::from(["HOH".to_string()]),
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        assert_eq!(structure.chain("A").unwrap().residue_count(), 1);
    }

    #[test]
    fn keep_list_wins_over_remove_list_and_still_strips_hydrogens() {
        let mut structure = make_glycine_with_hydrogen();

        let config = CleanConfig {
            remove_hydrogens: true,
            remove_residue_names: HashSet::from(["GLY".to_string()]),
            keep_residue_names: HashSet::from(["GLY".to_string()]),
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        let residue = structure.chain("A").unwrap().residue(1, None).unwrap();
        assert_eq!(residue.atom_count(), 1);
        assert!(residue.atom("HA").is_none());
    }

    #[test]
    fn prunes_empty_chains_after_residue_removal() {
        let mut structure = make_structure(vec![(
            "HOH".to_string(),
            ResidueCategory::Standard,
            Some(StandardResidue::HOH),
        )]);
        let mut chain_b = Chain::new("B");
        let mut residue_b = Residue::new(
            10,
            None,
            "ALA",
            Some(StandardResidue::ALA),
            ResidueCategory::Standard,
        );
        residue_b.add_atom(Atom::new("CA", Element::C, Point::new(2.0, 0.0, 0.0)));
        chain_b.add_residue(residue_b);
        structure.add_chain(chain_b);

        let config = CleanConfig {
            remove_water: true,
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        assert!(structure.chain("A").is_none());
        assert!(structure.chain("B").is_some());
    }

    #[test]
    fn presets_enable_only_their_flags() {
        let water = CleanConfig::water_only();
        assert!(water.remove_water);
        assert!(!water.remove_ions);
        assert!(!water.remove_hydrogens);
        assert!(!water.remove_hetero);
        assert!(water.remove_residue_names.is_empty());
        assert!(water.keep_residue_names.is_empty());

        let water_and_ions = CleanConfig::water_and_ions();
        assert!(water_and_ions.remove_water);
        assert!(water_and_ions.remove_ions);
        assert!(!water_and_ions.remove_hydrogens);
        assert!(!water_and_ions.remove_hetero);
        assert!(water_and_ions.remove_residue_names.is_empty());
        assert!(water_and_ions.keep_residue_names.is_empty());
    }
}