Skip to main content

java_jni_extras/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::quote;
5use syn::{Ident, LitStr, Token, Result, parse::{Parse, ParseStream}, parse_macro_input};
6
7struct JavaClass {
8    package: String,
9    name: Ident,
10    methods: Vec<JavaMethod>,
11}
12
13struct JavaMethod {
14    is_static: bool,
15    is_native: bool,
16    return_type: JavaType,
17    name: Ident,
18    params: Vec<JavaParam>,
19}
20
21struct JavaParam {
22    ty: JavaType,
23    name: Ident,
24}
25
/// A Java type as written in a `java_class_decl!` declaration.
#[derive(Clone)]
enum JavaType {
    /// `void`; only meaningful as a return type.
    Void,
    /// `boolean` (JNI descriptor `Z`).
    Boolean,
    /// `byte` (JNI descriptor `B`).
    Byte,
    /// `char` (JNI descriptor `C`).
    Char,
    /// `short` (JNI descriptor `S`).
    Short,
    /// `int` (JNI descriptor `I`).
    Int,
    /// `long` (JNI descriptor `J`).
    Long,
    /// `float` (JNI descriptor `F`).
    Float,
    /// `double` (JNI descriptor `D`).
    Double,
    /// A class type, holding its name exactly as written (simple or dotted).
    Object(String),
    /// An array whose element type is the boxed value.
    Array(Box<JavaType>),
}
40
41impl JavaType {
42    fn to_jni_sig(&self) -> String {
43        match self {
44            JavaType::Void => "V".to_string(),
45            JavaType::Boolean => "Z".to_string(),
46            JavaType::Byte => "B".to_string(),
47            JavaType::Char => "C".to_string(),
48            JavaType::Short => "S".to_string(),
49            JavaType::Int => "I".to_string(),
50            JavaType::Long => "J".to_string(),
51            JavaType::Float => "F".to_string(),
52            JavaType::Double => "D".to_string(),
53            JavaType::Object(name) => {
54                let slashed = name.replace('.', "/");
55                // handle common unqualified names
56                let resolved = match slashed.as_str() {
57                    "String" => "java/lang/String".to_string(),
58                    "Object" => "java/lang/Object".to_string(),
59                    other => other.to_string(),
60                };
61                format!("L{};", resolved)
62            }
63            JavaType::Array(inner) => format!("[{}", inner.to_jni_sig()),
64        }
65    }
66
67    fn to_rust_type(&self) -> TokenStream2 {
68        match self {
69            JavaType::Void => quote! { () },
70            JavaType::Boolean => quote! { bool },
71            JavaType::Byte => quote! { i8 },
72            JavaType::Char => quote! { u16 },
73            JavaType::Short => quote! { i16 },
74            JavaType::Int => quote! { i32 },
75            JavaType::Long => quote! { i64 },
76            JavaType::Float => quote! { f32 },
77            JavaType::Double => quote! { f64 },
78            JavaType::Object(name) if name == "String" => quote! { &'refs str },
79            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
80            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
81        }
82    }
83    fn to_rust_return_type(&self) -> TokenStream2 {
84        match self {
85            JavaType::Void => quote! { () },
86            JavaType::Boolean => quote! { bool },
87            JavaType::Byte => quote! { i8 },
88            JavaType::Char => quote! { u16 },
89            JavaType::Short => quote! { i16 },
90            JavaType::Int => quote! { i32 },
91            JavaType::Long => quote! { i64 },
92            JavaType::Float => quote! { f32 },
93            JavaType::Double => quote! { f64 },
94            JavaType::Object(name) if name == "String" => quote! { String },
95            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
96            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
97        }
98    }
99
100    fn to_jvalue(&self, ident: &Ident) -> TokenStream2 {
101        match self {
102            JavaType::Boolean => quote! { jni::objects::JValue::Bool(#ident as u8) },
103            JavaType::Byte => quote! { jni::objects::JValue::Byte(#ident) },
104            JavaType::Char => quote! { jni::objects::JValue::Char(#ident) },
105            JavaType::Short => quote! { jni::objects::JValue::Short(#ident) },
106            JavaType::Int => quote! { jni::objects::JValue::Int(#ident) },
107            JavaType::Long => quote! { jni::objects::JValue::Long(#ident) },
108            JavaType::Float => quote! { jni::objects::JValue::Float(#ident) },
109            JavaType::Double => quote! { jni::objects::JValue::Double(#ident) },
110            JavaType::Object(name) if name == "String" => {
111                let tmp = Ident::new(&format!("__jstr_{}", ident), Span::call_site());
112                quote! { jni::objects::JValue::Object(&(#tmp).into()) }
113            },
114            JavaType::Object(_) | JavaType::Array(_) => {
115                quote! { jni::objects::JValue::Object(&#ident) }
116            }
117            JavaType::Void => quote! { compile_error!("void cannot be a parameter type") },
118        }
119    }
120
121    fn extract_return(&self) -> TokenStream2 {
122        match self {
123            JavaType::Void => quote! { ; },
124            JavaType::Boolean => quote! { .z() },
125            JavaType::Byte => quote! { .b() },
126            JavaType::Char => quote! { .c() },
127            JavaType::Short => quote! { .s() },
128            JavaType::Int => quote! { .i() },
129            JavaType::Long => quote! { .j() },
130            JavaType::Float => quote! { .f() },
131            JavaType::Double => quote! { .d() },
132            JavaType::Object(_) | JavaType::Array(_) => quote! { .l() },
133        }
134    }
135
136    fn to_jni_param_type(&self) -> TokenStream2 {
137        match self {
138            JavaType::Void => quote!(),
139            JavaType::Boolean => quote!(boolean),
140            JavaType::Byte => quote!(byte),
141            JavaType::Char => quote!(char),
142            JavaType::Short => quote!(short),
143            JavaType::Int => quote!(int),
144            JavaType::Long => quote!(jlong),
145            JavaType::Float => quote!(float),
146            JavaType::Double => quote!(double),
147            JavaType::Object(x) => quote!(#x),
148            JavaType::Array(_) => quote!(todo!()),
149        }
150    }
151    fn to_jni_return_type(&self) -> TokenStream2 {
152        match self {
153            JavaType::Void => quote!(),
154            JavaType::Boolean => quote!(-> boolean),
155            JavaType::Byte => quote!(-> byte),
156            JavaType::Char => quote!(-> char),
157            JavaType::Short => quote!(-> short),
158            JavaType::Int => quote!(-> int),
159            JavaType::Long => quote!(-> jlong),
160            JavaType::Float => quote!(-> float),
161            JavaType::Double => quote!(-> double),
162            JavaType::Object(x) => quote!(-> #x),
163            JavaType::Array(_) => quote!(-> todo!()),
164        }
165    }
166}
167
168fn parse_java_type(input: ParseStream) -> Result<JavaType> {
169    let ty = if input.peek(Ident) {
170        let ident: Ident = input.parse()?;
171        match ident.to_string().as_str() {
172            "void" => JavaType::Void,
173            "boolean" => JavaType::Boolean,
174            "byte" => JavaType::Byte,
175            "char" => JavaType::Char,
176            "short" => JavaType::Short,
177            "int" => JavaType::Int,
178            "long" => JavaType::Long,
179            "float" => JavaType::Float,
180            "double" => JavaType::Double,
181            other => {
182                let mut name = other.to_string();
183                while input.peek(Token![.]) {
184                    input.parse::<Token![.]>()?;
185                    let next: Ident = input.parse()?;
186                    name.push('.');
187                    name.push_str(&next.to_string());
188                }
189                JavaType::Object(name)
190            }
191        }
192    } else {
193        return Err(input.error("expected Java type"));
194    };
195
196    if input.peek(syn::token::Bracket) {
197        let content;
198        syn::bracketed!(content in input);
199        let _ = content; // empty []
200        return Ok(JavaType::Array(Box::new(ty)));
201    }
202
203    Ok(ty)
204}
205
206impl Parse for JavaMethod {
207    fn parse(input: ParseStream) -> Result<Self> {
208        let mut is_static = false;
209        let mut is_native = false;
210
211        loop {
212            if input.peek(Token![static]) {
213                input.parse::<Token![static]>()?;
214                is_static = true;
215            } else if input.peek(Ident) {
216                let ident: Ident = input.fork().parse()?;
217                match ident.to_string().as_str() {
218                    "native" => { input.parse::<Ident>()?; is_native = true; }
219                    "public" | "private" | "protected" | "final" | "synchronized" => {
220                        input.parse::<Ident>()?;
221                    }
222                    _ => break,
223                }
224            } else {
225                break;
226            }
227        }
228
229        let return_type = parse_java_type(input)?;
230        let name: Ident = input.parse()?;
231
232        let content;
233        syn::parenthesized!(content in input);
234
235        let mut params = Vec::new();
236        while !content.is_empty() {
237            let ty = parse_java_type(&content)?;
238            let param_name: Ident = content.parse()?;
239            params.push(JavaParam { ty, name: param_name });
240            if content.peek(Token![,]) {
241                content.parse::<Token![,]>()?;
242            }
243        }
244
245        input.parse::<Token![;]>()?;
246
247        Ok(JavaMethod { is_static, is_native, return_type, name, params })
248    }
249}
250
251impl Parse for JavaClass {
252    fn parse(input: ParseStream) -> Result<Self> {
253        let pkg_kw: Ident = input.parse()?;
254        if pkg_kw != "package" {
255            return Err(syn::Error::new(pkg_kw.span(), "expected 'package'"));
256        }
257
258        let mut package = String::new();
259        loop {
260            let ident: Ident = input.parse()?;
261            package.push_str(&ident.to_string());
262            if input.peek(Token![;]) {
263                input.parse::<Token![;]>()?;
264                break;
265            }
266            input.parse::<Token![.]>()?;
267            package.push('.');
268        }
269
270        let class_kw: Ident = input.parse()?;
271        if class_kw != "class" {
272            return Err(syn::Error::new(class_kw.span(), "expected 'class'"));
273        }
274
275        let name: Ident = input.parse()?;
276
277        let content;
278        syn::braced!(content in input);
279
280        let mut methods = Vec::new();
281        while !content.is_empty() {
282            methods.push(content.parse::<JavaMethod>()?);
283        }
284
285        Ok(JavaClass { package, name, methods })
286    }
287}
288
/// Builds the slash-separated JNI class name, e.g. `("com.example", "Foo")` gives
/// `"com/example/Foo"`.
///
/// An empty `package` (the default package) yields just the class name, rather than a
/// name with a leading `/`, which is not a valid JNI class name.
fn class_path(package: &str, name: &str) -> String {
    let qualified = if package.is_empty() {
        name.to_string()
    } else {
        format!("{}.{}", package, name)
    };
    qualified.replace('.', "/")
}
292
293fn generate_method(
294    class_path_lit: &LitStr,
295    method: &JavaMethod,
296) -> TokenStream2 {
297    let method_name = &method.name;
298    let method_name_str = method_name.to_string();
299
300    let param_sig: String = method.params.iter().map(|p| p.ty.to_jni_sig()).collect();
301    let return_sig = method.return_type.to_jni_sig();
302    let full_sig = format!("({}){}", param_sig, return_sig);
303    let sig_lit = LitStr::new(&full_sig, Span::call_site());
304    let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
305
306    let rust_params: Vec<TokenStream2> = method.params.iter().map(|p| {
307        let pname = &p.name;
308        let pty = p.ty.to_rust_type();
309        quote! { #pname: #pty }
310    }).collect();
311
312    let string_conversions: Vec<TokenStream2> = method.params.iter().map(|p| {
313        let pname = &p.name;
314        match &p.ty {
315            JavaType::Object(name) if name == "String" => {
316                let tmp = Ident::new(&format!("__jstr_{}", pname), Span::call_site());
317                quote! { let #tmp = env.new_string(#pname)?; }
318            }
319            _ => quote! {},
320        }
321    }).collect();
322
323    let jvalues: Vec<TokenStream2> = method.params.iter().map(|p| {
324        p.ty.to_jvalue(&p.name)
325    }).collect();
326
327    let return_type = method.return_type.to_rust_return_type();
328
329    let call = if method.is_static {
330        quote! { env.call_static_method(jni_str!(#class_path_lit), jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
331    } else {
332        quote! { env.call_method(obj, jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
333    };
334
335    let body = match &method.return_type {
336        JavaType::Void => quote! {
337            #(#string_conversions)*
338            #call;
339            Ok(())
340        },
341        _ => {
342            let extract = method.return_type.extract_return();
343            quote! {
344                #(#string_conversions)*
345                #call #extract
346            }
347
348        }
349    };
350
351    if method.is_static {
352        quote! {
353            pub fn #method_name<'caller, 'refs>(
354                env: &'refs mut jni::Env<'caller>,
355                #(#rust_params),*
356            ) -> Result<#return_type, jni::errors::Error> {
357                #body
358            }
359        }
360    } else {
361        quote! {
362            pub fn #method_name<'caller, 'refs>(
363                env: &'refs mut jni::Env<'caller>,
364                obj: &'refs jni::objects::JObject<'caller>,
365                #(#rust_params),*
366            ) -> Result<#return_type, jni::errors::Error> {
367                #body
368            }
369        }
370    }
371}
372
373#[proc_macro]
374pub fn java_class_decl(input: TokenStream) -> TokenStream {
375    let class = parse_macro_input!(input as JavaClass);
376
377    let struct_name = &class.name;
378    let cp = class_path(&class.package, &class.name.to_string());
379    let class_path_lit = LitStr::new(&cp, Span::call_site());
380
381    let native_registrations: Vec<TokenStream2> = class.methods.iter()
382        .filter(|m| m.is_native)
383        .map(|m| {
384            let name = &m.name;
385            let package_class = format!("{}.{}", class.package, class.name);
386            let package_class_lit = LitStr::new(&package_class, Span::call_site());
387            let return_type = m.return_type.to_jni_return_type();
388            let params: Vec<TokenStream2> = m.params.iter().map(|p| {
389                p.ty.to_jni_param_type()
390            }).collect();
391
392            if m.is_static {
393                quote! {
394                    const _: jni::NativeMethod = jni::native_method! {
395                        java_type = #package_class_lit,
396                        static extern fn #name(#(#params),*) #return_type,
397                    };
398                }
399            } else {
400                quote! {
401                    const _: jni::NativeMethod = jni::native_method! {
402                        java_type = #package_class_lit,
403                        extern fn #name(#(#params),*) #return_type,
404                    };
405                }
406            }
407        })
408        .collect();
409
410    let validate_checks: Vec<TokenStream2> = class.methods.iter()
411        .filter(|m| !m.is_native)
412        .map(|m| {
413            let method_name_str = m.name.to_string();
414            let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
415            let param_sig: String = m.params.iter().map(|p| p.ty.to_jni_sig()).collect();
416            let return_sig = m.return_type.to_jni_sig();
417            let sig_lit = LitStr::new(&format!("({}){}", param_sig, return_sig), Span::call_site());
418            let class_path_lit_str = LitStr::new(&format!("{}.{}", class.package, class.name), Span::call_site());
419
420            if m.is_static {
421                quote! {
422                env.get_static_method_id(
423                    jni_str!(#class_path_lit_str),
424                    jni_str!(#method_name_lit),
425                    jni_sig!(#sig_lit),
426                )?;
427            }
428            } else {
429                quote! {
430                env.get_method_id(
431                    jni_str!(#class_path_lit_str),
432                    jni_str!(#method_name_lit),
433                    jni_sig!(#sig_lit),
434                )?;
435            }
436            }
437        })
438        .collect();
439
440    let methods: Vec<TokenStream2> = class.methods.iter()
441        .filter(|m| !m.is_native)
442        .map(|m| generate_method(&class_path_lit, m))
443        .collect();
444
445
446    let expanded = quote! {
447        #(#native_registrations)*
448
449        pub struct #struct_name;
450
451        impl #struct_name {
452            pub fn _validate_interface(env: &mut jni::Env<'_>) -> Result<(), jni::errors::Error> {
453                #(#validate_checks)*
454                Ok(())
455            }
456
457            #(#methods)*
458        }
459    };
460
461    expanded.into()
462}