jni_sys_macros/
lib.rs

1extern crate proc_macro;
2
3use std::{cmp::Ordering, collections::HashSet};
4
5use proc_macro::TokenStream;
6use quote::{format_ident, quote};
7use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, Ident, LitStr};
8
/// A JNI specification version such as 1.1, 1.2 or 24.
///
/// Versions order numerically by major component first, then by minor
/// component. Major version 999 is used elsewhere in this crate as a
/// stand-in for the `"reserved"` marker.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
struct JniVersion {
    major: u16,
    minor: u16,
}

impl Default for JniVersion {
    /// JNI 1.1, the version assumed for fields that carry no `jni_added` attribute.
    fn default() -> Self {
        JniVersion { major: 1, minor: 1 }
    }
}

impl Ord for JniVersion {
    /// Lexicographic comparison of `(major, minor)`.
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs = (self.major, self.minor);
        let rhs = (other.major, other.minor);
        lhs.cmp(&rhs)
    }
}

impl PartialOrd for JniVersion {
    /// Always `Some`: the ordering is total and delegates to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
32
33impl syn::parse::Parse for JniVersion {
34    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
35        let version: LitStr = input.parse()?;
36        let version = version.value();
37        if version == "reserved" {
38            // We special case version 999 later instead of making JniVersion an enum
39            return Ok(JniVersion {
40                major: 999,
41                minor: 0,
42            });
43        }
44        let mut split = version.splitn(2, '.');
45        const EXPECTED_MSG: &str = "Expected \"major.minor\" version number or \"reserved\"";
46        let major = split
47            .next()
48            .ok_or(syn::Error::new(input.span(), EXPECTED_MSG))?;
49        let major = major
50            .parse::<u16>()
51            .map_err(|_| syn::Error::new(input.span(), EXPECTED_MSG))?;
52        let minor = split
53            .next()
54            .unwrap_or("0")
55            .parse::<u16>()
56            .map_err(|_| syn::Error::new(input.span(), EXPECTED_MSG))?;
57        Ok(JniVersion { major, minor })
58    }
59}
60
61fn jni_to_union_impl(input: DeriveInput) -> syn::Result<TokenStream> {
62    let original_name = &input.ident;
63    let original_visibility = &input.vis;
64
65    let mut versions = HashSet::new();
66    let mut versioned_fields = vec![];
67
68    if let Data::Struct(data) = &input.data {
69        if let Fields::Named(fields) = &data.fields {
70            for field in &fields.named {
71                // Default to version 1.1
72                let mut min_version = JniVersion::default();
73
74                let mut field = field.clone();
75
76                let mut jni_added_attr = None;
77                field.attrs.retain(|attr| {
78                    if attr.path().is_ident("jni_added") {
79                        jni_added_attr = Some(attr.clone());
80                        false
81                    } else {
82                        true
83                    }
84                });
85                if let Some(attr) = jni_added_attr {
86                    let version = attr.parse_args::<JniVersion>()?;
87                    min_version = version;
88                }
89
90                versions.insert(min_version);
91                versioned_fields.push((min_version, field.clone()));
92            }
93
94            // Quote structs and union
95            let mut expanded = quote! {};
96
97            let mut union_members = quote!();
98
99            let mut versions: Vec<_> = versions.into_iter().collect();
100            versions.sort();
101
102            for version in versions {
103                let (struct_ident, version_ident, version_suffix) = if version.major == 999 {
104                    (
105                        Ident::new(&format!("{}_reserved", original_name), original_name.span()),
106                        Ident::new("reserved", original_name.span()),
107                        "reserved".to_string(),
108                    )
109                } else if version.minor == 0 {
110                    (
111                        Ident::new(
112                            &format!("{}_{}", original_name, version.major),
113                            original_name.span(),
114                        ),
115                        Ident::new(&format!("v{}", version.major), original_name.span()),
116                        format!("{}", version.major),
117                    )
118                } else {
119                    let struct_ident = Ident::new(
120                        &format!("{}_{}_{}", original_name, version.major, version.minor),
121                        original_name.span(),
122                    );
123                    let version_ident = Ident::new(
124                        &format!("v{}_{}", version.major, version.minor),
125                        original_name.span(),
126                    );
127                    (
128                        struct_ident,
129                        version_ident,
130                        format!("{}_{}", version.major, version.minor),
131                    )
132                };
133
134                let last = versioned_fields
135                    .iter()
136                    .rposition(|(v, _f)| v <= &version)
137                    .unwrap_or(versioned_fields.len());
138                let mut padding_idx = 0u32;
139
140                let mut version_field_tokens = quote!();
141                for (i, (field_min_version, field)) in versioned_fields.iter().enumerate() {
142                    if i > last {
143                        break;
144                    }
145                    if field_min_version > &version {
146                        let reserved_ident = format_ident!("_padding_{}", padding_idx);
147                        padding_idx += 1;
148                        version_field_tokens.extend(quote! { #reserved_ident: *mut c_void, });
149                    } else {
150                        version_field_tokens.extend(quote! { #field, });
151                    }
152                }
153                expanded.extend(quote! {
154                    #[allow(non_snake_case, non_camel_case_types)]
155                    #[repr(C)]
156                    #[derive(Copy, Clone)]
157                    #original_visibility struct #struct_ident {
158                        #version_field_tokens
159                    }
160                });
161
162                let api_comment =
163                    format!("API when JNI version >= `JNI_VERSION_{}`", version_suffix);
164                union_members.extend(quote! {
165                    #[doc = #api_comment]
166                    #original_visibility #version_ident: #struct_ident,
167                });
168            }
169
170            expanded.extend(quote! {
171                #[repr(C)]
172                #original_visibility union #original_name {
173                    #union_members
174                }
175            });
176
177            return Ok(TokenStream::from(expanded));
178        }
179    }
180
181    Err(syn::Error::new(
182        input.span(),
183        "Expected a struct with fields",
184    ))
185}
186
187#[proc_macro_attribute]
188pub fn jni_to_union(_attr: TokenStream, item: TokenStream) -> TokenStream {
189    let input = parse_macro_input!(item as DeriveInput);
190
191    match jni_to_union_impl(input) {
192        Ok(tokens) => tokens,
193        Err(err) => err.into_compile_error().into(),
194    }
195}