1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use parking_lot::RwLock;
14use tokio::sync::mpsc;
15
16use atomr_core::actor::Address;
17
18use crate::endpoint::{spawn_endpoint, EndpointHandle, InboundEnvelope, InboundPdu};
19use crate::failure_detector_registry::FailureDetectorRegistry;
20use crate::metrics::RemoteMetrics;
21use crate::pdu::DisassociateReason;
22use crate::settings::RemoteSettings;
23use crate::transport::akka_protocol::{AkkaProtocolTransport, ProtocolEvent};
24use crate::transport::{Transport, TransportError};
25
/// Lifecycle state of the association with one remote peer, as tracked by
/// [`EndpointManager`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AssociationState {
    /// No association and no attempt in progress.
    Idle,
    /// An outbound association attempt has been started and not yet resolved.
    Pending,
    /// An association with the peer has been established.
    Connected,
    /// Association attempts are refused until `quarantine_duration` elapses.
    Quarantined,
    /// Association attempts are refused until the entry is purged.
    Tombstoned,
}
42
43#[derive(Debug, Clone)]
44struct PeerEntry {
45 state: AssociationState,
46 state_since: Instant,
48 attempt: u32,
50}
51
52impl PeerEntry {
53 fn new() -> Self {
54 Self { state: AssociationState::Idle, state_since: Instant::now(), attempt: 0 }
55 }
56
57 fn transition(&mut self, next: AssociationState) {
58 self.state = next;
59 self.state_since = Instant::now();
60 if next == AssociationState::Connected {
61 self.attempt = 0;
62 }
63 }
64}
65
66#[derive(Clone)]
67pub struct EndpointManager {
68 inner: Arc<EndpointManagerInner>,
69}
70
71struct EndpointManagerInner {
72 protocol: Arc<AkkaProtocolTransport>,
73 settings: RemoteSettings,
74 local_address: RwLock<Option<Address>>,
75 endpoints: DashMap<String, EndpointHandle>,
76 peers: RwLock<HashMap<String, PeerEntry>>,
77 inbound_sink: mpsc::UnboundedSender<InboundEnvelope>,
78 inbound_rx: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<InboundEnvelope>>>,
79 failure_detectors: FailureDetectorRegistry,
80 metrics: RemoteMetrics,
81}
82
83impl EndpointManager {
84 pub fn new(protocol: Arc<AkkaProtocolTransport>, settings: RemoteSettings) -> Self {
85 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
86 Self {
87 inner: Arc::new(EndpointManagerInner {
88 protocol,
89 settings,
90 local_address: RwLock::new(None),
91 endpoints: DashMap::new(),
92 peers: RwLock::new(HashMap::new()),
93 inbound_sink: inbound_tx,
94 inbound_rx: parking_lot::Mutex::new(Some(inbound_rx)),
95 failure_detectors: FailureDetectorRegistry::default_phi(),
96 metrics: RemoteMetrics::new(),
97 }),
98 }
99 }
100
101 pub fn metrics(&self) -> RemoteMetrics {
102 self.inner.metrics.clone()
103 }
104
105 pub fn failure_detectors(&self) -> FailureDetectorRegistry {
106 self.inner.failure_detectors.clone()
107 }
108
109 pub fn settings(&self) -> &RemoteSettings {
110 &self.inner.settings
111 }
112
113 pub fn protocol(&self) -> Arc<AkkaProtocolTransport> {
114 self.inner.protocol.clone()
115 }
116
117 pub fn local_address(&self) -> Option<Address> {
118 self.inner.local_address.read().clone()
119 }
120
121 pub async fn start(&self) -> Result<Address, TransportError> {
124 let address = self.inner.protocol.start().await?;
125 *self.inner.local_address.write() = Some(address.clone());
126 self.start_dispatch();
127 Ok(address)
128 }
129
130 fn start_dispatch(&self) {
131 let mgr = self.clone();
132 let mut events = self.inner.protocol.events();
133 tokio::spawn(async move {
134 while let Some(ev) = events.recv().await {
135 mgr.dispatch_event(ev).await;
136 }
137 });
138 }
139
140 async fn dispatch_event(&self, ev: ProtocolEvent) {
141 match ev {
142 ProtocolEvent::Associated(assoc) => {
143 self.inner.failure_detectors.heartbeat(&assoc.address);
144 let key = assoc.address.to_string();
145 let mut peers = self.inner.peers.write();
146 let entry = peers.entry(key.clone()).or_insert_with(PeerEntry::new);
147 entry.transition(AssociationState::Connected);
148 drop(peers);
149 if !self.inner.endpoints.contains_key(&key) {
150 let handle = spawn_endpoint(
151 self.inner.protocol.clone(),
152 self.inner.settings.clone(),
153 assoc.address.clone(),
154 assoc.uid,
155 self.inner.inbound_sink.clone(),
156 );
157 self.inner.endpoints.insert(key, handle);
158 } else {
159 if let Some(h) = self.inner.endpoints.get(&key) {
161 h.resend();
162 }
163 }
164 }
165 ProtocolEvent::Disassociated { peer, reason } => {
166 self.inner.failure_detectors.remove(&peer);
167 let key = peer.to_string();
168 if let Some((_, h)) = self.inner.endpoints.remove(&key) {
169 h.shutdown(reason.clone());
170 }
171 let mut peers = self.inner.peers.write();
172 let entry = peers.entry(key.clone()).or_insert_with(PeerEntry::new);
173 match reason {
174 DisassociateReason::Quarantined => {
175 entry.transition(AssociationState::Quarantined);
176 }
177 _ => {
178 entry.transition(AssociationState::Idle);
179 }
180 }
181 }
182 ProtocolEvent::Payload { from, pdu } => {
183 use crate::pdu::AkkaPdu;
184 self.inner.failure_detectors.heartbeat(&from);
185 let key = from.to_string();
186 let bytes = match crate::codec::encode_pdu(&pdu) {
187 Ok(b) => b.len(),
188 Err(_) => 0,
189 };
190 self.inner.metrics.record_receive(&from, bytes);
191 let inbound = match pdu {
192 AkkaPdu::Payload(env) => Some(InboundPdu::Payload(env)),
193 AkkaPdu::Ack(ack) => Some(InboundPdu::Ack(ack)),
194 _ => None,
195 };
196 if let Some(p) = inbound {
197 if let Some(h) = self.inner.endpoints.get(&key) {
198 h.deliver(p);
199 }
200 }
201 }
202 }
203 }
204
205 pub async fn endpoint_for(&self, target: &Address) -> Result<EndpointHandle, TransportError> {
208 let key = target.to_string();
209 if let Some(h) = self.inner.endpoints.get(&key) {
210 return Ok(h.clone());
211 }
212 {
214 let peers = self.inner.peers.read();
215 if let Some(p) = peers.get(&key) {
216 if p.state == AssociationState::Quarantined
217 && p.state_since.elapsed() < self.inner.settings.quarantine_duration
218 {
219 return Err(TransportError::HandshakeRejected(format!("{key} is quarantined")));
220 }
221 if p.state == AssociationState::Tombstoned {
222 return Err(TransportError::HandshakeRejected(format!("{key} is tombstoned")));
223 }
224 }
225 }
226 {
228 let mut peers = self.inner.peers.write();
229 let e = peers.entry(key.clone()).or_insert_with(PeerEntry::new);
230 e.transition(AssociationState::Pending);
231 e.attempt = e.attempt.saturating_add(1);
232 }
233 let local = self.inner.local_address.read().clone().ok_or(TransportError::Closed)?;
234 self.inner.protocol.associate(target, &local).await?;
235
236 let deadline = Instant::now() + self.inner.settings.handshake_timeout;
240 loop {
241 if let Some(h) = self.inner.endpoints.get(&key) {
242 return Ok(h.clone());
243 }
244 if Instant::now() > deadline {
245 let mut peers = self.inner.peers.write();
246 if let Some(e) = peers.get_mut(&key) {
247 e.transition(AssociationState::Idle);
248 }
249 return Err(TransportError::HandshakeRejected(format!("handshake timeout to {target}")));
250 }
251 tokio::time::sleep(Duration::from_millis(20)).await;
252 }
253 }
254
255 pub async fn quarantine(&self, target: &Address) {
258 let key = target.to_string();
259 if let Some((_, h)) = self.inner.endpoints.remove(&key) {
260 h.shutdown(DisassociateReason::Quarantined);
261 }
262 let _ = self.inner.protocol.disassociate(target, DisassociateReason::Quarantined).await;
263 let mut peers = self.inner.peers.write();
264 let e = peers.entry(key).or_insert_with(PeerEntry::new);
265 e.transition(AssociationState::Quarantined);
266 }
267
268 pub fn tombstone(&self, target: &Address) {
270 let key = target.to_string();
271 if let Some((_, h)) = self.inner.endpoints.remove(&key) {
272 h.shutdown(DisassociateReason::Other("tombstoned".into()));
273 }
274 let mut peers = self.inner.peers.write();
275 let e = peers.entry(key).or_insert_with(PeerEntry::new);
276 e.transition(AssociationState::Tombstoned);
277 }
278
279 pub fn purge_tombstones(&self, older_than: Duration) -> usize {
284 let mut peers = self.inner.peers.write();
285 let before = peers.len();
286 peers.retain(|_, e| {
287 !(e.state == AssociationState::Tombstoned && e.state_since.elapsed() >= older_than)
288 });
289 before - peers.len()
290 }
291
292 pub fn peer_state(&self, target: &Address) -> Option<AssociationState> {
295 self.inner.peers.read().get(&target.to_string()).map(|e| e.state)
296 }
297
298 pub fn take_inbound(&self) -> mpsc::UnboundedReceiver<InboundEnvelope> {
302 self.inner.inbound_rx.lock().take().unwrap_or_else(|| {
303 let (_t, r) = mpsc::unbounded_channel();
304 r
305 })
306 }
307
308 pub fn peer_states(&self) -> Vec<(String, &'static str, u32)> {
310 self.inner.peers.read().iter().map(|(k, p)| (k.clone(), state_name(p.state), p.attempt)).collect()
311 }
312
313 pub async fn shutdown(&self) -> Result<(), TransportError> {
314 for kv in self.inner.endpoints.iter() {
315 kv.value().shutdown(DisassociateReason::Normal);
316 }
317 self.inner.endpoints.clear();
318 self.inner.protocol.shutdown().await
319 }
320}
321
322fn state_name(s: AssociationState) -> &'static str {
323 match s {
324 AssociationState::Idle => "idle",
325 AssociationState::Pending => "pending",
326 AssociationState::Connected => "connected",
327 AssociationState::Quarantined => "quarantined",
328 AssociationState::Tombstoned => "tombstoned",
329 }
330}