use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use atomr_cluster_sharding::{EntityRef, MessageExtractor, ShardCoordinator, ShardRegion};
use atomr_core::actor::ActorRef;
use crate::device::DeviceMsg;
/// Envelope pairing a [`DeviceMsg`] payload with the entity id used to route it.
pub struct RoutedDeviceMsg {
    /// Sharding key: hashed by [`DeviceExtractor`] to choose a shard.
    pub entity_id: String,
    /// Payload that is forwarded (by value) to the selected device actor.
    pub msg: DeviceMsg,
}
/// Message extractor for [`RoutedDeviceMsg`] envelopes.
///
/// Entities are spread across a fixed number of shards by hashing their id.
pub struct DeviceExtractor {
    /// Number of shards; never zero, so the shard modulo cannot divide by zero.
    shard_count: usize,
}

impl DeviceExtractor {
    /// Creates an extractor that spreads entities over `shard_count` shards.
    ///
    /// A `shard_count` of `0` is treated as `1`.
    pub fn new(shard_count: usize) -> Self {
        let shard_count = if shard_count == 0 { 1 } else { shard_count };
        Self { shard_count }
    }
}
impl MessageExtractor for DeviceExtractor {
    type Message = RoutedDeviceMsg;

    /// Returns (a copy of) the entity id carried by the envelope.
    fn entity_id(&self, message: &Self::Message) -> String {
        message.entity_id.clone()
    }

    /// Maps the envelope's entity id to a shard id of the form `shard-{n}`,
    /// where `n` is in `0..shard_count`.
    ///
    /// The modulo is taken on the full 64-bit hash *before* any narrowing.
    /// Casting the hash to `usize` first would keep only its low 32 bits on
    /// 32-bit targets, so nodes of different pointer widths would disagree on
    /// an entity's shard.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable across
    /// Rust releases. If nodes may be built with different toolchains, consider
    /// a hash with a fixed, specified algorithm.
    fn shard_id(&self, message: &Self::Message) -> String {
        let mut hasher = DefaultHasher::new();
        message.entity_id.hash(&mut hasher);
        // `usize -> u64` is lossless on all supported targets; `new` clamps
        // `shard_count` to >= 1, so this cannot divide by zero.
        let shard = hasher.finish() % self.shard_count as u64;
        format!("shard-{shard}")
    }
}
pub struct PlacementShardingAdapter {
region: Arc<ShardRegion<DeviceExtractor>>,
}
impl PlacementShardingAdapter {
pub fn start(
region_id: impl Into<String>,
devices: Vec<ActorRef<DeviceMsg>>,
shard_count: usize,
) -> Self {
let n_devices = devices.len().max(1);
let n_shards = if shard_count == 0 {
n_devices
} else {
shard_count
};
let extractor = Arc::new(DeviceExtractor::new(n_shards));
let coord = Arc::new(ShardCoordinator::new());
let devices = Arc::new(devices);
let devices_for_factory = devices.clone();
let region = ShardRegion::new(
region_id,
extractor,
coord,
Arc::new(move || {
let devices = devices_for_factory.clone();
Box::new(move |entity_id: &str, msg: RoutedDeviceMsg| {
if devices.is_empty() {
return;
}
let mut h = DefaultHasher::new();
entity_id.hash(&mut h);
let idx = (h.finish() as usize) % devices.len();
devices[idx].tell(msg.msg);
})
}),
);
Self { region }
}
pub fn entity(&self, entity_id: impl Into<String>) -> EntityRef<DeviceExtractor> {
EntityRef::new(self.region.clone(), entity_id.into())
}
pub fn region(&self) -> Arc<ShardRegion<DeviceExtractor>> {
self.region.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{DeviceActor, DeviceConfig};
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// An `Allocate` sent through an entity ref must reach one of the devices
    /// and actually be answered.
    ///
    /// Both layers of the timeout result are checked. The outer `expect`
    /// catches a hang. The inner `expect` catches the reply sender being
    /// dropped without an answer, for example when the message is routed
    /// nowhere.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn entity_ref_routes_to_one_of_the_devices() {
        let sys = ActorSystem::create("sharding-adapter", Config::empty())
            .await
            .unwrap();
        let d0 = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "d0")
            .unwrap();
        let d1 = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(1)), "d1")
            .unwrap();
        let adapter = PlacementShardingAdapter::start("gpu", vec![d0, d1], 16);
        let entity = adapter.entity("user-42");
        let (tx, rx) = oneshot::channel();
        entity.tell(RoutedDeviceMsg {
            entity_id: "user-42".into(),
            msg: DeviceMsg::Allocate { len: 16, reply: tx },
        });
        let _reply = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("Allocate reply should arrive within timeout")
            .expect("device dropped the Allocate reply sender without answering");
        sys.terminate().await;
    }
}