#![deny(missing_docs)]
use std::collections::BTreeMap;
use s2::{
cap::Cap, cellid::CellID, cellunion::CellUnion, latlng::LatLng, point::Point,
region::RegionCoverer, s1,
};
use serde::{
de::{MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Serialize,
};
use serde_derive::{Deserialize, Serialize};
use crate::{
cell_list::{CellList, CellScorer, UserCountScorer},
users::User,
};
/// Approximate radius of the Earth (about 6,370 km), expressed in meters. Used to turn a
/// search distance into the angle it subtends at the Earth's center.
const EARTH_RADIUS: f64 = 6_370_000.0;
/// Collects everything needed to score cells and partition them into a
/// [`GeoshardCollection`].
///
/// Construct one with `GeoshardBuilder::new` (any scorer) or
/// `GeoshardBuilder::user_count_scorer`, then call `build`.
pub struct GeoshardBuilder<Scorer, UserCollection> {
    /// S2 level at which cells are scored and shards are stored.
    storage_level: u64,
    /// Users handed to the scorer when the shards are built.
    users: UserCollection,
    /// Strategy that assigns a load score to each cell.
    cell_scorer: Scorer,
    /// Lower bound on the desired number of shards.
    min_shard_count: i32,
    /// Upper bound on the desired number of shards.
    max_shard_count: i32,
}
impl<Scorer, UserCollection> GeoshardBuilder<Scorer, UserCollection> {
    /// Creates a builder from its raw parts.
    ///
    /// * `storage_level` - S2 level at which cells are scored and shards are stored.
    /// * `users` - users passed to `cell_scorer` when [`GeoshardBuilder::build`] runs.
    /// * `cell_scorer` - strategy that assigns a load score to each cell.
    /// * `min_shard_count` / `max_shard_count` - bounds used to derive the range of
    ///   per-shard capacities that `build` tries.
    pub fn new(
        storage_level: u64,
        users: UserCollection,
        cell_scorer: Scorer,
        min_shard_count: i32,
        max_shard_count: i32,
    ) -> Self {
        Self {
            storage_level,
            cell_scorer,
            users,
            min_shard_count,
            max_shard_count,
        }
    }

    /// Scores every cell and returns the most evenly balanced partition found.
    ///
    /// Every integer capacity from `total_load / max_shard_count` to
    /// `total_load / min_shard_count` (inclusive) is tried with
    /// [`GeoshardCollection::new`]. The candidate whose shard scores have the lowest
    /// standard deviation wins, and ties go to the smallest capacity. A NaN deviation ranks
    /// below any real value. A shard count of `0` (or less) is treated as `1`, so the
    /// capacity bounds never divide by zero.
    ///
    /// # Panics
    ///
    /// Panics if no capacity is tried. That happens when `total_load / max_shard_count`
    /// exceeds `total_load / min_shard_count`, which is possible when
    /// `min_shard_count > max_shard_count`.
    pub fn build<T>(self) -> GeoshardCollection
    where
        Scorer: CellScorer<UserCollection>,
        UserCollection: Iterator<Item = T>,
        T: User,
    {
        let cell_list = self
            .cell_scorer
            .score_cell_list(CellList::new(self.storage_level), self.users);
        let scored_cells = cell_list.cell_list();
        let total_load: i32 = scored_cells.values().sum();
        // More shards means smaller shards: the *max* count bounds the *min* capacity.
        let max_size = total_load / self.min_shard_count.max(1);
        let min_size = total_load / self.max_shard_count.max(1);

        let mut best: Option<(f64, GeoshardCollection)> = None;
        for container_size in min_size..=max_size {
            let shards = GeoshardCollection::new(container_size, scored_cells, self.storage_level);
            let standard_deviation = shards.standard_deviation();
            // Always keep the first candidate, so a NaN deviation can never leave us empty.
            // After that, only a strictly smaller deviation, or a real value replacing NaN,
            // wins.
            let is_better = match &best {
                None => true,
                Some((best_deviation, _)) => {
                    standard_deviation < *best_deviation
                        || (best_deviation.is_nan() && !standard_deviation.is_nan())
                }
            };
            if is_better {
                best = Some((standard_deviation, shards));
            }
        }
        best.map(|(_, shards)| shards)
            .expect("no shard capacity to try: min_shard_count must not exceed max_shard_count")
    }
}
impl<UserCollection> GeoshardBuilder<UserCountScorer, UserCollection> {
    /// Convenience constructor that scores cells with [`UserCountScorer`].
    ///
    /// Equivalent to [`GeoshardBuilder::new`] with `UserCountScorer` as the scorer.
    pub fn user_count_scorer(
        storage_level: u64,
        users: UserCollection,
        min_shard_count: i32,
        max_shard_count: i32,
    ) -> Self {
        Self::new(
            storage_level,
            users,
            UserCountScorer,
            min_shard_count,
            max_shard_count,
        )
    }
}
/// A run of S2 cells, taken in cell-id order, that is served as a single shard.
#[derive(Debug)]
pub struct Geoshard {
    // Shard name; `GeoshardCollection::new` produces names like `geoshard_user_index_1`.
    name: String,
    // S2 level recorded for this shard.
    storage_level: u64,
    // First cell of the run.
    start: CellID,
    // Last cell assigned to the run.
    end: CellID,
    // Summed score of the cells in the run.
    cell_score: i32,
    // Derived from `start` and `end` in `Geoshard::new`; used for containment checks and
    // not serialized.
    cell_union: CellUnion,
    // Number of scored cells in the run (exposed as `cell_count`).
    size: usize,
}
/// Writes a shard as a six-field struct. The cell ids are emitted as S2 tokens, and the
/// derived `cell_union` is omitted because deserialization rebuilds it from `start` and `end`.
impl Serialize for Geoshard {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let start_token = self.start.to_token();
        let end_token = self.end.to_token();
        let mut fields = serializer.serialize_struct("Geoshard", 6)?;
        fields.serialize_field("name", &self.name)?;
        fields.serialize_field("storage_level", &self.storage_level)?;
        fields.serialize_field("start", &start_token)?;
        fields.serialize_field("end", &end_token)?;
        fields.serialize_field("cell_score", &self.cell_score)?;
        fields.serialize_field("size", &self.size)?;
        fields.end()
    }
}
impl<'de> Deserialize<'de> for Geoshard {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
enum Field {
Name,
StorageLevel,
Start,
End,
CellScore,
Size,
}
impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct FieldVisitor;
impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str(
"`name` or `storage_level` or `start` or `end` or `cell_score` or `size`",
)
}
fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: serde::de::Error,
{
match value {
"name" => Ok(Field::Name),
"storage_level" => Ok(Field::StorageLevel),
"start" => Ok(Field::Start),
"end" => Ok(Field::End),
"cell_score" => Ok(Field::CellScore),
"size" => Ok(Field::Size),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_identifier(FieldVisitor)
}
}
struct GeoshardVisitor;
impl<'de> Visitor<'de> for GeoshardVisitor {
type Value = Geoshard;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("struct Geoshard")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let name = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let storage_level = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
let start = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
let end = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(3, &self))?;
let cell_score = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(4, &self))?;
let size = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(5, &self))?;
Ok(Geoshard::new(
name,
CellID::from_token(start),
CellID::from_token(end),
cell_score,
storage_level,
size,
))
}
fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut name = None;
let mut storage_level = None;
let mut start = None;
let mut end = None;
let mut cell_score = None;
let mut size = None;
while let Some(key) = map.next_key()? {
match key {
Field::Name => {
if name.is_some() {
return Err(serde::de::Error::duplicate_field("name"));
}
name = Some(map.next_value()?);
}
Field::StorageLevel => {
if storage_level.is_some() {
return Err(serde::de::Error::duplicate_field("storage_level"));
}
storage_level = Some(map.next_value()?);
}
Field::Start => {
if start.is_some() {
return Err(serde::de::Error::duplicate_field("start"));
}
start = Some(CellID::from_token(map.next_value()?));
}
Field::End => {
if end.is_some() {
return Err(serde::de::Error::duplicate_field("end"));
}
end = Some(CellID::from_token(map.next_value()?));
}
Field::CellScore => {
if cell_score.is_some() {
return Err(serde::de::Error::duplicate_field("cell_score"));
}
cell_score = Some(map.next_value()?);
}
Field::Size => {
if size.is_some() {
return Err(serde::de::Error::duplicate_field("size"));
}
size = Some(map.next_value()?);
}
}
}
let name = name.ok_or_else(|| serde::de::Error::missing_field("name"))?;
let start = start.ok_or_else(|| serde::de::Error::missing_field("start"))?;
let end = end.ok_or_else(|| serde::de::Error::missing_field("end"))?;
let cell_score =
cell_score.ok_or_else(|| serde::de::Error::missing_field("cell_score"))?;
let storage_level = storage_level
.ok_or_else(|| serde::de::Error::missing_field("storage_level"))?;
let size = size.ok_or_else(|| serde::de::Error::missing_field("size"))?;
Ok(Geoshard::new(
name,
start,
end,
cell_score,
storage_level,
size,
))
}
}
const FIELDS: &'static [&'static str] =
&["name", "storage_level", "start", "end", "cell_score"];
deserializer.deserialize_struct("Geoshard", FIELDS, GeoshardVisitor)
}
}
impl Geoshard {
    /// Creates a shard that covers every cell from `start` through `end`, both inclusive.
    ///
    /// The cell union spans the leaf cells from `start.range_min()` up to and including
    /// `end.range_max()`. As a result, the last cell of the run is part of the shard, and a
    /// run of a single cell (`start == end`) has a non-empty union.
    pub fn new(
        name: String,
        start: CellID,
        end: CellID,
        cell_score: i32,
        storage_level: u64,
        size: usize,
    ) -> Self {
        // `CellUnion::from_range` covers the half-open *leaf* range [begin, end). Start at
        // `start`'s first leaf and stop just past `end`'s last leaf, which makes the upper
        // bound inclusive.
        let cell_union = CellUnion::from_range(start.range_min(), end.range_max().next());
        Self {
            size,
            start,
            end,
            name,
            storage_level,
            cell_score,
            cell_union,
        }
    }

    /// Returns the shard's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the number of scored cells assigned to this shard.
    pub fn cell_count(&self) -> usize {
        self.size
    }

    /// Returns the first cell of the shard's run.
    pub fn start(&self) -> &CellID {
        &self.start
    }

    /// Returns the last cell of the shard's run.
    pub fn end(&self) -> &CellID {
        &self.end
    }

    /// Returns the cells covered by this shard, used for containment lookups.
    pub fn cell_union(&self) -> &CellUnion {
        &self.cell_union
    }

    /// Returns the S2 level recorded for this shard.
    pub fn storage_level(&self) -> u64 {
        self.storage_level
    }
}
/// An ordered list of shards together with the S2 level they were built at.
#[derive(Debug, Deserialize, Serialize)]
pub struct GeoshardCollection {
    // S2 level the collection was built at; `GeoshardSearcher` uses it to map locations
    // to cells.
    storage_level: u64,
    // Shards in the order they were produced (ascending start cell when built by
    // `GeoshardCollection::new`).
    shards: Vec<Geoshard>,
}
impl GeoshardCollection {
    /// Borrows the shards in their stored order.
    pub fn shards(&self) -> &Vec<Geoshard> {
        let Self { shards, .. } = self;
        shards
    }

    /// Returns the S2 level the collection was built at.
    pub fn storage_level(&self) -> u64 {
        let Self { storage_level, .. } = self;
        *storage_level
    }
}
impl GeoshardCollection {
    /// Greedily partitions `scored_cells` into consecutive shards.
    ///
    /// Cells are visited in ascending `CellID` order. Each cell joins the current shard
    /// unless that would push the shard's summed score above `container_size`. In that case
    /// the current shard is closed and a new one starts at this cell. A cell whose score
    /// alone exceeds `container_size` therefore gets a shard of its own, and no shard is
    /// ever empty.
    ///
    /// Shards are named `geoshard_user_index_1`, `geoshard_user_index_2`, … and every
    /// shard is tagged with `storage_level`. An empty `scored_cells` yields an empty
    /// collection.
    pub fn new(
        container_size: i32,
        scored_cells: &BTreeMap<CellID, i32>,
        storage_level: u64,
    ) -> Self {
        let mut shards = Vec::new();
        let mut current_start = match scored_cells.keys().next() {
            Some(&first) => first,
            None => {
                return Self {
                    shards,
                    storage_level,
                }
            }
        };
        let mut current_end = current_start;
        let mut current_cell_count: usize = 0;
        let mut current_score: i32 = 0;
        for (&cell_id, &cell_score) in scored_cells {
            // Only close a shard that already holds cells. Otherwise an oversized first
            // cell would emit an empty shard overlapping the next one.
            if current_cell_count > 0 && current_score + cell_score > container_size {
                let name = format!("geoshard_user_index_{}", shards.len() + 1);
                shards.push(Geoshard::new(
                    name,
                    current_start,
                    current_end,
                    current_score,
                    storage_level,
                    current_cell_count,
                ));
                current_start = cell_id;
                current_cell_count = 0;
                current_score = 0;
            }
            current_end = cell_id;
            current_cell_count += 1;
            current_score += cell_score;
        }
        // The input is non-empty, so the loop always leaves a non-empty trailing shard.
        let name = format!("geoshard_user_index_{}", shards.len() + 1);
        shards.push(Geoshard::new(
            name,
            current_start,
            current_end,
            current_score,
            storage_level,
            current_cell_count,
        ));
        Self {
            shards,
            storage_level,
        }
    }

    /// Returns the population standard deviation of the shards' scores.
    ///
    /// Lower values mean the load is spread more evenly. A collection with no shards has
    /// no spread and yields `0.0` rather than NaN.
    pub fn standard_deviation(&self) -> f64 {
        if self.shards.is_empty() {
            return 0.0;
        }
        let count = self.shards.len() as f64;
        let mean = self
            .shards
            .iter()
            .fold(0.0, |sum, shard| sum + f64::from(shard.cell_score))
            / count;
        let variance = self
            .shards
            .iter()
            .map(|shard| {
                let deviation = f64::from(shard.cell_score) - mean;
                deviation * deviation
            })
            .sum::<f64>()
            / count;
        variance.sqrt()
    }
}
/// Looks up shards of a [`GeoshardCollection`] by user, location, cell id, or radius.
#[derive(Debug)]
pub struct GeoshardSearcher {
    // S2 level used to map locations to cells; copied from the collection on construction.
    storage_level: u64,
    // The shards being searched.
    shards: GeoshardCollection,
}
impl GeoshardSearcher {
    /// Returns the collection this searcher queries.
    pub fn shards(&self) -> &GeoshardCollection {
        &self.shards
    }

    /// Returns the shard responsible for `user`'s location.
    pub fn get_shard_for_user<T>(&self, user: T) -> &Geoshard
    where
        T: User,
    {
        self.get_shard_from_location(user.location())
    }

    /// Maps `location` to its enclosing cell at the searcher's storage level.
    pub fn get_cell_id_from_location(&self, location: &LatLng) -> CellID {
        CellID::from(location).parent(self.storage_level)
    }

    /// Returns the shard containing `location`'s storage-level cell. Falls back to the last
    /// shard when no shard contains it.
    ///
    /// # Panics
    ///
    /// Panics if the collection has no shards.
    pub fn get_shard_from_location(&self, location: &LatLng) -> &Geoshard {
        self.get_shard_from_cell_id(&self.get_cell_id_from_location(location))
    }

    /// Returns the first shard whose cell union contains `cell_id`.
    ///
    /// If no shard contains it, the last shard is returned instead.
    ///
    /// # Panics
    ///
    /// Panics if the collection has no shards.
    pub fn get_shard_from_cell_id(&self, cell_id: &CellID) -> &Geoshard {
        let shards = &self.shards.shards;
        shards
            .iter()
            .find(|geoshard| geoshard.cell_union().contains_cellid(cell_id))
            .or_else(|| shards.last())
            .expect("GeoshardSearcher requires a collection with at least one shard")
    }

    /// Returns the shard for each storage-level cell covering a cap of `radius` around
    /// `location`.
    ///
    /// One entry is returned per covering cell, so the same shard can appear more than once.
    pub fn get_shards_from_radius(&self, location: &LatLng, radius: u32) -> Vec<&Geoshard> {
        self.cell_ids_from_radius(location, radius)
            .into_iter()
            .map(|cell_id| self.get_shard_from_cell_id(&cell_id))
            .collect()
    }

    /// Returns the storage-level cells covering a spherical cap of `radius` around
    /// `location`.
    ///
    /// `radius` is interpreted in the same unit as `EARTH_RADIUS` (meters).
    pub fn cell_ids_from_radius(&self, location: &LatLng, radius: u32) -> Vec<CellID> {
        let center_point = Point::from(location);
        // Arc length divided by the sphere's radius is the subtended angle in *radians*.
        let center_angle = s1::Rad(f64::from(radius) / EARTH_RADIUS).into();
        let cap = Cap::from_center_angle(&center_point, &center_angle);
        // Pinning min and max level to the storage level yields cells of exactly that level.
        // S2 levels are at most 30, so the narrowing casts are lossless for valid levels.
        let region_cover = RegionCoverer {
            max_level: self.storage_level as u8,
            min_level: self.storage_level as u8,
            level_mod: 0,
            max_cells: 0,
        };
        region_cover.covering(&cap).0
    }
}
/// Wraps a collection for lookups, searching at the collection's own storage level.
impl From<GeoshardCollection> for GeoshardSearcher {
    fn from(shards: GeoshardCollection) -> Self {
        // Copy the level out before the collection is moved into the searcher.
        let level = shards.storage_level;
        Self {
            shards,
            storage_level: level,
        }
    }
}
/// Unit tests for geoshard building, searching, and scoring.
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::utils::ll;
    use rand::Rng;
    use lazy_static::lazy_static;
    use rand::{distributions::Alphanumeric, prelude::SliceRandom, thread_rng};
    use s2::cellid::CellID;

    // Fixed list of test coordinates; `new_city` picks one at random.
    struct RandCityFactory {
        cities: Vec<LatLng>,
    }
    impl RandCityFactory {
        // Returns a copy of a randomly chosen entry from `cities`.
        fn new_city(&self) -> LatLng {
            let mut rng = rand::thread_rng();
            // `choose` returns `None` only for an empty list; `Default` fills it.
            self.cities.choose(&mut rng).unwrap().clone()
        }
        fn cities(&self) -> &Vec<LatLng> {
            &self.cities
        }
    }
    impl Default for RandCityFactory {
        fn default() -> Self {
            // NOTE(review): every entry passes the same value as both `ll!` arguments, which
            // looks like the longitudes were copied from the latitudes. Confirm whether real
            // longitudes were intended.
            let cities: Vec<LatLng> = vec![
                ll!(40.745255, 40.745255),
                ll!(34.155834, 34.155834),
                ll!(42.933334, 42.933334),
                ll!(42.095554, 42.095554),
                ll!(38.846668, 38.846668),
                ll!(41.392502, 41.392502),
                ll!(27.192223, 27.192223),
                ll!(31.442778, 31.442778),
                ll!(40.560001, 40.560001),
                ll!(33.193611, 33.193611),
                ll!(41.676388, 41.676388),
                ll!(41.543056, 41.543056),
                ll!(39.554443, 39.554443),
                ll!(44.513332, 44.513332),
                ll!(37.554169, 37.554169),
                ll!(32.349998, 32.349998),
                ll!(29.499722, 29.499722),
                ll!(33.038334, 33.038334),
                ll!(43.614166, 43.614166),
                ll!(41.55611, 41.55611),
                ll!(34.00, 34.00),
                ll!(26.709723, 26.709723),
                ll!(38.005001, 38.005001),
                ll!(35.970554, 35.970554),
                ll!(25.942122, 25.942122),
                ll!(33.569443, 33.569443),
                ll!(39.799999, 39.799999),
                ll!(34.073334, 34.073334),
                ll!(40.606388, 40.606388),
                ll!(30.601389, 30.601389),
                ll!(38.257778, 38.257778),
                ll!(37.977222, 37.977222),
                ll!(42.373611, 42.373611),
                ll!(32.965557, 32.965557),
                ll!(37.871666, 37.871666),
                ll!(38.951561, 38.951561),
                ll!(33.950001, 33.950001),
                ll!(30.216667, 30.216667),
                ll!(42.580276, 42.580276),
                ll!(36.316666, 36.316666),
                ll!(37.034946, 37.034946),
                ll!(40.689167, 40.689167),
                ll!(33.630554, 33.630554),
                ll!(39.903057, 39.903057),
                ll!(25.978889, 25.978889),
                ll!(35.846111, 35.846111),
                ll!(34.156113, 34.156113),
                ll!(41.18639, 41.18639),
                ll!(40.914745, 40.914745),
                ll!(42.259445, 42.259445),
                ll!(41.520557, 41.520557),
                ll!(33.124722, 33.124722),
                ll!(39.106667, 39.106667),
                ll!(42.101391, 42.101391),
                ll!(37.210388, 37.210388),
                ll!(33.866669, 33.866669),
                ll!(26.012501, 26.012501),
                ll!(38.438332, 38.438332),
                ll!(33.211666, 33.211666),
                ll!(37.070831, 37.070831),
                ll!(43.536388, 43.536388),
                ll!(45.633331, 45.633331),
                ll!(42.271389, 42.271389),
                ll!(30.455, 30.455),
                ll!(32.492222, 32.492222),
                ll!(33.466667, 33.466667),
                ll!(32.361668, 32.361668),
                ll!(41.763889, 41.763889),
                ll!(35.199165, 35.199165),
                ll!(37.661388, 37.661388),
                ll!(32.907223, 32.907223),
                ll!(33.669445, 33.669445),
                ll!(39.710835, 39.710835),
                ll!(32.705002, 32.705002),
                ll!(39.099724, 39.099724),
                ll!(35.1175, 35.1175),
                ll!(39.791, 39.791),
                ll!(39.983334, 39.983334),
                ll!(30.266666, 30.266666),
                ll!(32.779167, 32.779167),
                ll!(37.487846, 37.487846),
                ll!(35.25528, 35.25528),
                ll!(29.700001, 29.700001),
                ll!(26.838619, 26.838619),
                ll!(38.473625, 38.473625),
                ll!(29.749907, 29.749907),
                ll!(40.191891, 40.191891),
                ll!(33.830517, 33.830517),
                ll!(34.496212, 34.496212),
                ll!(37.54129, 37.54129),
                ll!(36.082157, 36.082157),
                ll!(32.698437, 32.698437),
                ll!(33.580944, 33.580944),
                ll!(33.427204, 33.427204),
                ll!(34.028622, 34.028622),
                ll!(32.609856, 32.609856),
                ll!(33.405746, 33.405746),
                ll!(34.603817, 34.603817),
                ll!(44.840797, 44.840797),
                ll!(71.290558, 71.290558),
            ];
            Self { cities }
        }
    }
    // Shared city list so every `FakeUser::new` draws from the same data.
    lazy_static! {
        static ref RANDOM_CITY_FACTORY: RandCityFactory = RandCityFactory::default();
    }
    /// A test user with a random 30-character alphanumeric name and a location drawn from
    /// the fixed city list.
    #[derive(Clone)]
    pub struct FakeUser {
        /// Random identifier; equality compares only this field.
        pub name: String,
        location: LatLng,
    }
    // Users are equal when their names match; the location is ignored.
    impl PartialEq for FakeUser {
        fn eq(&self, other: &Self) -> bool {
            other.name == self.name
        }
    }
    impl FakeUser {
        /// Creates a user with a fresh random name at a random city.
        pub fn new() -> Self {
            let name: String = thread_rng()
                .sample_iter(&Alphanumeric)
                .take(30)
                .map(char::from)
                .collect();
            Self {
                name,
                location: RANDOM_CITY_FACTORY.new_city(),
            }
        }
    }
    // Implemented on the reference so tests can pass `vec.iter()` straight to the builder.
    impl User for &FakeUser {
        fn location(&self) -> &LatLng {
            &self.location
        }
    }
    // Builds a placeholder shard in which only `cell_score` varies; the token range, level
    // and size are fixed.
    macro_rules! shard {
        ($cell_score:expr) => {
            Geoshard::new(
                "fake-shard".to_owned(),
                CellID::from_token("00001"),
                CellID::from_token("00003"),
                $cell_score,
                0,
                2,
            )
        };
    }
    /// Cell scorer that ignores the users and inserts random loads for random cells.
    pub struct RandomCellScore;
    // The shard returned for a location must contain that location's storage-level cell.
    #[test]
    fn test_shard_search() {
        let geoshards =
            GeoshardBuilder::user_count_scorer(4, Box::new(vec![FakeUser::new()].iter()), 40, 100)
                .build();
        let geoshard_searcher = GeoshardSearcher::from(geoshards);
        let geoshard = geoshard_searcher.get_shard_from_location(&ll!(34.181061, -103.345177));
        let cell_id = geoshard_searcher.get_cell_id_from_location(&ll!(34.181061, -103.345177));
        assert!(geoshard.cell_union().contains_cellid(&cell_id));
    }
    // A radius of 200 at storage level 4 is expected to yield exactly one covering cell,
    // and therefore one shard entry.
    #[test]
    fn test_shard_radius_search() {
        let geoshard = GeoshardBuilder::new(
            4,
            Box::new(vec![FakeUser::new()].iter()),
            RandomCellScore,
            40,
            100,
        )
        .build();
        let geoshards = GeoshardSearcher::from(geoshard);
        let geoshards = geoshards.get_shards_from_radius(&ll!(34.181061, -103.345177), 200);
        assert_eq!(geoshards.len(), 1);
    }
    // The number of shards built must fall inside the requested [40, 100] bounds.
    #[test]
    fn test_generate_shards() {
        let geoshard = GeoshardBuilder::new(
            4,
            Box::new(vec![FakeUser::new()].iter()),
            RandomCellScore,
            40,
            100,
        )
        .build();
        let shards = geoshard.shards;
        if (shards.len() as i32) > 100 || (shards.len() as i32) < 40 {
            panic!("Shard len out of range: {}", shards.len());
        }
    }
    impl<UserCollection> CellScorer<UserCollection> for RandomCellScore {
        // Inserts four tiers of random loads. Inclusive ranges give 1001, 101, 51 and 11
        // inserts; a cell drawn twice keeps whichever load was inserted last.
        // NOTE(review): lat/long are drawn from 0..2000, well outside valid degree ranges.
        // This presumably relies on `ll!` / `CellID::from` to normalize them — TODO confirm.
        fn score_cell_list<T>(&self, mut cell_list: CellList, _users: UserCollection) -> CellList {
            let mock_values = cell_list.mut_cell_list();
            let mut rng = rand::thread_rng();
            // Many lightly loaded cells (load 0..5).
            for _ in 0..=1000 {
                let rand_lat = rng.gen_range(0.000000..2000.000000);
                let rand_long = rng.gen_range(0.000000..2000.000000);
                let cell_id = CellID::from(ll!(rand_lat, rand_long));
                let rand_load_count = rng.gen_range(0..5);
                mock_values.insert(cell_id, rand_load_count);
            }
            // Moderately loaded cells (load 10..100).
            for _ in 0..=100 {
                let rand_lat = rng.gen_range(0.000000..2000.000000);
                let rand_long = rng.gen_range(0.000000..2000.000000);
                let cell_id = CellID::from(ll!(rand_lat, rand_long));
                let rand_load_count = rng.gen_range(10..100);
                mock_values.insert(cell_id, rand_load_count);
            }
            // Heavily loaded cells (load 100..500).
            for _ in 0..=50 {
                let rand_lat = rng.gen_range(0.000000..2000.000000);
                let rand_long = rng.gen_range(0.000000..2000.000000);
                let cell_id = CellID::from(ll!(rand_lat, rand_long));
                let rand_load_count = rng.gen_range(100..500);
                mock_values.insert(cell_id, rand_load_count);
            }
            // A few hot spots (load 1000..2000).
            for _ in 0..=10 {
                let rand_lat = rng.gen_range(0.000000..2000.000000);
                let rand_long = rng.gen_range(0.000000..2000.000000);
                let cell_id = CellID::from(ll!(rand_lat, rand_long));
                let rand_load_count = rng.gen_range(1000..2000);
                mock_values.insert(cell_id, rand_load_count);
            }
            cell_list
        }
    }
    // Population standard deviation of the 20 scores: mean 140 / 20 = 7, squared deviations
    // sum to 178, variance 178 / 20 = 8.9, and sqrt(8.9) ≈ 2.98329.
    #[test]
    fn test_standard_deviation() {
        let shards = vec![
            shard!(9),
            shard!(2),
            shard!(5),
            shard!(4),
            shard!(12),
            shard!(7),
            shard!(8),
            shard!(11),
            shard!(9),
            shard!(3),
            shard!(7),
            shard!(4),
            shard!(12),
            shard!(5),
            shard!(4),
            shard!(10),
            shard!(9),
            shard!(6),
            shard!(9),
            shard!(4),
        ];
        let geoshard_collection = GeoshardCollection {
            shards,
            storage_level: 4,
        };
        let standard_dev = geoshard_collection.standard_deviation();
        assert_eq!(standard_dev, 2.9832867780352594_f64)
    }
}