dora_ros2_bridge_msg_gen/
lib.rs

1// Based on https://github.com/rclrust/rclrust/tree/3a48dbb8f23a3d67d3031351da3ed236a354f039/rclrust-msg-gen
2
3#![warn(
4    rust_2018_idioms,
5    elided_lifetimes_in_paths,
6    clippy::all,
7    clippy::nursery
8)]
9
10use std::path::Path;
11
12use quote::{ToTokens, format_ident, quote};
13
14pub mod parser;
15pub mod types;
16
17pub use crate::parser::get_packages;
18use crate::types::Package;
19
20#[allow(clippy::cognitive_complexity)]
21pub fn generate_package(package: &Package, create_cxx_bridge: bool) -> proc_macro2::TokenStream {
22    let mut shared_type_defs = Vec::new();
23    let mut message_struct_impls = Vec::new();
24    let mut message_topic_defs = Vec::new();
25    let mut message_topic_impls = Vec::new();
26    let mut service_defs = Vec::new();
27    let mut service_impls = Vec::new();
28    let mut service_creation_defs = Vec::new();
29    let mut service_creation_impls = Vec::new();
30
31    let mut action_defs = Vec::new();
32    let mut action_impls = Vec::new();
33    let mut action_creation_defs = Vec::new();
34    let mut action_creation_impls = Vec::new();
35
36    for message in &package.messages {
37        let (def, imp) = message.struct_token_stream(&package.name, create_cxx_bridge);
38        shared_type_defs.push(def);
39        message_struct_impls.push(imp);
40        if create_cxx_bridge {
41            let (topic_def, topic_impl) = message.topic_def(&package.name);
42            message_topic_defs.push(topic_def);
43            message_topic_impls.push(topic_impl);
44        }
45    }
46
47    for service in &package.services {
48        let (def, imp) = service.struct_token_stream(&package.name, create_cxx_bridge);
49        service_defs.push(def);
50        service_impls.push(imp);
51        if create_cxx_bridge {
52            let (service_creation_def, service_creation_impl) =
53                service.cxx_service_creation_functions(&package.name);
54            service_creation_defs.push(service_creation_def);
55            service_creation_impls.push(service_creation_impl);
56        }
57    }
58
59    for action in &package.actions {
60        let (def, imp) = action.struct_token_stream(&package.name, create_cxx_bridge);
61        action_defs.push(def);
62        action_impls.push(imp);
63        if create_cxx_bridge {
64            let (action_creation_def, action_creation_impl) =
65                action.cxx_action_creation_functions(&package.name);
66            let action_creation_def = quote! { #action_creation_def };
67            let action_creation_impl = quote! { #action_creation_impl };
68            action_creation_defs.push(action_creation_def);
69            action_creation_impls.push(action_creation_impl);
70        }
71    }
72
73    let aliases = package.aliases_token_stream();
74
75    let (attributes, ffi_imports, extern_block, rust_imports) = if create_cxx_bridge {
76        let reuse_bindings = package.reuse_bindings_token_stream();
77        let rust_imports = generate_rust_imports_for_cxx();
78        (
79            quote! { #[cxx::bridge] },
80            quote! {},
81            quote! {
82                extern "Rust" {
83                    #(#message_topic_defs)*
84                    #(#service_creation_defs)*
85                    #(#action_creation_defs)*
86                }
87                extern "C++" {
88                    include!("ros2-bridge/impl.rs.h");
89                    include!("dora-node-api.h");
90                    type CombinedEvents = crate::ffi::CombinedEvents;
91                    type CombinedEvent = crate::ffi::CombinedEvent;
92
93                    type Ros2Context = crate::ros2::default_impl::Ros2Context;
94                    type Ros2Node = crate::ros2::default_impl::Ros2Node;
95
96                    type Ros2Durability = crate::ros2::default_impl::ffi::Ros2Durability;
97                    type Ros2Liveliness = crate::ros2::default_impl::ffi::Ros2Liveliness;
98                    type Ros2ActionClientQosPolicies = crate::ros2::default_impl::ffi::Ros2ActionClientQosPolicies;
99                    type Ros2QosPolicies = crate::ros2::default_impl::ffi::Ros2QosPolicies;
100
101                    type U16String = crate::ros2::default_impl::ffi::U16String;
102
103                    #reuse_bindings
104                }
105            },
106            quote! {
107                #rust_imports
108
109                use crate::ros2::default_impl::Ros2Node;
110            },
111        )
112    } else {
113        let dependencies = package.dependencies_import_token_stream();
114        (
115            quote! {},
116            quote! {
117                use serde::{Deserialize, Serialize};
118
119                #[allow(unused_imports)]
120                use crate::messages::default_impl::ffi::*;
121                #dependencies
122            },
123            quote! {},
124            quote! {},
125        )
126    };
127
128    quote! {
129        #rust_imports
130
131        #attributes
132        pub mod ffi {
133            #ffi_imports
134
135            #extern_block
136            #(#shared_type_defs)*
137            #(#service_defs)*
138            #(#action_defs)*
139        }
140
141        #(#message_struct_impls)*
142
143        #(#message_topic_impls)*
144        #(#service_creation_impls)*
145        #(#action_creation_impls)*
146
147        #(#service_impls)*
148        #(#action_impls)*
149
150        #aliases
151    }
152}
153
154#[allow(clippy::cognitive_complexity)]
155pub fn generate<P>(paths: &[P], out_dir: &Path, create_cxx_bridge: bool) -> proc_macro2::TokenStream
156where
157    P: AsRef<Path>,
158{
159    use rust_format::Formatter;
160    let packages = get_packages(paths).unwrap();
161    let mut mod_decl = vec![];
162    let msg_dir = out_dir.join("msg");
163    if !msg_dir.exists() {
164        std::fs::create_dir(&msg_dir).unwrap();
165    }
166    // generate mod
167    for package in packages.iter() {
168        let mod_impl = generate_package(package, create_cxx_bridge);
169        let generated_string = rust_format::PrettyPlease::default()
170            .format_tokens(mod_impl)
171            .unwrap();
172        let package_name = &package.name;
173        let file_path = msg_dir.join(format!("{}.rs", package_name));
174        std::fs::write(&file_path, generated_string).unwrap();
175        let file_path_str = file_path.to_str().unwrap();
176        let package_name_ident = format_ident!("{}", package_name);
177        mod_decl.push(quote! {
178            #[path = #file_path_str]
179            pub mod #package_name_ident;
180        });
181    }
182
183    {
184        let generated_default_impls = generate_default_impls(create_cxx_bridge);
185        let generated_string = rust_format::PrettyPlease::default()
186            .format_tokens(generated_default_impls)
187            .unwrap();
188        let file_path = out_dir.join("impl.rs");
189        std::fs::write(&file_path, generated_string).unwrap();
190        let file_path_str = file_path.to_str().unwrap();
191        mod_decl.push(quote! {
192            #[path = #file_path_str]
193            pub mod default_impl;
194        });
195    }
196
197    quote! {
198        #(#mod_decl)*
199
200        pub use default_impl::*;
201    }
202}
203
204fn generate_default_impls(create_cxx_bridge: bool) -> proc_macro2::TokenStream {
205    let cxx_ros2_decl = quote! {
206        extern "Rust" {
207            type Ros2Context;
208            type Ros2Node;
209            fn init_ros2_context() -> Result<Box<Ros2Context>>;
210            fn new_node(self: &Ros2Context, name_space: &str, base_name: &str) -> Result<Box<Ros2Node>>;
211            fn qos_default() -> Ros2QosPolicies;
212            fn actionqos_default() -> Ros2ActionClientQosPolicies;
213        }
214
215        #[derive(Debug, Clone)]
216        pub struct Ros2QosPolicies {
217            pub durability: Ros2Durability,
218            pub liveliness: Ros2Liveliness,
219            pub lease_duration: f64,
220            pub reliable: bool,
221            pub max_blocking_time: f64,
222            pub keep_all: bool,
223            pub keep_last: i32,
224        }
225
226        #[derive(Debug, Clone)]
227        pub struct Ros2ActionClientQosPolicies {
228            pub goal_service: Ros2QosPolicies,
229            pub result_service: Ros2QosPolicies,
230            pub cancel_service: Ros2QosPolicies,
231            pub feedback_subscription: Ros2QosPolicies,
232            pub status_subscription: Ros2QosPolicies,
233        }
234
235        /// DDS 2.2.3.4 DURABILITY
236        #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
237        pub enum Ros2Durability {
238            Volatile,
239            TransientLocal,
240            Transient,
241            Persistent,
242        }
243
244        /// DDS 2.2.3.11 LIVELINESS
245        #[derive(Copy, Clone, Debug, PartialEq)]
246        pub enum Ros2Liveliness {
247            Automatic,
248            ManualByParticipant,
249            ManualByTopic,
250        }
251    };
252    let cxx_ros2_impl = quote! {
253        pub struct Ros2Context{
254            context: crate::ros2_client::Context,
255            executor: std::sync::Arc<futures::executor::ThreadPool>,
256        }
257
258        fn init_ros2_context() -> eyre::Result<Box<Ros2Context>> {
259            Ok(Box::new(Ros2Context{
260                context: crate::ros2_client::Context::new()?,
261                executor: std::sync::Arc::new(futures::executor::ThreadPool::new()?),
262            }))
263        }
264
265        impl Ros2Context {
266            fn new_node(&self, name_space: &str, base_name: &str) -> eyre::Result<Box<Ros2Node>> {
267                use futures::task::SpawnExt as _;
268                use eyre::WrapErr as _;
269
270                let name = crate::ros2_client::NodeName::new(name_space, base_name).map_err(|e| eyre::eyre!(e))?;
271                let options = crate::ros2_client::NodeOptions::new().enable_rosout(true);
272                let mut node = self.context.new_node(name, options)
273                    .map_err(|e| eyre::eyre!("failed to create ROS2 node: {e:?}"))?;
274
275                let spinner = node.spinner().context("failed to create spinner")?;
276                self.executor.spawn(async {
277                    if let Err(err) = spinner.spin().await {
278                        eprintln!("ros2 spinner failed: {err:?}");
279                    }
280                })
281                .context("failed to spawn ros2 spinner")?;
282
283                Ok(Box::new(Ros2Node{ node, executor: self.executor.clone(), }))
284            }
285        }
286
287        pub struct Ros2Node {
288            pub node : ros2_client::Node,
289            pub executor: std::sync::Arc<futures::executor::ThreadPool>,
290        }
291
292        unsafe impl cxx::ExternType for Ros2Node {
293            type Id = cxx::type_id!("Ros2Node");
294            type Kind = cxx::kind::Opaque;
295        }
296
297        unsafe impl cxx::ExternType for Ros2Context {
298            type Id = cxx::type_id!("Ros2Context");
299            type Kind = cxx::kind::Opaque;
300        }
301
302        fn qos_default() -> ffi::Ros2QosPolicies {
303            ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)
304        }
305
306        fn actionqos_default() -> ffi::Ros2ActionClientQosPolicies {
307            ffi::Ros2ActionClientQosPolicies::new(
308                Some(qos_default()),
309                Some(qos_default()),
310                Some(qos_default()),
311                Some(qos_default()),
312                Some(qos_default())
313            )
314        }
315
316        impl ffi::Ros2QosPolicies {
317            pub fn new(
318                durability: Option<ffi::Ros2Durability>,
319                liveliness: Option<ffi::Ros2Liveliness>,
320                reliable: Option<bool>,
321                keep_all: Option<bool>,
322                lease_duration: Option<f64>,
323                max_blocking_time: Option<f64>,
324                keep_last: Option<i32>,
325            ) -> Self {
326                Self {
327                    durability: durability.unwrap_or(ffi::Ros2Durability::Volatile),
328                    liveliness: liveliness.unwrap_or(ffi::Ros2Liveliness::Automatic),
329                    lease_duration: lease_duration.unwrap_or(f64::INFINITY),
330                    reliable: reliable.unwrap_or(false),
331                    max_blocking_time: max_blocking_time.unwrap_or(0.0),
332                    keep_all: keep_all.unwrap_or(false),
333                    keep_last: keep_last.unwrap_or(1),
334                }
335            }
336        }
337
338        impl From<ffi::Ros2QosPolicies> for rustdds::QosPolicies {
339            fn from(value: ffi::Ros2QosPolicies) -> Self {
340                rustdds::QosPolicyBuilder::new()
341                    .durability(value.durability.into())
342                    .liveliness(value.liveliness.convert(value.lease_duration))
343                    .reliability(if value.reliable {
344                        rustdds::policy::Reliability::Reliable {
345                            max_blocking_time: rustdds::Duration::from_frac_seconds(
346                                value.max_blocking_time,
347                            ),
348                        }
349                    } else {
350                        rustdds::policy::Reliability::BestEffort
351                    })
352                    .history(if value.keep_all {
353                        rustdds::policy::History::KeepAll
354                    } else {
355                        rustdds::policy::History::KeepLast {
356                            depth: value.keep_last,
357                        }
358                    })
359                    .build()
360            }
361        }
362
363
364
365        impl From<ffi::Ros2Durability> for rustdds::policy::Durability {
366            fn from(value: ffi::Ros2Durability) -> Self {
367                match value {
368                    ffi::Ros2Durability::Volatile => rustdds::policy::Durability::Volatile,
369                    ffi::Ros2Durability::TransientLocal => rustdds::policy::Durability::TransientLocal,
370                    ffi::Ros2Durability::Transient => rustdds::policy::Durability::Transient,
371                    ffi::Ros2Durability::Persistent => rustdds::policy::Durability::Persistent,
372                    _ => unreachable!(), // required because enums are represented as integers in bridge
373                }
374            }
375        }
376
377
378        impl ffi::Ros2Liveliness {
379            fn convert(self, lease_duration: f64) -> rustdds::policy::Liveliness {
380                let lease_duration = if lease_duration.is_infinite() {
381                    rustdds::Duration::INFINITE
382                } else {
383                    rustdds::Duration::from_frac_seconds(lease_duration)
384                };
385                match self {
386                    ffi::Ros2Liveliness::Automatic => rustdds::policy::Liveliness::Automatic { lease_duration },
387                    ffi::Ros2Liveliness::ManualByParticipant => {
388                        rustdds::policy::Liveliness::ManualByParticipant { lease_duration }
389                    }
390                    ffi::Ros2Liveliness::ManualByTopic => rustdds::policy::Liveliness::ManualByTopic { lease_duration },
391                    _ => unreachable!(), // required because enums are represented as integers in bridge
392                }
393            }
394        }
395
396        impl ffi::Ros2ActionClientQosPolicies {
397            pub fn new(
398                goal_service: Option<ffi::Ros2QosPolicies>,
399                result_service: Option<ffi::Ros2QosPolicies>,
400                cancel_service: Option<ffi::Ros2QosPolicies>,
401                feedback_subscription: Option<ffi::Ros2QosPolicies>,
402                status_subscription: Option<ffi::Ros2QosPolicies>,
403            ) -> Self {
404                Self {
405                    goal_service: goal_service.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)),
406                    result_service: result_service.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)),
407                    cancel_service: cancel_service.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)),
408                    feedback_subscription: feedback_subscription.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)),
409                    status_subscription: status_subscription.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)),
410                }
411            }
412        }
413
414        impl From<ffi::Ros2ActionClientQosPolicies> for crate::ros2_client::action::ActionClientQosPolicies {
415            fn from(value: ffi::Ros2ActionClientQosPolicies) -> Self {
416                crate::ros2_client::action::ActionClientQosPolicies {
417                    goal_service: value.goal_service.into(),
418                    result_service: value.result_service.into(),
419                    cancel_service: value.cancel_service.into(),
420                    feedback_subscription: value.feedback_subscription.into(),
421                    status_subscription: value.status_subscription.into(),
422                }
423            }
424        }
425    };
426    let u16str_decl = quote! {
427        #[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
428        pub struct U16String {
429            pub chars: Vec<u16>,
430        }
431    };
432    let u16str_impl = quote! {
433        impl ffi::U16String {
434            #[allow(dead_code)]
435            fn from_str(arg: &str) -> Self {
436                Self {
437                    chars: crate::_core::widestring::U16String::from_str(arg).into_vec(),
438                }
439            }
440        }
441
442        impl crate::_core::InternalDefault for ffi::U16String {
443            fn _default() -> Self {
444                Default::default()
445            }
446        }
447    };
448
449    let (attribute, ffi_imports, rust_imports) = if create_cxx_bridge {
450        (
451            quote! {
452                #[cxx::bridge]
453            },
454            quote! {},
455            generate_rust_imports_for_cxx(),
456        )
457    } else {
458        (
459            quote! {},
460            quote! {
461                use serde::{Deserialize, Serialize};
462            },
463            quote! {},
464        )
465    };
466
467    let mut declares = vec![u16str_decl];
468    let mut implements = vec![u16str_impl];
469
470    if create_cxx_bridge {
471        declares.push(cxx_ros2_decl);
472        implements.push(cxx_ros2_impl);
473    }
474
475    quote! {
476        #rust_imports
477
478        #attribute
479        pub mod ffi {
480            #ffi_imports
481            #(#declares)*
482        }
483
484        #(#implements)*
485    }
486}
487
488fn generate_rust_imports_for_cxx() -> proc_macro2::TokenStream {
489    quote! {
490        #[allow(unused_imports)]
491        use crate::prelude::*;
492    }
493}