// rs_ctrl_os/comms.rs

1use crate::config::StaticBase;
2use crate::discovery::ServiceRegistry;
3use crate::error::{Result, RsCtrlError};
4use bincode;
5use once_cell::sync::Lazy;
6use std::collections::{HashMap, HashSet};
7use std::time::{Duration, Instant};
8use tracing::{debug, info, warn};
9use zmq::{Context, Socket};
10
11static ZMQ_CONTEXT: Lazy<Context> = Lazy::new(|| Context::new());
12
// --- RPC binary envelope ---
// User payloads are prefixed with a fixed 10-byte header so that requests and
// responses can be correlated.
// Layout at the start of the payload frame:
//   [magic: 1 byte][type: 1 byte][request_id: u64, little-endian]

/// First header byte; identifies a payload as an RPC envelope (ASCII 'R').
const RPC_MAGIC: u8 = b'R';
/// Header type byte for a request.
const RPC_MSG_REQUEST: u8 = 0x01;
/// Header type byte for a response.
const RPC_MSG_RESPONSE: u8 = 0x02;
/// Total header size: magic byte + type byte + 8-byte request id.
const RPC_HEADER_LEN: usize = 1 + 1 + std::mem::size_of::<u64>();
21
22fn build_rpc_header(msg_type: u8, request_id: u64) -> [u8; RPC_HEADER_LEN] {
23    let mut hdr = [0u8; RPC_HEADER_LEN];
24    hdr[0] = RPC_MAGIC;
25    hdr[1] = msg_type;
26    hdr[2..RPC_HEADER_LEN].copy_from_slice(&request_id.to_le_bytes());
27    hdr
28}
29
30fn parse_rpc_header(data: &[u8]) -> Option<(u8, u64)> {
31    if data.len() < RPC_HEADER_LEN || data[0] != RPC_MAGIC {
32        return None;
33    }
34    let msg_type = data[1];
35    if msg_type != RPC_MSG_REQUEST && msg_type != RPC_MSG_RESPONSE {
36        return None;
37    }
38    let rid = u64::from_le_bytes(data[2..RPC_HEADER_LEN].try_into().unwrap());
39    Some((msg_type, rid))
40}
41
/// Parses `"host:port"` or `"[ipv6]:port"` without performing DNS resolution.
///
/// The string is split at the last `:`. The part after it must be a valid
/// `u16` port and the part before it must be a non-empty host. For the
/// bracketed form the brackets are kept in the returned host (e.g. `"[::1]"`).
///
/// A host that still contains `:` is only accepted when wrapped in `[...]`.
/// Without that rule an unbracketed IPv6 literal such as `"::1"` would be
/// mis-split into host `":"` and port `1`.
///
/// Returns `None` for any input that does not match these rules.
fn parse_host_port(s: &str) -> Option<(String, u16)> {
    let (host, port) = s.rsplit_once(':')?;
    if host.is_empty() {
        return None;
    }
    // A colon left in the host part means an IPv6 literal; it is ambiguous
    // unless it is enclosed in brackets.
    let bracketed = host.starts_with('[') && host.ends_with(']');
    if host.contains(':') && !bracketed {
        return None;
    }
    let port = port.parse::<u16>().ok()?;
    Some((host.to_string(), port))
}
55
56struct SubSocket {
57    socket: Socket,
58    /// 若为空集,则不过滤 sub_topic;非空时,仅保留在集合内的 sub_topic。
59    topics: HashSet<String>,
60}
61
62/// Pub/Sub 管理器
63///
64/// - 发布频率控制:通过 `publish_hz` 限制 `publish_topic` 的最大发送速率。
65/// - 订阅频率控制:通过 `subscribe_hz` 限制 `try_recv_*` 的轮询频率。
66///   频率配置建议从各节点的 `[dynamic]` 中传入(如 `publish_hz` / `subscribe_hz`)。
67pub struct PubSubManager {
68    shared_pub: Option<Socket>,
69    subs: HashMap<String, SubSocket>,
70    registry: ServiceRegistry,
71    /// node_id -> (host, port),discovery 失败时的 fallback
72    static_nodes: HashMap<String, (String, u16)>,
73    pending_subs: HashMap<String, String>,
74    my_id: String,
75
76    // 频率控制(节点级别)
77    publish_hz: i64,
78    subscribe_hz: i64,
79    last_publish: HashMap<String, Instant>, // 按 topic_key 跟踪
80    last_sub_poll: HashMap<String, Instant>, // 按 local_name 跟踪
81}
82
83impl PubSubManager {
84    pub fn new(static_cfg: &StaticBase, registry: ServiceRegistry) -> Result<Self> {
85        // Validate: only "self" publishers are supported.
86        for (topic_key, target) in &static_cfg.publishers {
87            if target != "self" {
88                return Err(RsCtrlError::Config(format!(
89                    "publisher '{}' has target '{}' — only \"self\" is supported",
90                    topic_key, target
91                )));
92            }
93        }
94
95        let mut subs = HashMap::new();
96        let mut pending_subs = HashMap::new();
97
98        let shared_pub = if static_cfg.publishers.is_empty() {
99            None
100        } else {
101            let socket = ZMQ_CONTEXT.socket(zmq::PUB)?;
102            let endpoint = format!("tcp://{}:{}", static_cfg.host, static_cfg.port);
103            socket.set_sndhwm(1000)?;
104            socket.bind(&endpoint)?;
105            info!("📢 [PUB] bound to {} (topics: {:?})", endpoint, static_cfg.publishers);
106            Some(socket)
107        };
108
109        let static_nodes: HashMap<String, (String, u16)> = static_cfg
110            .static_nodes
111            .iter()
112            .filter_map(|(k, v)| parse_host_port(v).map(|hp| (k.clone(), hp)))
113            .collect();
114
115        for (local_name, target_node_id) in &static_cfg.subscribers {
116            let addr = registry.get_address(target_node_id).or_else(|| {
117                static_nodes
118                    .get(target_node_id)
119                    .map(|(h, p)| (h.clone(), *p))
120            });
121            if let Some((host, port)) = addr {
122                Self::connect_sub(&mut subs, local_name, target_node_id, &host, port)?;
123            } else {
124                warn!("⏳ [SUB] '{}' waiting for '{}'", local_name, target_node_id);
125                pending_subs.insert(local_name.clone(), target_node_id.clone());
126            }
127        }
128
129        Ok(Self {
130            shared_pub,
131            subs,
132            registry,
133            static_nodes,
134            pending_subs,
135            my_id: static_cfg.my_id.clone(),
136            publish_hz: static_cfg.publish_hz,
137            subscribe_hz: static_cfg.subscribe_hz,
138            last_publish: HashMap::new(),
139            last_sub_poll: HashMap::new(),
140        })
141    }
142
143    fn connect_sub(
144        subs: &mut HashMap<String, SubSocket>,
145        local_name: &str,
146        target_id: &str,
147        host: &str,
148        port: u16,
149    ) -> Result<()> {
150        let socket = ZMQ_CONTEXT.socket(zmq::SUB)?;
151        let endpoint = format!("tcp://{}:{}", host, port);
152        socket.connect(&endpoint)?;
153        socket.set_subscribe(b"")?; // Subscribe all, filter by app logic
154        socket.set_rcvtimeo(0)?;
155        socket.set_reconnect_ivl(100)?;
156        socket.set_reconnect_ivl_max(5000)?;
157        socket.set_rcvhwm(1000)?;
158
159        info!(
160            "🔗 [SUB] '{}' connected to {} (Target: {})",
161            local_name, endpoint, target_id
162        );
163        subs.insert(
164            local_name.to_string(),
165            SubSocket {
166                socket,
167                topics: HashSet::new(),
168            },
169        );
170        Ok(())
171    }
172
173    /// 设置节点级发布频率(Hz)。
174    ///
175    /// - `hz > 0`:对所有 `publish_topic` 生效,按最小时间间隔限频;
176    /// - `hz = 0`:动态频率(有多少发多快,仍受 ZMQ HWM 影响)。
177    /// - `hz < 0`:不发布(publish_topic 直接返回 Ok(()))。
178    pub fn set_publish_hz(&mut self, hz: i64) {
179        self.publish_hz = hz;
180    }
181
182    /// 设置节点级订阅/处理频率(Hz)。
183    ///
184    /// - `hz > 0`:对所有 `try_recv_*` 生效,按最小时间间隔限频;
185    /// - `hz = 0`:动态频率(按调用频率尝试收取)。
186    /// - `hz < 0`:不订阅/不消费(直接返回 Ok(None))。
187    pub fn set_subscribe_hz(&mut self, hz: i64) {
188        self.subscribe_hz = hz;
189    }
190
191    /// 为指定本地订阅名配置需要保留的 sub_topic 列表。
192    ///
193    /// - `topics` 为空:不过滤任何 sub_topic(保留所有)。
194    /// - 非空:仅当收到的 sub_topic 在此列表中时才返回;其他 sub_topic 会被静默丢弃。
195    pub fn set_sub_topics<S: AsRef<str>>(&mut self, local_name: &str, topics: &[S]) -> Result<()> {
196        let entry = self
197            .subs
198            .get_mut(local_name)
199            .ok_or_else(|| RsCtrlError::Comms(format!("SUB '{}' not found", local_name)))?;
200        entry.topics.clear();
201        for t in topics {
202            entry.topics.insert(t.as_ref().to_string());
203        }
204        Ok(())
205    }
206
207    pub fn tick(&mut self) -> Result<()> {
208        let mut to_connect = Vec::new();
209        for (local_name, target_id) in &self.pending_subs.clone() {
210            let addr = self.registry.get_address(target_id).or_else(|| {
211                self.static_nodes
212                    .get(target_id)
213                    .map(|(h, p)| (h.clone(), *p))
214            });
215            if let Some((host, port)) = addr {
216                to_connect.push((local_name.clone(), target_id.clone(), host, port));
217            }
218        }
219        for (local_name, target_id, host, port) in to_connect {
220            match Self::connect_sub(&mut self.subs, &local_name, &target_id, &host, port) {
221                Ok(_) => {
222                    self.pending_subs.remove(&local_name);
223                }
224                Err(e) => warn!("Failed to connect {} to {}: {}", local_name, target_id, e),
225            }
226        }
227        Ok(())
228    }
229
230    fn trim_stale_rate_entries(map: &mut HashMap<String, Instant>, now: Instant) {
231        if map.len() > 64 {
232            map.retain(|_, v| now.duration_since(*v) < Duration::from_secs(60));
233        }
234    }
235
236    /// Core send: builds 3-frame multipart and sends on the PUB socket.
237    /// When `bypass_rate` is true, the publish_hz rate limiter is skipped.
238    fn send_raw_inner(
239        &mut self,
240        topic_key: &str,
241        sub_topic: &str,
242        payload: &[u8],
243        bypass_rate: bool,
244    ) -> Result<()> {
245        if self.publish_hz < 0 {
246            return Ok(());
247        }
248        if self.publish_hz > 0 && !bypass_rate {
249            let now = Instant::now();
250            let min_interval = Duration::from_secs_f64(1.0 / self.publish_hz as f64);
251            if let Some(last) = self.last_publish.get(topic_key) {
252                if now.duration_since(*last) < min_interval {
253                    return Ok(());
254                }
255            }
256            self.last_publish.insert(topic_key.to_string(), now);
257            Self::trim_stale_rate_entries(&mut self.last_publish, now);
258        }
259
260        let socket = self.shared_pub.as_ref().ok_or_else(|| {
261            RsCtrlError::Comms(format!("Pub key '{}' not initialized", topic_key))
262        })?;
263
264        let id_bytes = self.my_id.as_bytes();
265        let topic_bytes = sub_topic.as_bytes();
266
267        match socket.send_multipart(&[id_bytes, topic_bytes, payload], zmq::DONTWAIT) {
268            Ok(_) => Ok(()),
269            Err(e) if e == zmq::Error::EAGAIN => Ok(()),
270            Err(e) => Err(RsCtrlError::Zmq(e)),
271        }
272    }
273
274    /// 发布原始字节(不经过 serde/bincode,直接透传)。
275    /// 适用于图像、点云等已编码的二进制数据(JPEG、压缩点云等)。
276    /// 频率控制与 `publish_topic` 共享。
277    pub fn publish_raw(&mut self, topic_key: &str, sub_topic: &str, payload: &[u8]) -> Result<()> {
278        self.send_raw_inner(topic_key, sub_topic, payload, false)
279    }
280
281    /// 发布特定子话题 (Bincode 序列化)
282    pub fn publish_topic<T: serde::Serialize>(
283        &mut self,
284        topic_key: &str,
285        sub_topic: &str,
286        data: &T,
287    ) -> Result<()> {
288        let payload = bincode::serialize(data)?;
289        self.send_raw_inner(topic_key, sub_topic, &payload, false)
290    }
291
292    /// Core receive: performs tick + rate limiting + ZMQ recv + topic filter.
293    /// Returns all 3 frames: (sender_id, sub_topic, payload).
294    fn try_recv_inner(&mut self, local_name: &str) -> Result<Option<(String, String, Vec<u8>)>> {
295        let _ = self.tick();
296
297        if self.subscribe_hz < 0 {
298            return Ok(None);
299        }
300        if self.subscribe_hz > 0 {
301            let now = Instant::now();
302            let min_interval = Duration::from_secs_f64(1.0 / self.subscribe_hz as f64);
303            if let Some(last) = self.last_sub_poll.get(local_name) {
304                if now.duration_since(*last) < min_interval {
305                    return Ok(None);
306                }
307            }
308            self.last_sub_poll.insert(local_name.to_string(), now);
309            Self::trim_stale_rate_entries(&mut self.last_sub_poll, now);
310        }
311
312        let Some(sub_entry) = self.subs.get(local_name) else {
313            return Ok(None);
314        };
315
316        match sub_entry.socket.recv_multipart(0) {
317            Ok(frames) => {
318                if frames.len() < 3 {
319                    return Ok(None);
320                }
321                let sender_id = String::from_utf8_lossy(&frames[0]).to_string();
322                let sub_topic = String::from_utf8_lossy(&frames[1]).to_string();
323
324                if let Some(entry) = self.subs.get(local_name) {
325                    if !entry.topics.is_empty() && !entry.topics.contains(&sub_topic) {
326                        return Ok(None);
327                    }
328                }
329
330                let payload = frames[2].to_vec();
331                Ok(Some((sender_id, sub_topic, payload)))
332            }
333            Err(e) if e == zmq::Error::EAGAIN => Ok(None),
334            Err(e) => {
335                debug!("Recv error on {}: {}", local_name, e);
336                Ok(None)
337            }
338        }
339    }
340
341    /// 接收原始字节 (由用户反序列化)
342    /// 返回 (sender_id, sub_topic, payload),其中 sender_id 是发布者的 node_id。
343    /// 内部自动调用 tick(),无需在主循环手动调用。
344    pub fn try_recv_raw(
345        &mut self,
346        local_name: &str,
347    ) -> Result<Option<(String, String, Vec<u8>)>> {
348        self.try_recv_inner(local_name)
349    }
350
351    /// 辅助:直接接收并反序列化为特定类型 (如果知道具体话题)
352    pub fn try_recv_specific<T: for<'de> serde::Deserialize<'de>>(
353        &mut self,
354        local_name: &str,
355        target_sub: &str,
356    ) -> Result<Option<T>> {
357        if let Some((_sender, topic, bytes)) = self.try_recv_raw(local_name)? {
358            if topic == target_sub {
359                let data = bincode::deserialize(&bytes)?;
360                return Ok(Some(data));
361            }
362        }
363        Ok(None)
364    }
365
    // --- RPC methods: request-response on top of PUB/SUB ---
    //
    // RPC messages use a 10-byte binary envelope prepended to the payload:
    //   [magic: 'R'][type: 0x01=req/0x02=res][request_id: u64 LE]
    //
    // These methods **bypass** the per-topic publish_hz interval check so that
    // imperative commands (e.g. emergency stop) are not discarded by rate
    // limiting. Note that they are still suppressed when publish_hz < 0, and a
    // send that hits EAGAIN is dropped without an error.
373
374    /// 发布 RPC 请求。自动绕过发布频率限制。
375    pub fn publish_request(
376        &mut self,
377        topic_key: &str,
378        sub_topic: &str,
379        request_id: u64,
380        payload: &[u8],
381    ) -> Result<()> {
382        let header = build_rpc_header(RPC_MSG_REQUEST, request_id);
383        let mut buf = Vec::with_capacity(RPC_HEADER_LEN + payload.len());
384        buf.extend_from_slice(&header);
385        buf.extend_from_slice(payload);
386        self.send_raw_inner(topic_key, sub_topic, &buf, true)
387    }
388
389    /// 发布 RPC 响应。自动绕过发布频率限制。
390    pub fn publish_response(
391        &mut self,
392        topic_key: &str,
393        sub_topic: &str,
394        request_id: u64,
395        payload: &[u8],
396    ) -> Result<()> {
397        let header = build_rpc_header(RPC_MSG_RESPONSE, request_id);
398        let mut buf = Vec::with_capacity(RPC_HEADER_LEN + payload.len());
399        buf.extend_from_slice(&header);
400        buf.extend_from_slice(payload);
401        self.send_raw_inner(topic_key, sub_topic, &buf, true)
402    }
403
404    /// 接收 RPC 请求。
405    /// 返回 `(sender_id, request_id, sub_topic, payload)`。
406    /// 非 RPC 请求的消息会被静默丢弃,请通过 sub_topic 分离普通流量和 RPC 流量。
407    pub fn try_recv_request(
408        &mut self,
409        local_name: &str,
410    ) -> Result<Option<(String, u64, String, Vec<u8>)>> {
411        if let Some((sender, sub_topic, raw)) = self.try_recv_inner(local_name)? {
412            if let Some((msg_type, rid)) = parse_rpc_header(&raw) {
413                if msg_type == RPC_MSG_REQUEST {
414                    let payload = raw[RPC_HEADER_LEN..].to_vec();
415                    return Ok(Some((sender, rid, sub_topic, payload)));
416                }
417            }
418        }
419        Ok(None)
420    }
421
422    /// 接收 RPC 响应。
423    /// 返回 `(sender_id, request_id, sub_topic, payload)`。
424    /// 非 RPC 响应的消息会被静默丢弃。
425    pub fn try_recv_response(
426        &mut self,
427        local_name: &str,
428    ) -> Result<Option<(String, u64, String, Vec<u8>)>> {
429        if let Some((sender, sub_topic, raw)) = self.try_recv_inner(local_name)? {
430            if let Some((msg_type, rid)) = parse_rpc_header(&raw) {
431                if msg_type == RPC_MSG_RESPONSE {
432                    let payload = raw[RPC_HEADER_LEN..].to_vec();
433                    return Ok(Some((sender, rid, sub_topic, payload)));
434                }
435            }
436        }
437        Ok(None)
438    }
439}