1use crate::model::structure::Structure;
9use crate::model::types::{ResidueCategory, StandardResidue};
10use crate::ops::error::Error;
11use std::collections::HashSet;
12
/// Options controlling what [`clean_structure`] removes from a structure.
///
/// Every flag defaults to `false` and both name sets default to empty, so
/// `CleanConfig::default()` removes nothing.
///
/// Precedence during residue filtering: `keep_residue_names` is consulted first, then
/// `remove_residue_names`, then the water / ion / hetero flags.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CleanConfig {
    /// Remove residues whose standard name is `StandardResidue::HOH`.
    pub remove_water: bool,
    /// Remove residues whose category is `ResidueCategory::Ion`.
    pub remove_ions: bool,
    /// Strip hydrogens from every residue, including residues protected by
    /// `keep_residue_names` (stripping happens before residue filtering).
    pub remove_hydrogens: bool,
    /// Remove residues whose category is `ResidueCategory::Hetero`.
    pub remove_hetero: bool,
    /// Residue names to remove regardless of the flags above (exact, case-sensitive match).
    pub remove_residue_names: HashSet<String>,
    /// Residue names that are never removed; overrides every removal rule
    /// (exact, case-sensitive match).
    pub keep_residue_names: HashSet<String>,
}
32
33impl CleanConfig {
34 pub fn water_only() -> Self {
42 Self {
43 remove_water: true,
44 ..Default::default()
45 }
46 }
47
48 pub fn water_and_ions() -> Self {
56 Self {
57 remove_water: true,
58 remove_ions: true,
59 ..Default::default()
60 }
61 }
62}
63
64pub fn clean_structure(structure: &mut Structure, config: &CleanConfig) -> Result<(), Error> {
83 if config.remove_hydrogens {
84 for chain in structure.iter_chains_mut() {
85 for residue in chain.iter_residues_mut() {
86 residue.strip_hydrogens();
87 }
88 }
89 }
90
91 structure.retain_residues(|_chain_id, residue| {
92 if config.keep_residue_names.contains(&residue.name) {
93 return true;
94 }
95
96 if config.remove_residue_names.contains(&residue.name) {
97 return false;
98 }
99
100 if config.remove_water && residue.standard_name == Some(StandardResidue::HOH) {
101 return false;
102 }
103
104 if config.remove_ions && residue.category == ResidueCategory::Ion {
105 return false;
106 }
107
108 if config.remove_hetero && residue.category == ResidueCategory::Hetero {
109 return false;
110 }
111
112 true
113 });
114
115 structure.prune_empty_chains();
116
117 Ok(())
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{
        atom::Atom,
        chain::Chain,
        residue::Residue,
        structure::Structure,
        types::{Element, Point},
    };

    /// Builds a structure with a single chain `"A"` holding one residue per entry.
    ///
    /// Residues are numbered from 1 in input order and each receives a single carbon atom.
    fn make_structure(
        residues: Vec<(&str, ResidueCategory, Option<StandardResidue>)>,
    ) -> Structure {
        let mut chain = Chain::new("A");

        for (id, (name, category, standard)) in (1_i32..).zip(residues) {
            let mut residue = Residue::new(id, None, name, standard, category);
            residue.add_atom(Atom::new("C", Element::C, Point::origin()));
            chain.add_residue(residue);
        }

        let mut structure = Structure::new();
        structure.add_chain(chain);
        structure
    }

    #[test]
    fn default_config_removes_nothing() {
        let mut structure = make_structure(vec![
            ("HOH", ResidueCategory::Standard, Some(StandardResidue::HOH)),
            ("NA", ResidueCategory::Ion, None),
            ("LIG", ResidueCategory::Hetero, None),
            ("GLY", ResidueCategory::Standard, Some(StandardResidue::GLY)),
        ]);

        clean_structure(&mut structure, &CleanConfig::default()).unwrap();

        assert_eq!(structure.chain("A").unwrap().residue_count(), 4);
    }

    #[test]
    fn removes_hydrogens_when_flag_enabled() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("A");
        let mut residue = Residue::new(
            1,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        residue.add_atom(Atom::new("CA", Element::C, Point::origin()));
        residue.add_atom(Atom::new("HA", Element::H, Point::new(1.0, 0.0, 0.0)));
        chain.add_residue(residue);
        structure.add_chain(chain);

        let config = CleanConfig {
            remove_hydrogens: true,
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        let residue = structure.chain("A").unwrap().residue(1, None).unwrap();
        assert_eq!(residue.atom_count(), 1);
        assert!(residue.atom("HA").is_none());
    }

    #[test]
    fn water_only_preset_keeps_ions() {
        let mut structure = make_structure(vec![
            ("HOH", ResidueCategory::Standard, Some(StandardResidue::HOH)),
            ("NA", ResidueCategory::Ion, None),
            ("GLY", ResidueCategory::Standard, Some(StandardResidue::GLY)),
        ]);

        clean_structure(&mut structure, &CleanConfig::water_only()).unwrap();

        let chain = structure.chain("A").unwrap();
        let mut names: Vec<_> = chain.iter_residues().map(|res| res.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["GLY", "NA"]);
    }

    #[test]
    fn removes_water_and_ions_when_configured() {
        let mut structure = make_structure(vec![
            ("HOH", ResidueCategory::Standard, Some(StandardResidue::HOH)),
            ("NA", ResidueCategory::Ion, None),
            ("GLY", ResidueCategory::Standard, Some(StandardResidue::GLY)),
        ]);

        // Using the preset also covers `CleanConfig::water_and_ions`.
        clean_structure(&mut structure, &CleanConfig::water_and_ions()).unwrap();

        let chain = structure.chain("A").unwrap();
        assert_eq!(chain.residue_count(), 1);
        assert_eq!(chain.residue(3, None).unwrap().name, "GLY");
    }

    #[test]
    fn removes_hetero_categories() {
        let mut structure = make_structure(vec![
            ("LIG", ResidueCategory::Hetero, None),
            ("ALA", ResidueCategory::Standard, Some(StandardResidue::ALA)),
        ]);

        let config = CleanConfig {
            remove_hetero: true,
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        let chain = structure.chain("A").unwrap();
        let names: Vec<_> = chain.iter_residues().map(|res| res.name.as_str()).collect();
        assert_eq!(names, vec!["ALA"]);
    }

    #[test]
    fn removes_named_residues() {
        let mut structure = make_structure(vec![
            ("LIG", ResidueCategory::Hetero, None),
            ("SO4", ResidueCategory::Hetero, None),
            ("ALA", ResidueCategory::Standard, Some(StandardResidue::ALA)),
        ]);

        // `remove_hetero` stays off so that only the name list can remove SO4; LIG must
        // survive even though it shares SO4's category.
        let config = CleanConfig {
            remove_residue_names: HashSet::from(["SO4".to_string()]),
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        let chain = structure.chain("A").unwrap();
        let mut names: Vec<_> = chain.iter_residues().map(|res| res.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["ALA", "LIG"]);
    }

    #[test]
    fn keep_list_overrides_removal_rules() {
        let mut structure = make_structure(vec![(
            "HOH",
            ResidueCategory::Standard,
            Some(StandardResidue::HOH),
        )]);

        let config = CleanConfig {
            remove_water: true,
            keep_residue_names: HashSet::from(["HOH".to_string()]),
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        assert_eq!(structure.chain("A").unwrap().residue_count(), 1);
    }

    #[test]
    fn keep_list_overrides_named_removal() {
        let mut structure = make_structure(vec![("LIG", ResidueCategory::Hetero, None)]);

        let config = CleanConfig {
            remove_hetero: true,
            remove_residue_names: HashSet::from(["LIG".to_string()]),
            keep_residue_names: HashSet::from(["LIG".to_string()]),
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        assert_eq!(structure.chain("A").unwrap().residue_count(), 1);
    }

    #[test]
    fn prunes_empty_chains_after_residue_removal() {
        let mut structure = make_structure(vec![(
            "HOH",
            ResidueCategory::Standard,
            Some(StandardResidue::HOH),
        )]);
        let mut chain_b = Chain::new("B");
        let mut residue_b = Residue::new(
            10,
            None,
            "ALA",
            Some(StandardResidue::ALA),
            ResidueCategory::Standard,
        );
        residue_b.add_atom(Atom::new("CA", Element::C, Point::new(2.0, 0.0, 0.0)));
        chain_b.add_residue(residue_b);
        structure.add_chain(chain_b);

        let config = CleanConfig {
            remove_water: true,
            ..Default::default()
        };

        clean_structure(&mut structure, &config).unwrap();

        assert!(structure.chain("A").is_none());
        assert!(structure.chain("B").is_some());
    }
}