Skip to main content

nautilus_common/python/
msgbus.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Python bindings for the message bus, including configuration types and the
17//! [`PyMessageBus`] wrapper that routes Python events through the Rust
18//! thread-local [`MessageBus`] via the Any-based dispatch path.
19
use std::{
    any::Any,
    cell::{Cell, RefCell},
    fmt::Debug,
    rc::Rc,
};
21
22use ahash::AHashMap;
23use nautilus_core::{UUID4, python::to_pyruntime_err};
24use nautilus_model::identifiers::TraderId;
25use pyo3::{Py, Python, prelude::*, types::PyBytes};
26use ustr::Ustr;
27
28use crate::{
29    enums::SerializationEncoding,
30    msgbus::{
31        self as msgbus_api, BusMessage, MessageBus,
32        core::Subscription,
33        database::{DatabaseConfig, MessageBusConfig},
34        get_message_bus,
35        matching::is_matching,
36        mstr::{Endpoint, MStr, Pattern, Topic},
37        typed_handler::{Handler, ShareableMessageHandler, TypedHandler},
38    },
39};
40
41#[pymethods]
42#[pyo3_stub_gen::derive::gen_stub_pymethods]
43impl BusMessage {
44    #[getter]
45    #[pyo3(name = "topic")]
46    fn py_topic(&self) -> String {
47        self.topic.to_string()
48    }
49
50    #[getter]
51    #[pyo3(name = "payload")]
52    fn py_payload(&self, py: Python<'_>) -> Py<PyBytes> {
53        PyBytes::new(py, self.payload.as_ref()).into()
54    }
55
56    fn __repr__(&self) -> String {
57        format!("{}('{}')", stringify!(BusMessage), self)
58    }
59
60    fn __str__(&self) -> String {
61        self.to_string()
62    }
63}
64
65#[pymethods]
66#[pyo3_stub_gen::derive::gen_stub_pymethods]
67impl DatabaseConfig {
68    /// Configuration for database connections.
69    ///
70    /// # Notes
71    ///
72    /// If `database_type` is `"redis"`, it requires Redis version 6.2 or higher for correct operation.
73    #[new]
74    #[allow(clippy::too_many_arguments)]
75    #[pyo3(signature = (database_type=None, host=None, port=None, username=None, password=None, ssl=None, connection_timeout=None, response_timeout=None, number_of_retries=None, exponent_base=None, max_delay=None, factor=None))]
76    fn py_new(
77        database_type: Option<String>,
78        host: Option<String>,
79        port: Option<u16>,
80        username: Option<String>,
81        password: Option<String>,
82        ssl: Option<bool>,
83        connection_timeout: Option<u16>,
84        response_timeout: Option<u16>,
85        number_of_retries: Option<usize>,
86        exponent_base: Option<u64>,
87        max_delay: Option<u64>,
88        factor: Option<u64>,
89    ) -> Self {
90        let default = Self::default();
91        Self {
92            database_type: database_type.unwrap_or(default.database_type),
93            host,
94            port,
95            username,
96            password,
97            ssl: ssl.unwrap_or(default.ssl),
98            connection_timeout: connection_timeout.unwrap_or(default.connection_timeout),
99            response_timeout: response_timeout.unwrap_or(default.response_timeout),
100            number_of_retries: number_of_retries.unwrap_or(default.number_of_retries),
101            exponent_base: exponent_base.unwrap_or(default.exponent_base),
102            max_delay: max_delay.unwrap_or(default.max_delay),
103            factor: factor.unwrap_or(default.factor),
104        }
105    }
106
107    fn __repr__(&self) -> String {
108        format!("{self:?}")
109    }
110
111    fn __str__(&self) -> String {
112        format!("{self:?}")
113    }
114
115    #[getter]
116    fn database_type(&self) -> &str {
117        &self.database_type
118    }
119
120    #[getter]
121    fn host(&self) -> Option<&str> {
122        self.host.as_deref()
123    }
124
125    #[getter]
126    fn port(&self) -> Option<u16> {
127        self.port
128    }
129
130    #[getter]
131    fn username(&self) -> Option<&str> {
132        self.username.as_deref()
133    }
134
135    #[getter]
136    fn password(&self) -> Option<&str> {
137        self.password.as_deref()
138    }
139
140    #[getter]
141    fn ssl(&self) -> bool {
142        self.ssl
143    }
144
145    #[getter]
146    fn connection_timeout(&self) -> u16 {
147        self.connection_timeout
148    }
149
150    #[getter]
151    fn response_timeout(&self) -> u16 {
152        self.response_timeout
153    }
154
155    #[getter]
156    fn number_of_retries(&self) -> usize {
157        self.number_of_retries
158    }
159
160    #[getter]
161    fn exponent_base(&self) -> u64 {
162        self.exponent_base
163    }
164
165    #[getter]
166    fn max_delay(&self) -> u64 {
167        self.max_delay
168    }
169
170    #[getter]
171    fn factor(&self) -> u64 {
172        self.factor
173    }
174}
175
176#[pymethods]
177#[pyo3_stub_gen::derive::gen_stub_pymethods]
178impl MessageBusConfig {
179    /// Configuration for `MessageBus` instances.
180    #[new]
181    #[allow(clippy::too_many_arguments)]
182    #[pyo3(signature = (database=None, encoding=None, timestamps_as_iso8601=None, buffer_interval_ms=None, autotrim_mins=None, use_trader_prefix=None, use_trader_id=None, use_instance_id=None, streams_prefix=None, stream_per_topic=None, external_streams=None, types_filter=None, heartbeat_interval_secs=None))]
183    fn py_new(
184        database: Option<DatabaseConfig>,
185        encoding: Option<SerializationEncoding>,
186        timestamps_as_iso8601: Option<bool>,
187        buffer_interval_ms: Option<u32>,
188        autotrim_mins: Option<u32>,
189        use_trader_prefix: Option<bool>,
190        use_trader_id: Option<bool>,
191        use_instance_id: Option<bool>,
192        streams_prefix: Option<String>,
193        stream_per_topic: Option<bool>,
194        external_streams: Option<Vec<String>>,
195        types_filter: Option<Vec<String>>,
196        heartbeat_interval_secs: Option<u16>,
197    ) -> Self {
198        let default = Self::default();
199        Self {
200            database,
201            encoding: encoding.unwrap_or(default.encoding),
202            timestamps_as_iso8601: timestamps_as_iso8601.unwrap_or(default.timestamps_as_iso8601),
203            buffer_interval_ms,
204            autotrim_mins,
205            use_trader_prefix: use_trader_prefix.unwrap_or(default.use_trader_prefix),
206            use_trader_id: use_trader_id.unwrap_or(default.use_trader_id),
207            use_instance_id: use_instance_id.unwrap_or(default.use_instance_id),
208            streams_prefix: streams_prefix.unwrap_or(default.streams_prefix),
209            stream_per_topic: stream_per_topic.unwrap_or(default.stream_per_topic),
210            external_streams,
211            types_filter,
212            heartbeat_interval_secs,
213        }
214    }
215
216    fn __repr__(&self) -> String {
217        format!("{self:?}")
218    }
219
220    fn __str__(&self) -> String {
221        format!("{self:?}")
222    }
223
224    #[getter]
225    fn database(&self) -> Option<DatabaseConfig> {
226        self.database.clone()
227    }
228
229    #[getter]
230    fn encoding(&self) -> SerializationEncoding {
231        self.encoding
232    }
233
234    #[getter]
235    fn timestamps_as_iso8601(&self) -> bool {
236        self.timestamps_as_iso8601
237    }
238
239    #[getter]
240    fn buffer_interval_ms(&self) -> Option<u32> {
241        self.buffer_interval_ms
242    }
243
244    #[getter]
245    fn autotrim_mins(&self) -> Option<u32> {
246        self.autotrim_mins
247    }
248
249    #[getter]
250    fn use_trader_prefix(&self) -> bool {
251        self.use_trader_prefix
252    }
253
254    #[getter]
255    fn use_trader_id(&self) -> bool {
256        self.use_trader_id
257    }
258
259    #[getter]
260    fn use_instance_id(&self) -> bool {
261        self.use_instance_id
262    }
263
264    #[getter]
265    fn streams_prefix(&self) -> &str {
266        &self.streams_prefix
267    }
268
269    #[getter]
270    fn stream_per_topic(&self) -> bool {
271        self.stream_per_topic
272    }
273
274    #[getter]
275    fn external_streams(&self) -> Option<Vec<String>> {
276        self.external_streams.clone()
277    }
278
279    #[getter]
280    fn types_filter(&self) -> Option<Vec<String>> {
281        self.types_filter.clone()
282    }
283
284    #[getter]
285    fn heartbeat_interval_secs(&self) -> Option<u16> {
286        self.heartbeat_interval_secs
287    }
288}
289
290/// Wraps a Python object so it can travel through the Rust Any-based message bus.
291pub struct PyMessage(pub Py<PyAny>);
292
293impl Debug for PyMessage {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        f.debug_tuple(stringify!(PyMessage))
296            .field(&"<PyObject>")
297            .finish()
298    }
299}
300
301/// Adapts a Python callable as a [`ShareableMessageHandler`].
302///
303/// Expects messages to be [`PyMessage`] instances. Acquires the GIL and calls
304/// the Python callable with the inner Python object.
305pub struct PyCallableHandler {
306    id: Ustr,
307    callable: Py<PyAny>,
308}
309
310impl Debug for PyCallableHandler {
311    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312        f.debug_struct(stringify!(PyCallableHandler))
313            .field("id", &self.id)
314            .finish()
315    }
316}
317
318impl PyCallableHandler {
319    /// Creates a new handler from a Python callable.
320    ///
321    /// The handler ID is derived from `repr(callable)` for stable identity
322    /// across subscribe/unsubscribe calls.
323    pub fn new(py: Python<'_>, callable: Py<PyAny>) -> PyResult<Self> {
324        let repr_str = callable.bind(py).repr()?.to_string();
325        let id = Ustr::from(&repr_str);
326        Ok(Self { id, callable })
327    }
328}
329
330impl Handler<dyn Any> for PyCallableHandler {
331    fn id(&self) -> Ustr {
332        self.id
333    }
334
335    fn handle(&self, message: &dyn Any) {
336        if let Some(py_msg) = message.downcast_ref::<PyMessage>() {
337            Python::attach(|py| {
338                if let Err(e) = self.callable.call1(py, (&py_msg.0,)) {
339                    log::error!("Python handler {id} failed: {e}", id = self.id);
340                }
341            });
342        } else {
343            log::error!(
344                "Python handler {id} received non-PyMessage type",
345                id = self.id
346            );
347        }
348    }
349}
350
351fn make_handler(py: Python<'_>, callable: Py<PyAny>) -> PyResult<ShareableMessageHandler> {
352    let handler = PyCallableHandler::new(py, callable)?;
353    Ok(TypedHandler(Rc::new(handler) as Rc<dyn Handler<dyn Any>>))
354}
355
356/// Python message bus backed by the Rust thread-local [`MessageBus`].
357///
358/// Provides the same API as the legacy Cython `MessageBus` while routing all
359/// messages through the single Rust bus. Python custom events travel through
360/// the Any-based dispatch path via [`PyMessage`] wrappers.
361#[pyclass(
362    module = "nautilus_trader.core.nautilus_pyo3.common",
363    name = "MessageBus",
364    unsendable
365)]
366#[pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.common")]
367pub struct PyMessageBus {
368    trader_id: TraderId,
369    instance_id: UUID4,
370    name: String,
371    has_backing: bool,
372    serializer: Option<Py<PyAny>>,
373    database: Option<Py<PyAny>>,
374    listeners: Vec<Py<PyAny>>,
375    types_filter: Option<Py<PyAny>>,
376    streaming_types: Vec<Py<PyAny>>,
377    correlation_index: AHashMap<UUID4, Py<PyAny>>,
378    sent_count: u64,
379    req_count: u64,
380    res_count: u64,
381    pub_count: u64,
382}
383
384impl Debug for PyMessageBus {
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        f.debug_struct(stringify!(PyMessageBus))
387            .field("trader_id", &self.trader_id)
388            .field("name", &self.name)
389            .finish()
390    }
391}
392
393#[pymethods]
394#[pyo3_stub_gen::derive::gen_stub_pymethods]
395impl PyMessageBus {
396    /// Creates a new `MessageBus` instance.
397    ///
398    /// This creates and registers the underlying Rust `MessageBus` as the
399    /// thread-local bus, then wraps it for Python access.
400    #[new]
401    #[pyo3(signature = (trader_id, clock=None, instance_id=None, name=None, serializer=None, database=None, config=None))]
402    #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
403    fn py_new(
404        py: Python<'_>,
405        trader_id: TraderId,
406        clock: Option<Py<PyAny>>,
407        instance_id: Option<UUID4>,
408        name: Option<String>,
409        serializer: Option<Py<PyAny>>,
410        database: Option<Py<PyAny>>,
411        config: Option<Py<PyAny>>,
412    ) -> PyResult<Self> {
413        let _ = clock;
414        let instance_id = instance_id.unwrap_or_default();
415        let bus_name = name.clone();
416        let has_backing = database.is_some();
417
418        let msgbus = MessageBus::new(trader_id, instance_id, bus_name, None);
419        msgbus.register_message_bus();
420
421        let types_filter = if let Some(ref cfg) = config {
422            let tf = cfg.getattr(py, "types_filter")?;
423            if tf.is_none(py) {
424                None
425            } else {
426                // Convert to tuple for isinstance() checks
427                let tuple = py
428                    .import("builtins")?
429                    .call_method1("tuple", (tf,))?
430                    .unbind();
431                Some(tuple)
432            }
433        } else {
434            None
435        };
436
437        Ok(Self {
438            trader_id,
439            instance_id,
440            name: name.unwrap_or_else(|| "MessageBus".to_owned()),
441            has_backing,
442            serializer,
443            database,
444            listeners: Vec::new(),
445            types_filter,
446            streaming_types: Vec::new(),
447            correlation_index: AHashMap::new(),
448            sent_count: 0,
449            req_count: 0,
450            res_count: 0,
451            pub_count: 0,
452        })
453    }
454
455    /// Returns the trader ID associated with the message bus.
456    #[getter]
457    #[pyo3(name = "trader_id")]
458    fn py_trader_id(&self) -> TraderId {
459        self.trader_id
460    }
461
462    /// Returns the instance ID associated with the message bus.
463    #[getter]
464    #[pyo3(name = "instance_id")]
465    fn py_instance_id(&self) -> UUID4 {
466        self.instance_id
467    }
468
469    /// Returns the name of the message bus.
470    #[getter]
471    #[pyo3(name = "name")]
472    fn py_name(&self) -> &str {
473        &self.name
474    }
475
476    /// Returns whether the message bus is backed by a database.
477    #[getter]
478    #[pyo3(name = "has_backing")]
479    fn py_has_backing(&self) -> bool {
480        self.has_backing
481    }
482
483    /// Returns the count of messages sent via point-to-point.
484    #[getter]
485    #[pyo3(name = "sent_count")]
486    fn py_sent_count(&self) -> u64 {
487        self.sent_count
488    }
489
490    /// Returns the count of requests made.
491    #[getter]
492    #[pyo3(name = "req_count")]
493    fn py_req_count(&self) -> u64 {
494        self.req_count
495    }
496
497    /// Returns the count of responses handled.
498    #[getter]
499    #[pyo3(name = "res_count")]
500    fn py_res_count(&self) -> u64 {
501        self.res_count
502    }
503
504    /// Returns the count of messages published.
505    #[getter]
506    #[pyo3(name = "pub_count")]
507    fn py_pub_count(&self) -> u64 {
508        self.pub_count
509    }
510
511    /// Returns all registered endpoint addresses.
512    #[pyo3(name = "endpoints")]
513    fn py_endpoints(&self) -> Vec<String> {
514        let bus = get_message_bus();
515        let bus_ref = bus.borrow();
516        bus_ref.endpoints().into_iter().map(String::from).collect()
517    }
518
519    /// Returns all topics with active subscribers.
520    #[pyo3(name = "topics")]
521    fn py_topics(&self) -> Vec<String> {
522        let bus = get_message_bus();
523        let bus_ref = bus.borrow();
524        let mut topics: Vec<String> = bus_ref.patterns().into_iter().map(String::from).collect();
525        topics.sort();
526        topics.dedup();
527        topics
528    }
529
530    /// Returns subscriptions matching the given topic pattern.
531    #[pyo3(name = "subscriptions")]
532    #[pyo3(signature = (pattern=None))]
533    fn py_subscriptions(&self, pattern: Option<&str>) -> Vec<String> {
534        let bus = get_message_bus();
535        let bus_ref = bus.borrow();
536        let subs: Vec<&Subscription> = bus_ref.subscriptions();
537
538        match pattern {
539            Some(p) => {
540                let filter = MStr::<Pattern>::pattern(p);
541                subs.into_iter()
542                    .filter(|s| is_matching(s.pattern.as_bytes(), filter.as_bytes()))
543                    .map(|s| {
544                        format!(
545                            "Subscription(topic={}, handler={})",
546                            s.pattern, s.handler_id
547                        )
548                    })
549                    .collect()
550            }
551            None => subs
552                .into_iter()
553                .map(|s| {
554                    format!(
555                        "Subscription(topic={}, handler={})",
556                        s.pattern, s.handler_id
557                    )
558                })
559                .collect(),
560        }
561    }
562
563    /// Returns whether there are subscribers for the given topic pattern.
564    #[pyo3(name = "has_subscribers")]
565    #[pyo3(signature = (pattern=None))]
566    fn py_has_subscribers(&self, pattern: Option<&str>) -> bool {
567        let bus = get_message_bus();
568        let bus_ref = bus.borrow();
569
570        match pattern {
571            Some(p) => {
572                let filter = MStr::<Pattern>::pattern(p);
573                bus_ref
574                    .subscriptions()
575                    .iter()
576                    .any(|s| is_matching(s.pattern.as_bytes(), filter.as_bytes()))
577            }
578            None => !bus_ref.subscriptions().is_empty(),
579        }
580    }
581
582    /// Returns whether the given topic and handler is subscribed.
583    #[pyo3(name = "is_subscribed")]
584    fn py_is_subscribed(&self, py: Python<'_>, topic: &str, handler: Py<PyAny>) -> PyResult<bool> {
585        let handler = make_handler(py, handler)?;
586        let pattern = MStr::<Pattern>::pattern(topic);
587        let sub = Subscription::new(pattern, handler, None);
588        Ok(get_message_bus().borrow().subscriptions.contains(&sub))
589    }
590
591    /// Returns whether the given request ID is pending a response.
592    #[pyo3(name = "is_pending_request")]
593    fn py_is_pending_request(&self, request_id: UUID4) -> bool {
594        self.correlation_index.contains_key(&request_id)
595    }
596
597    /// Returns whether the given type is registered for streaming.
598    #[pyo3(name = "is_streaming_type")]
599    #[allow(clippy::needless_pass_by_value)]
600    fn py_is_streaming_type(&self, py: Python<'_>, cls: Py<PyAny>) -> bool {
601        let cls_ref = cls.bind(py);
602        self.streaming_types.iter().any(|t| t.bind(py).is(cls_ref))
603    }
604
605    /// Returns all types registered for streaming.
606    #[pyo3(name = "streaming_types")]
607    fn py_streaming_types(&self, py: Python<'_>) -> Vec<Py<PyAny>> {
608        self.streaming_types
609            .iter()
610            .map(|t| t.clone_ref(py))
611            .collect()
612    }
613
614    /// Registers a handler at the given endpoint address.
615    #[pyo3(name = "register")]
616    fn py_register(&self, py: Python<'_>, endpoint: &str, handler: Py<PyAny>) -> PyResult<()> {
617        let handler = make_handler(py, handler)?;
618        let endpoint = MStr::<Endpoint>::from(endpoint);
619        msgbus_api::register_any(endpoint, handler);
620        Ok(())
621    }
622
623    /// Deregisters the handler from the given endpoint address.
624    #[pyo3(name = "deregister")]
625    #[pyo3(signature = (endpoint, handler=None))]
626    #[allow(clippy::needless_pass_by_value)]
627    fn py_deregister(&self, endpoint: &str, handler: Option<Py<PyAny>>) {
628        let _ = handler;
629        let endpoint = MStr::<Endpoint>::from(endpoint);
630        msgbus_api::deregister_any(endpoint);
631    }
632
633    /// Sends a message to the given endpoint address.
634    #[pyo3(name = "send")]
635    fn py_send(&mut self, endpoint: &str, msg: Py<PyAny>) {
636        let endpoint = MStr::<Endpoint>::from(endpoint);
637        let py_msg = PyMessage(msg);
638        msgbus_api::send_any(endpoint, &py_msg);
639        self.sent_count += 1;
640    }
641
642    /// Sends a request to the given endpoint with correlation tracking.
643    #[pyo3(name = "request")]
644    fn py_request(&mut self, py: Python<'_>, endpoint: &str, request: Py<PyAny>) -> PyResult<()> {
645        let request_ref = request.bind(py);
646
647        let request_id: UUID4 = request_ref.getattr("id")?.extract()?;
648        let callback = request_ref.getattr("callback")?;
649
650        if self.correlation_index.contains_key(&request_id) {
651            log::error!(
652                "Cannot handle request: duplicate ID {request_id} found in correlation index"
653            );
654            return Ok(());
655        }
656
657        if !callback.is_none() {
658            self.correlation_index.insert(request_id, callback.unbind());
659        }
660
661        let endpoint = MStr::<Endpoint>::from(endpoint);
662        let py_msg = PyMessage(request);
663        msgbus_api::send_any(endpoint, &py_msg);
664        self.req_count += 1;
665
666        Ok(())
667    }
668
669    /// Handles a response by invoking the correlated callback.
670    #[pyo3(name = "response")]
671    #[allow(clippy::needless_pass_by_value)]
672    fn py_response(&mut self, py: Python<'_>, response: Py<PyAny>) -> PyResult<()> {
673        let correlation_id: UUID4 = response.getattr(py, "correlation_id")?.extract(py)?;
674
675        if let Some(callback) = self.correlation_index.remove(&correlation_id) {
676            callback.call1(py, (&response,))?;
677        } else {
678            log::debug!("No callback for correlation_id {correlation_id}");
679        }
680
681        self.res_count += 1;
682        Ok(())
683    }
684
685    /// Subscribes to the given topic with the given handler.
686    #[pyo3(name = "subscribe")]
687    #[pyo3(signature = (topic, handler, priority=0))]
688    fn py_subscribe(
689        &self,
690        py: Python<'_>,
691        topic: &str,
692        handler: Py<PyAny>,
693        priority: u8,
694    ) -> PyResult<()> {
695        let handler = make_handler(py, handler)?;
696        let pattern = MStr::<Pattern>::pattern(topic);
697        msgbus_api::subscribe_any(pattern, handler, Some(priority));
698        Ok(())
699    }
700
701    /// Unsubscribes the given handler from the given topic.
702    #[pyo3(name = "unsubscribe")]
703    fn py_unsubscribe(&self, py: Python<'_>, topic: &str, handler: Py<PyAny>) -> PyResult<()> {
704        let handler = make_handler(py, handler)?;
705        let pattern = MStr::<Pattern>::pattern(topic);
706        msgbus_api::unsubscribe_any(pattern, &handler);
707        Ok(())
708    }
709
710    /// Publishes a message for the given topic.
711    #[pyo3(name = "publish")]
712    #[pyo3(signature = (topic, msg, external_pub=true))]
713    #[allow(clippy::needless_pass_by_value)]
714    fn py_publish(
715        &mut self,
716        py: Python<'_>,
717        topic: &str,
718        msg: Py<PyAny>,
719        external_pub: bool,
720    ) -> PyResult<()> {
721        let topic_mstr = MStr::<Topic>::topic(topic).map_err(to_pyruntime_err)?;
722
723        let py_msg = PyMessage(msg.clone_ref(py));
724        msgbus_api::publish_any(topic_mstr, &py_msg);
725
726        if external_pub {
727            self.publish_external(py, topic, &msg)?;
728        }
729
730        self.pub_count += 1;
731        Ok(())
732    }
733
734    /// Disposes of the message bus, clearing all state.
735    #[pyo3(name = "dispose")]
736    fn py_dispose(&mut self, py: Python<'_>) -> PyResult<()> {
737        log::debug!("Closing message bus");
738
739        get_message_bus().borrow_mut().dispose();
740
741        self.correlation_index.clear();
742        self.listeners.clear();
743        self.streaming_types.clear();
744
745        if let Some(ref database) = self.database {
746            let db = database.bind(py);
747            if !db.call_method0("is_closed")?.extract::<bool>()? {
748                db.call_method0("close")?;
749            }
750        }
751
752        log::info!("Closed message bus");
753        Ok(())
754    }
755
756    /// Registers a type for external-to-internal message streaming.
757    #[pyo3(name = "add_streaming_type")]
758    fn py_add_streaming_type(&mut self, cls: Py<PyAny>) {
759        self.streaming_types.push(cls);
760    }
761
762    /// Adds a listener to the message bus.
763    #[pyo3(name = "add_listener")]
764    fn py_add_listener(&mut self, listener: Py<PyAny>) {
765        self.listeners.push(listener);
766    }
767}
768
769impl PyMessageBus {
770    fn publish_external(&self, py: Python<'_>, topic: &str, msg: &Py<PyAny>) -> PyResult<()> {
771        if let Some(ref filter) = self.types_filter {
772            let is_excluded = py
773                .import("builtins")?
774                .call_method1("isinstance", (msg, filter))?
775                .extract::<bool>()?;
776
777            if is_excluded {
778                return Ok(());
779            }
780        }
781
782        // Serialize: raw bytes pass through, other types need a serializer
783        let msg_ref = msg.bind(py);
784        let payload: Py<PyAny> = if msg_ref.is_instance_of::<pyo3::types::PyBytes>() {
785            msg.clone_ref(py)
786        } else if let Some(ref serializer) = self.serializer {
787            serializer.call_method1(py, "serialize", (msg,))?
788        } else {
789            return Ok(());
790        };
791
792        if let Some(ref database) = self.database {
793            let db = database.bind(py);
794            if !db.call_method0("is_closed")?.extract::<bool>()? {
795                db.call_method1("publish", (topic, &payload))?;
796            }
797        }
798
799        for listener in &self.listeners {
800            let l = listener.bind(py);
801            if l.call_method0("is_closed")?.extract::<bool>()? {
802                continue;
803            }
804            l.call_method1("publish", (topic, &payload))?;
805        }
806
807        Ok(())
808    }
809}
810
#[cfg(test)]
mod tests {
    use std::any::Any;

    use pyo3::ffi::c_str;
    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_py_message_downcast() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let py_obj = py.eval(c_str!("42"), None, None).unwrap();
            let msg = PyMessage(py_obj.unbind());

            let any_ref: &dyn Any = &msg;
            let downcasted = any_ref.downcast_ref::<PyMessage>();
            assert!(downcasted.is_some());

            let inner = &downcasted.unwrap().0;
            let value: i64 = inner.extract(py).unwrap();
            assert_eq!(value, 42);
        });
    }

    #[rstest]
    fn test_py_callable_handler_id_stability() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let callable = py.eval(c_str!("lambda x: x"), None, None).unwrap().unbind();

            let handler1 = PyCallableHandler::new(py, callable.clone_ref(py)).unwrap();
            let handler2 = PyCallableHandler::new(py, callable).unwrap();

            assert_eq!(handler1.id(), handler2.id());
        });
    }

    #[rstest]
    fn test_py_callable_handler_ids_differ_for_distinct_callables() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // Both callables stay alive (owned by the handlers), so their reprs differ.
            let first = py.eval(c_str!("lambda x: x"), None, None).unwrap().unbind();
            let second = py.eval(c_str!("lambda x: x"), None, None).unwrap().unbind();

            let handler1 = PyCallableHandler::new(py, first).unwrap();
            let handler2 = PyCallableHandler::new(py, second).unwrap();

            assert_ne!(handler1.id(), handler2.id());
        });
    }

    #[rstest]
    fn test_py_callable_handler_dispatch() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let main = py.import("__main__").unwrap();
            let globals = main.dict();
            py.run(
                c_str!("results = []\ndef handler(x): results.append(x)"),
                Some(&globals),
                None,
            )
            .unwrap();

            let handler_fn = globals.get_item("handler").unwrap().unwrap().unbind();
            let handler = PyCallableHandler::new(py, handler_fn).unwrap();

            let py_obj = py.eval(c_str!("'hello'"), None, None).unwrap();
            let msg = PyMessage(py_obj.unbind());

            let any_ref: &dyn Any = &msg;
            handler.handle(any_ref);

            // A payload that is not a `PyMessage` must be rejected without calling Python.
            let not_py_msg: &dyn Any = &42_u32;
            handler.handle(not_py_msg);

            let results = globals.get_item("results").unwrap().unwrap();
            let len: usize = results.len().unwrap();
            assert_eq!(len, 1);

            let received: String = results.get_item(0).unwrap().extract().unwrap();
            assert_eq!(received, "hello");
        });
    }
}