flatipc_derive/
lib.rs

1use std::sync::atomic::AtomicUsize;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, spanned::Spanned, DeriveInput};
6
7fn ast_hash(ast: &syn::DeriveInput) -> usize {
8    use std::hash::{Hash, Hasher};
9    let mut hasher = std::collections::hash_map::DefaultHasher::new();
10    ast.hash(&mut hasher);
11    let full_hash = hasher.finish();
12
13    #[cfg(target_pointer_width = "64")]
14    {
15        full_hash as usize
16    }
17    #[cfg(target_pointer_width = "32")]
18    {
19        (((full_hash >> 32) as u32) ^ (full_hash as u32)) as usize
20    }
21    #[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
22    compile_error!("Unsupported target_pointer_width");
23}
24
25#[proc_macro_derive(IpcSafe)]
26pub fn derive_transmittable(ts: TokenStream) -> TokenStream {
27    let ast = parse_macro_input!(ts as syn::DeriveInput);
28    derive_transmittable_inner(ast).unwrap_or_else(|e| e).into()
29}
30
31fn derive_transmittable_inner(
32    ast: DeriveInput,
33) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
34    let ident = ast.ident.clone();
35    let transmittable_checks = match &ast.data {
36        syn::Data::Struct(r#struct) => generate_transmittable_checks_struct(&ast, r#struct)?,
37        syn::Data::Enum(r#enum) => generate_transmittable_checks_enum(&ast, r#enum)?,
38        syn::Data::Union(r#union) => generate_transmittable_checks_union(&ast, r#union)?,
39    };
40    let result = quote! {
41        #transmittable_checks
42
43        unsafe impl flatipc::IpcSafe for #ident {}
44    };
45
46    Ok(result)
47}
48
49#[proc_macro_derive(Ipc)]
50pub fn derive_ipc(ts: TokenStream) -> TokenStream {
51    let ast = parse_macro_input!(ts as syn::DeriveInput);
52    derive_ipc_inner(ast).unwrap_or_else(|e| e).into()
53}
54
55fn derive_ipc_inner(ast: DeriveInput) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
56    // Ensure the thing is using a repr we support.
57    ensure_valid_repr(&ast)?;
58
59    let transmittable_checks = match &ast.data {
60        syn::Data::Struct(r#struct) => generate_transmittable_checks_struct(&ast, r#struct)?,
61        syn::Data::Enum(r#enum) => generate_transmittable_checks_enum(&ast, r#enum)?,
62        syn::Data::Union(r#union) => generate_transmittable_checks_union(&ast, r#union)?,
63    };
64
65    let ipc_struct = generate_ipc_struct(&ast)?;
66    Ok(quote! {
67        #transmittable_checks
68        #ipc_struct
69    })
70}
71
72fn ensure_valid_repr(ast: &DeriveInput) -> Result<(), proc_macro2::TokenStream> {
73    let mut repr_c = false;
74    for attr in ast.attrs.iter() {
75        if attr.path().is_ident("repr") {
76            attr.parse_nested_meta(|meta| {
77                if meta.path.is_ident("C") {
78                    repr_c = true;
79                }
80                Ok(())
81            })
82            .map_err(|e| e.to_compile_error())?;
83        }
84    }
85    if !repr_c {
86        Err(syn::Error::new(ast.span(), "Structs must be marked as repr(C) to be IPC-safe")
87            .to_compile_error())
88    } else {
89        Ok(())
90    }
91}
92
93fn type_to_string(ty: &syn::Type) -> String {
94    match ty {
95        syn::Type::Array(_type_array) => "Array".to_owned(),
96        syn::Type::BareFn(_type_bare_fn) => "BareFn".to_owned(),
97        syn::Type::Group(_type_group) => "Group".to_owned(),
98        syn::Type::ImplTrait(_type_impl_trait) => "ImplTrait".to_owned(),
99        syn::Type::Infer(_type_infer) => "Infer".to_owned(),
100        syn::Type::Macro(_type_macro) => "Macro".to_owned(),
101        syn::Type::Never(_type_never) => "Never".to_owned(),
102        syn::Type::Paren(_type_paren) => "Paren".to_owned(),
103        syn::Type::Path(_type_path) => "Path".to_owned(),
104        syn::Type::Ptr(_type_ptr) => "Ptr".to_owned(),
105        syn::Type::Reference(_type_reference) => "Reference".to_owned(),
106        syn::Type::Slice(_type_slice) => "Slice".to_owned(),
107        syn::Type::TraitObject(_type_trait_object) => "TraitObject".to_owned(),
108        syn::Type::Tuple(_type_tuple) => "Tuple".to_owned(),
109        syn::Type::Verbatim(_token_stream) => "Verbatim".to_owned(),
110        _ => "Other (Unknown)".to_owned(),
111    }
112}
113
114fn ensure_type_exists_for(ty: &syn::Type) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
115    match ty {
116        syn::Type::Path(_) => {
117            static ATOMIC_INDEX: AtomicUsize = AtomicUsize::new(0);
118            let fn_name = format_ident!(
119                "assert_type_exists_for_parameter_{}",
120                ATOMIC_INDEX.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
121            );
122            Ok(quote! {
123                fn #fn_name (_var: #ty) { ensure_is_transmittable::<#ty>(); }
124            })
125        }
126        syn::Type::Tuple(tuple) => {
127            let mut check_functions = vec![];
128            for ty in tuple.elems.iter() {
129                check_functions.push(ensure_type_exists_for(ty)?);
130            }
131            Ok(quote! {
132                #(#check_functions)*
133            })
134        }
135        syn::Type::Array(array) => ensure_type_exists_for(&array.elem),
136        _ => Err(syn::Error::new(ty.span(), format!("The type `{}` is unsupported", type_to_string(ty)))
137            .to_compile_error()),
138    }
139}
140
141fn generate_transmittable_checks_enum(
142    ast: &syn::DeriveInput,
143    enm: &syn::DataEnum,
144) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
145    let mut variants = Vec::new();
146
147    let surrounding_function = format_ident!("ensure_members_are_transmittable_for_{}", ast.ident);
148    for variant in &enm.variants {
149        let fields = match &variant.fields {
150            syn::Fields::Named(fields) => {
151                fields.named.iter().map(|f| ensure_type_exists_for(&f.ty)).collect()
152            }
153            syn::Fields::Unnamed(fields) => {
154                fields.unnamed.iter().map(|f| ensure_type_exists_for(&f.ty)).collect()
155            }
156            syn::Fields::Unit => Vec::new(),
157        };
158
159        let mut vetted_fields = vec![];
160        for field in fields {
161            match field {
162                Ok(f) => vetted_fields.push(f),
163                Err(e) => return Err(e),
164            }
165        }
166
167        variants.push(quote! {
168                #(#vetted_fields)*
169        });
170    }
171
172    Ok(quote! {
173        #[allow(non_snake_case, dead_code)]
174        fn #surrounding_function () {
175            pub fn ensure_is_transmittable<T: flatipc::IpcSafe>() {}
176            #(#variants)*
177        }
178
179    })
180}
181
182fn generate_transmittable_checks_struct(
183    ast: &syn::DeriveInput,
184    strct: &syn::DataStruct,
185) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
186    let surrounding_function = format_ident!("ensure_members_are_transmittable_for_{}", ast.ident);
187    let fields = match &strct.fields {
188        syn::Fields::Named(fields) => fields.named.iter().map(|f| ensure_type_exists_for(&f.ty)).collect(),
189        syn::Fields::Unnamed(fields) => {
190            fields.unnamed.iter().map(|f| ensure_type_exists_for(&f.ty)).collect()
191        }
192        syn::Fields::Unit => Vec::new(),
193    };
194    let mut vetted_fields = vec![];
195    for field in fields {
196        match field {
197            Ok(f) => vetted_fields.push(f),
198            Err(e) => return Err(e),
199        }
200    }
201    Ok(quote! {
202        #[allow(non_snake_case, dead_code)]
203        fn #surrounding_function () {
204            pub fn ensure_is_transmittable<T: flatipc::IpcSafe>() {}
205            #(#vetted_fields)*
206        }
207    })
208}
209
210fn generate_transmittable_checks_union(
211    ast: &syn::DeriveInput,
212    unn: &syn::DataUnion,
213) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
214    let surrounding_function = format_ident!("ensure_members_are_transmittable_for_{}", ast.ident);
215    let fields: Vec<Result<proc_macro2::TokenStream, proc_macro2::TokenStream>> =
216        unn.fields.named.iter().map(|f| ensure_type_exists_for(&f.ty)).collect();
217
218    let mut vetted_fields = vec![];
219    for field in fields {
220        match field {
221            Ok(f) => vetted_fields.push(f),
222            Err(e) => return Err(e),
223        }
224    }
225    Ok(quote! {
226        #[allow(non_snake_case, dead_code)]
227        fn #surrounding_function () {
228            pub fn ensure_is_transmittable<T: flatipc::IpcSafe>() {}
229            #(#vetted_fields)*
230        }
231    })
232}
233
234fn generate_ipc_struct(ast: &DeriveInput) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
235    let visibility = ast.vis.clone();
236    let ident = ast.ident.clone();
237    let ipc_ident = format_ident!("Ipc{}", ast.ident);
238    let ident_size = quote! { core::mem::size_of::< #ident >() };
239    let padded_size = quote! { (#ident_size + (4096 - 1)) & !(4096 - 1) };
240    let padding_size = quote! { #padded_size - #ident_size };
241    let hash = ast_hash(ast);
242
243    let build_message = quote! {
244        use xous::definitions::{MemoryMessage, MemoryAddress, MemoryRange};
245        let mut buf = unsafe { MemoryRange::new(data.as_ptr() as usize, data.len()) }.unwrap();
246        let msg = MemoryMessage {
247            id: opcode,
248            buf,
249            offset: MemoryAddress::new(signature),
250            valid: None,
251        };
252    };
253
254    let lend = if cfg!(feature = "xous") {
255        quote! {
256            #build_message
257            xous::send_message(connection, xous::Message::MutableBorrow(msg))?;
258        }
259    } else {
260        quote! {
261            flatipc::backend::mock::IPC_MACHINE.lock().unwrap().lend(connection, opcode, signature, 0, &data);
262        }
263    };
264
265    let try_lend = if cfg!(feature = "xous") {
266        quote! {
267            #build_message
268            xous::try_send_message(connection, xous::Message::MutableBorrow(msg))?;
269        }
270    } else {
271        quote! {
272            flatipc::backend::mock::IPC_MACHINE.lock().unwrap().lend(connection, opcode, signature, 0, &data);
273        }
274    };
275
276    let lend_mut = if cfg!(feature = "xous") {
277        quote! {
278            #build_message
279            xous::send_message(connection, xous::Message::MutableBorrow(msg))?;
280        }
281    } else {
282        quote! {
283            flatipc::backend::mock::IPC_MACHINE.lock().unwrap().lend_mut(connection, opcode, signature, 0, &mut data);
284        }
285    };
286
287    let try_lend_mut = if cfg!(feature = "xous") {
288        quote! {
289            #build_message
290            xous::try_send_message(connection, xous::Message::MutableBorrow(msg))?;
291        }
292    } else {
293        quote! {
294            flatipc::backend::mock::IPC_MACHINE.lock().unwrap().lend_mut(connection, opcode, signature, 0, &mut data);
295        }
296    };
297
298    let memory_messages = if cfg!(feature = "xous") {
299        quote! {
300            fn from_memory_message<'a>(msg: &'a xous::MemoryMessage) -> Option<&'a Self> {
301                if msg.buf.len() < core::mem::size_of::< #ipc_ident >() {
302                    return None;
303                }
304                let signature = msg.offset.map(|offset| offset.get()).unwrap_or_default();
305                if signature != #hash {
306                    return None;
307                }
308                unsafe { Some(&*(msg.buf.as_ptr() as *const #ipc_ident)) }
309            }
310
311            fn from_memory_message_mut<'a>(msg: &'a mut xous::MemoryMessage) -> Option<&'a mut Self> {
312                if msg.buf.len() < core::mem::size_of::< #ipc_ident >() {
313                    return None;
314                }
315                let signature = msg.offset.map(|offset| offset.get()).unwrap_or_default();
316                if signature != #hash {
317                    return None;
318                }
319                unsafe { Some(&mut *(msg.buf.as_mut_ptr() as *mut #ipc_ident)) }
320            }
321        }
322    } else {
323        quote! {}
324    };
325
326    Ok(quote! {
327        #[repr(C, align(4096))]
328        #visibility struct #ipc_ident {
329            original: #ident,
330            padding: [u8; #padding_size],
331        }
332
333        impl core::ops::Deref for #ipc_ident {
334            type Target = #ident ;
335            fn deref(&self) -> &Self::Target {
336                &self.original
337            }
338        }
339
340        impl core::ops::DerefMut for #ipc_ident {
341            fn deref_mut(&mut self) -> &mut Self::Target {
342                &mut self.original
343            }
344        }
345
346        impl flatipc::IntoIpc for #ident {
347            type IpcType = #ipc_ident;
348            fn into_ipc(self) -> Self::IpcType {
349                #ipc_ident {
350                    original: self,
351                    padding: [0; #padding_size],
352                }
353            }
354        }
355
356        unsafe impl flatipc::Ipc for #ipc_ident {
357            type Original = #ident ;
358
359            fn from_slice<'a>(data: &'a [u8], signature: usize) -> Option<&'a Self> {
360                if data.len() < core::mem::size_of::< #ipc_ident >() {
361                    return None;
362                }
363                if signature != #hash {
364                    return None;
365                }
366                unsafe { Some(&*(data.as_ptr() as *const u8 as *const #ipc_ident)) }
367            }
368
369            unsafe fn from_buffer_unchecked<'a>(data: &'a [u8]) -> &'a Self {
370                &*(data.as_ptr() as *const u8 as *const #ipc_ident)
371            }
372
373            fn from_slice_mut<'a>(data: &'a mut [u8], signature: usize) -> Option<&'a mut Self> {
374                if data.len() < core::mem::size_of::< #ipc_ident >() {
375                    return None;
376                }
377                if signature != #hash {
378                    return None;
379                }
380                unsafe { Some(&mut *(data.as_mut_ptr() as *mut u8 as *mut #ipc_ident)) }
381            }
382
383            unsafe fn from_buffer_mut_unchecked<'a>(data: &'a mut [u8]) -> &'a mut Self {
384                unsafe { &mut *(data.as_mut_ptr() as *mut u8 as *mut #ipc_ident) }
385            }
386
387            fn lend(&self, connection: flatipc::CID, opcode: usize) -> Result<(), flatipc::Error> {
388                let signature = self.signature();
389                let data = unsafe {
390                    core::slice::from_raw_parts(
391                        self as *const #ipc_ident as *const u8,
392                        core::mem::size_of::< #ipc_ident >(),
393                    )
394                };
395                #lend
396                Ok(())
397            }
398
399            fn try_lend(&self, connection: flatipc::CID, opcode: usize) -> Result<(), flatipc::Error> {
400                let signature = self.signature();
401                let data = unsafe {
402                    core::slice::from_raw_parts(
403                        self as *const #ipc_ident as *const u8,
404                        core::mem::size_of::< #ipc_ident >(),
405                    )
406                };
407                #try_lend
408                Ok(())
409            }
410
411            fn lend_mut(&mut self, connection: flatipc::CID, opcode: usize) -> Result<(), flatipc::Error> {
412                let signature = self.signature();
413                let mut data = unsafe {
414                    core::slice::from_raw_parts_mut(
415                        self as *mut #ipc_ident as *mut u8,
416                        #padded_size,
417                    )
418                };
419                #lend_mut
420                Ok(())
421            }
422
423            fn try_lend_mut(&mut self, connection: flatipc::CID, opcode: usize) -> Result<(), flatipc::Error> {
424                let signature = self.signature();
425                let mut data = unsafe {
426                    core::slice::from_raw_parts_mut(
427                        self as *mut #ipc_ident as *mut u8,
428                        #padded_size,
429                    )
430                };
431                #try_lend_mut
432                Ok(())
433            }
434
435            fn as_original(&self) -> &Self::Original {
436                &self.original
437            }
438
439            fn as_original_mut(&mut self) -> &mut Self::Original {
440                &mut self.original
441            }
442
443            fn into_original(self) -> Self::Original {
444                self.original
445            }
446
447            fn signature(&self) -> usize {
448                #hash
449            }
450
451            #memory_messages
452        }
453    })
454}