1use std::{future::Future, sync::atomic::AtomicU64};
2
3use async_channel::{Sender, WeakSender};
4
5use crate::{
6 executor::Context,
7 message::{Envelope, Message},
8};
9
/// Global source of [`Address`] ids.
///
/// `Address::new` takes the next value with `fetch_add`, so every address created
/// through it gets a distinct id (the atomic wraps only after 2^64 allocations).
/// Clones copy their id instead of drawing a new one.
static ADDRESS_COUNTER: AtomicU64 = AtomicU64::new(0);
11
12pub trait Actor: Sized {
18 fn starting(&mut self, _ctx: &Context<Self>) -> impl Future<Output = ()> + Send {
19 std::future::ready(())
20 }
21
22 fn stopping(&mut self, _ctx: &Context<Self>) -> impl Future<Output = ()> + Send {
23 std::future::ready(())
24 }
25}
26
27pub trait Handler<M>
32where
33 Self: Actor,
34 M: Message,
35{
36 fn handle(&mut self, msg: M, ctx: &Context<Self>) -> impl Future<Output = ()> + Send;
38}
39
40#[derive(Debug)]
45pub struct Address<A> {
46 id: u64,
47 sender: Sender<Envelope<A>>,
48}
49
50impl<A> PartialEq for Address<A> {
51 fn eq(&self, other: &Self) -> bool {
52 self.id == other.id
53 }
54}
55
56unsafe impl<A> std::marker::Send for Address<A> {}
60unsafe impl<A> std::marker::Sync for Address<A> {}
62impl<A> std::marker::Unpin for Address<A> {}
63
64impl<A> Clone for Address<A> {
65 fn clone(&self) -> Self {
66 Self {
67 sender: self.sender.clone(),
68 id: self.id,
69 }
70 }
71}
72
73impl<A> Address<A> {
74 pub(crate) fn new(sender: Sender<Envelope<A>>) -> Self {
75 let id = ADDRESS_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
76
77 Self { sender, id }
78 }
79
80 pub fn downgrade(&self) -> WeakAddress<A> {
81 let sender = self.sender.downgrade();
82 WeakAddress::new(self.id, sender)
83 }
84
85 #[cfg(feature = "mocking")]
94 pub fn new_leak(cap: usize) -> Self {
95 let (sender, receiver) = async_channel::bounded::<Envelope<A>>(cap);
96 Box::leak(Box::new(receiver));
97 Self::new(sender)
98 }
99}
100
101impl<A> Address<A>
102where
103 A: 'static + Actor + Send,
104{
105 pub async fn send<M>(&self, message: M)
109 where
110 A: Handler<M>,
111 M: Message,
112 {
113 let env = Envelope::pack(message);
114
115 let _ = self.sender.send(env).await;
117 }
118
119 pub fn try_send<M>(&self, message: M)
120 where
121 A: Handler<M>,
122 M: Message,
123 {
124 let env = Envelope::pack(message);
125
126 let _ = self.sender.try_send(env);
128 }
129}
130
131#[derive(Debug)]
136pub struct WeakAddress<A> {
137 id: u64,
138 sender: WeakSender<Envelope<A>>,
139}
140
141impl<A> Clone for WeakAddress<A> {
142 fn clone(&self) -> Self {
143 Self {
144 id: self.id,
145 sender: self.sender.clone(),
146 }
147 }
148}
149
150impl<A> WeakAddress<A> {
151 pub(crate) fn new(id: u64, sender: WeakSender<Envelope<A>>) -> Self {
152 Self { id, sender }
153 }
154
155 pub fn upgrade(&self) -> Option<Address<A>> {
156 let sender = self.sender.upgrade()?;
157 Some(Address::new(sender))
158 }
159}
160
161unsafe impl<A> std::marker::Send for WeakAddress<A> {}
165unsafe impl<A> std::marker::Sync for WeakAddress<A> {}
167impl<A> std::marker::Unpin for WeakAddress<A> {}
168
#[cfg(test)]
mod test {
    use std::sync::Mutex;

    use crate::Executor;

    use super::*;

    /// Payload-free message, used only to give `Act` a `Handler` impl.
    struct Msg;

    /// Minimal actor with default lifecycle hooks and a no-op handler.
    struct Act;

    impl Actor for Act {}

    impl Handler<Msg> for Act {
        async fn handle(&mut self, _msg: Msg, _ctx: &Context<Self>) {}
    }

    /// A clone shares its source's id and therefore compares equal.
    #[test]
    fn partial_eq_on_clone() {
        let (_executor, original) = Executor::new(Act);
        let duplicate = original.clone();
        assert!(original == duplicate);
    }

    /// Two independently created addresses never compare equal.
    #[test]
    fn partial_eq_on_different_addrs() {
        let (_first_executor, first) = Executor::new(Act);
        let (_second_executor, second) = Executor::new(Act);
        assert!(first != second);
    }

    /// Each new address differs from every address created before it.
    #[test]
    fn partial_eq_on_a_thousand_different_addrs() {
        let mut seen: Vec<Address<Act>> = Vec::with_capacity(1_000);
        for _ in 0..1_000 {
            let (_executor, candidate) = Executor::new(Act);
            assert!(seen.iter().all(|earlier| *earlier != candidate));
            seen.push(candidate);
        }
    }

    /// Ids stay unique when addresses are created concurrently from many threads.
    #[test]
    fn partial_eq_on_a_thousand_different_threads() {
        const NUM_THREAD: usize = 1_000;
        let collected = Mutex::new(Vec::<Address<Act>>::new());
        std::thread::scope(|scope| {
            for _ in 0..NUM_THREAD {
                scope.spawn(|| {
                    let (_executor, address) = Executor::new(Act);
                    collected.lock().unwrap().push(address);
                });
            }
        });
        let addrs = collected.into_inner().unwrap();
        assert_eq!(addrs.len(), NUM_THREAD);
        for (i, left) in addrs.iter().enumerate() {
            for right in &addrs[i + 1..] {
                assert!(left != right);
            }
        }
    }
}