1use std::sync::atomic::AtomicUsize;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, spanned::Spanned, DeriveInput};
6
7fn ast_hash(ast: &syn::DeriveInput) -> usize {
8 use std::hash::{Hash, Hasher};
9 let mut hasher = std::collections::hash_map::DefaultHasher::new();
10 ast.hash(&mut hasher);
11 let full_hash = hasher.finish();
12
13 #[cfg(target_pointer_width = "64")]
14 {
15 full_hash as usize
16 }
17 #[cfg(target_pointer_width = "32")]
18 {
19 (((full_hash >> 32) as u32) ^ (full_hash as u32)) as usize
20 }
21 #[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
22 compile_error!("Unsupported target_pointer_width");
23}
24
25#[proc_macro_derive(IpcSafe)]
26pub fn derive_transmittable(ts: TokenStream) -> TokenStream {
27 let ast = parse_macro_input!(ts as syn::DeriveInput);
28 derive_transmittable_inner(ast).unwrap_or_else(|e| e).into()
29}
30
31fn derive_transmittable_inner(
32 ast: DeriveInput,
33) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
34 let ident = ast.ident.clone();
35 let transmittable_checks = match &ast.data {
36 syn::Data::Struct(r#struct) => generate_transmittable_checks_struct(&ast, r#struct)?,
37 syn::Data::Enum(r#enum) => generate_transmittable_checks_enum(&ast, r#enum)?,
38 syn::Data::Union(r#union) => generate_transmittable_checks_union(&ast, r#union)?,
39 };
40 let result = quote! {
41 #transmittable_checks
42
43 unsafe impl flatipc::IpcSafe for #ident {}
44 };
45
46 Ok(result)
47}
48
49#[proc_macro_derive(Ipc)]
50pub fn derive_ipc(ts: TokenStream) -> TokenStream {
51 let ast = parse_macro_input!(ts as syn::DeriveInput);
52 derive_ipc_inner(ast).unwrap_or_else(|e| e).into()
53}
54
55fn derive_ipc_inner(ast: DeriveInput) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
56 ensure_valid_repr(&ast)?;
58
59 let transmittable_checks = match &ast.data {
60 syn::Data::Struct(r#struct) => generate_transmittable_checks_struct(&ast, r#struct)?,
61 syn::Data::Enum(r#enum) => generate_transmittable_checks_enum(&ast, r#enum)?,
62 syn::Data::Union(r#union) => generate_transmittable_checks_union(&ast, r#union)?,
63 };
64
65 let ipc_struct = generate_ipc_struct(&ast)?;
66 Ok(quote! {
67 #transmittable_checks
68 #ipc_struct
69 })
70}
71
72fn ensure_valid_repr(ast: &DeriveInput) -> Result<(), proc_macro2::TokenStream> {
73 let mut repr_c = false;
74 for attr in ast.attrs.iter() {
75 if attr.path().is_ident("repr") {
76 attr.parse_nested_meta(|meta| {
77 if meta.path.is_ident("C") {
78 repr_c = true;
79 }
80 Ok(())
81 })
82 .map_err(|e| e.to_compile_error())?;
83 }
84 }
85 if !repr_c {
86 Err(syn::Error::new(ast.span(), "Structs must be marked as repr(C) to be IPC-safe")
87 .to_compile_error())
88 } else {
89 Ok(())
90 }
91}
92
93fn type_to_string(ty: &syn::Type) -> String {
94 match ty {
95 syn::Type::Array(_type_array) => "Array".to_owned(),
96 syn::Type::BareFn(_type_bare_fn) => "BareFn".to_owned(),
97 syn::Type::Group(_type_group) => "Group".to_owned(),
98 syn::Type::ImplTrait(_type_impl_trait) => "ImplTrait".to_owned(),
99 syn::Type::Infer(_type_infer) => "Infer".to_owned(),
100 syn::Type::Macro(_type_macro) => "Macro".to_owned(),
101 syn::Type::Never(_type_never) => "Never".to_owned(),
102 syn::Type::Paren(_type_paren) => "Paren".to_owned(),
103 syn::Type::Path(_type_path) => "Path".to_owned(),
104 syn::Type::Ptr(_type_ptr) => "Ptr".to_owned(),
105 syn::Type::Reference(_type_reference) => "Reference".to_owned(),
106 syn::Type::Slice(_type_slice) => "Slice".to_owned(),
107 syn::Type::TraitObject(_type_trait_object) => "TraitObject".to_owned(),
108 syn::Type::Tuple(_type_tuple) => "Tuple".to_owned(),
109 syn::Type::Verbatim(_token_stream) => "Verbatim".to_owned(),
110 _ => "Other (Unknown)".to_owned(),
111 }
112}
113
114fn ensure_type_exists_for(ty: &syn::Type) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
115 match ty {
116 syn::Type::Path(_) => {
117 static ATOMIC_INDEX: AtomicUsize = AtomicUsize::new(0);
118 let fn_name = format_ident!(
119 "assert_type_exists_for_parameter_{}",
120 ATOMIC_INDEX.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
121 );
122 Ok(quote! {
123 fn #fn_name (_var: #ty) { ensure_is_transmittable::<#ty>(); }
124 })
125 }
126 syn::Type::Tuple(tuple) => {
127 let mut check_functions = vec![];
128 for ty in tuple.elems.iter() {
129 check_functions.push(ensure_type_exists_for(ty)?);
130 }
131 Ok(quote! {
132 #(#check_functions)*
133 })
134 }
135 syn::Type::Array(array) => ensure_type_exists_for(&array.elem),
136 _ => Err(syn::Error::new(ty.span(), format!("The type `{}` is unsupported", type_to_string(ty)))
137 .to_compile_error()),
138 }
139}
140
141fn generate_transmittable_checks_enum(
142 ast: &syn::DeriveInput,
143 enm: &syn::DataEnum,
144) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
145 let mut variants = Vec::new();
146
147 let surrounding_function = format_ident!("ensure_members_are_transmittable_for_{}", ast.ident);
148 for variant in &enm.variants {
149 let fields = match &variant.fields {
150 syn::Fields::Named(fields) => {
151 fields.named.iter().map(|f| ensure_type_exists_for(&f.ty)).collect()
152 }
153 syn::Fields::Unnamed(fields) => {
154 fields.unnamed.iter().map(|f| ensure_type_exists_for(&f.ty)).collect()
155 }
156 syn::Fields::Unit => Vec::new(),
157 };
158
159 let mut vetted_fields = vec![];
160 for field in fields {
161 match field {
162 Ok(f) => vetted_fields.push(f),
163 Err(e) => return Err(e),
164 }
165 }
166
167 variants.push(quote! {
168 #(#vetted_fields)*
169 });
170 }
171
172 Ok(quote! {
173 #[allow(non_snake_case, dead_code)]
174 fn #surrounding_function () {
175 pub fn ensure_is_transmittable<T: flatipc::IpcSafe>() {}
176 #(#variants)*
177 }
178
179 })
180}
181
182fn generate_transmittable_checks_struct(
183 ast: &syn::DeriveInput,
184 strct: &syn::DataStruct,
185) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
186 let surrounding_function = format_ident!("ensure_members_are_transmittable_for_{}", ast.ident);
187 let fields = match &strct.fields {
188 syn::Fields::Named(fields) => fields.named.iter().map(|f| ensure_type_exists_for(&f.ty)).collect(),
189 syn::Fields::Unnamed(fields) => {
190 fields.unnamed.iter().map(|f| ensure_type_exists_for(&f.ty)).collect()
191 }
192 syn::Fields::Unit => Vec::new(),
193 };
194 let mut vetted_fields = vec![];
195 for field in fields {
196 match field {
197 Ok(f) => vetted_fields.push(f),
198 Err(e) => return Err(e),
199 }
200 }
201 Ok(quote! {
202 #[allow(non_snake_case, dead_code)]
203 fn #surrounding_function () {
204 pub fn ensure_is_transmittable<T: flatipc::IpcSafe>() {}
205 #(#vetted_fields)*
206 }
207 })
208}
209
210fn generate_transmittable_checks_union(
211 ast: &syn::DeriveInput,
212 unn: &syn::DataUnion,
213) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
214 let surrounding_function = format_ident!("ensure_members_are_transmittable_for_{}", ast.ident);
215 let fields: Vec<Result<proc_macro2::TokenStream, proc_macro2::TokenStream>> =
216 unn.fields.named.iter().map(|f| ensure_type_exists_for(&f.ty)).collect();
217
218 let mut vetted_fields = vec![];
219 for field in fields {
220 match field {
221 Ok(f) => vetted_fields.push(f),
222 Err(e) => return Err(e),
223 }
224 }
225 Ok(quote! {
226 #[allow(non_snake_case, dead_code)]
227 fn #surrounding_function () {
228 pub fn ensure_is_transmittable<T: flatipc::IpcSafe>() {}
229 #(#vetted_fields)*
230 }
231 })
232}
233
234fn generate_ipc_struct(ast: &DeriveInput) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
235 let visibility = ast.vis.clone();
236 let ident = ast.ident.clone();
237 let ipc_ident = format_ident!("Ipc{}", ast.ident);
238 let ident_size = quote! { core::mem::size_of::< #ident >() };
239 let padded_size = quote! { (#ident_size + (4096 - 1)) & !(4096 - 1) };
240 let padding_size = quote! { #padded_size - #ident_size };
241 let hash = ast_hash(ast);
242
243 let build_message = quote! {
244 use xous::definitions::{MemoryMessage, MemoryAddress, MemoryRange};
245 let mut buf = unsafe { MemoryRange::new(data.as_ptr() as usize, data.len()) }.unwrap();
246 let msg = MemoryMessage {
247 id: opcode,
248 buf,
249 offset: MemoryAddress::new(signature),
250 valid: None,
251 };
252 };
253
254 let lend = if cfg!(feature = "xous") {
255 quote! {
256 #build_message
257 xous::send_message(connection, xous::Message::MutableBorrow(msg))?;
258 }
259 } else {
260 quote! {
261 flatipc::backend::mock::IPC_MACHINE.lock().unwrap().lend(connection, opcode, signature, 0, &data);
262 }
263 };
264
265 let try_lend = if cfg!(feature = "xous") {
266 quote! {
267 #build_message
268 xous::try_send_message(connection, xous::Message::MutableBorrow(msg))?;
269 }
270 } else {
271 quote! {
272 flatipc::backend::mock::IPC_MACHINE.lock().unwrap().lend(connection, opcode, signature, 0, &data);
273 }
274 };
275
276 let lend_mut = if cfg!(feature = "xous") {
277 quote! {
278 #build_message
279 xous::send_message(connection, xous::Message::MutableBorrow(msg))?;
280 }
281 } else {
282 quote! {
283 flatipc::backend::mock::IPC_MACHINE.lock().unwrap().lend_mut(connection, opcode, signature, 0, &mut data);
284 }
285 };
286
287 let try_lend_mut = if cfg!(feature = "xous") {
288 quote! {
289 #build_message
290 xous::try_send_message(connection, xous::Message::MutableBorrow(msg))?;
291 }
292 } else {
293 quote! {
294 flatipc::backend::mock::IPC_MACHINE.lock().unwrap().lend_mut(connection, opcode, signature, 0, &mut data);
295 }
296 };
297
298 let memory_messages = if cfg!(feature = "xous") {
299 quote! {
300 fn from_memory_message<'a>(msg: &'a xous::MemoryMessage) -> Option<&'a Self> {
301 if msg.buf.len() < core::mem::size_of::< #ipc_ident >() {
302 return None;
303 }
304 let signature = msg.offset.map(|offset| offset.get()).unwrap_or_default();
305 if signature != #hash {
306 return None;
307 }
308 unsafe { Some(&*(msg.buf.as_ptr() as *const #ipc_ident)) }
309 }
310
311 fn from_memory_message_mut<'a>(msg: &'a mut xous::MemoryMessage) -> Option<&'a mut Self> {
312 if msg.buf.len() < core::mem::size_of::< #ipc_ident >() {
313 return None;
314 }
315 let signature = msg.offset.map(|offset| offset.get()).unwrap_or_default();
316 if signature != #hash {
317 return None;
318 }
319 unsafe { Some(&mut *(msg.buf.as_mut_ptr() as *mut #ipc_ident)) }
320 }
321 }
322 } else {
323 quote! {}
324 };
325
326 Ok(quote! {
327 #[repr(C, align(4096))]
328 #visibility struct #ipc_ident {
329 original: #ident,
330 padding: [u8; #padding_size],
331 }
332
333 impl core::ops::Deref for #ipc_ident {
334 type Target = #ident ;
335 fn deref(&self) -> &Self::Target {
336 &self.original
337 }
338 }
339
340 impl core::ops::DerefMut for #ipc_ident {
341 fn deref_mut(&mut self) -> &mut Self::Target {
342 &mut self.original
343 }
344 }
345
346 impl flatipc::IntoIpc for #ident {
347 type IpcType = #ipc_ident;
348 fn into_ipc(self) -> Self::IpcType {
349 #ipc_ident {
350 original: self,
351 padding: [0; #padding_size],
352 }
353 }
354 }
355
356 unsafe impl flatipc::Ipc for #ipc_ident {
357 type Original = #ident ;
358
359 fn from_slice<'a>(data: &'a [u8], signature: usize) -> Option<&'a Self> {
360 if data.len() < core::mem::size_of::< #ipc_ident >() {
361 return None;
362 }
363 if signature != #hash {
364 return None;
365 }
366 unsafe { Some(&*(data.as_ptr() as *const u8 as *const #ipc_ident)) }
367 }
368
369 unsafe fn from_buffer_unchecked<'a>(data: &'a [u8]) -> &'a Self {
370 &*(data.as_ptr() as *const u8 as *const #ipc_ident)
371 }
372
373 fn from_slice_mut<'a>(data: &'a mut [u8], signature: usize) -> Option<&'a mut Self> {
374 if data.len() < core::mem::size_of::< #ipc_ident >() {
375 return None;
376 }
377 if signature != #hash {
378 return None;
379 }
380 unsafe { Some(&mut *(data.as_mut_ptr() as *mut u8 as *mut #ipc_ident)) }
381 }
382
383 unsafe fn from_buffer_mut_unchecked<'a>(data: &'a mut [u8]) -> &'a mut Self {
384 unsafe { &mut *(data.as_mut_ptr() as *mut u8 as *mut #ipc_ident) }
385 }
386
387 fn lend(&self, connection: flatipc::CID, opcode: usize) -> Result<(), flatipc::Error> {
388 let signature = self.signature();
389 let data = unsafe {
390 core::slice::from_raw_parts(
391 self as *const #ipc_ident as *const u8,
392 core::mem::size_of::< #ipc_ident >(),
393 )
394 };
395 #lend
396 Ok(())
397 }
398
399 fn try_lend(&self, connection: flatipc::CID, opcode: usize) -> Result<(), flatipc::Error> {
400 let signature = self.signature();
401 let data = unsafe {
402 core::slice::from_raw_parts(
403 self as *const #ipc_ident as *const u8,
404 core::mem::size_of::< #ipc_ident >(),
405 )
406 };
407 #try_lend
408 Ok(())
409 }
410
411 fn lend_mut(&mut self, connection: flatipc::CID, opcode: usize) -> Result<(), flatipc::Error> {
412 let signature = self.signature();
413 let mut data = unsafe {
414 core::slice::from_raw_parts_mut(
415 self as *mut #ipc_ident as *mut u8,
416 #padded_size,
417 )
418 };
419 #lend_mut
420 Ok(())
421 }
422
423 fn try_lend_mut(&mut self, connection: flatipc::CID, opcode: usize) -> Result<(), flatipc::Error> {
424 let signature = self.signature();
425 let mut data = unsafe {
426 core::slice::from_raw_parts_mut(
427 self as *mut #ipc_ident as *mut u8,
428 #padded_size,
429 )
430 };
431 #try_lend_mut
432 Ok(())
433 }
434
435 fn as_original(&self) -> &Self::Original {
436 &self.original
437 }
438
439 fn as_original_mut(&mut self) -> &mut Self::Original {
440 &mut self.original
441 }
442
443 fn into_original(self) -> Self::Original {
444 self.original
445 }
446
447 fn signature(&self) -> usize {
448 #hash
449 }
450
451 #memory_messages
452 }
453 })
454}