use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use crossbeam_channel::{Receiver, Sender, bounded};
use parking_lot::Mutex;
/// Identifier given to an actor when it is spawned (see `Actor::spawn`).
pub type ActorId = u64;
/// Identifier attached to each message sent to an actor; the actor's reply carries the same id.
pub type MessageId = u64;
/// Envelope wrapping a payload on its way to an actor.
pub struct Message<T> {
    /// Identifier of this message; the actor's reply carries the same id.
    pub id: MessageId,
    /// Value handed to the actor's handler.
    pub payload: T,
    /// Moment the envelope was built; used to measure how long it waited in the mailbox.
    pub created_at: Instant,
}

impl<T> Message<T> {
    /// Wraps `payload` under `id`, stamping the envelope with the current time.
    pub fn new(id: MessageId, payload: T) -> Self {
        let created_at = Instant::now();
        Message {
            id,
            payload,
            created_at,
        }
    }

    /// Time elapsed since this envelope was created.
    pub fn age(&self) -> Duration {
        Instant::now().duration_since(self.created_at)
    }
}
/// Reply produced by an actor for one processed message.
pub struct Response<R> {
/// Id of the message this reply answers.
pub message_id: MessageId,
/// The handler's result for that message.
pub result: Result<R, ActorError>,
/// Time spent inside the handler, measured on the actor's thread.
pub processing_time: Duration,
}
/// Errors reported by actors, actor handles and pools.
#[derive(Debug)]
pub enum ActorError {
    /// A non-blocking send found the actor's mailbox full.
    MailboxFull,
    /// The actor is not running (it was stopped, or its thread has gone away).
    ActorStopped,
    /// Error returned by a message handler, or by the pool for an invalid request.
    HandlerError(String),
    /// No reply arrived within the allotted time.
    Timeout,
}

impl std::fmt::Display for ActorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::MailboxFull => "Actor mailbox is full",
            Self::ActorStopped => "Actor has stopped",
            Self::Timeout => "Request timed out",
            // The only variant whose message needs formatting.
            Self::HandlerError(detail) => return write!(f, "Handler error: {}", detail),
        };
        f.write_str(text)
    }
}

impl std::error::Error for ActorError {}
/// Message-processing logic run by an actor.
///
/// The actor's worker thread calls `handle` once per received message, one message at a
/// time, and sends the returned value back as the reply for that message.
pub trait Handler<M, R>: Send + Sync {
/// Processes one message. An `Err` is delivered to the requester as the reply's result.
fn handle(&mut self, message: M) -> Result<R, ActorError>;
}
/// Per-actor processing statistics, updated by the actor's worker thread after each message.
///
/// NOTE(review): nothing in this file reads these statistics back out of the actor, and
/// `messages_pending` is never updated — confirm whether an accessor is missing.
#[derive(Debug, Clone, Default)]
pub struct ActorStats {
/// Number of messages handled so far.
pub messages_processed: u64,
/// Not written by the worker loop in this file.
pub messages_pending: usize,
/// Sum of handler execution times, in microseconds.
pub total_processing_time_us: u64,
/// Longest single handler execution time, in microseconds.
pub max_processing_time_us: u64,
/// Average time from message creation to the start of its handling, in microseconds.
pub avg_wait_time_us: u64,
}
/// State moved onto an actor's worker thread by `Actor::spawn`.
struct ActorInner<M, R, H: Handler<M, R>> {
/// The actor's id; kept for diagnostics but not read by the worker loop (hence the allow).
#[allow(dead_code)]
id: ActorId,
/// User-supplied message processor.
handler: H,
/// Receiving side of the actor's mailbox.
inbox: Receiver<Message<M>>,
/// Flag shared with the `ActorRef`; the worker loop runs while it is `true`.
running: Arc<AtomicBool>,
/// Statistics accumulated while processing messages.
stats: ActorStats,
/// Ties the reply type `R` to this struct without storing a value of it.
_phantom: std::marker::PhantomData<R>,
}
impl<M: Send + 'static, R: Send + 'static, H: Handler<M, R> + 'static> ActorInner<M, R, H> {
fn run(mut self, response_tx: Sender<Response<R>>) {
while self.running.load(Ordering::Acquire) {
match self.inbox.recv_timeout(Duration::from_millis(100)) {
Ok(msg) => {
let wait_time = msg.age();
let start = Instant::now();
let result = self.handler.handle(msg.payload);
let processing_time = start.elapsed();
self.stats.messages_processed += 1;
let proc_us = processing_time.as_micros() as u64;
self.stats.total_processing_time_us += proc_us;
if proc_us > self.stats.max_processing_time_us {
self.stats.max_processing_time_us = proc_us;
}
let wait_us = wait_time.as_micros() as u64;
let n = self.stats.messages_processed;
self.stats.avg_wait_time_us =
(self.stats.avg_wait_time_us * (n - 1) + wait_us) / n;
let _ = response_tx.send(Response {
message_id: msg.id,
result,
processing_time,
});
}
Err(crossbeam_channel::RecvTimeoutError::Timeout) => {
continue;
}
Err(crossbeam_channel::RecvTimeoutError::Disconnected) => {
break;
}
}
}
}
}
pub struct ActorRef<M, R> {
id: ActorId,
inbox: Sender<Message<M>>,
responses: Receiver<Response<R>>,
next_message_id: AtomicU64,
running: Arc<AtomicBool>,
}
impl<M: Send + 'static, R: Send + 'static> ActorRef<M, R> {
pub fn ask(&self, message: M) -> Result<R, ActorError> {
self.ask_timeout(message, Duration::from_secs(30))
}
pub fn ask_timeout(&self, message: M, timeout: Duration) -> Result<R, ActorError> {
if !self.running.load(Ordering::Acquire) {
return Err(ActorError::ActorStopped);
}
let id = self.next_message_id.fetch_add(1, Ordering::SeqCst);
let msg = Message::new(id, message);
self.inbox.send(msg).map_err(|_| ActorError::ActorStopped)?;
match self.responses.recv_timeout(timeout) {
Ok(resp) => resp.result,
Err(_) => Err(ActorError::Timeout),
}
}
pub fn tell(&self, message: M) -> Result<(), ActorError> {
if !self.running.load(Ordering::Acquire) {
return Err(ActorError::ActorStopped);
}
let id = self.next_message_id.fetch_add(1, Ordering::SeqCst);
let msg = Message::new(id, message);
self.inbox.try_send(msg).map_err(|e| match e {
crossbeam_channel::TrySendError::Full(_) => ActorError::MailboxFull,
crossbeam_channel::TrySendError::Disconnected(_) => ActorError::ActorStopped,
})
}
pub fn is_running(&self) -> bool {
self.running.load(Ordering::Acquire)
}
pub fn id(&self) -> ActorId {
self.id
}
pub fn stop(&self) {
self.running.store(false, Ordering::Release);
}
}
pub struct Actor;
impl Actor {
pub fn spawn<M, R, H>(
id: ActorId,
handler: H,
mailbox_size: usize,
) -> (ActorRef<M, R>, JoinHandle<()>)
where
M: Send + 'static,
R: Send + 'static,
H: Handler<M, R> + 'static,
{
let (inbox_tx, inbox_rx) = bounded(mailbox_size);
let (resp_tx, resp_rx) = bounded(mailbox_size);
let running = Arc::new(AtomicBool::new(true));
let inner = ActorInner {
id,
handler,
inbox: inbox_rx,
running: Arc::clone(&running),
stats: ActorStats::default(),
_phantom: std::marker::PhantomData,
};
let handle = thread::spawn(move || {
inner.run(resp_tx);
});
let actor_ref = ActorRef {
id,
inbox: inbox_tx,
responses: resp_rx,
next_message_id: AtomicU64::new(1),
running,
};
(actor_ref, handle)
}
}
/// Fixed-size group of actors built from one handler factory, with round-robin dispatch.
pub struct ActorPool<M: Send + Clone + 'static, R: Send + 'static> {
/// Handles to the pooled actors; as built by `new`, index `i` holds the actor with id `i + 1`.
actors: Vec<ActorRef<M, R>>,
/// Worker thread handles, drained and joined by `shutdown`.
handles: Mutex<Vec<JoinHandle<()>>>,
/// Round-robin cursor used by `ask`.
next_actor: AtomicUsize,
/// Id the next actor would receive; stored by `new` but not read anywhere (hence the allow).
#[allow(dead_code)]
next_actor_id: AtomicU64,
}
impl<M: Send + Clone + 'static, R: Send + 'static> ActorPool<M, R> {
pub fn new<F, H>(size: usize, factory: F, mailbox_size: usize) -> Self
where
F: Fn() -> H,
H: Handler<M, R> + 'static,
{
let mut actors = Vec::with_capacity(size);
let mut handles = Vec::with_capacity(size);
let next_id = AtomicU64::new(1);
for _ in 0..size {
let id = next_id.fetch_add(1, Ordering::SeqCst);
let handler = factory();
let (actor_ref, handle) = Actor::spawn(id, handler, mailbox_size);
actors.push(actor_ref);
handles.push(handle);
}
Self {
actors,
handles: Mutex::new(handles),
next_actor: AtomicUsize::new(0),
next_actor_id: next_id,
}
}
pub fn ask(&self, message: M) -> Result<R, ActorError> {
let idx = self.next_actor.fetch_add(1, Ordering::Relaxed) % self.actors.len();
self.actors[idx].ask(message)
}
pub fn ask_actor(&self, actor_idx: usize, message: M) -> Result<R, ActorError> {
if actor_idx >= self.actors.len() {
return Err(ActorError::HandlerError("Invalid actor index".to_string()));
}
self.actors[actor_idx].ask(message)
}
pub fn broadcast(&self, message: M) -> Vec<Result<R, ActorError>> {
self.actors.iter().map(|a| a.ask(message.clone())).collect()
}
pub fn size(&self) -> usize {
self.actors.len()
}
pub fn shutdown(&self) {
for actor in &self.actors {
actor.stop();
}
let mut handles = self.handles.lock();
for handle in handles.drain(..) {
let _ = handle.join();
}
}
}
/// Stops all actors and joins their threads when the pool goes out of scope.
impl<M: Send + Clone + 'static, R: Send + 'static> Drop for ActorPool<M, R> {
    fn drop(&mut self) {
        // `shutdown` drains the join handles, so an earlier explicit call makes this a no-op
        // apart from re-setting the stop flags.
        Self::shutdown(self);
    }
}
/// Apparently a placeholder for a work-stealing variant of `ActorPool`.
///
/// NOTE(review): this type is never constructed and has no methods in this file, which is
/// why `dead_code` is allowed; confirm it is still planned before relying on it.
#[allow(dead_code)]
pub struct WorkStealingPool<M: Send + 'static, R: Send + 'static> {
/// Shared actor handles, presumably one per worker (unconfirmed).
actors: Vec<Arc<ActorRef<M, R>>>,
/// Per-worker message queues — presumably what idle workers would steal from (unconfirmed).
queues: Vec<Arc<Mutex<VecDeque<Message<M>>>>>,
/// Worker thread join handles.
handles: Mutex<Vec<JoinHandle<()>>>,
/// Shared run flag for the pool.
running: Arc<AtomicBool>,
}
/// Tells `RequestRouter::route` how to choose the actor that handles a request.
#[derive(Debug, Clone)]
pub enum AffinityHint {
/// Derive the actor from this key, so requests with equal keys go to the same actor.
KeyBased(u64),
/// Intended to pick the least-busy actor; see `RequestRouter::route` for how this is
/// currently approximated.
LeastLoaded,
/// Rotate through the pool's actors.
RoundRobin,
/// Target a specific actor selected from this id; see `RequestRouter::route` for the mapping.
Specific(ActorId),
}
pub struct RequestRouter<M: Send + Clone + 'static, R: Send + 'static> {
pool: Arc<ActorPool<M, R>>,
key_to_actor: Mutex<std::collections::HashMap<u64, usize>>,
}
impl<M: Send + Clone + 'static, R: Send + 'static> RequestRouter<M, R> {
pub fn new(pool: Arc<ActorPool<M, R>>) -> Self {
Self {
pool,
key_to_actor: Mutex::new(std::collections::HashMap::new()),
}
}
pub fn route(&self, message: M, hint: AffinityHint) -> Result<R, ActorError> {
let actor_idx = match hint {
AffinityHint::KeyBased(key) => {
let mut mapping = self.key_to_actor.lock();
*mapping
.entry(key)
.or_insert_with(|| (key as usize) % self.pool.size())
}
AffinityHint::LeastLoaded => {
0 }
AffinityHint::RoundRobin => {
return self.pool.ask(message);
}
AffinityHint::Specific(id) => (id as usize) % self.pool.size(),
};
self.pool.ask_actor(actor_idx, message)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Replies with the incoming text prefixed by "Echo: ".
    struct EchoHandler;

    impl Handler<String, String> for EchoHandler {
        fn handle(&mut self, message: String) -> Result<String, ActorError> {
            Ok(format!("Echo: {}", message))
        }
    }

    /// Replies with the number of messages handled so far (1 for the first one).
    struct CounterHandler {
        count: u64,
    }

    impl Handler<(), u64> for CounterHandler {
        fn handle(&mut self, _: ()) -> Result<u64, ActorError> {
            self.count += 1;
            Ok(self.count)
        }
    }

    #[test]
    fn test_actor_spawn() {
        let (echo, worker) = Actor::spawn(1, EchoHandler, 10);
        assert_eq!(echo.ask(String::from("Hello")).unwrap(), "Echo: Hello");
        echo.stop();
        worker.join().unwrap();
    }

    #[test]
    fn test_actor_pool() {
        let pool = ActorPool::new(4, || EchoHandler, 100);
        let replies: Vec<_> = (0..10).map(|n| pool.ask(format!("Message {}", n))).collect();
        for (n, reply) in (0..10).zip(replies) {
            assert_eq!(reply.unwrap(), format!("Echo: Message {}", n));
        }
        pool.shutdown();
    }

    #[test]
    fn test_counter_handler() {
        let (counter, worker) = Actor::spawn(1, CounterHandler { count: 0 }, 10);
        for expected in 1..=3 {
            assert_eq!(counter.ask(()).unwrap(), expected);
        }
        counter.stop();
        worker.join().unwrap();
    }

    #[test]
    fn test_broadcast() {
        let pool = ActorPool::new(4, || CounterHandler { count: 0 }, 100);
        let results = pool.broadcast(());
        assert_eq!(results.len(), 4);
        results
            .into_iter()
            .for_each(|result| assert_eq!(result.unwrap(), 1));
        pool.shutdown();
    }

    #[test]
    fn test_request_router() {
        let pool = Arc::new(ActorPool::new(4, || EchoHandler, 100));
        let router = RequestRouter::new(Arc::clone(&pool));
        let replies: Vec<String> = ["Test1", "Test2"]
            .iter()
            .map(|text| {
                router
                    .route(text.to_string(), AffinityHint::KeyBased(42))
                    .unwrap()
            })
            .collect();
        for reply in &replies {
            assert!(reply.starts_with("Echo:"));
        }
        pool.shutdown();
    }
}