1use std::fs::File;
10use std::f64;
11use std::error::Error;
12use std::collections::HashMap;
13use std::path::Path;
14use std::io::BufReader;
15use csv::StringRecord;
16use petgraph::graph::{UnGraph, NodeIndex};
17
18
19pub mod morse;
20pub mod graph;
21
22
23#[macro_use] extern crate cpython;
24use cpython::{PyResult, Python, PyList, PyTuple, PyObject, ToPyObject, FromPyObject, PyErr, exc, PyString};
25use crate::cpython::ObjectProtocol;
26
27
28
// Python module initialisation via rust-cpython's `py_module_initializer!`.
// The macro is given the module name `talus` together with the two init
// symbol names (`inittalus`, `PyInit_talus`); the closure body fills in the
// module's attributes.
py_module_initializer!(talus, inittalus, PyInit_talus, |py, m| {
    m.add(py, "__doc__", "This module is implemented in Rust.")?;
    // Persistence from an explicit graph: a list of nodes plus a list of edges.
    m.add(py, "_persistence", py_fn!(py, persistence_py(nodes: PyList, edges: PyList)))?;
    // Persistence from a graph built by `graph::build_knn` over the points.
    m.add(py, "_persistence_by_knn", py_fn!(py, knn_persistence_py(points: PyList, k: usize)))?;
    // Persistence from a graph built by `graph::build_knn_approximate`, which
    // additionally takes `sample_rate` and `precision`.
    m.add(py, "_persistence_by_approximate_knn", py_fn!(py, approximate_knn_persistence_py(points: PyList, k: usize, sample_rate: f64, precision: f64)))?;
    Ok(())
});
36
37fn approximate_knn_persistence_py(py: Python, points: PyList, k: usize, sample_rate: f64, precision: f64) -> PyResult<PyTuple> {
38 let mut labeled_points = Vec::with_capacity(points.len(py));
39 for point in points.iter(py) {
40 labeled_points.push(point.extract(py)?);
41 }
42 let g = match graph::build_knn_approximate(&labeled_points, k, sample_rate, precision){
43 Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
44 Ok(g) => g
45 };
46 let complex = match morse::MorseSmaleComplex::from_graph(&g) {
47 Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
48 Ok(complex) => complex
49 };
50 let data = complex.to_data(&g);
51 Ok(data.into_py_object(py))
52}
53
54fn knn_persistence_py(py: Python, points: PyList, k: usize) -> PyResult<PyTuple> {
55 let mut labeled_points = Vec::with_capacity(points.len(py));
56 for point in points.iter(py) {
57 labeled_points.push(point.extract(py)?);
58 }
59 let g = match graph::build_knn(&labeled_points, k) {
60 Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
61 Ok(g) => g
62 };
63 let complex = match morse::MorseSmaleComplex::from_graph(&g) {
64 Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
65 Ok(complex) => complex
66 };
67 let data = complex.to_data(&g);
68 Ok(data.into_py_object(py))
69}
70
71fn persistence_py(py: Python, nodes: PyList, edges: PyList) -> PyResult<PyTuple> {
72 let mut labeled_nodes: Vec<NodeIndex> = Vec::with_capacity(nodes.len(py));
73 let mut id_lookup: HashMap<i64, (usize, NodeIndex)> = HashMap::with_capacity(nodes.len(py));
74 let mut g = UnGraph::new_undirected();
75 for (i, node) in nodes.iter(py).enumerate() {
76 let point: LabeledPoint<Vec<f64>> = node.extract(py)?;
77 let node = g.add_node(point.clone());
78 labeled_nodes.push(node);
79 id_lookup.insert(point.id, (i, node));
80 }
81 for edge in edges.iter(py) {
82 let node_tuple: PyTuple = edge.extract(py)?;
83 let left: i64 = node_tuple.get_item(py, 0).extract(py)?;
84 let right: i64 = node_tuple.get_item(py, 1).extract(py)?;
85 g.add_edge((id_lookup.get(&left).unwrap()).1, id_lookup.get(&right).unwrap().1, 1.);
86 }
87 let complex = match morse::MorseSmaleComplex::from_graph(&g) {
88 Err(err) => return Err(PyErr::new::<exc::Exception, PyString>(py, PyString::new(py, &format!("{:?}", err)))),
89 Ok(complex) => complex
90 };
91 let data = complex.to_data(&g);
92 Ok(data.into_py_object(py))
93}
94
95impl ToPyObject for morse::MorseComplexData {
96 type ObjectType = PyTuple;
97 fn to_py_object(&self, py: Python) -> Self::ObjectType {
98 (self.lifetimes.clone(), self.filtration.clone(), self.complex.clone()).to_py_object(py)
99 }
100
101 fn into_py_object(self, py: Python) -> Self::ObjectType {
102 (self.lifetimes, self.filtration, self.complex).to_py_object(py)
103 }
104}
105
106
/// A type whose values have a distance to one another.
pub trait Metric {
    /// Returns the distance between `self` and `other`.
    fn distance(&self, other: &Self) -> f64;
}

/// Euclidean distance between coordinate vectors.
impl Metric for Vec<f64> {
    /// Square root of the sum of squared coordinate differences.
    ///
    /// Coordinates are paired positionally; when the vectors differ in
    /// length, the surplus coordinates of the longer one are ignored.
    fn distance(&self, other: &Self) -> f64 {
        let squared_sum: f64 = self
            .iter()
            .zip(other)
            .map(|(lhs, rhs)| {
                let delta = lhs - rhs;
                delta.powi(2)
            })
            .sum();
        squared_sum.sqrt()
    }
}
118
119pub trait PreMetric {
120 fn predistance(&self, other: &Self) -> f64;
121}
122
123impl PreMetric for Vec<f64> {
124 fn predistance(&self, other:&Self) -> f64 {
125 self.distance(other)
126 }
127}
128
/// A point of type `T` tagged with an integer identifier and a scalar value.
///
/// Instances are built either from a Python object (attributes `identifier`,
/// `vector` and `value`) or from a CSV record laid out as
/// `id, value, coord_0, coord_1, ...`.
#[derive(Debug)]
pub struct LabeledPoint<T> {
    /// Identifier of the point: the Python `identifier` attribute or CSV
    /// column 0. `persistence_py` uses it to resolve edge endpoints.
    pub id: i64,

    /// The point itself: for `Vec<f64>`, the Python `vector` attribute or
    /// CSV columns 2 and onwards.
    pub point: T,

    /// Scalar attached to the point: the Python `value` attribute or CSV
    /// column 1.
    pub value: f64
}
148
149impl<'s> FromPyObject<'s> for LabeledPoint<Vec<f64>> {
150 fn extract(py: Python, obj: &'s PyObject) -> PyResult<Self>{
151 let id: i64 = obj.getattr(py, "identifier")?.extract(py)?;
152 let value: f64 = obj.getattr(py, "value")?.extract(py)?;
153 let list: PyList = obj.getattr(py, "vector")?.extract(py)?;
154 let mut point: Vec<f64> = Vec::with_capacity(list.len(py));
155 for value in list.iter(py) {
156 let v = value.extract(py)?;
157 point.push(v);
158 };
159 Ok(LabeledPoint{id, value, point})
160 }
161}
162
163impl<T: Clone> Clone for LabeledPoint<T> {
164 fn clone(&self) -> Self {
165 LabeledPoint{value: self.value, point: self.point.clone(), id: self.id}
166 }
167}
168
169impl LabeledPoint<Vec<f64>> {
170 pub fn from_record(record: &StringRecord) -> LabeledPoint<Vec<f64>> {
172 let id = record[0].parse::<i64>().expect("Expected an int");
173 let value = record[1].parse::<f64>().expect("Expected a float");
174 let point = record.iter()
175 .skip(2)
176 .map(|v| v.parse::<f64>().expect("Expected a float"))
177 .collect();
178 LabeledPoint{id, point, value}
179 }
180
181 pub fn points_from_file<P: AsRef<Path>>(filename: P) -> Result<Vec<LabeledPoint<Vec<f64>>>, Box<dyn Error>> {
182 let f = File::open(filename).expect("Unable to open file");
183 let f = BufReader::new(f);
184 let mut points = Vec::with_capacity(16);
185 let mut rdr = csv::ReaderBuilder::new()
186 .has_headers(false)
187 .from_reader(f);
188 for result in rdr.records() {
189 let mut record = result?;
190 record.trim();
191 points.push(LabeledPoint::from_record(&record));
192 }
193 Ok(points)
194 }
195}