Skip to main content

talus/
lib.rs

//! Talus
//! =====
//!
//! A collection of computational topology algorithms written in Rust, with Python bindings.
//!
//! The current use case covered by this crate is the creation of kNN graphs, and the computation
//! of the Morse-Smale complex (see `morse::MorseSmaleComplex`) of those graphs, together with the
//! corresponding persistence values for the extrema in the graph.
9use std::fs::File;
10use std::f64;
11use std::error::Error;
12use std::collections::HashMap;
13use std::path::Path;
14use std::io::BufReader;
15use csv::StringRecord;
16use petgraph::graph::{UnGraph, NodeIndex};
17
18
19pub mod morse;
20pub mod graph;
21
22
23#[macro_use] extern crate cpython;
24use cpython::{PyResult, Python, PyList, PyTuple, PyObject, ToPyObject, FromPyObject, PyErr, exc, PyString};
25use crate::cpython::ObjectProtocol;
26
27
28
// Python module entry point for the `talus` extension (rust-cpython).
//
// The two init-function names passed here (`inittalus`, `PyInit_talus`) follow the Python 2 and
// Python 3 naming conventions respectively. The module registers three underscore-prefixed
// functions, each returning the tuple built by `MorseComplexData`'s `ToPyObject` impl:
//   * `_persistence(nodes, edges)`: Morse-Smale complex of an explicitly supplied graph.
//   * `_persistence_by_knn(points, k)`: the same, over a kNN graph of `points`.
//   * `_persistence_by_approximate_knn(points, k, sample_rate, precision)`: the same, over an
//     approximate kNN graph of `points`.
py_module_initializer!(talus, inittalus, PyInit_talus, |py, m| {
    m.add(py, "__doc__", "This module is implemented in Rust.")?;
    m.add(py, "_persistence", py_fn!(py, persistence_py(nodes: PyList, edges: PyList)))?;
    m.add(py, "_persistence_by_knn", py_fn!(py, knn_persistence_py(points: PyList, k: usize)))?;
    m.add(py, "_persistence_by_approximate_knn", py_fn!(py, approximate_knn_persistence_py(points: PyList, k: usize, sample_rate: f64, precision: f64)))?;
    Ok(())
});
36
37fn approximate_knn_persistence_py(py: Python, points: PyList, k: usize, sample_rate: f64, precision: f64) -> PyResult<PyTuple> {
38    let mut labeled_points = Vec::with_capacity(points.len(py));
39    for point in points.iter(py) {
40        labeled_points.push(point.extract(py)?);
41    }
42    let g = match graph::build_knn_approximate(&labeled_points, k, sample_rate, precision){
43        Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
44        Ok(g) => g
45    };
46    let complex = match morse::MorseSmaleComplex::from_graph(&g) {
47        Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
48        Ok(complex) => complex
49    };
50    let data = complex.to_data(&g);
51    Ok(data.into_py_object(py))
52}
53
54fn knn_persistence_py(py: Python, points: PyList, k: usize) -> PyResult<PyTuple> {
55    let mut labeled_points = Vec::with_capacity(points.len(py));
56    for point in points.iter(py) {
57        labeled_points.push(point.extract(py)?);
58    }
59    let g = match graph::build_knn(&labeled_points, k) {
60        Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
61        Ok(g) => g
62    };
63    let complex = match morse::MorseSmaleComplex::from_graph(&g) {
64        Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
65        Ok(complex) => complex
66    };
67    let data = complex.to_data(&g);
68    Ok(data.into_py_object(py))
69}
70
71fn persistence_py(py: Python, nodes: PyList, edges: PyList) -> PyResult<PyTuple> {
72    let mut labeled_nodes: Vec<NodeIndex> = Vec::with_capacity(nodes.len(py));
73    let mut id_lookup: HashMap<i64, (usize, NodeIndex)> = HashMap::with_capacity(nodes.len(py));
74    let mut g = UnGraph::new_undirected();
75    for (i, node) in nodes.iter(py).enumerate() {
76        let point: LabeledPoint<Vec<f64>> = node.extract(py)?;
77        let node = g.add_node(point.clone());
78        labeled_nodes.push(node);
79        id_lookup.insert(point.id, (i, node));
80    }
81    for edge in edges.iter(py) {
82        let node_tuple: PyTuple = edge.extract(py)?;
83        let left: i64 = node_tuple.get_item(py, 0).extract(py)?;
84        let right: i64 = node_tuple.get_item(py, 1).extract(py)?;
85        g.add_edge((id_lookup.get(&left).unwrap()).1, id_lookup.get(&right).unwrap().1, 1.);
86    }
87    let complex = match morse::MorseSmaleComplex::from_graph(&g) {
88        Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
89        Ok(complex) => complex
90    };
91    let data = complex.to_data(&g);
92    Ok(data.into_py_object(py))
93}
94
95impl ToPyObject for morse::MorseComplexData {
96    type ObjectType = PyTuple;
97    fn to_py_object(&self, py: Python) -> Self::ObjectType {
98        (self.lifetimes.clone(), self.filtration.clone(), self.complex.clone()).to_py_object(py)
99    }
100
101    fn into_py_object(self, py: Python) -> Self::ObjectType {
102        (self.lifetimes, self.filtration, self.complex).to_py_object(py)
103    }
104}
105
106
/// A distance between two values of the same type.
///
/// NOTE(review): the name suggests implementations should behave like a true metric
/// (non-negative, symmetric, triangle inequality); nothing here enforces that.
pub trait Metric {
    /// Distance from `self` to `other`.
    fn distance(&self, other: &Self) -> f64;
}

/// Euclidean (L2) distance between two coordinate vectors.
///
/// Coordinates are paired positionally. If the vectors differ in length, the trailing
/// coordinates of the longer one are ignored.
impl Metric for Vec<f64> {
    fn distance(&self, other: &Self) -> f64 {
        let squared_gaps = self.iter().zip(other.iter()).map(|(lhs, rhs)| {
            let gap = lhs - rhs;
            gap.powi(2)
        });
        squared_gaps.sum::<f64>().sqrt()
    }
}
118
119pub trait PreMetric {
120    fn predistance(&self, other: &Self) -> f64;
121}
122
123impl PreMetric for Vec<f64> {
124    fn predistance(&self, other:&Self) -> f64 {
125        self.distance(other)
126    }
127}
128
/// A point in a graph that carries enough information for Morse complex construction.
///
/// `T` is the type of the point's location. The `Metric`/`PreMetric`, `FromPyObject` and CSV
/// parsing code in this file is provided for `T = Vec<f64>`.
#[derive(Debug)]
pub struct LabeledPoint<T> {
    /// An identifier for this point. Assumed to be unique; this type does not enforce it.
    pub id: i64,

    /// The point's location in some space, used for distance computations.
    ///
    /// FIXME: this description assumes a coordinate vector, but `T` is generic.
    pub point: T,

    /// The scalar value associated with this point.
    ///
    /// This is the value that is used to determine extrema in the graph.
    ///
    /// Mathematically speaking, this corresponds to the value of some Morse function at this
    /// point.
    pub value: f64
}
148
149impl<'s> FromPyObject<'s> for LabeledPoint<Vec<f64>> {
150    fn extract(py: Python, obj: &'s PyObject) -> PyResult<Self>{
151        let id: i64 = obj.getattr(py, "identifier")?.extract(py)?;
152        let value: f64 = obj.getattr(py, "value")?.extract(py)?;
153        let list: PyList = obj.getattr(py, "vector")?.extract(py)?;
154        let mut point: Vec<f64> = Vec::with_capacity(list.len(py));
155        for value in list.iter(py) {
156            let v = value.extract(py)?;
157            point.push(v);
158        };
159        Ok(LabeledPoint{id, value, point})
160    }
161}
162
163impl<T: Clone> Clone for LabeledPoint<T> {
164    fn clone(&self) -> Self {
165        LabeledPoint{value: self.value, point: self.point.clone(), id: self.id}
166    }
167}
168
169impl LabeledPoint<Vec<f64>> {
170    // FIXME: move hte vec stuff to its own impl
171    pub fn from_record(record: &StringRecord) -> LabeledPoint<Vec<f64>> {
172        let id = record[0].parse::<i64>().expect("Expected an int");
173        let value = record[1].parse::<f64>().expect("Expected a float");
174        let point = record.iter()
175            .skip(2)
176            .map(|v| v.parse::<f64>().expect("Expected a float"))
177            .collect();
178        LabeledPoint{id, point, value}
179    }
180
181    pub fn points_from_file<P: AsRef<Path>>(filename: P) -> Result<Vec<LabeledPoint<Vec<f64>>>, Box<dyn Error>> {
182        let f = File::open(filename).expect("Unable to open file");
183        let f = BufReader::new(f);
184        let mut points = Vec::with_capacity(16);
185        let mut rdr = csv::ReaderBuilder::new()
186            .has_headers(false)
187            .from_reader(f);
188        for result in rdr.records() {
189            let mut record = result?;
190            record.trim();
191            points.push(LabeledPoint::from_record(&record));
192        }
193        Ok(points)
194    }
195}