1use std::{
19 collections::BTreeSet,
20 sync::{
21 atomic::{AtomicUsize, Ordering::Relaxed},
22 Arc,
23 },
24};
25
26use bit_set::BitSet;
27use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
28
29use crate::{
30 molecule::Bond, molecule::Element, molecule::Molecule, utils::connected_components_under_edges,
31};
32
33#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
34struct EdgeType {
35 bond: Bond,
36 ends: (Element, Element),
37}
38
/// Number of matches above which `index_search` uses the rayon-based parallel
/// search instead of the serial one.
static PARALLEL_MATCH_SIZE_THRESHOLD: usize = 100;
40
/// Pruning bounds that can be enabled for the index search.
///
/// Each variant selects the helper whose value is subtracted from a search
/// state's index and compared against the best index found so far.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum Bound {
    /// Prune using `log_bound`.
    Log,
    /// Prune using `addition_bound`.
    IntChain,
    /// Prune using `vec_bound_simple`.
    VecChainSimple,
    /// Prune using `vec_bound_small_frags`.
    VecChainSmallFrags,
}
58
59pub fn naive_assembly_depth(mol: &Molecule) -> u32 {
60 let mut ix = u32::MAX;
61 for (left, right) in mol.partitions().unwrap() {
62 let l = if left.is_basic_unit() {
63 0
64 } else {
65 naive_assembly_depth(&left)
66 };
67
68 let r = if right.is_basic_unit() {
69 0
70 } else {
71 naive_assembly_depth(&right)
72 };
73
74 ix = ix.min(l.max(r) + 1)
75 }
76 ix
77}
78
79fn recurse_naive_index_search(
80 mol: &Molecule,
81 matches: &BTreeSet<(BitSet, BitSet)>,
82 fragments: &[BitSet],
83 ix: usize,
84) -> usize {
85 let mut cx = ix;
86 for (h1, h2) in matches {
87 let mut fractures = fragments.to_owned();
88 let f1 = fragments.iter().enumerate().find(|(_, c)| h1.is_subset(c));
89 let f2 = fragments.iter().enumerate().find(|(_, c)| h2.is_subset(c));
90
91 let (Some((i1, f1)), Some((i2, f2))) = (f1, f2) else {
92 continue;
93 };
94
95 if i1 == i2 {
97 let mut union = h1.clone();
98 union.union_with(h2);
99 let mut difference = f1.clone();
100 difference.difference_with(&union);
101 let c = connected_components_under_edges(mol.graph(), &difference);
102 fractures.extend(c);
103 fractures.swap_remove(i1);
104 fractures.push(h1.clone());
105 } else {
106 let mut f1r = f1.clone();
107 f1r.difference_with(h1);
108 let mut f2r = f2.clone();
109 f2r.difference_with(h2);
110
111 let c1 = connected_components_under_edges(mol.graph(), &f1r);
112 let c2 = connected_components_under_edges(mol.graph(), &f2r);
113
114 fractures.extend(c1);
115 fractures.extend(c2);
116
117 fractures.swap_remove(i1.max(i2));
118 fractures.swap_remove(i1.min(i2));
119
120 fractures.push(h1.clone());
121 }
122 cx = cx.min(recurse_naive_index_search(
123 mol,
124 matches,
125 &fractures,
126 ix - h1.len() + 1,
127 ));
128 }
129 cx
130}
131
132pub fn naive_index_search(mol: &Molecule) -> u32 {
136 let mut init = BitSet::new();
137 init.extend(mol.graph().edge_indices().map(|ix| ix.index()));
138
139 recurse_naive_index_search(
140 mol,
141 &mol.matches().collect(),
142 &[init],
143 mol.graph().edge_count() - 1,
144 ) as u32
145}
146
147#[allow(clippy::too_many_arguments)]
148fn recurse_index_search(
149 mol: &Molecule,
150 matches: &[(BitSet, BitSet)],
151 fragments: &[BitSet],
152 ix: usize,
153 largest_remove: usize,
154 mut best: usize,
155 bounds: &[Bound],
156 states_searched: &mut usize,
157) -> usize {
158 let mut cx = ix;
159
160 *states_searched += 1;
161
162 for bound_type in bounds {
164 let exceeds = match bound_type {
165 Bound::Log => ix - log_bound(fragments) >= best,
166 Bound::IntChain => ix - addition_bound(fragments, largest_remove) >= best,
167 Bound::VecChainSimple => ix - vec_bound_simple(fragments, largest_remove, mol) >= best,
168 Bound::VecChainSmallFrags => {
169 ix - vec_bound_small_frags(fragments, largest_remove, mol) >= best
170 }
171 };
172 if exceeds {
173 return ix;
174 }
175 }
176
177 for (i, (h1, h2)) in matches.iter().enumerate() {
179 let mut fractures = fragments.to_owned();
180 let f1 = fragments.iter().enumerate().find(|(_, c)| h1.is_subset(c));
181 let f2 = fragments.iter().enumerate().find(|(_, c)| h2.is_subset(c));
182
183 let largest_remove = h1.len();
184
185 let (Some((i1, f1)), Some((i2, f2))) = (f1, f2) else {
186 continue;
187 };
188
189 if i1 == i2 {
191 let mut union = h1.clone();
192 union.union_with(h2);
193 let mut difference = f1.clone();
194 difference.difference_with(&union);
195 let c = connected_components_under_edges(mol.graph(), &difference);
196 fractures.extend(c);
197 fractures.swap_remove(i1);
198 } else {
199 let mut f1r = f1.clone();
200 f1r.difference_with(h1);
201 let mut f2r = f2.clone();
202 f2r.difference_with(h2);
203
204 let c1 = connected_components_under_edges(mol.graph(), &f1r);
205 let c2 = connected_components_under_edges(mol.graph(), &f2r);
206
207 fractures.extend(c1);
208 fractures.extend(c2);
209
210 fractures.swap_remove(i1.max(i2));
211 fractures.swap_remove(i1.min(i2));
212 }
213
214 fractures.retain(|i| i.len() > 1);
215 fractures.push(h1.clone());
216
217 cx = cx.min(recurse_index_search(
218 mol,
219 &matches[i + 1..],
220 &fractures,
221 ix - h1.len() + 1,
222 largest_remove,
223 best,
224 bounds,
225 states_searched,
226 ));
227 best = best.min(cx);
228 }
229
230 cx
231}
232
233#[allow(clippy::too_many_arguments)]
234fn parallel_recurse_index_search(
235 mol: &Molecule,
236 matches: &[(BitSet, BitSet)],
237 fragments: &[BitSet],
238 ix: usize,
239 largest_remove: usize,
240 best: AtomicUsize,
241 bounds: &[Bound],
242 states_searched: Arc<AtomicUsize>,
243) -> usize {
244 let cx = AtomicUsize::from(ix);
245
246 states_searched.fetch_add(1, Relaxed);
247
248 for bound_type in bounds {
250 let best = best.load(Relaxed);
251 let exceeds = match bound_type {
252 Bound::Log => ix - log_bound(fragments) >= best,
253 Bound::IntChain => ix - addition_bound(fragments, largest_remove) >= best,
254 Bound::VecChainSimple => ix - vec_bound_simple(fragments, largest_remove, mol) >= best,
255 Bound::VecChainSmallFrags => {
256 ix - vec_bound_small_frags(fragments, largest_remove, mol) >= best
257 }
258 };
259 if exceeds {
260 return ix;
261 }
262 }
263
264 matches.par_iter().enumerate().for_each(|(i, (h1, h2))| {
266 let mut fractures = fragments.to_owned();
267 let f1 = fragments.iter().enumerate().find(|(_, c)| h1.is_subset(c));
268 let f2 = fragments.iter().enumerate().find(|(_, c)| h2.is_subset(c));
269
270 let largest_remove = h1.len();
271
272 let (Some((i1, f1)), Some((i2, f2))) = (f1, f2) else {
273 return;
274 };
275
276 if i1 == i2 {
278 let mut union = h1.clone();
279 union.union_with(h2);
280 let mut difference = f1.clone();
281 difference.difference_with(&union);
282 let c = connected_components_under_edges(mol.graph(), &difference);
283 fractures.extend(c);
284 fractures.swap_remove(i1);
285 } else {
286 let mut f1r = f1.clone();
287 f1r.difference_with(h1);
288 let mut f2r = f2.clone();
289 f2r.difference_with(h2);
290
291 let c1 = connected_components_under_edges(mol.graph(), &f1r);
292 let c2 = connected_components_under_edges(mol.graph(), &f2r);
293
294 fractures.extend(c1);
295 fractures.extend(c2);
296
297 fractures.swap_remove(i1.max(i2));
298 fractures.swap_remove(i1.min(i2));
299 }
300
301 fractures.retain(|i| i.len() > 1);
302 fractures.push(h1.clone());
303
304 let output = parallel_recurse_index_search(
305 mol,
306 &matches[i + 1..],
307 &fractures,
308 ix - h1.len() + 1,
309 largest_remove,
310 best.load(Relaxed).into(),
311 bounds,
312 states_searched.clone(),
313 );
314 cx.fetch_min(output, Relaxed);
315
316 best.fetch_min(cx.load(Relaxed), Relaxed);
317 });
318
319 cx.load(Relaxed)
320}
321
322pub fn index_search(mol: &Molecule, bounds: &[Bound]) -> (u32, u32, usize) {
359 let mut init = BitSet::new();
360 init.extend(mol.graph().edge_indices().map(|ix| ix.index()));
361
362 let mut matches: Vec<(BitSet, BitSet)> = mol.matches().collect();
364 matches.sort_by(|e1, e2| e2.0.len().cmp(&e1.0.len()));
365
366 let edge_count = mol.graph().edge_count();
367
368 let (index, total_search) = if matches.len() > PARALLEL_MATCH_SIZE_THRESHOLD {
369 let total_search = Arc::new(AtomicUsize::from(0));
370 let index = parallel_recurse_index_search(
371 mol,
372 &matches,
373 &[init],
374 edge_count - 1,
375 edge_count,
376 (edge_count - 1).into(),
377 bounds,
378 total_search.clone(),
379 );
380 let total_search = total_search.load(Relaxed);
381 (index as u32, total_search)
382 } else {
383 let mut total_search = 0;
384 let index = recurse_index_search(
385 mol,
386 &matches,
387 &[init],
388 edge_count - 1,
389 edge_count,
390 edge_count - 1,
391 bounds,
392 &mut total_search,
393 );
394 (index as u32, total_search)
395 };
396
397 (index, matches.len() as u32, total_search)
398}
399
400pub fn serial_index_search(mol: &Molecule, bounds: &[Bound]) -> (u32, u32, usize) {
426 let mut init = BitSet::new();
427 init.extend(mol.graph().edge_indices().map(|ix| ix.index()));
428
429 let mut matches: Vec<(BitSet, BitSet)> = mol.matches().collect();
431 matches.sort_by(|e1, e2| e2.0.len().cmp(&e1.0.len()));
432
433 let edge_count = mol.graph().edge_count();
434 let mut total_search = 0;
435 let index = recurse_index_search(
436 mol,
437 &matches,
438 &[init],
439 edge_count - 1,
440 edge_count,
441 edge_count - 1,
442 bounds,
443 &mut total_search,
444 );
445 (index as u32, matches.len() as u32, total_search)
446}
447
448fn log_bound(fragments: &[BitSet]) -> usize {
449 let mut size = 0;
450 for f in fragments {
451 size += f.len();
452 }
453
454 size - (size as f32).log2().ceil() as usize
455}
456
457fn addition_bound(fragments: &[BitSet], m: usize) -> usize {
458 let mut max_s: usize = 0;
459 let mut frag_sizes: Vec<usize> = Vec::new();
460
461 for f in fragments {
462 frag_sizes.push(f.len());
463 }
464
465 let size_sum: usize = frag_sizes.iter().sum();
466
467 for max in 2..m + 1 {
469 let log = (max as f32).log2().ceil();
470 let mut aux_sum: usize = 0;
471
472 for len in &frag_sizes {
473 aux_sum += (len / max) + (len % max != 0) as usize
474 }
475
476 max_s = max_s.max(size_sum - log as usize - aux_sum);
477 }
478
479 max_s
480}
481
482fn unique_edges(fragment: &BitSet, mol: &Molecule) -> Vec<EdgeType> {
485 let g = mol.graph();
486 let mut nodes: Vec<Element> = Vec::new();
487 for v in g.node_weights() {
488 nodes.push(v.element());
489 }
490 let edges: Vec<petgraph::prelude::EdgeIndex> = g.edge_indices().collect();
491 let weights: Vec<Bond> = g.edge_weights().copied().collect();
492
493 let mut types: Vec<EdgeType> = Vec::new();
495 for idx in fragment.iter() {
496 let bond = weights[idx];
497 let e = edges[idx];
498
499 let (e1, e2) = g.edge_endpoints(e).expect("bad");
500 let e1 = nodes[e1.index()];
501 let e2 = nodes[e2.index()];
502 let ends = if e1 < e2 { (e1, e2) } else { (e2, e1) };
503
504 let edge_type = EdgeType { bond, ends };
505
506 if types.iter().any(|&t| t == edge_type) {
507 continue;
508 } else {
509 types.push(edge_type);
510 }
511 }
512
513 types
514}
515
516fn vec_bound_simple(fragments: &[BitSet], m: usize, mol: &Molecule) -> usize {
517 let mut s = 0;
520 for f in fragments {
521 s += f.len();
522 }
523
524 let mut union_set = BitSet::new();
525 for f in fragments {
526 union_set.union_with(f);
527 }
528 let z = unique_edges(&union_set, mol).len();
529
530 (s - z) - ((s - z) as f32 / m as f32).ceil() as usize
531}
532
533fn vec_bound_small_frags(fragments: &[BitSet], m: usize, mol: &Molecule) -> usize {
534 let mut size_two_fragments: Vec<BitSet> = Vec::new();
535 let mut large_fragments: Vec<BitSet> = fragments.to_owned();
536 let mut indices_to_remove: Vec<usize> = Vec::new();
537
538 for (i, frag) in fragments.iter().enumerate() {
540 if frag.len() == 2 {
541 indices_to_remove.push(i);
542 }
543 }
544 for &index in indices_to_remove.iter().rev() {
545 let removed_bitset = large_fragments.remove(index);
546 size_two_fragments.push(removed_bitset);
547 }
548
549 let mut fragments_union = BitSet::new();
551 let mut size_two_fragments_union = BitSet::new();
552 for f in fragments {
553 fragments_union.union_with(f);
554 }
555 for f in size_two_fragments.iter() {
556 size_two_fragments_union.union_with(f);
557 }
558 let z = unique_edges(&fragments_union, mol).len()
559 - unique_edges(&size_two_fragments_union, mol).len();
560
561 let mut s = 0;
564 let mut sl = 0;
565 for f in fragments {
566 s += f.len();
567 }
568 for f in large_fragments {
569 sl += f.len();
570 }
571
572 let mut size_two_types: Vec<(EdgeType, EdgeType)> = Vec::new();
574 for f in size_two_fragments.iter() {
575 let mut types = unique_edges(f, mol);
576 types.sort();
577 if types.len() == 1 {
578 size_two_types.push((types[0], types[0]));
579 } else {
580 size_two_types.push((types[0], types[1]));
581 }
582 }
583 size_two_types.sort();
584 size_two_types.dedup();
585
586 s - (z + size_two_types.len() + size_two_fragments.len())
587 - ((sl - z) as f32 / m as f32).ceil() as usize
588}
589
590pub fn index(m: &Molecule) -> u32 {
608 index_search(
609 m,
610 &[
611 Bound::IntChain,
612 Bound::VecChainSimple,
613 Bound::VecChainSmallFrags,
614 ],
615 )
616 .0
617}
618
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, fs, path::PathBuf};

    use csv::ReaderBuilder;

    use crate::loader;

    use super::*;

    /// Loads `./data/{dataset}/ma-index.csv` into a map from the first column
    /// (used as the file name key) to the second column parsed as `u32`.
    fn read_dataset_index(dataset: &str) -> HashMap<String, u32> {
        let path = format!("./data/{dataset}/ma-index.csv");
        let mut reader = ReaderBuilder::new()
            .from_path(path)
            .expect("ma-index.csv does not exist.");
        reader
            .records()
            .map(|result| {
                let record = result.expect("ma-index.csv is malformed.");
                let fields: Vec<&str> = record.iter().collect();
                let name = fields[0].to_string();
                let value = fields[1]
                    .parse::<u32>()
                    .expect("Assembly index is not an integer.");
                (name, value)
            })
            .collect()
    }

    /// Parses `./data/{dataset}/{filename}` as a molfile, runs `function` on
    /// it and asserts that the result equals the dataset's recorded value.
    fn test_molecule<F>(function: F, dataset: &str, filename: &str)
    where
        F: Fn(&Molecule) -> u32,
    {
        let path = PathBuf::from(format!("./data/{dataset}/{filename}"));
        let contents = fs::read_to_string(path).expect("Cannot read file");
        let molecule = loader::parse_molfile_str(&contents).expect("Cannot parse molecule");
        let records = read_dataset_index(dataset);
        let expected = *records
            .get(filename)
            .expect("Index dataset has no ground truth value");
        assert_eq!(function(&molecule), expected);
    }

    #[test]
    fn all_bounds_benzene() {
        test_molecule(index, "checks", "benzene.mol");
    }

    #[test]
    fn all_bounds_aspirin() {
        test_molecule(index, "checks", "aspirin.mol");
    }

    #[test]
    #[ignore = "expensive test"]
    fn all_bounds_morphine() {
        test_molecule(index, "checks", "morphine.mol");
    }

    #[test]
    fn naive_method_benzene() {
        test_molecule(naive_index_search, "checks", "benzene.mol");
    }

    #[test]
    fn naive_method_aspirin() {
        test_molecule(naive_index_search, "checks", "aspirin.mol");
    }

    #[test]
    #[ignore = "expensive test"]
    fn naive_method_morphine() {
        test_molecule(naive_index_search, "checks", "morphine.mol");
    }
}