assembly_theory/
assembly.rs

//! Compute assembly indices of molecules.
//!
//! # Example
3//! ```
4//! # use std::fs;
5//! # use std::path::PathBuf;
6//! # use assembly_theory::*;
7//! # fn main() -> Result<(), std::io::Error> {
8//! # let path = PathBuf::from(format!("./data/checks/benzene.mol"));
9//! // Read a molecule data file
10//! let molfile = fs::read_to_string(path)?;
11//! let benzene = loader::parse_molfile_str(&molfile).expect("Cannot parse molfile.");
12//!
13//! // Compute assembly index of benzene
14//! assert_eq!(assembly::index(&benzene), 3);
15//! # Ok(())
16//! # }
17//! ```
18use std::{
19    collections::BTreeSet,
20    sync::{
21        atomic::{AtomicUsize, Ordering::Relaxed},
22        Arc,
23    },
24};
25
26use bit_set::BitSet;
27use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
28
29use crate::{
30    molecule::Bond, molecule::Element, molecule::Molecule, utils::connected_components_under_edges,
31};
32
33#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
34struct EdgeType {
35    bond: Bond,
36    ends: (Element, Element),
37}
38
/// Number of duplicatable-subgraph matches above which `index_search` switches from the serial
/// search to the parallel one.
const PARALLEL_MATCH_SIZE_THRESHOLD: usize = 100;
40
/// The bounds available while computing molecular assembly indices.
///
/// `index_search()` checks the selected bounds, in the order given, to cut off search paths that
/// cannot improve on the best assembly index found so far.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Bound {
    /// `Log` bounds by the logarithm base 2 of the remaining edges.
    Log,
    /// `IntChain` bounds by the length of the smallest addition chain needed to create the
    /// remaining fragments.
    IntChain,
    /// `VecChainSimple` bounds using addition chain length together with the information of the
    /// edge types in a molecule.
    VecChainSimple,
    /// `VecChainSmallFrags` bounds using information on the number of fragments of size 2 in the
    /// molecule.
    VecChainSmallFrags,
}
58
59pub fn naive_assembly_depth(mol: &Molecule) -> u32 {
60    let mut ix = u32::MAX;
61    for (left, right) in mol.partitions().unwrap() {
62        let l = if left.is_basic_unit() {
63            0
64        } else {
65            naive_assembly_depth(&left)
66        };
67
68        let r = if right.is_basic_unit() {
69            0
70        } else {
71            naive_assembly_depth(&right)
72        };
73
74        ix = ix.min(l.max(r) + 1)
75    }
76    ix
77}
78
79fn recurse_naive_index_search(
80    mol: &Molecule,
81    matches: &BTreeSet<(BitSet, BitSet)>,
82    fragments: &[BitSet],
83    ix: usize,
84) -> usize {
85    let mut cx = ix;
86    for (h1, h2) in matches {
87        let mut fractures = fragments.to_owned();
88        let f1 = fragments.iter().enumerate().find(|(_, c)| h1.is_subset(c));
89        let f2 = fragments.iter().enumerate().find(|(_, c)| h2.is_subset(c));
90
91        let (Some((i1, f1)), Some((i2, f2))) = (f1, f2) else {
92            continue;
93        };
94
95        // All of these clones are on bitsets and cheap enough
96        if i1 == i2 {
97            let mut union = h1.clone();
98            union.union_with(h2);
99            let mut difference = f1.clone();
100            difference.difference_with(&union);
101            let c = connected_components_under_edges(mol.graph(), &difference);
102            fractures.extend(c);
103            fractures.swap_remove(i1);
104            fractures.push(h1.clone());
105        } else {
106            let mut f1r = f1.clone();
107            f1r.difference_with(h1);
108            let mut f2r = f2.clone();
109            f2r.difference_with(h2);
110
111            let c1 = connected_components_under_edges(mol.graph(), &f1r);
112            let c2 = connected_components_under_edges(mol.graph(), &f2r);
113
114            fractures.extend(c1);
115            fractures.extend(c2);
116
117            fractures.swap_remove(i1.max(i2));
118            fractures.swap_remove(i1.min(i2));
119
120            fractures.push(h1.clone());
121        }
122        cx = cx.min(recurse_naive_index_search(
123            mol,
124            matches,
125            &fractures,
126            ix - h1.len() + 1,
127        ));
128    }
129    cx
130}
131
132/// Calculates the assembly index of a molecule without using any bounding strategy or
133/// parallelization. This function is very inefficient and should only be used as a performance
134/// benchmark against other strategies.
135pub fn naive_index_search(mol: &Molecule) -> u32 {
136    let mut init = BitSet::new();
137    init.extend(mol.graph().edge_indices().map(|ix| ix.index()));
138
139    recurse_naive_index_search(
140        mol,
141        &mol.matches().collect(),
142        &[init],
143        mol.graph().edge_count() - 1,
144    ) as u32
145}
146
147#[allow(clippy::too_many_arguments)]
148fn recurse_index_search(
149    mol: &Molecule,
150    matches: &[(BitSet, BitSet)],
151    fragments: &[BitSet],
152    ix: usize,
153    largest_remove: usize,
154    mut best: usize,
155    bounds: &[Bound],
156    states_searched: &mut usize,
157) -> usize {
158    let mut cx = ix;
159
160    *states_searched += 1;
161
162    // Branch and Bound
163    for bound_type in bounds {
164        let exceeds = match bound_type {
165            Bound::Log => ix - log_bound(fragments) >= best,
166            Bound::IntChain => ix - addition_bound(fragments, largest_remove) >= best,
167            Bound::VecChainSimple => ix - vec_bound_simple(fragments, largest_remove, mol) >= best,
168            Bound::VecChainSmallFrags => {
169                ix - vec_bound_small_frags(fragments, largest_remove, mol) >= best
170            }
171        };
172        if exceeds {
173            return ix;
174        }
175    }
176
177    // Search for duplicatable fragment
178    for (i, (h1, h2)) in matches.iter().enumerate() {
179        let mut fractures = fragments.to_owned();
180        let f1 = fragments.iter().enumerate().find(|(_, c)| h1.is_subset(c));
181        let f2 = fragments.iter().enumerate().find(|(_, c)| h2.is_subset(c));
182
183        let largest_remove = h1.len();
184
185        let (Some((i1, f1)), Some((i2, f2))) = (f1, f2) else {
186            continue;
187        };
188
189        // All of these clones are on bitsets and cheap enough
190        if i1 == i2 {
191            let mut union = h1.clone();
192            union.union_with(h2);
193            let mut difference = f1.clone();
194            difference.difference_with(&union);
195            let c = connected_components_under_edges(mol.graph(), &difference);
196            fractures.extend(c);
197            fractures.swap_remove(i1);
198        } else {
199            let mut f1r = f1.clone();
200            f1r.difference_with(h1);
201            let mut f2r = f2.clone();
202            f2r.difference_with(h2);
203
204            let c1 = connected_components_under_edges(mol.graph(), &f1r);
205            let c2 = connected_components_under_edges(mol.graph(), &f2r);
206
207            fractures.extend(c1);
208            fractures.extend(c2);
209
210            fractures.swap_remove(i1.max(i2));
211            fractures.swap_remove(i1.min(i2));
212        }
213
214        fractures.retain(|i| i.len() > 1);
215        fractures.push(h1.clone());
216
217        cx = cx.min(recurse_index_search(
218            mol,
219            &matches[i + 1..],
220            &fractures,
221            ix - h1.len() + 1,
222            largest_remove,
223            best,
224            bounds,
225            states_searched,
226        ));
227        best = best.min(cx);
228    }
229
230    cx
231}
232
233#[allow(clippy::too_many_arguments)]
234fn parallel_recurse_index_search(
235    mol: &Molecule,
236    matches: &[(BitSet, BitSet)],
237    fragments: &[BitSet],
238    ix: usize,
239    largest_remove: usize,
240    best: AtomicUsize,
241    bounds: &[Bound],
242    states_searched: Arc<AtomicUsize>,
243) -> usize {
244    let cx = AtomicUsize::from(ix);
245
246    states_searched.fetch_add(1, Relaxed);
247
248    // Branch and Bound
249    for bound_type in bounds {
250        let best = best.load(Relaxed);
251        let exceeds = match bound_type {
252            Bound::Log => ix - log_bound(fragments) >= best,
253            Bound::IntChain => ix - addition_bound(fragments, largest_remove) >= best,
254            Bound::VecChainSimple => ix - vec_bound_simple(fragments, largest_remove, mol) >= best,
255            Bound::VecChainSmallFrags => {
256                ix - vec_bound_small_frags(fragments, largest_remove, mol) >= best
257            }
258        };
259        if exceeds {
260            return ix;
261        }
262    }
263
264    // Search for duplicatable fragment
265    matches.par_iter().enumerate().for_each(|(i, (h1, h2))| {
266        let mut fractures = fragments.to_owned();
267        let f1 = fragments.iter().enumerate().find(|(_, c)| h1.is_subset(c));
268        let f2 = fragments.iter().enumerate().find(|(_, c)| h2.is_subset(c));
269
270        let largest_remove = h1.len();
271
272        let (Some((i1, f1)), Some((i2, f2))) = (f1, f2) else {
273            return;
274        };
275
276        // All of these clones are on bitsets and cheap enough
277        if i1 == i2 {
278            let mut union = h1.clone();
279            union.union_with(h2);
280            let mut difference = f1.clone();
281            difference.difference_with(&union);
282            let c = connected_components_under_edges(mol.graph(), &difference);
283            fractures.extend(c);
284            fractures.swap_remove(i1);
285        } else {
286            let mut f1r = f1.clone();
287            f1r.difference_with(h1);
288            let mut f2r = f2.clone();
289            f2r.difference_with(h2);
290
291            let c1 = connected_components_under_edges(mol.graph(), &f1r);
292            let c2 = connected_components_under_edges(mol.graph(), &f2r);
293
294            fractures.extend(c1);
295            fractures.extend(c2);
296
297            fractures.swap_remove(i1.max(i2));
298            fractures.swap_remove(i1.min(i2));
299        }
300
301        fractures.retain(|i| i.len() > 1);
302        fractures.push(h1.clone());
303
304        let output = parallel_recurse_index_search(
305            mol,
306            &matches[i + 1..],
307            &fractures,
308            ix - h1.len() + 1,
309            largest_remove,
310            best.load(Relaxed).into(),
311            bounds,
312            states_searched.clone(),
313        );
314        cx.fetch_min(output, Relaxed);
315
316        best.fetch_min(cx.load(Relaxed), Relaxed);
317    });
318
319    cx.load(Relaxed)
320}
321
322/// Computes information related to the assembly index of a molecule using the provided bounds.
323///
324/// The first result in the returned tuple is the assembly index of the molecule. The second result
325/// gives the number of duplicatable subgraphs (pairs of disjoint and isomorphic subgraphs) in the
326/// molecule. The third result is the number of states searched where a new state is considered to
327/// be searched each time a duplicatable subgraph is removed.
328///
329/// If the search space of the molecule is large (>100) parallelization will be used.
330///
331/// Bounds will be used in the order provided in the `bounds` slice. Execution along a search path
332/// will halt immediately after finding a bound that exceeds the current best assembly pathway. It
333/// is generally better to provide bounds that are quick to compute first.
334///
335/// # Example
336/// ```
337/// # use std::fs;
338/// # use std::path::PathBuf;
339/// # use assembly_theory::*;
340/// use assembly_theory::assembly::{Bound, index_search};
341/// # fn main() -> Result<(), std::io::Error> {
342/// # let path = PathBuf::from(format!("./data/checks/benzene.mol"));
343/// // Read a molecule data file
344/// let molfile = fs::read_to_string(path).expect("Cannot read input file.");
345/// let benzene = loader::parse_molfile_str(&molfile).expect("Cannot parse molfile.");
346///
347/// // Compute assembly index of benzene naively, with no bounds.
348/// let (slow_index, _, _) = index_search(&benzene, &[]);
349///
350/// // Compute assembly index of benzene with the log and integer chain bounds
351/// let (fast_index, _, _) = index_search(&benzene, &[Bound::Log, Bound::IntChain]);
352///
353/// assert_eq!(slow_index, 3);
354/// assert_eq!(fast_index, 3);
355/// # Ok(())
356/// # }
357/// ```
358pub fn index_search(mol: &Molecule, bounds: &[Bound]) -> (u32, u32, usize) {
359    let mut init = BitSet::new();
360    init.extend(mol.graph().edge_indices().map(|ix| ix.index()));
361
362    // Create and sort matches array
363    let mut matches: Vec<(BitSet, BitSet)> = mol.matches().collect();
364    matches.sort_by(|e1, e2| e2.0.len().cmp(&e1.0.len()));
365
366    let edge_count = mol.graph().edge_count();
367
368    let (index, total_search) = if matches.len() > PARALLEL_MATCH_SIZE_THRESHOLD {
369        let total_search = Arc::new(AtomicUsize::from(0));
370        let index = parallel_recurse_index_search(
371            mol,
372            &matches,
373            &[init],
374            edge_count - 1,
375            edge_count,
376            (edge_count - 1).into(),
377            bounds,
378            total_search.clone(),
379        );
380        let total_search = total_search.load(Relaxed);
381        (index as u32, total_search)
382    } else {
383        let mut total_search = 0;
384        let index = recurse_index_search(
385            mol,
386            &matches,
387            &[init],
388            edge_count - 1,
389            edge_count,
390            edge_count - 1,
391            bounds,
392            &mut total_search,
393        );
394        (index as u32, total_search)
395    };
396
397    (index, matches.len() as u32, total_search)
398}
399
400/// Like [`index_search`], but no parallelism is used.
401///
402/// # Example
403/// ```
404/// # use std::fs;
405/// # use std::path::PathBuf;
406/// # use assembly_theory::*;
407/// use assembly_theory::assembly::{Bound, serial_index_search};
408/// # fn main() -> Result<(), std::io::Error> {
409/// # let path = PathBuf::from(format!("./data/checks/benzene.mol"));
410/// // Read a molecule data file
411/// let molfile = fs::read_to_string(path).expect("Cannot read input file.");
412/// let benzene = loader::parse_molfile_str(&molfile).expect("Cannot parse molfile.");
413///
414/// // Compute assembly index of benzene naively, with no bounds.
415/// let (slow_index, _, _) = serial_index_search(&benzene, &[]);
416///
417/// // Compute assembly index of benzene with the log and integer chain bounds
418/// let (fast_index, _, _) = serial_index_search(&benzene, &[Bound::Log, Bound::IntChain]);
419///
420/// assert_eq!(slow_index, 3);
421/// assert_eq!(fast_index, 3);
422/// # Ok(())
423/// # }
424/// ```
425pub fn serial_index_search(mol: &Molecule, bounds: &[Bound]) -> (u32, u32, usize) {
426    let mut init = BitSet::new();
427    init.extend(mol.graph().edge_indices().map(|ix| ix.index()));
428
429    // Create and sort matches array
430    let mut matches: Vec<(BitSet, BitSet)> = mol.matches().collect();
431    matches.sort_by(|e1, e2| e2.0.len().cmp(&e1.0.len()));
432
433    let edge_count = mol.graph().edge_count();
434    let mut total_search = 0;
435    let index = recurse_index_search(
436        mol,
437        &matches,
438        &[init],
439        edge_count - 1,
440        edge_count,
441        edge_count - 1,
442        bounds,
443        &mut total_search,
444    );
445    (index as u32, matches.len() as u32, total_search)
446}
447
448fn log_bound(fragments: &[BitSet]) -> usize {
449    let mut size = 0;
450    for f in fragments {
451        size += f.len();
452    }
453
454    size - (size as f32).log2().ceil() as usize
455}
456
457fn addition_bound(fragments: &[BitSet], m: usize) -> usize {
458    let mut max_s: usize = 0;
459    let mut frag_sizes: Vec<usize> = Vec::new();
460
461    for f in fragments {
462        frag_sizes.push(f.len());
463    }
464
465    let size_sum: usize = frag_sizes.iter().sum();
466
467    // Test for all sizes m of largest removed duplicate
468    for max in 2..m + 1 {
469        let log = (max as f32).log2().ceil();
470        let mut aux_sum: usize = 0;
471
472        for len in &frag_sizes {
473            aux_sum += (len / max) + (len % max != 0) as usize
474        }
475
476        max_s = max_s.max(size_sum - log as usize - aux_sum);
477    }
478
479    max_s
480}
481
482// Count number of unique edges in a fragment
483// Helper function for vector bounds
484fn unique_edges(fragment: &BitSet, mol: &Molecule) -> Vec<EdgeType> {
485    let g = mol.graph();
486    let mut nodes: Vec<Element> = Vec::new();
487    for v in g.node_weights() {
488        nodes.push(v.element());
489    }
490    let edges: Vec<petgraph::prelude::EdgeIndex> = g.edge_indices().collect();
491    let weights: Vec<Bond> = g.edge_weights().copied().collect();
492
493    // types will hold an element for every unique edge type in fragment
494    let mut types: Vec<EdgeType> = Vec::new();
495    for idx in fragment.iter() {
496        let bond = weights[idx];
497        let e = edges[idx];
498
499        let (e1, e2) = g.edge_endpoints(e).expect("bad");
500        let e1 = nodes[e1.index()];
501        let e2 = nodes[e2.index()];
502        let ends = if e1 < e2 { (e1, e2) } else { (e2, e1) };
503
504        let edge_type = EdgeType { bond, ends };
505
506        if types.iter().any(|&t| t == edge_type) {
507            continue;
508        } else {
509            types.push(edge_type);
510        }
511    }
512
513    types
514}
515
516fn vec_bound_simple(fragments: &[BitSet], m: usize, mol: &Molecule) -> usize {
517    // Calculate s (total number of edges)
518    // Calculate z (number of unique edges)
519    let mut s = 0;
520    for f in fragments {
521        s += f.len();
522    }
523
524    let mut union_set = BitSet::new();
525    for f in fragments {
526        union_set.union_with(f);
527    }
528    let z = unique_edges(&union_set, mol).len();
529
530    (s - z) - ((s - z) as f32 / m as f32).ceil() as usize
531}
532
533fn vec_bound_small_frags(fragments: &[BitSet], m: usize, mol: &Molecule) -> usize {
534    let mut size_two_fragments: Vec<BitSet> = Vec::new();
535    let mut large_fragments: Vec<BitSet> = fragments.to_owned();
536    let mut indices_to_remove: Vec<usize> = Vec::new();
537
538    // Find and remove fragments of size 2
539    for (i, frag) in fragments.iter().enumerate() {
540        if frag.len() == 2 {
541            indices_to_remove.push(i);
542        }
543    }
544    for &index in indices_to_remove.iter().rev() {
545        let removed_bitset = large_fragments.remove(index);
546        size_two_fragments.push(removed_bitset);
547    }
548
549    // Compute z = num unique edges of large_fragments NOT also in size_two_fragments
550    let mut fragments_union = BitSet::new();
551    let mut size_two_fragments_union = BitSet::new();
552    for f in fragments {
553        fragments_union.union_with(f);
554    }
555    for f in size_two_fragments.iter() {
556        size_two_fragments_union.union_with(f);
557    }
558    let z = unique_edges(&fragments_union, mol).len()
559        - unique_edges(&size_two_fragments_union, mol).len();
560
561    // Compute s = total number of edges in fragments
562    // Compute sl = total number of edges in large fragments
563    let mut s = 0;
564    let mut sl = 0;
565    for f in fragments {
566        s += f.len();
567    }
568    for f in large_fragments {
569        sl += f.len();
570    }
571
572    // Find number of unique size two fragments
573    let mut size_two_types: Vec<(EdgeType, EdgeType)> = Vec::new();
574    for f in size_two_fragments.iter() {
575        let mut types = unique_edges(f, mol);
576        types.sort();
577        if types.len() == 1 {
578            size_two_types.push((types[0], types[0]));
579        } else {
580            size_two_types.push((types[0], types[1]));
581        }
582    }
583    size_two_types.sort();
584    size_two_types.dedup();
585
586    s - (z + size_two_types.len() + size_two_fragments.len())
587        - ((sl - z) as f32 / m as f32).ceil() as usize
588}
589
590/// Computes the assembly index of a molecule using an effecient bounding strategy
591/// # Example
592/// ```
593/// # use std::fs;
594/// # use std::path::PathBuf;
595/// # use assembly_theory::*;
596/// # fn main() -> Result<(), std::io::Error> {
597/// # let path = PathBuf::from(format!("./data/checks/benzene.mol"));
598/// // Read a molecule data file
599/// let molfile = fs::read_to_string(path).expect("Cannot read input file.");
600/// let benzene = loader::parse_molfile_str(&molfile).expect("Cannot parse molfile.");
601///
602/// // Compute assembly index of benzene
603/// assert_eq!(assembly::index(&benzene), 3);
604/// # Ok(())
605/// # }
606/// ```
607pub fn index(m: &Molecule) -> u32 {
608    index_search(
609        m,
610        &[
611            Bound::IntChain,
612            Bound::VecChainSimple,
613            Bound::VecChainSmallFrags,
614        ],
615    )
616    .0
617}
618
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, fs, path::PathBuf};

    use csv::ReaderBuilder;

    use crate::loader;

    use super::*;

    /// Reads `./data/{dataset}/ma-index.csv` into a map from the first column of each record to
    /// the second column parsed as an assembly index.
    fn read_dataset_index(dataset: &str) -> HashMap<String, u32> {
        let mut reader = ReaderBuilder::new()
            .from_path(format!("./data/{dataset}/ma-index.csv"))
            .expect("ma-index.csv does not exist.");

        reader
            .records()
            .map(|result| {
                let record = result.expect("ma-index.csv is malformed.");
                let fields: Vec<&str> = record.iter().collect();
                let name = fields[0].to_string();
                let ground_truth = fields[1]
                    .parse::<u32>()
                    .expect("Assembly index is not an integer.");
                (name, ground_truth)
            })
            .collect()
    }

    /// Loads `./data/{dataset}/{filename}`, runs `function` on the parsed molecule and checks
    /// the result against the ground truth recorded for `filename` in the dataset index.
    fn test_molecule<F>(function: F, dataset: &str, filename: &str)
    where
        F: Fn(&Molecule) -> u32,
    {
        let mol_path = PathBuf::from(format!("./data/{dataset}/{filename}"));
        let molfile = fs::read_to_string(mol_path).expect("Cannot read file");
        let molecule = loader::parse_molfile_str(&molfile).expect("Cannot parse molecule");

        let known_indices = read_dataset_index(dataset);
        let expected = known_indices
            .get(filename)
            .expect("Index dataset has no ground truth value");

        let computed = function(&molecule);
        assert_eq!(computed, *expected);
    }

    #[test]
    fn all_bounds_benzene() {
        test_molecule(index, "checks", "benzene.mol");
    }

    #[test]
    fn all_bounds_aspirin() {
        test_molecule(index, "checks", "aspirin.mol");
    }

    #[test]
    #[ignore = "expensive test"]
    fn all_bounds_morphine() {
        test_molecule(index, "checks", "morphine.mol");
    }

    #[test]
    fn naive_method_benzene() {
        test_molecule(naive_index_search, "checks", "benzene.mol");
    }

    #[test]
    fn naive_method_aspirin() {
        test_molecule(naive_index_search, "checks", "aspirin.mol");
    }

    #[test]
    #[ignore = "expensive test"]
    fn naive_method_morphine() {
        test_molecule(naive_index_search, "checks", "morphine.mol");
    }
}