1use std::{any::Any, fmt::Debug, rc::Rc};
21
22use ahash::AHashMap;
23use nautilus_core::{UUID4, python::to_pyruntime_err};
24use nautilus_model::identifiers::TraderId;
25use pyo3::{Py, Python, prelude::*, types::PyBytes};
26use ustr::Ustr;
27
28use crate::{
29 enums::SerializationEncoding,
30 msgbus::{
31 self as msgbus_api, BusMessage, MessageBus,
32 core::Subscription,
33 database::{DatabaseConfig, MessageBusConfig},
34 get_message_bus,
35 matching::is_matching,
36 mstr::{Endpoint, MStr, Pattern, Topic},
37 typed_handler::{Handler, ShareableMessageHandler, TypedHandler},
38 },
39};
40
41#[pymethods]
42#[pyo3_stub_gen::derive::gen_stub_pymethods]
43impl BusMessage {
44 #[getter]
45 #[pyo3(name = "topic")]
46 fn py_topic(&self) -> String {
47 self.topic.to_string()
48 }
49
50 #[getter]
51 #[pyo3(name = "payload")]
52 fn py_payload(&self, py: Python<'_>) -> Py<PyBytes> {
53 PyBytes::new(py, self.payload.as_ref()).into()
54 }
55
56 fn __repr__(&self) -> String {
57 format!("{}('{}')", stringify!(BusMessage), self)
58 }
59
60 fn __str__(&self) -> String {
61 self.to_string()
62 }
63}
64
65#[pymethods]
66#[pyo3_stub_gen::derive::gen_stub_pymethods]
67impl DatabaseConfig {
68 #[new]
74 #[allow(clippy::too_many_arguments)]
75 #[pyo3(signature = (database_type=None, host=None, port=None, username=None, password=None, ssl=None, connection_timeout=None, response_timeout=None, number_of_retries=None, exponent_base=None, max_delay=None, factor=None))]
76 fn py_new(
77 database_type: Option<String>,
78 host: Option<String>,
79 port: Option<u16>,
80 username: Option<String>,
81 password: Option<String>,
82 ssl: Option<bool>,
83 connection_timeout: Option<u16>,
84 response_timeout: Option<u16>,
85 number_of_retries: Option<usize>,
86 exponent_base: Option<u64>,
87 max_delay: Option<u64>,
88 factor: Option<u64>,
89 ) -> Self {
90 let default = Self::default();
91 Self {
92 database_type: database_type.unwrap_or(default.database_type),
93 host,
94 port,
95 username,
96 password,
97 ssl: ssl.unwrap_or(default.ssl),
98 connection_timeout: connection_timeout.unwrap_or(default.connection_timeout),
99 response_timeout: response_timeout.unwrap_or(default.response_timeout),
100 number_of_retries: number_of_retries.unwrap_or(default.number_of_retries),
101 exponent_base: exponent_base.unwrap_or(default.exponent_base),
102 max_delay: max_delay.unwrap_or(default.max_delay),
103 factor: factor.unwrap_or(default.factor),
104 }
105 }
106
107 fn __repr__(&self) -> String {
108 format!("{self:?}")
109 }
110
111 fn __str__(&self) -> String {
112 format!("{self:?}")
113 }
114
115 #[getter]
116 fn database_type(&self) -> &str {
117 &self.database_type
118 }
119
120 #[getter]
121 fn host(&self) -> Option<&str> {
122 self.host.as_deref()
123 }
124
125 #[getter]
126 fn port(&self) -> Option<u16> {
127 self.port
128 }
129
130 #[getter]
131 fn username(&self) -> Option<&str> {
132 self.username.as_deref()
133 }
134
135 #[getter]
136 fn password(&self) -> Option<&str> {
137 self.password.as_deref()
138 }
139
140 #[getter]
141 fn ssl(&self) -> bool {
142 self.ssl
143 }
144
145 #[getter]
146 fn connection_timeout(&self) -> u16 {
147 self.connection_timeout
148 }
149
150 #[getter]
151 fn response_timeout(&self) -> u16 {
152 self.response_timeout
153 }
154
155 #[getter]
156 fn number_of_retries(&self) -> usize {
157 self.number_of_retries
158 }
159
160 #[getter]
161 fn exponent_base(&self) -> u64 {
162 self.exponent_base
163 }
164
165 #[getter]
166 fn max_delay(&self) -> u64 {
167 self.max_delay
168 }
169
170 #[getter]
171 fn factor(&self) -> u64 {
172 self.factor
173 }
174}
175
176#[pymethods]
177#[pyo3_stub_gen::derive::gen_stub_pymethods]
178impl MessageBusConfig {
179 #[new]
181 #[allow(clippy::too_many_arguments)]
182 #[pyo3(signature = (database=None, encoding=None, timestamps_as_iso8601=None, buffer_interval_ms=None, autotrim_mins=None, use_trader_prefix=None, use_trader_id=None, use_instance_id=None, streams_prefix=None, stream_per_topic=None, external_streams=None, types_filter=None, heartbeat_interval_secs=None))]
183 fn py_new(
184 database: Option<DatabaseConfig>,
185 encoding: Option<SerializationEncoding>,
186 timestamps_as_iso8601: Option<bool>,
187 buffer_interval_ms: Option<u32>,
188 autotrim_mins: Option<u32>,
189 use_trader_prefix: Option<bool>,
190 use_trader_id: Option<bool>,
191 use_instance_id: Option<bool>,
192 streams_prefix: Option<String>,
193 stream_per_topic: Option<bool>,
194 external_streams: Option<Vec<String>>,
195 types_filter: Option<Vec<String>>,
196 heartbeat_interval_secs: Option<u16>,
197 ) -> Self {
198 let default = Self::default();
199 Self {
200 database,
201 encoding: encoding.unwrap_or(default.encoding),
202 timestamps_as_iso8601: timestamps_as_iso8601.unwrap_or(default.timestamps_as_iso8601),
203 buffer_interval_ms,
204 autotrim_mins,
205 use_trader_prefix: use_trader_prefix.unwrap_or(default.use_trader_prefix),
206 use_trader_id: use_trader_id.unwrap_or(default.use_trader_id),
207 use_instance_id: use_instance_id.unwrap_or(default.use_instance_id),
208 streams_prefix: streams_prefix.unwrap_or(default.streams_prefix),
209 stream_per_topic: stream_per_topic.unwrap_or(default.stream_per_topic),
210 external_streams,
211 types_filter,
212 heartbeat_interval_secs,
213 }
214 }
215
216 fn __repr__(&self) -> String {
217 format!("{self:?}")
218 }
219
220 fn __str__(&self) -> String {
221 format!("{self:?}")
222 }
223
224 #[getter]
225 fn database(&self) -> Option<DatabaseConfig> {
226 self.database.clone()
227 }
228
229 #[getter]
230 fn encoding(&self) -> SerializationEncoding {
231 self.encoding
232 }
233
234 #[getter]
235 fn timestamps_as_iso8601(&self) -> bool {
236 self.timestamps_as_iso8601
237 }
238
239 #[getter]
240 fn buffer_interval_ms(&self) -> Option<u32> {
241 self.buffer_interval_ms
242 }
243
244 #[getter]
245 fn autotrim_mins(&self) -> Option<u32> {
246 self.autotrim_mins
247 }
248
249 #[getter]
250 fn use_trader_prefix(&self) -> bool {
251 self.use_trader_prefix
252 }
253
254 #[getter]
255 fn use_trader_id(&self) -> bool {
256 self.use_trader_id
257 }
258
259 #[getter]
260 fn use_instance_id(&self) -> bool {
261 self.use_instance_id
262 }
263
264 #[getter]
265 fn streams_prefix(&self) -> &str {
266 &self.streams_prefix
267 }
268
269 #[getter]
270 fn stream_per_topic(&self) -> bool {
271 self.stream_per_topic
272 }
273
274 #[getter]
275 fn external_streams(&self) -> Option<Vec<String>> {
276 self.external_streams.clone()
277 }
278
279 #[getter]
280 fn types_filter(&self) -> Option<Vec<String>> {
281 self.types_filter.clone()
282 }
283
284 #[getter]
285 fn heartbeat_interval_secs(&self) -> Option<u16> {
286 self.heartbeat_interval_secs
287 }
288}
289
290pub struct PyMessage(pub Py<PyAny>);
292
293impl Debug for PyMessage {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 f.debug_tuple(stringify!(PyMessage))
296 .field(&"<PyObject>")
297 .finish()
298 }
299}
300
301pub struct PyCallableHandler {
306 id: Ustr,
307 callable: Py<PyAny>,
308}
309
310impl Debug for PyCallableHandler {
311 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 f.debug_struct(stringify!(PyCallableHandler))
313 .field("id", &self.id)
314 .finish()
315 }
316}
317
318impl PyCallableHandler {
319 pub fn new(py: Python<'_>, callable: Py<PyAny>) -> PyResult<Self> {
324 let repr_str = callable.bind(py).repr()?.to_string();
325 let id = Ustr::from(&repr_str);
326 Ok(Self { id, callable })
327 }
328}
329
330impl Handler<dyn Any> for PyCallableHandler {
331 fn id(&self) -> Ustr {
332 self.id
333 }
334
335 fn handle(&self, message: &dyn Any) {
336 if let Some(py_msg) = message.downcast_ref::<PyMessage>() {
337 Python::attach(|py| {
338 if let Err(e) = self.callable.call1(py, (&py_msg.0,)) {
339 log::error!("Python handler {id} failed: {e}", id = self.id);
340 }
341 });
342 } else {
343 log::error!(
344 "Python handler {id} received non-PyMessage type",
345 id = self.id
346 );
347 }
348 }
349}
350
351fn make_handler(py: Python<'_>, callable: Py<PyAny>) -> PyResult<ShareableMessageHandler> {
352 let handler = PyCallableHandler::new(py, callable)?;
353 Ok(TypedHandler(Rc::new(handler) as Rc<dyn Handler<dyn Any>>))
354}
355
356#[pyclass(
362 module = "nautilus_trader.core.nautilus_pyo3.common",
363 name = "MessageBus",
364 unsendable
365)]
366#[pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.common")]
367pub struct PyMessageBus {
368 trader_id: TraderId,
369 instance_id: UUID4,
370 name: String,
371 has_backing: bool,
372 serializer: Option<Py<PyAny>>,
373 database: Option<Py<PyAny>>,
374 listeners: Vec<Py<PyAny>>,
375 types_filter: Option<Py<PyAny>>,
376 streaming_types: Vec<Py<PyAny>>,
377 correlation_index: AHashMap<UUID4, Py<PyAny>>,
378 sent_count: u64,
379 req_count: u64,
380 res_count: u64,
381 pub_count: u64,
382}
383
384impl Debug for PyMessageBus {
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 f.debug_struct(stringify!(PyMessageBus))
387 .field("trader_id", &self.trader_id)
388 .field("name", &self.name)
389 .finish()
390 }
391}
392
393#[pymethods]
394#[pyo3_stub_gen::derive::gen_stub_pymethods]
395impl PyMessageBus {
396 #[new]
401 #[pyo3(signature = (trader_id, clock=None, instance_id=None, name=None, serializer=None, database=None, config=None))]
402 #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
403 fn py_new(
404 py: Python<'_>,
405 trader_id: TraderId,
406 clock: Option<Py<PyAny>>,
407 instance_id: Option<UUID4>,
408 name: Option<String>,
409 serializer: Option<Py<PyAny>>,
410 database: Option<Py<PyAny>>,
411 config: Option<Py<PyAny>>,
412 ) -> PyResult<Self> {
413 let _ = clock;
414 let instance_id = instance_id.unwrap_or_default();
415 let bus_name = name.clone();
416 let has_backing = database.is_some();
417
418 let msgbus = MessageBus::new(trader_id, instance_id, bus_name, None);
419 msgbus.register_message_bus();
420
421 let types_filter = if let Some(ref cfg) = config {
422 let tf = cfg.getattr(py, "types_filter")?;
423 if tf.is_none(py) {
424 None
425 } else {
426 let tuple = py
428 .import("builtins")?
429 .call_method1("tuple", (tf,))?
430 .unbind();
431 Some(tuple)
432 }
433 } else {
434 None
435 };
436
437 Ok(Self {
438 trader_id,
439 instance_id,
440 name: name.unwrap_or_else(|| "MessageBus".to_owned()),
441 has_backing,
442 serializer,
443 database,
444 listeners: Vec::new(),
445 types_filter,
446 streaming_types: Vec::new(),
447 correlation_index: AHashMap::new(),
448 sent_count: 0,
449 req_count: 0,
450 res_count: 0,
451 pub_count: 0,
452 })
453 }
454
455 #[getter]
457 #[pyo3(name = "trader_id")]
458 fn py_trader_id(&self) -> TraderId {
459 self.trader_id
460 }
461
462 #[getter]
464 #[pyo3(name = "instance_id")]
465 fn py_instance_id(&self) -> UUID4 {
466 self.instance_id
467 }
468
469 #[getter]
471 #[pyo3(name = "name")]
472 fn py_name(&self) -> &str {
473 &self.name
474 }
475
476 #[getter]
478 #[pyo3(name = "has_backing")]
479 fn py_has_backing(&self) -> bool {
480 self.has_backing
481 }
482
483 #[getter]
485 #[pyo3(name = "sent_count")]
486 fn py_sent_count(&self) -> u64 {
487 self.sent_count
488 }
489
490 #[getter]
492 #[pyo3(name = "req_count")]
493 fn py_req_count(&self) -> u64 {
494 self.req_count
495 }
496
497 #[getter]
499 #[pyo3(name = "res_count")]
500 fn py_res_count(&self) -> u64 {
501 self.res_count
502 }
503
504 #[getter]
506 #[pyo3(name = "pub_count")]
507 fn py_pub_count(&self) -> u64 {
508 self.pub_count
509 }
510
511 #[pyo3(name = "endpoints")]
513 fn py_endpoints(&self) -> Vec<String> {
514 let bus = get_message_bus();
515 let bus_ref = bus.borrow();
516 bus_ref.endpoints().into_iter().map(String::from).collect()
517 }
518
519 #[pyo3(name = "topics")]
521 fn py_topics(&self) -> Vec<String> {
522 let bus = get_message_bus();
523 let bus_ref = bus.borrow();
524 let mut topics: Vec<String> = bus_ref.patterns().into_iter().map(String::from).collect();
525 topics.sort();
526 topics.dedup();
527 topics
528 }
529
530 #[pyo3(name = "subscriptions")]
532 #[pyo3(signature = (pattern=None))]
533 fn py_subscriptions(&self, pattern: Option<&str>) -> Vec<String> {
534 let bus = get_message_bus();
535 let bus_ref = bus.borrow();
536 let subs: Vec<&Subscription> = bus_ref.subscriptions();
537
538 match pattern {
539 Some(p) => {
540 let filter = MStr::<Pattern>::pattern(p);
541 subs.into_iter()
542 .filter(|s| is_matching(s.pattern.as_bytes(), filter.as_bytes()))
543 .map(|s| {
544 format!(
545 "Subscription(topic={}, handler={})",
546 s.pattern, s.handler_id
547 )
548 })
549 .collect()
550 }
551 None => subs
552 .into_iter()
553 .map(|s| {
554 format!(
555 "Subscription(topic={}, handler={})",
556 s.pattern, s.handler_id
557 )
558 })
559 .collect(),
560 }
561 }
562
563 #[pyo3(name = "has_subscribers")]
565 #[pyo3(signature = (pattern=None))]
566 fn py_has_subscribers(&self, pattern: Option<&str>) -> bool {
567 let bus = get_message_bus();
568 let bus_ref = bus.borrow();
569
570 match pattern {
571 Some(p) => {
572 let filter = MStr::<Pattern>::pattern(p);
573 bus_ref
574 .subscriptions()
575 .iter()
576 .any(|s| is_matching(s.pattern.as_bytes(), filter.as_bytes()))
577 }
578 None => !bus_ref.subscriptions().is_empty(),
579 }
580 }
581
582 #[pyo3(name = "is_subscribed")]
584 fn py_is_subscribed(&self, py: Python<'_>, topic: &str, handler: Py<PyAny>) -> PyResult<bool> {
585 let handler = make_handler(py, handler)?;
586 let pattern = MStr::<Pattern>::pattern(topic);
587 let sub = Subscription::new(pattern, handler, None);
588 Ok(get_message_bus().borrow().subscriptions.contains(&sub))
589 }
590
591 #[pyo3(name = "is_pending_request")]
593 fn py_is_pending_request(&self, request_id: UUID4) -> bool {
594 self.correlation_index.contains_key(&request_id)
595 }
596
597 #[pyo3(name = "is_streaming_type")]
599 #[allow(clippy::needless_pass_by_value)]
600 fn py_is_streaming_type(&self, py: Python<'_>, cls: Py<PyAny>) -> bool {
601 let cls_ref = cls.bind(py);
602 self.streaming_types.iter().any(|t| t.bind(py).is(cls_ref))
603 }
604
605 #[pyo3(name = "streaming_types")]
607 fn py_streaming_types(&self, py: Python<'_>) -> Vec<Py<PyAny>> {
608 self.streaming_types
609 .iter()
610 .map(|t| t.clone_ref(py))
611 .collect()
612 }
613
614 #[pyo3(name = "register")]
616 fn py_register(&self, py: Python<'_>, endpoint: &str, handler: Py<PyAny>) -> PyResult<()> {
617 let handler = make_handler(py, handler)?;
618 let endpoint = MStr::<Endpoint>::from(endpoint);
619 msgbus_api::register_any(endpoint, handler);
620 Ok(())
621 }
622
623 #[pyo3(name = "deregister")]
625 #[pyo3(signature = (endpoint, handler=None))]
626 #[allow(clippy::needless_pass_by_value)]
627 fn py_deregister(&self, endpoint: &str, handler: Option<Py<PyAny>>) {
628 let _ = handler;
629 let endpoint = MStr::<Endpoint>::from(endpoint);
630 msgbus_api::deregister_any(endpoint);
631 }
632
633 #[pyo3(name = "send")]
635 fn py_send(&mut self, endpoint: &str, msg: Py<PyAny>) {
636 let endpoint = MStr::<Endpoint>::from(endpoint);
637 let py_msg = PyMessage(msg);
638 msgbus_api::send_any(endpoint, &py_msg);
639 self.sent_count += 1;
640 }
641
642 #[pyo3(name = "request")]
644 fn py_request(&mut self, py: Python<'_>, endpoint: &str, request: Py<PyAny>) -> PyResult<()> {
645 let request_ref = request.bind(py);
646
647 let request_id: UUID4 = request_ref.getattr("id")?.extract()?;
648 let callback = request_ref.getattr("callback")?;
649
650 if self.correlation_index.contains_key(&request_id) {
651 log::error!(
652 "Cannot handle request: duplicate ID {request_id} found in correlation index"
653 );
654 return Ok(());
655 }
656
657 if !callback.is_none() {
658 self.correlation_index.insert(request_id, callback.unbind());
659 }
660
661 let endpoint = MStr::<Endpoint>::from(endpoint);
662 let py_msg = PyMessage(request);
663 msgbus_api::send_any(endpoint, &py_msg);
664 self.req_count += 1;
665
666 Ok(())
667 }
668
669 #[pyo3(name = "response")]
671 #[allow(clippy::needless_pass_by_value)]
672 fn py_response(&mut self, py: Python<'_>, response: Py<PyAny>) -> PyResult<()> {
673 let correlation_id: UUID4 = response.getattr(py, "correlation_id")?.extract(py)?;
674
675 if let Some(callback) = self.correlation_index.remove(&correlation_id) {
676 callback.call1(py, (&response,))?;
677 } else {
678 log::debug!("No callback for correlation_id {correlation_id}");
679 }
680
681 self.res_count += 1;
682 Ok(())
683 }
684
685 #[pyo3(name = "subscribe")]
687 #[pyo3(signature = (topic, handler, priority=0))]
688 fn py_subscribe(
689 &self,
690 py: Python<'_>,
691 topic: &str,
692 handler: Py<PyAny>,
693 priority: u8,
694 ) -> PyResult<()> {
695 let handler = make_handler(py, handler)?;
696 let pattern = MStr::<Pattern>::pattern(topic);
697 msgbus_api::subscribe_any(pattern, handler, Some(priority));
698 Ok(())
699 }
700
701 #[pyo3(name = "unsubscribe")]
703 fn py_unsubscribe(&self, py: Python<'_>, topic: &str, handler: Py<PyAny>) -> PyResult<()> {
704 let handler = make_handler(py, handler)?;
705 let pattern = MStr::<Pattern>::pattern(topic);
706 msgbus_api::unsubscribe_any(pattern, &handler);
707 Ok(())
708 }
709
710 #[pyo3(name = "publish")]
712 #[pyo3(signature = (topic, msg, external_pub=true))]
713 #[allow(clippy::needless_pass_by_value)]
714 fn py_publish(
715 &mut self,
716 py: Python<'_>,
717 topic: &str,
718 msg: Py<PyAny>,
719 external_pub: bool,
720 ) -> PyResult<()> {
721 let topic_mstr = MStr::<Topic>::topic(topic).map_err(to_pyruntime_err)?;
722
723 let py_msg = PyMessage(msg.clone_ref(py));
724 msgbus_api::publish_any(topic_mstr, &py_msg);
725
726 if external_pub {
727 self.publish_external(py, topic, &msg)?;
728 }
729
730 self.pub_count += 1;
731 Ok(())
732 }
733
734 #[pyo3(name = "dispose")]
736 fn py_dispose(&mut self, py: Python<'_>) -> PyResult<()> {
737 log::debug!("Closing message bus");
738
739 get_message_bus().borrow_mut().dispose();
740
741 self.correlation_index.clear();
742 self.listeners.clear();
743 self.streaming_types.clear();
744
745 if let Some(ref database) = self.database {
746 let db = database.bind(py);
747 if !db.call_method0("is_closed")?.extract::<bool>()? {
748 db.call_method0("close")?;
749 }
750 }
751
752 log::info!("Closed message bus");
753 Ok(())
754 }
755
756 #[pyo3(name = "add_streaming_type")]
758 fn py_add_streaming_type(&mut self, cls: Py<PyAny>) {
759 self.streaming_types.push(cls);
760 }
761
762 #[pyo3(name = "add_listener")]
764 fn py_add_listener(&mut self, listener: Py<PyAny>) {
765 self.listeners.push(listener);
766 }
767}
768
769impl PyMessageBus {
770 fn publish_external(&self, py: Python<'_>, topic: &str, msg: &Py<PyAny>) -> PyResult<()> {
771 if let Some(ref filter) = self.types_filter {
772 let is_excluded = py
773 .import("builtins")?
774 .call_method1("isinstance", (msg, filter))?
775 .extract::<bool>()?;
776
777 if is_excluded {
778 return Ok(());
779 }
780 }
781
782 let msg_ref = msg.bind(py);
784 let payload: Py<PyAny> = if msg_ref.is_instance_of::<pyo3::types::PyBytes>() {
785 msg.clone_ref(py)
786 } else if let Some(ref serializer) = self.serializer {
787 serializer.call_method1(py, "serialize", (msg,))?
788 } else {
789 return Ok(());
790 };
791
792 if let Some(ref database) = self.database {
793 let db = database.bind(py);
794 if !db.call_method0("is_closed")?.extract::<bool>()? {
795 db.call_method1("publish", (topic, &payload))?;
796 }
797 }
798
799 for listener in &self.listeners {
800 let l = listener.bind(py);
801 if l.call_method0("is_closed")?.extract::<bool>()? {
802 continue;
803 }
804 l.call_method1("publish", (topic, &payload))?;
805 }
806
807 Ok(())
808 }
809}
810
#[cfg(test)]
mod tests {
    use std::any::Any;

    use pyo3::ffi::c_str;
    use rstest::rstest;

    use super::*;

    /// A `PyMessage` can be recovered from `&dyn Any` with its Python object intact.
    #[rstest]
    fn test_py_message_downcast() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let forty_two = py.eval(c_str!("42"), None, None).unwrap().unbind();
            let msg = PyMessage(forty_two);

            let erased: &dyn Any = &msg;
            let recovered = erased.downcast_ref::<PyMessage>();
            assert!(recovered.is_some());

            let value: i64 = recovered.unwrap().0.extract(py).unwrap();
            assert_eq!(value, 42);
        });
    }

    /// Two handlers built from the same callable share the same ID.
    #[rstest]
    fn test_py_callable_handler_id_stability() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let callable = py.eval(c_str!("lambda x: x"), None, None).unwrap().unbind();

            let first = PyCallableHandler::new(py, callable.clone_ref(py)).unwrap();
            let second = PyCallableHandler::new(py, callable).unwrap();

            assert_eq!(first.id(), second.id());
        });
    }

    /// Handling a `PyMessage` invokes the wrapped Python function exactly once.
    #[rstest]
    fn test_py_callable_handler_dispatch() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // Define a Python handler that records every message it receives.
            let globals = py.import("__main__").unwrap().dict();
            py.run(
                c_str!("results = []\ndef handler(x): results.append(x)"),
                Some(&globals),
                None,
            )
            .unwrap();

            let handler_fn = globals.get_item("handler").unwrap().unwrap().unbind();
            let handler = PyCallableHandler::new(py, handler_fn).unwrap();

            let greeting = py.eval(c_str!("'hello'"), None, None).unwrap().unbind();
            let msg = PyMessage(greeting);

            let erased: &dyn Any = &msg;
            handler.handle(erased);

            let results = globals.get_item("results").unwrap().unwrap();
            let recorded: usize = results.len().unwrap();
            assert_eq!(recorded, 1);
        });
    }
}