geo_weights/
weights.rs

1use geo::{Centroid, GeoFloat, Line};
2use geo_types::Geometry;
3use geojson::{Feature, FeatureCollection};
4use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
5use std::collections::{HashMap, HashSet};
6use std::fmt;
7use std::iter::IntoIterator;
8
/// Transformations that can be requested when exporting a weights matrix, e.g. via
/// `Weights::as_sparse_matrix`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransformType {
    /// Row-standardise: divide each weight by the sum of its row's weights.
    Row,
    /// Intended for binary weights. Not yet applied by `as_sparse_matrix`, which leaves the
    /// weights unchanged for this variant.
    Binary,
    /// Intended for doubly-standardised weights. Not yet applied by `as_sparse_matrix`, which
    /// leaves the weights unchanged for this variant.
    DoublyStandardized,
}
14
15pub trait WeightBuilder<A>
16where
17    A: GeoFloat,
18{
19    fn compute_weights<T>(&self, geoms: &T) -> Weights
20    where
21        for<'a> &'a T: IntoIterator<Item = &'a Geometry<A>>;
22}
23
/// Structure holding and providing methods to access and query a spatial weights matrix.
///
/// Weights are either loaded from external representations or constructed by a
/// `WeightBuilder`. The matrix is stored sparsely as `origin id => (dest id => weight)`, together
/// with the total number of elements so the full dimension is known even when some ids have no
/// stored entries.
#[derive(Debug, Clone, PartialEq)]
pub struct Weights {
    /// Sparse mapping `origin => (dest => weight)`.
    weights: HashMap<usize, HashMap<usize, f64>>,
    /// Number of geometries in the original set; used as both dimensions of the exported matrix.
    no_elements: usize,
}
31
32impl fmt::Display for Weights {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        write!(f, "({:#?})", self.weights())
35    }
36}
37
38impl Weights {
39    /// Create a new weights object from a hashmap indicating origin and destination ids and
40    /// weights.
41    ///
42    /// # Arguments
43    ///
44    /// * `weights` - A mapping of {origin => dest =>weight}f64
45    /// * `no_elements` - The number of elements in the original geometry set (because we want to be
46    /// sure of the full length given this is a sparse representation)
47    ///
48    pub fn new(weights: HashMap<usize, HashMap<usize, f64>>, no_elements: usize) -> Weights {
49        Self {
50            weights,
51            no_elements,
52        }
53    }
54
55    /// Create a new weights object from a series of lists representing the origin, destinations
56    /// and weights of the matrix
57    ///
58    /// # Arguments
59    ///
60    /// * `origins` - A  list of the origin ids
61    /// * `origins` - A  list of the destination ids
62    /// * `weights` - A  list of the weights
63    /// * `no_elements` - The number of elements in the original geometry set (because we want to be
64    /// sure of the full length given this is a sparse representation)
65    ///
66    pub fn from_list_rep<T, W>(origins: &T, dests: &T, weights: &W, no_elements: usize) -> Weights
67    where
68        for<'a> &'a T: std::iter::IntoIterator<Item = &'a usize>,
69        for<'a> &'a W: std::iter::IntoIterator<Item = &'a f64>,
70    {
71        let mut weights_lookup: HashMap<usize, HashMap<usize, f64>> = HashMap::new();
72
73        for ((origin, dest), weight) in origins
74            .into_iter()
75            .zip(dests.into_iter())
76            .zip(weights.into_iter())
77        {
78            let entry = weights_lookup.entry(*origin).or_insert(HashMap::new());
79            entry.insert(*dest, *weight);
80
81            let entry = weights_lookup.entry(*dest).or_insert(HashMap::new());
82            entry.insert(*origin, *weight);
83        }
84        Self {
85            weights: weights_lookup,
86            no_elements,
87        }
88    }
89
90    /// Return a reference to the hash map representation of the weights
91    pub fn weights(&self) -> &HashMap<usize, HashMap<usize, f64>> {
92        &self.weights
93    }
94
95    /// Return the total number of elements in the original geometry set
96    pub fn no_elements(&self) -> usize {
97        self.no_elements
98    }
99
100    /// Returns true if the origin and destination are neighbors
101    ///
102    /// # Arguments
103    ///
104    /// * `origin` - the id of the origin geometry
105    /// * `destination` - the id of the destination geometry
106    ///
107    pub fn are_neighbors(&self, origin: usize, dest: usize) -> bool {
108        self.weights.get(&origin).unwrap().contains_key(&dest)
109    }
110
111    /// Returns the ids of a given geometries neighbors
112    ///
113    /// # Arguments
114    ///
115    /// * `origin` - the id of the origin geometry
116    ///
117    pub fn get_neighbor_ids(&self, origin: usize) -> Option<HashSet<usize>> {
118        match self.weights.get(&origin) {
119            Some(m) => {
120                let results: HashSet<usize> = m.keys().into_iter().cloned().collect();
121                Some(results)
122            }
123            None => None,
124        }
125    }
126
127    /// Returns the weights matrix as a nalgebra sparse matrix
128    ///
129    /// # Arguments
130    ///
131    /// * `transfrom` - what transform, if any to apply to the weights matrix as we  transform.
132    /// Only TransformType::Row for row normalized is currently implemented.
133    ///
134    pub fn as_sparse_matrix(&self, transform: Option<TransformType>) -> CsrMatrix<f64> {
135        let mut coo_matrix = CooMatrix::new(self.no_elements, self.no_elements);
136
137        for (key, vals) in self.weights.iter() {
138            let norm: f64 = match &transform {
139                Some(TransformType::Row) => vals.values().sum(),
140                _ => 1.0,
141            };
142            for (key2, weight) in vals.iter() {
143                coo_matrix.push(*key, *key2, *weight / norm);
144            }
145        }
146
147        CsrMatrix::from(&coo_matrix)
148    }
149
150    /// Returns the weights matrix in a list format
151    ///
152    /// Output format is a tuple of origin ids, dest ids, weight values
153    ///
154    pub fn to_list(&self) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
155        let mut origin_list: Vec<usize> = vec![];
156        let mut dest_list: Vec<usize> = vec![];
157        let mut weight_list: Vec<f64> = vec![];
158
159        for (origin, dests) in self.weights.iter() {
160            for (dest, weight) in dests.iter() {
161                origin_list.push(*origin);
162                dest_list.push(*dest);
163                weight_list.push(*weight);
164            }
165        }
166        (origin_list, dest_list, weight_list)
167    }
168
169    /// Returns the weights matrix in a list format with geometries
170    ///
171    /// Output format is a tuple of origin ids, dest ids, weight values, geometry linking origin
172    /// and destination
173    ///
174    /// # Arguments
175    ///
176    /// * `geoms` - the list of geometries originally used to generate the weights matrix.
177    pub fn to_list_with_geom<A: GeoFloat>(
178        &self,
179        geoms: &[Geometry<A>],
180    ) -> Result<(Vec<usize>, Vec<usize>, Vec<f64>, Vec<Geometry<A>>), String> {
181        let mut origin_list: Vec<usize> = vec![];
182        let mut dest_list: Vec<usize> = vec![];
183        let mut weight_list: Vec<f64> = vec![];
184        let mut link_geoms: Vec<Geometry<A>> = vec![];
185        let no_geoms = geoms.len();
186
187        for (origin, dests) in self.weights.iter() {
188            for (dest, weight) in dests.iter() {
189                origin_list.push(*origin);
190                dest_list.push(*dest);
191                weight_list.push(*weight);
192                let origin_centroid = geoms
193                    .get(*origin)
194                    .ok_or_else(|| format!("Failed to get origin {} {}", origin, no_geoms))?
195                    .centroid()
196                    .unwrap();
197                let dest_centroid = geoms
198                    .get(*dest)
199                    .ok_or_else(|| format!("Failed to get origin {} {}", dest, no_geoms))?
200                    .centroid()
201                    .unwrap();
202                let line: geo::Geometry<A> =
203                    geo::Geometry::Line(Line::new(origin_centroid, dest_centroid));
204                link_geoms.push(line);
205            }
206        }
207        Ok((origin_list, dest_list, weight_list, link_geoms))
208    }
209
210    /// Returns the weights matrix in a GeoJson format with lines between the origin and
211    /// destinations
212    ///
213    /// # Arguments
214    ///
215    /// * `geoms` - the list of geometries originally used to generate the weights matrix.
216    pub fn links_geojson<A: GeoFloat>(&self, geoms: &[Geometry<A>]) -> FeatureCollection {
217        let mut features: Vec<Feature> = vec![];
218
219        for (origin, dests) in self.weights.iter() {
220            for (dest, _weight) in dests.iter() {
221                let origin_centroid = geoms.get(*origin).unwrap().centroid().unwrap();
222                let dest_centroid = geoms.get(*dest).unwrap().centroid().unwrap();
223                let line: geojson::Geometry =
224                    geojson::Value::from(&Line::new(origin_centroid, dest_centroid)).into();
225
226                let mut feature = Feature {
227                    geometry: Some(line),
228                    ..Default::default()
229                };
230                feature.set_property("origin", format!("{}", origin));
231                feature.set_property("dest", format!("{}", dest));
232                features.push(feature);
233            }
234        }
235        FeatureCollection {
236            features,
237            bbox: None,
238            foreign_members: None,
239        }
240    }
241
242    // pub fn as_geopoalrs(&self, geoms: &[Geometry<A>], ids: &Vec<T>)->Result<DataFrame, Error>{
243    //    use polars::io::{SerWriter, SerReader};
244    //    use polars::prelude::NamedFromOwned;
245    //    use geopolars::geoseries::GeoSeries;
246
247    //    let mut origin_ids : Vec<i32> = Vec::with_capacity(geoms.len());
248    //    let mut dests_ids: Vec<i32> = Vec::with_capacity(geoms.len());
249    //    let mut weights: Vec<f32> = Vec::with_capacity(geoms.len());
250    //    let mut lines: Vec<Line> = Vec::with_capacity(geoms.len());
251
252    //    for (origin, dests) in self.weights().unwrap().iter(){
253    //         for (dest, _weight) in dests.iter(){
254    //             origin_ids.push(origin);
255    //             dest_ids.push(dest);
256    //             weights.push(weight);
257    //             let origin_index =  ids.iter().position(|a| a == origin).unwrap();
258    //             let dest_index =  ids.iter().position(|a| a == dest).unwrap();
259    //             let origin_centroid = geoms.get(origin_index).unwrap().centroid().unwrap();
260    //             let dest_centroid = geoms.get(dest_index).unwrap().centroid().unwrap();
261    //             let line  = Line::new(origin_centroid, dest_centroid);
262    //             geoms.push(line);
263    //         }
264    //    }
265    //    let geom_col = Series::from_geom_vec(&lines);
266    //    let result = DataFrame:::new([
267    //         Series::from_vec("origin_id", origin_ids),
268    //         Series::from_vec("dest_id", dests_ids),
269    //         Series::from_vec("weight", weights),
270    //         Series::from_vec("geom", geoms),
271    //    ]);
272    //    result
273    // }
274}