geo_weights/weights.rs
1use geo::{Centroid, GeoFloat, Line};
2use geo_types::Geometry;
3use geojson::{Feature, FeatureCollection};
4use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
5use std::collections::{HashMap, HashSet};
6use std::fmt;
7use std::iter::IntoIterator;
8
/// Transformations that can be requested when exporting a weights matrix.
///
/// Only [`TransformType::Row`] is currently acted upon by `Weights::as_sparse_matrix`; the other
/// variants are accepted but leave the weights unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransformType {
    /// Row standardisation: every weight is divided by the sum of the weights in its row.
    Row,
    /// Binary transform (not currently applied when exporting).
    Binary,
    /// Doubly standardised transform (not currently applied when exporting).
    DoublyStandardized,
}
14
15pub trait WeightBuilder<A>
16where
17 A: GeoFloat,
18{
19 fn compute_weights<T>(&self, geoms: &T) -> Weights
20 where
21 for<'a> &'a T: IntoIterator<Item = &'a Geometry<A>>;
22}
23
/// Structure holding and providing methods to access and query a weights matrix. These are either
/// loaded from external representations or constructed from a [`WeightBuilder`].
///
/// The matrix is stored sparsely: only links that carry a weight are present in the lookup.
#[derive(Debug, Clone, PartialEq)]
pub struct Weights {
    /// Nested lookup of `origin id => (destination id => weight)`.
    weights: HashMap<usize, HashMap<usize, f64>>,
    /// Number of elements in the original geometry set; kept separately because the sparse
    /// lookup alone does not reveal the full size of the matrix.
    no_elements: usize,
}
31
32impl fmt::Display for Weights {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 write!(f, "({:#?})", self.weights())
35 }
36}
37
38impl Weights {
39 /// Create a new weights object from a hashmap indicating origin and destination ids and
40 /// weights.
41 ///
42 /// # Arguments
43 ///
44 /// * `weights` - A mapping of {origin => dest =>weight}f64
45 /// * `no_elements` - The number of elements in the original geometry set (because we want to be
46 /// sure of the full length given this is a sparse representation)
47 ///
48 pub fn new(weights: HashMap<usize, HashMap<usize, f64>>, no_elements: usize) -> Weights {
49 Self {
50 weights,
51 no_elements,
52 }
53 }
54
55 /// Create a new weights object from a series of lists representing the origin, destinations
56 /// and weights of the matrix
57 ///
58 /// # Arguments
59 ///
60 /// * `origins` - A list of the origin ids
61 /// * `origins` - A list of the destination ids
62 /// * `weights` - A list of the weights
63 /// * `no_elements` - The number of elements in the original geometry set (because we want to be
64 /// sure of the full length given this is a sparse representation)
65 ///
66 pub fn from_list_rep<T, W>(origins: &T, dests: &T, weights: &W, no_elements: usize) -> Weights
67 where
68 for<'a> &'a T: std::iter::IntoIterator<Item = &'a usize>,
69 for<'a> &'a W: std::iter::IntoIterator<Item = &'a f64>,
70 {
71 let mut weights_lookup: HashMap<usize, HashMap<usize, f64>> = HashMap::new();
72
73 for ((origin, dest), weight) in origins
74 .into_iter()
75 .zip(dests.into_iter())
76 .zip(weights.into_iter())
77 {
78 let entry = weights_lookup.entry(*origin).or_insert(HashMap::new());
79 entry.insert(*dest, *weight);
80
81 let entry = weights_lookup.entry(*dest).or_insert(HashMap::new());
82 entry.insert(*origin, *weight);
83 }
84 Self {
85 weights: weights_lookup,
86 no_elements,
87 }
88 }
89
90 /// Return a reference to the hash map representation of the weights
91 pub fn weights(&self) -> &HashMap<usize, HashMap<usize, f64>> {
92 &self.weights
93 }
94
95 /// Return the total number of elements in the original geometry set
96 pub fn no_elements(&self) -> usize {
97 self.no_elements
98 }
99
100 /// Returns true if the origin and destination are neighbors
101 ///
102 /// # Arguments
103 ///
104 /// * `origin` - the id of the origin geometry
105 /// * `destination` - the id of the destination geometry
106 ///
107 pub fn are_neighbors(&self, origin: usize, dest: usize) -> bool {
108 self.weights.get(&origin).unwrap().contains_key(&dest)
109 }
110
111 /// Returns the ids of a given geometries neighbors
112 ///
113 /// # Arguments
114 ///
115 /// * `origin` - the id of the origin geometry
116 ///
117 pub fn get_neighbor_ids(&self, origin: usize) -> Option<HashSet<usize>> {
118 match self.weights.get(&origin) {
119 Some(m) => {
120 let results: HashSet<usize> = m.keys().into_iter().cloned().collect();
121 Some(results)
122 }
123 None => None,
124 }
125 }
126
127 /// Returns the weights matrix as a nalgebra sparse matrix
128 ///
129 /// # Arguments
130 ///
131 /// * `transfrom` - what transform, if any to apply to the weights matrix as we transform.
132 /// Only TransformType::Row for row normalized is currently implemented.
133 ///
134 pub fn as_sparse_matrix(&self, transform: Option<TransformType>) -> CsrMatrix<f64> {
135 let mut coo_matrix = CooMatrix::new(self.no_elements, self.no_elements);
136
137 for (key, vals) in self.weights.iter() {
138 let norm: f64 = match &transform {
139 Some(TransformType::Row) => vals.values().sum(),
140 _ => 1.0,
141 };
142 for (key2, weight) in vals.iter() {
143 coo_matrix.push(*key, *key2, *weight / norm);
144 }
145 }
146
147 CsrMatrix::from(&coo_matrix)
148 }
149
150 /// Returns the weights matrix in a list format
151 ///
152 /// Output format is a tuple of origin ids, dest ids, weight values
153 ///
154 pub fn to_list(&self) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
155 let mut origin_list: Vec<usize> = vec![];
156 let mut dest_list: Vec<usize> = vec![];
157 let mut weight_list: Vec<f64> = vec![];
158
159 for (origin, dests) in self.weights.iter() {
160 for (dest, weight) in dests.iter() {
161 origin_list.push(*origin);
162 dest_list.push(*dest);
163 weight_list.push(*weight);
164 }
165 }
166 (origin_list, dest_list, weight_list)
167 }
168
169 /// Returns the weights matrix in a list format with geometries
170 ///
171 /// Output format is a tuple of origin ids, dest ids, weight values, geometry linking origin
172 /// and destination
173 ///
174 /// # Arguments
175 ///
176 /// * `geoms` - the list of geometries originally used to generate the weights matrix.
177 pub fn to_list_with_geom<A: GeoFloat>(
178 &self,
179 geoms: &[Geometry<A>],
180 ) -> Result<(Vec<usize>, Vec<usize>, Vec<f64>, Vec<Geometry<A>>), String> {
181 let mut origin_list: Vec<usize> = vec![];
182 let mut dest_list: Vec<usize> = vec![];
183 let mut weight_list: Vec<f64> = vec![];
184 let mut link_geoms: Vec<Geometry<A>> = vec![];
185 let no_geoms = geoms.len();
186
187 for (origin, dests) in self.weights.iter() {
188 for (dest, weight) in dests.iter() {
189 origin_list.push(*origin);
190 dest_list.push(*dest);
191 weight_list.push(*weight);
192 let origin_centroid = geoms
193 .get(*origin)
194 .ok_or_else(|| format!("Failed to get origin {} {}", origin, no_geoms))?
195 .centroid()
196 .unwrap();
197 let dest_centroid = geoms
198 .get(*dest)
199 .ok_or_else(|| format!("Failed to get origin {} {}", dest, no_geoms))?
200 .centroid()
201 .unwrap();
202 let line: geo::Geometry<A> =
203 geo::Geometry::Line(Line::new(origin_centroid, dest_centroid));
204 link_geoms.push(line);
205 }
206 }
207 Ok((origin_list, dest_list, weight_list, link_geoms))
208 }
209
210 /// Returns the weights matrix in a GeoJson format with lines between the origin and
211 /// destinations
212 ///
213 /// # Arguments
214 ///
215 /// * `geoms` - the list of geometries originally used to generate the weights matrix.
216 pub fn links_geojson<A: GeoFloat>(&self, geoms: &[Geometry<A>]) -> FeatureCollection {
217 let mut features: Vec<Feature> = vec![];
218
219 for (origin, dests) in self.weights.iter() {
220 for (dest, _weight) in dests.iter() {
221 let origin_centroid = geoms.get(*origin).unwrap().centroid().unwrap();
222 let dest_centroid = geoms.get(*dest).unwrap().centroid().unwrap();
223 let line: geojson::Geometry =
224 geojson::Value::from(&Line::new(origin_centroid, dest_centroid)).into();
225
226 let mut feature = Feature {
227 geometry: Some(line),
228 ..Default::default()
229 };
230 feature.set_property("origin", format!("{}", origin));
231 feature.set_property("dest", format!("{}", dest));
232 features.push(feature);
233 }
234 }
235 FeatureCollection {
236 features,
237 bbox: None,
238 foreign_members: None,
239 }
240 }
241
242 // pub fn as_geopoalrs(&self, geoms: &[Geometry<A>], ids: &Vec<T>)->Result<DataFrame, Error>{
243 // use polars::io::{SerWriter, SerReader};
244 // use polars::prelude::NamedFromOwned;
245 // use geopolars::geoseries::GeoSeries;
246
247 // let mut origin_ids : Vec<i32> = Vec::with_capacity(geoms.len());
248 // let mut dests_ids: Vec<i32> = Vec::with_capacity(geoms.len());
249 // let mut weights: Vec<f32> = Vec::with_capacity(geoms.len());
250 // let mut lines: Vec<Line> = Vec::with_capacity(geoms.len());
251
252 // for (origin, dests) in self.weights().unwrap().iter(){
253 // for (dest, _weight) in dests.iter(){
254 // origin_ids.push(origin);
255 // dest_ids.push(dest);
256 // weights.push(weight);
257 // let origin_index = ids.iter().position(|a| a == origin).unwrap();
258 // let dest_index = ids.iter().position(|a| a == dest).unwrap();
259 // let origin_centroid = geoms.get(origin_index).unwrap().centroid().unwrap();
260 // let dest_centroid = geoms.get(dest_index).unwrap().centroid().unwrap();
261 // let line = Line::new(origin_centroid, dest_centroid);
262 // geoms.push(line);
263 // }
264 // }
265 // let geom_col = Series::from_geom_vec(&lines);
266 // let result = DataFrame:::new([
267 // Series::from_vec("origin_id", origin_ids),
268 // Series::from_vec("dest_id", dests_ids),
269 // Series::from_vec("weight", weights),
270 // Series::from_vec("geom", geoms),
271 // ]);
272 // result
273 // }
274}