geosuggest_core/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::collections::HashMap;
3
4use itertools::Itertools;
5
6use kiddo::{self, SquaredEuclidean};
7
8use rayon::prelude::*;
9use rkyv::rend::{f32_le, u32_le};
10use strsim::jaro_winkler;
11
12#[cfg(feature = "geoip2")]
13use std::net::IpAddr;
14
15#[cfg(feature = "geoip2")]
16use geoip2::{City, Reader};
17
18#[cfg(feature = "oaph")]
19use oaph::schemars::{self, JsonSchema};
20
21pub mod index;
22pub mod storage;
23
24use index::{
25    ArchivedCitiesRecord, ArchivedCountryRecord, ArchivedEntry, ArchivedIndexData, IndexData,
26};
27
/// One reverse-geocoding hit that borrows a plain (non-archived) city record.
///
/// NOTE(review): nothing in this file constructs this type; presumably it
/// mirrors [`ArchivedReverseItem`] for callers that work with deserialized
/// [`index::CitiesRecord`] values — confirm against the rest of the crate.
#[cfg_attr(feature = "oaph", derive(JsonSchema))]
#[derive(Debug, serde::Serialize)]
pub struct ReverseItem<'a> {
    /// The matched city.
    pub city: &'a index::CitiesRecord,
    /// Distance between the query point and the city
    /// (presumably in the same unit as [`ArchivedReverseItem::distance`] — confirm).
    pub distance: f32,
    /// Ranking score of the hit
    /// (presumably lower is better, as in [`Engine::reverse`] — confirm).
    pub score: f32,
}
35
/// One reverse-geocoding hit produced by [`Engine::reverse`], borrowing the
/// city record straight from the archived index.
#[derive(Debug, serde::Serialize)]
pub struct ArchivedReverseItem<'a> {
    /// The matched city inside the archived index.
    pub city: &'a index::ArchivedCitiesRecord,
    /// Distance reported by the k-d tree for this hit. The tree is queried with
    /// the `SquaredEuclidean` metric, so this is a squared distance in
    /// coordinate units, not a geographic distance.
    pub distance: f32,
    /// Ranking score: equal to `distance` when no `k` was passed to
    /// [`Engine::reverse`], otherwise `distance - k * population`.
    pub score: f32,
}
42
/// Description of the source data an index was built from.
///
/// NOTE(review): this file only declares the fields; the per-field notes below
/// are inferred from the field names — confirm against the code that fills
/// this struct.
#[derive(
    Debug, Default, Clone, rkyv::Serialize, rkyv::Deserialize, rkyv::Archive, serde::Serialize,
)]
pub struct EngineSourceMetadata {
    /// Source of the cities dataset (presumably a path or URL).
    pub cities: String,
    /// Optional source of the alternate (multilingual) names dataset.
    pub names: Option<String>,
    /// Optional source of the countries dataset.
    pub countries: Option<String>,
    /// Optional source of the first-level administrative division codes.
    pub admin1_codes: Option<String>,
    /// Optional source of the second-level administrative division codes.
    pub admin2_codes: Option<String>,
    /// Languages the names were filtered by (presumably empty means no filtering).
    pub filter_languages: Vec<String>,
    /// ETag values keyed by string (presumably by source, for change detection).
    pub etag: HashMap<String, String>,
}
55
/// Metadata describing how and when an index was built.
#[derive(Debug, Clone, rkyv::Serialize, rkyv::Deserialize, rkyv::Archive, serde::Serialize)]
pub struct EngineMetadata {
    /// Version of geosuggest the index was built with
    /// (`Default` fills it from `CARGO_PKG_VERSION`).
    pub geosuggest_version: String,
    /// Creation time (`Default` uses the current system time).
    /// Archived through the `AsUnixTime` wrapper.
    #[rkyv(with = rkyv::with::AsUnixTime)]
    pub created_at: std::time::SystemTime,
    /// Metadata of the data sources the index was built from.
    pub source: EngineSourceMetadata,
    /// Free-form custom key/value metadata.
    pub extra: HashMap<String, String>,
}
68
69impl Default for EngineMetadata {
70    fn default() -> Self {
71        Self {
72            created_at: std::time::SystemTime::now(),
73            geosuggest_version: env!("CARGO_PKG_VERSION").to_owned(),
74            source: EngineSourceMetadata::default(),
75            extra: HashMap::default(),
76        }
77    }
78}
79
/// Owned storage backing an [`Engine`].
///
/// Holds the serialized index bytes (plus optional metadata and GeoIP2
/// database bytes); call [`EngineData::as_engine`] to get a borrowed,
/// queryable view.
pub struct EngineData {
    /// rkyv-serialized index bytes; `as_engine` accesses them as an
    /// `ArchivedIndexData`.
    pub data: rkyv::util::AlignedVec,
    /// Optional build metadata. The conversions in this file leave it `None`.
    pub metadata: Option<EngineMetadata>,
    /// Raw bytes of a GeoIP2 database, set by `load_geoip2` and parsed as a
    /// `City` reader in `as_engine`.
    #[cfg(feature = "geoip2")]
    pub geoip2: Option<Vec<u8>>,
}
86
87impl EngineData {
88    #[cfg(feature = "geoip2")]
89    pub fn load_geoip2<P: AsRef<std::path::Path>>(
90        &mut self,
91        path: P,
92    ) -> Result<(), Box<dyn std::error::Error>> {
93        self.geoip2 = std::fs::read(path)?.into();
94        Ok(())
95    }
96
97    pub fn as_engine(&self) -> Result<Engine, Box<dyn std::error::Error>> {
98        Ok(Engine {
99            data: rkyv::access::<_, rkyv::rancor::Error>(&self.data)?,
100            #[cfg(feature = "geoip2")]
101            geoip2: if let Some(geoip2) = &self.geoip2 {
102                Reader::<City>::from_bytes(geoip2)
103                    .map_err(|e| format!("Geoip2 error: {e:?}"))?
104                    .into()
105            } else {
106                None
107            },
108        })
109    }
110}
111
/// Borrowed, queryable view over the data owned by an [`EngineData`].
///
/// Created by [`EngineData::as_engine`]; all lookups return references into
/// the archived index.
pub struct Engine<'a> {
    /// The archived index all queries run against.
    pub data: &'a ArchivedIndexData,
    /// GeoIP2 `City` database reader; `None` when no database was loaded.
    #[cfg(feature = "geoip2")]
    geoip2: Option<Reader<'a, City<'a>>>,
}
117
118impl Engine<'_> {
119    pub fn get(&self, id: &u32) -> Option<&ArchivedCitiesRecord> {
120        self.data.geonames.get(&u32_le::from_native(*id))
121    }
122
123    /// Get capital by uppercase country code
124    pub fn capital(&self, country_code: &str) -> Option<&ArchivedCitiesRecord> {
125        if let Some(city_id) = self.data.capitals.get(country_code) {
126            self.data.geonames.get(city_id)
127        } else {
128            None
129        }
130    }
131
132    /// Suggest cities by pattern (multilang).
133    ///
134    /// Optional: filter by Jaro–Winkler distance via min_score
135    ///
136    /// Optional: prefilter by countries
137    pub fn suggest<T: AsRef<str>>(
138        &self,
139        pattern: &str,
140        limit: usize,
141        min_score: Option<f32>,
142        countries: Option<&[T]>,
143    ) -> Vec<&ArchivedCitiesRecord> {
144        if limit == 0 {
145            return Vec::new();
146        }
147
148        let min_score = min_score.unwrap_or(0.8);
149        let normalized_pattern = pattern.to_lowercase();
150
151        let filter_by_pattern = |item: &ArchivedEntry| -> Option<(&ArchivedCitiesRecord, f32)> {
152            let score = if item.value.starts_with(&normalized_pattern) {
153                1.0
154            } else {
155                jaro_winkler(&item.value, &normalized_pattern) as f32
156            };
157            if score >= min_score {
158                self.data.geonames.get(&item.id).map(|city| (city, score))
159            } else {
160                None
161            }
162        };
163
164        let mut result: Vec<(&ArchivedCitiesRecord, f32)> = match &countries {
165            Some(countries) => {
166                let country_ids = countries
167                    .iter()
168                    .filter_map(|code| {
169                        self.data
170                            .country_info_by_code
171                            .get(code.as_ref())
172                            .map(|c| &c.info.geonameid)
173                    })
174                    .collect::<Vec<_>>();
175                self.data
176                    .entries
177                    .par_iter()
178                    .filter(|item| {
179                        item.country_id
180                            .as_ref()
181                            .map(|id| country_ids.contains(&id))
182                            .unwrap_or_default()
183                    })
184                    .filter_map(filter_by_pattern)
185                    .collect()
186            }
187            None => self
188                .data
189                .entries
190                .par_iter()
191                .filter_map(filter_by_pattern)
192                .collect(),
193        };
194
195        // sort by score desc, population desc
196        result.sort_unstable_by(|lhs, rhs| {
197            if (lhs.1 - rhs.1).abs() < f32::EPSILON {
198                rhs.0
199                    .population
200                    .partial_cmp(&lhs.0.population)
201                    .unwrap_or(std::cmp::Ordering::Equal)
202            } else {
203                rhs.1
204                    .partial_cmp(&lhs.1)
205                    .unwrap_or(std::cmp::Ordering::Equal)
206            }
207        });
208
209        result
210            .iter()
211            .unique_by(|item| item.0.id)
212            .take(limit)
213            .map(|item| item.0)
214            .collect::<Vec<_>>()
215    }
216
217    /// Find the nearest cities by coordinates.
218    ///
219    /// Optional: score results by `k` as `distance - k * city.population` and sort by score.
220    ///
221    /// Optional: prefilter by countries. It's a very expensive case; consider building an index for concrete countries and not applying this filter at all.
222    pub fn reverse<T: AsRef<str>>(
223        &self,
224        loc: (f32, f32),
225        limit: usize,
226        k: Option<f32>,
227        countries: Option<&[T]>,
228    ) -> Option<Vec<ArchivedReverseItem>> {
229        if limit == 0 {
230            return None;
231        }
232
233        let nearest_limit = std::num::NonZero::new(if countries.is_some() {
234            // ugly hack try to fetch nearest cities in requested countries
235            // much better is to build index for concrete countries
236            self.data.geonames.len()
237        } else {
238            limit
239        })?;
240
241        let mut i1;
242        let mut i2;
243
244        let items = &mut self
245            .data
246            .tree
247            .nearest_n::<SquaredEuclidean>(&[loc.0, loc.1], nearest_limit);
248
249        let items: &mut dyn Iterator<Item = (_, &ArchivedCitiesRecord)> =
250            if let Some(countries) = countries {
251                // normalize
252                let countries = countries
253                    .iter()
254                    .map(|code| code.as_ref())
255                    .collect::<Vec<_>>();
256
257                i1 = items.iter_mut().filter_map(move |nearest| {
258                    let geonameid = self
259                        .data
260                        .tree_index_to_geonameid
261                        .get(&u32_le::from(nearest.item))?;
262                    let city = self.data.geonames.get(geonameid)?;
263                    let country = city.country.as_ref()?;
264                    if countries.contains(&country.code.as_str()) {
265                        Some((nearest, city))
266                    } else {
267                        None
268                    }
269                });
270                &mut i1
271            } else {
272                i2 = items.iter_mut().filter_map(|nearest| {
273                    let geonameid = self
274                        .data
275                        .tree_index_to_geonameid
276                        .get(&u32_le::from(nearest.item))?;
277                    let city = self.data.geonames.get(geonameid)?;
278                    Some((nearest, city))
279                });
280                &mut i2
281            };
282
283        if let Some(k) = k.map(f32_le::from_native) {
284            let mut points = items
285                .map(|item| {
286                    (
287                        item.0.distance,
288                        item.0.distance - k * (item.1.population.to_native() as f32),
289                        item.1,
290                    )
291                })
292                .take(limit)
293                .collect::<Vec<_>>();
294
295            points.sort_unstable_by(|a, b| {
296                a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
297            });
298
299            Some(
300                points
301                    .iter()
302                    .map(|p| ArchivedReverseItem {
303                        distance: p.0,
304                        score: p.1,
305                        city: p.2,
306                    })
307                    .collect(),
308            )
309        } else {
310            Some(
311                items
312                    .map(|item| ArchivedReverseItem {
313                        distance: item.0.distance,
314                        score: item.0.distance,
315                        city: item.1,
316                    })
317                    .take(limit)
318                    .collect(),
319            )
320        }
321    }
322
323    /// Get country info by iso 2-letter country code.
324    pub fn country_info(&self, country_code: &str) -> Option<&ArchivedCountryRecord> {
325        self.data.country_info_by_code.get(country_code)
326    }
327
328    #[cfg(feature = "geoip2")]
329    pub fn geoip2_lookup(&self, addr: IpAddr) -> Option<&ArchivedCitiesRecord> {
330        match self.geoip2.as_ref() {
331            Some(reader) => {
332                let result = reader.lookup(addr).ok()?;
333                let city = result.city?;
334                let id = city.geoname_id?;
335                self.data.geonames.get(&u32_le::from_native(id))
336            }
337            None => {
338                #[cfg(feature = "tracing")]
339                tracing::warn!("Geoip2 reader is't configured!");
340                None
341            }
342        }
343    }
344}
345
346impl TryFrom<IndexData> for EngineData {
347    type Error = rkyv::rancor::Error;
348    fn try_from(data: IndexData) -> Result<EngineData, Self::Error> {
349        Ok(EngineData {
350            data: rkyv::to_bytes(&data)?,
351            metadata: None,
352            #[cfg(feature = "geoip2")]
353            geoip2: None,
354        })
355    }
356}
357
358impl TryFrom<rkyv::util::AlignedVec> for EngineData {
359    type Error = rkyv::rancor::Error;
360    fn try_from(bytes: rkyv::util::AlignedVec) -> Result<EngineData, Self::Error> {
361        Ok(EngineData {
362            data: bytes,
363            metadata: None,
364            #[cfg(feature = "geoip2")]
365            geoip2: None,
366        })
367    }
368}