1use crate::config::StaticBase;
2use crate::discovery::ServiceRegistry;
3use crate::error::{Result, RsCtrlError};
4use bincode;
5use once_cell::sync::Lazy;
6use std::collections::{HashMap, HashSet};
7use std::time::{Duration, Instant};
8use tracing::{debug, info, warn};
9use zmq::{Context, Socket};
10
11static ZMQ_CONTEXT: Lazy<Context> = Lazy::new(|| Context::new());
12
/// First byte of every RPC header; payloads that do not start with it are rejected by
/// `parse_rpc_header`.
const RPC_MAGIC: u8 = 0x52;
/// Message-type byte identifying an RPC request.
const RPC_MSG_REQUEST: u8 = 0x01;
/// Message-type byte identifying an RPC response.
const RPC_MSG_RESPONSE: u8 = 0x02;
/// Header size in bytes: magic (1) + message type (1) + little-endian `u64` request id (8).
const RPC_HEADER_LEN: usize = 2 + std::mem::size_of::<u64>();
21
22fn build_rpc_header(msg_type: u8, request_id: u64) -> [u8; RPC_HEADER_LEN] {
23 let mut hdr = [0u8; RPC_HEADER_LEN];
24 hdr[0] = RPC_MAGIC;
25 hdr[1] = msg_type;
26 hdr[2..RPC_HEADER_LEN].copy_from_slice(&request_id.to_le_bytes());
27 hdr
28}
29
30fn parse_rpc_header(data: &[u8]) -> Option<(u8, u64)> {
31 if data.len() < RPC_HEADER_LEN || data[0] != RPC_MAGIC {
32 return None;
33 }
34 let msg_type = data[1];
35 if msg_type != RPC_MSG_REQUEST && msg_type != RPC_MSG_RESPONSE {
36 return None;
37 }
38 let rid = u64::from_le_bytes(data[2..RPC_HEADER_LEN].try_into().unwrap());
39 Some((msg_type, rid))
40}
41
/// Splits a `host:port` string at its last `:`.
///
/// Returns `None` when there is no `:`, when the host part is empty, or when the port
/// part is not a valid `u16`. Because the split happens at the *last* colon, the host
/// itself may contain colons (for example `[::1]:5555`).
fn parse_host_port(s: &str) -> Option<(String, u16)> {
    let (host, port_text) = s.rsplit_once(':')?;
    if host.is_empty() {
        return None;
    }
    let port: u16 = port_text.parse().ok()?;
    Some((host.to_owned(), port))
}
55
/// A connected SUB socket together with its sub-topic allow-list.
struct SubSocket {
    /// The ZeroMQ SUB socket that messages are read from.
    socket: Socket,
    /// Sub-topics accepted by `try_recv_inner`; when empty, every sub-topic is accepted.
    topics: HashSet<String>,
}
61
/// Owns one shared PUB socket and a set of named SUB sockets, with optional rate
/// limiting on both publishing and polling.
pub struct PubSubManager {
    /// PUB socket used for every published topic; `None` when no publishers are configured.
    shared_pub: Option<Socket>,
    /// Connected SUB sockets, keyed by the local subscriber name.
    subs: HashMap<String, SubSocket>,
    /// Looked up first when resolving a target node id to a `(host, port)` address.
    registry: ServiceRegistry,
    /// Fallback addresses from the static configuration, keyed by node id.
    static_nodes: HashMap<String, (String, u16)>,
    /// Subscribers whose target address is not known yet: local name -> target node id.
    pending_subs: HashMap<String, String>,
    /// Sent as the first frame of every published message.
    my_id: String,

    /// Publish rate limit: negative disables publishing, `0` means unlimited, a positive
    /// value is the maximum number of rate-limited publishes per second per topic key.
    publish_hz: i64,
    /// Poll rate limit: negative disables receiving, `0` means unlimited, a positive
    /// value is the maximum number of polls per second per local subscriber name.
    subscribe_hz: i64,
    /// Time of the last rate-limited publish, per topic key.
    last_publish: HashMap<String, Instant>,
    /// Time of the last accepted poll, per local subscriber name.
    last_sub_poll: HashMap<String, Instant>,
}
82
83impl PubSubManager {
84 pub fn new(static_cfg: &StaticBase, registry: ServiceRegistry) -> Result<Self> {
85 for (topic_key, target) in &static_cfg.publishers {
87 if target != "self" {
88 return Err(RsCtrlError::Config(format!(
89 "publisher '{}' has target '{}' — only \"self\" is supported",
90 topic_key, target
91 )));
92 }
93 }
94
95 let mut subs = HashMap::new();
96 let mut pending_subs = HashMap::new();
97
98 let shared_pub = if static_cfg.publishers.is_empty() {
99 None
100 } else {
101 let socket = ZMQ_CONTEXT.socket(zmq::PUB)?;
102 let endpoint = format!("tcp://{}:{}", static_cfg.host, static_cfg.port);
103 socket.set_sndhwm(1000)?;
104 socket.bind(&endpoint)?;
105 info!("📢 [PUB] bound to {} (topics: {:?})", endpoint, static_cfg.publishers);
106 Some(socket)
107 };
108
109 let static_nodes: HashMap<String, (String, u16)> = static_cfg
110 .static_nodes
111 .iter()
112 .filter_map(|(k, v)| parse_host_port(v).map(|hp| (k.clone(), hp)))
113 .collect();
114
115 for (local_name, target_node_id) in &static_cfg.subscribers {
116 let addr = registry.get_address(target_node_id).or_else(|| {
117 static_nodes
118 .get(target_node_id)
119 .map(|(h, p)| (h.clone(), *p))
120 });
121 if let Some((host, port)) = addr {
122 Self::connect_sub(&mut subs, local_name, target_node_id, &host, port)?;
123 } else {
124 warn!("⏳ [SUB] '{}' waiting for '{}'", local_name, target_node_id);
125 pending_subs.insert(local_name.clone(), target_node_id.clone());
126 }
127 }
128
129 Ok(Self {
130 shared_pub,
131 subs,
132 registry,
133 static_nodes,
134 pending_subs,
135 my_id: static_cfg.my_id.clone(),
136 publish_hz: static_cfg.publish_hz,
137 subscribe_hz: static_cfg.subscribe_hz,
138 last_publish: HashMap::new(),
139 last_sub_poll: HashMap::new(),
140 })
141 }
142
143 fn connect_sub(
144 subs: &mut HashMap<String, SubSocket>,
145 local_name: &str,
146 target_id: &str,
147 host: &str,
148 port: u16,
149 ) -> Result<()> {
150 let socket = ZMQ_CONTEXT.socket(zmq::SUB)?;
151 let endpoint = format!("tcp://{}:{}", host, port);
152 socket.connect(&endpoint)?;
153 socket.set_subscribe(b"")?; socket.set_rcvtimeo(0)?;
155 socket.set_reconnect_ivl(100)?;
156 socket.set_reconnect_ivl_max(5000)?;
157 socket.set_rcvhwm(1000)?;
158
159 info!(
160 "🔗 [SUB] '{}' connected to {} (Target: {})",
161 local_name, endpoint, target_id
162 );
163 subs.insert(
164 local_name.to_string(),
165 SubSocket {
166 socket,
167 topics: HashSet::new(),
168 },
169 );
170 Ok(())
171 }
172
    /// Sets the publish rate limit.
    ///
    /// As interpreted by `send_raw_inner`: a negative value disables publishing, `0`
    /// removes the limit, and a positive value caps rate-limited publishes at that many
    /// per second for each topic key.
    pub fn set_publish_hz(&mut self, hz: i64) {
        self.publish_hz = hz;
    }
181
    /// Sets the receive (poll) rate limit.
    ///
    /// As interpreted by `try_recv_inner`: a negative value disables receiving, `0`
    /// removes the limit, and a positive value caps polls at that many per second for
    /// each local subscriber name.
    pub fn set_subscribe_hz(&mut self, hz: i64) {
        self.subscribe_hz = hz;
    }
190
191 pub fn set_sub_topics<S: AsRef<str>>(&mut self, local_name: &str, topics: &[S]) -> Result<()> {
196 let entry = self
197 .subs
198 .get_mut(local_name)
199 .ok_or_else(|| RsCtrlError::Comms(format!("SUB '{}' not found", local_name)))?;
200 entry.topics.clear();
201 for t in topics {
202 entry.topics.insert(t.as_ref().to_string());
203 }
204 Ok(())
205 }
206
207 pub fn tick(&mut self) -> Result<()> {
208 let mut to_connect = Vec::new();
209 for (local_name, target_id) in &self.pending_subs.clone() {
210 let addr = self.registry.get_address(target_id).or_else(|| {
211 self.static_nodes
212 .get(target_id)
213 .map(|(h, p)| (h.clone(), *p))
214 });
215 if let Some((host, port)) = addr {
216 to_connect.push((local_name.clone(), target_id.clone(), host, port));
217 }
218 }
219 for (local_name, target_id, host, port) in to_connect {
220 match Self::connect_sub(&mut self.subs, &local_name, &target_id, &host, port) {
221 Ok(_) => {
222 self.pending_subs.remove(&local_name);
223 }
224 Err(e) => warn!("Failed to connect {} to {}: {}", local_name, target_id, e),
225 }
226 }
227 Ok(())
228 }
229
230 fn trim_stale_rate_entries(map: &mut HashMap<String, Instant>, now: Instant) {
231 if map.len() > 64 {
232 map.retain(|_, v| now.duration_since(*v) < Duration::from_secs(60));
233 }
234 }
235
236 fn send_raw_inner(
239 &mut self,
240 topic_key: &str,
241 sub_topic: &str,
242 payload: &[u8],
243 bypass_rate: bool,
244 ) -> Result<()> {
245 if self.publish_hz < 0 {
246 return Ok(());
247 }
248 if self.publish_hz > 0 && !bypass_rate {
249 let now = Instant::now();
250 let min_interval = Duration::from_secs_f64(1.0 / self.publish_hz as f64);
251 if let Some(last) = self.last_publish.get(topic_key) {
252 if now.duration_since(*last) < min_interval {
253 return Ok(());
254 }
255 }
256 self.last_publish.insert(topic_key.to_string(), now);
257 Self::trim_stale_rate_entries(&mut self.last_publish, now);
258 }
259
260 let socket = self.shared_pub.as_ref().ok_or_else(|| {
261 RsCtrlError::Comms(format!("Pub key '{}' not initialized", topic_key))
262 })?;
263
264 let id_bytes = self.my_id.as_bytes();
265 let topic_bytes = sub_topic.as_bytes();
266
267 match socket.send_multipart(&[id_bytes, topic_bytes, payload], zmq::DONTWAIT) {
268 Ok(_) => Ok(()),
269 Err(e) if e == zmq::Error::EAGAIN => Ok(()),
270 Err(e) => Err(RsCtrlError::Zmq(e)),
271 }
272 }
273
    /// Publishes an already-encoded `payload` under `sub_topic`, subject to the publish
    /// rate limit tracked per `topic_key`.
    ///
    /// See `send_raw_inner` for the drop and error behaviour.
    pub fn publish_raw(&mut self, topic_key: &str, sub_topic: &str, payload: &[u8]) -> Result<()> {
        self.send_raw_inner(topic_key, sub_topic, payload, false)
    }
280
    /// Serializes `data` with bincode and publishes it under `sub_topic`, subject to the
    /// publish rate limit tracked per `topic_key`.
    ///
    /// # Errors
    ///
    /// Propagates the bincode serialization error and any error from `send_raw_inner`.
    pub fn publish_topic<T: serde::Serialize>(
        &mut self,
        topic_key: &str,
        sub_topic: &str,
        data: &T,
    ) -> Result<()> {
        // Serialization happens before the rate-limit check inside `send_raw_inner`.
        let payload = bincode::serialize(data)?;
        self.send_raw_inner(topic_key, sub_topic, &payload, false)
    }
291
292 fn try_recv_inner(&mut self, local_name: &str) -> Result<Option<(String, String, Vec<u8>)>> {
295 let _ = self.tick();
296
297 if self.subscribe_hz < 0 {
298 return Ok(None);
299 }
300 if self.subscribe_hz > 0 {
301 let now = Instant::now();
302 let min_interval = Duration::from_secs_f64(1.0 / self.subscribe_hz as f64);
303 if let Some(last) = self.last_sub_poll.get(local_name) {
304 if now.duration_since(*last) < min_interval {
305 return Ok(None);
306 }
307 }
308 self.last_sub_poll.insert(local_name.to_string(), now);
309 Self::trim_stale_rate_entries(&mut self.last_sub_poll, now);
310 }
311
312 let Some(sub_entry) = self.subs.get(local_name) else {
313 return Ok(None);
314 };
315
316 match sub_entry.socket.recv_multipart(0) {
317 Ok(frames) => {
318 if frames.len() < 3 {
319 return Ok(None);
320 }
321 let sender_id = String::from_utf8_lossy(&frames[0]).to_string();
322 let sub_topic = String::from_utf8_lossy(&frames[1]).to_string();
323
324 if let Some(entry) = self.subs.get(local_name) {
325 if !entry.topics.is_empty() && !entry.topics.contains(&sub_topic) {
326 return Ok(None);
327 }
328 }
329
330 let payload = frames[2].to_vec();
331 Ok(Some((sender_id, sub_topic, payload)))
332 }
333 Err(e) if e == zmq::Error::EAGAIN => Ok(None),
334 Err(e) => {
335 debug!("Recv error on {}: {}", local_name, e);
336 Ok(None)
337 }
338 }
339 }
340
    /// Non-blocking receive on the subscriber `local_name`.
    ///
    /// Returns `(sender_id, sub_topic, payload)` when a message is available, or
    /// `Ok(None)` otherwise; see `try_recv_inner` for the exact conditions.
    pub fn try_recv_raw(
        &mut self,
        local_name: &str,
    ) -> Result<Option<(String, String, Vec<u8>)>> {
        self.try_recv_inner(local_name)
    }
350
351 pub fn try_recv_specific<T: for<'de> serde::Deserialize<'de>>(
353 &mut self,
354 local_name: &str,
355 target_sub: &str,
356 ) -> Result<Option<T>> {
357 if let Some((_sender, topic, bytes)) = self.try_recv_raw(local_name)? {
358 if topic == target_sub {
359 let data = bincode::deserialize(&bytes)?;
360 return Ok(Some(data));
361 }
362 }
363 Ok(None)
364 }
365
366 pub fn publish_request(
376 &mut self,
377 topic_key: &str,
378 sub_topic: &str,
379 request_id: u64,
380 payload: &[u8],
381 ) -> Result<()> {
382 let header = build_rpc_header(RPC_MSG_REQUEST, request_id);
383 let mut buf = Vec::with_capacity(RPC_HEADER_LEN + payload.len());
384 buf.extend_from_slice(&header);
385 buf.extend_from_slice(payload);
386 self.send_raw_inner(topic_key, sub_topic, &buf, true)
387 }
388
389 pub fn publish_response(
391 &mut self,
392 topic_key: &str,
393 sub_topic: &str,
394 request_id: u64,
395 payload: &[u8],
396 ) -> Result<()> {
397 let header = build_rpc_header(RPC_MSG_RESPONSE, request_id);
398 let mut buf = Vec::with_capacity(RPC_HEADER_LEN + payload.len());
399 buf.extend_from_slice(&header);
400 buf.extend_from_slice(payload);
401 self.send_raw_inner(topic_key, sub_topic, &buf, true)
402 }
403
404 pub fn try_recv_request(
408 &mut self,
409 local_name: &str,
410 ) -> Result<Option<(String, u64, String, Vec<u8>)>> {
411 if let Some((sender, sub_topic, raw)) = self.try_recv_inner(local_name)? {
412 if let Some((msg_type, rid)) = parse_rpc_header(&raw) {
413 if msg_type == RPC_MSG_REQUEST {
414 let payload = raw[RPC_HEADER_LEN..].to_vec();
415 return Ok(Some((sender, rid, sub_topic, payload)));
416 }
417 }
418 }
419 Ok(None)
420 }
421
422 pub fn try_recv_response(
426 &mut self,
427 local_name: &str,
428 ) -> Result<Option<(String, u64, String, Vec<u8>)>> {
429 if let Some((sender, sub_topic, raw)) = self.try_recv_inner(local_name)? {
430 if let Some((msg_type, rid)) = parse_rpc_header(&raw) {
431 if msg_type == RPC_MSG_RESPONSE {
432 let payload = raw[RPC_HEADER_LEN..].to_vec();
433 return Ok(Some((sender, rid, sub_topic, payload)));
434 }
435 }
436 }
437 Ok(None)
438 }
439}