Skip to main content

java_jni_extras/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::quote;
5use syn::{Ident, LitStr, Token, Result, parse::{Parse, ParseStream}, parse_macro_input};
6
7struct JavaClass {
8    package: String,
9    name: Ident,
10    methods: Vec<JavaMethod>,
11}
12
13enum JavaMethod {
14    Constructor {
15        name: Ident,
16        params: Vec<JavaParam>,
17    },
18    Method {
19        is_static: bool,
20        is_native: bool,
21        return_type: JavaType,
22        name: Ident,
23        params: Vec<JavaParam>,
24    },
25}
26
27struct JavaParam {
28    ty: JavaType,
29    name: Ident,
30}
31
/// A Java type as it may appear in a parameter list or as a return type.
#[derive(Clone)]
enum JavaType {
    /// `void`; only meaningful as a return type.
    Void,
    /// `boolean`
    Boolean,
    /// `byte`
    Byte,
    /// `char`
    Char,
    /// `short`
    Short,
    /// `int`
    Int,
    /// `long`
    Long,
    /// `float`
    Float,
    /// `double`
    Double,
    /// A reference type, stored as the (possibly dotted) name written in the
    /// declaration, e.g. `String` or `java.util.List`.
    Object(String),
    /// An array whose element type is boxed (`T[]`).
    Array(Box<JavaType>),
}
46
47impl JavaType {
48    fn to_jni_sig(&self) -> String {
49        match self {
50            JavaType::Void => "V".to_string(),
51            JavaType::Boolean => "Z".to_string(),
52            JavaType::Byte => "B".to_string(),
53            JavaType::Char => "C".to_string(),
54            JavaType::Short => "S".to_string(),
55            JavaType::Int => "I".to_string(),
56            JavaType::Long => "J".to_string(),
57            JavaType::Float => "F".to_string(),
58            JavaType::Double => "D".to_string(),
59            JavaType::Object(name) => {
60                let slashed = name.replace('.', "/");
61                // handle common unqualified names
62                let resolved = match slashed.as_str() {
63                    "String" => "java/lang/String".to_string(),
64                    "Object" => "java/lang/Object".to_string(),
65                    other => other.to_string(),
66                };
67                format!("L{};", resolved)
68            }
69            JavaType::Array(inner) => format!("[{}", inner.to_jni_sig()),
70        }
71    }
72
73    fn to_rust_type(&self) -> TokenStream2 {
74        match self {
75            JavaType::Void => quote! { () },
76            JavaType::Boolean => quote! { bool },
77            JavaType::Byte => quote! { i8 },
78            JavaType::Char => quote! { u16 },
79            JavaType::Short => quote! { i16 },
80            JavaType::Int => quote! { i32 },
81            JavaType::Long => quote! { i64 },
82            JavaType::Float => quote! { f32 },
83            JavaType::Double => quote! { f64 },
84            JavaType::Object(name) if name == "String" => quote! { &'refs str },
85            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
86            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
87        }
88    }
89    fn to_rust_return_type(&self) -> TokenStream2 {
90        match self {
91            JavaType::Void => quote! { () },
92            JavaType::Boolean => quote! { bool },
93            JavaType::Byte => quote! { i8 },
94            JavaType::Char => quote! { u16 },
95            JavaType::Short => quote! { i16 },
96            JavaType::Int => quote! { i32 },
97            JavaType::Long => quote! { i64 },
98            JavaType::Float => quote! { f32 },
99            JavaType::Double => quote! { f64 },
100            JavaType::Object(name) if name == "String" => quote! { String },
101            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
102            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
103        }
104    }
105
106    fn to_jvalue(&self, ident: &Ident) -> TokenStream2 {
107        match self {
108            JavaType::Boolean => quote! { jni::objects::JValue::Bool(#ident) },
109            JavaType::Byte => quote! { jni::objects::JValue::Byte(#ident) },
110            JavaType::Char => quote! { jni::objects::JValue::Char(#ident) },
111            JavaType::Short => quote! { jni::objects::JValue::Short(#ident) },
112            JavaType::Int => quote! { jni::objects::JValue::Int(#ident) },
113            JavaType::Long => quote! { jni::objects::JValue::Long(#ident) },
114            JavaType::Float => quote! { jni::objects::JValue::Float(#ident) },
115            JavaType::Double => quote! { jni::objects::JValue::Double(#ident) },
116            JavaType::Object(name) if name == "String" => {
117                let tmp = Ident::new(&format!("__jstr_{}", ident), Span::call_site());
118                quote! { jni::objects::JValue::Object(&(#tmp).into()) }
119            },
120            JavaType::Object(_) | JavaType::Array(_) => {
121                quote! { jni::objects::JValue::Object(&#ident) }
122            }
123            JavaType::Void => quote! { compile_error!("void cannot be a parameter type") },
124        }
125    }
126
127    fn extract_return(&self, call: TokenStream2) -> TokenStream2 {
128        match self {
129            JavaType::Void => quote! { #call; },
130            JavaType::Boolean => quote! { #call.z() },
131            JavaType::Byte => quote! { #call.b() },
132            JavaType::Char => quote! { #call.c() },
133            JavaType::Short => quote! { #call.s() },
134            JavaType::Int => quote! { #call.i() },
135            JavaType::Long => quote! { #call.j() },
136            JavaType::Float => quote! { #call.f() },
137            JavaType::Double => quote! { #call.d() },
138            JavaType::Object(name) if name == "String" => {
139                quote!( let o = #call.l()?; Ok(JString::cast_local(env, o)?.to_string()) )
140            }
141            JavaType::Object(_) | JavaType::Array(_) => quote! { #call.l() },
142        }
143    }
144
145    fn to_jni_param_type(&self) -> TokenStream2 {
146        match self {
147            JavaType::Void => quote!(),
148            JavaType::Boolean => quote!(boolean),
149            JavaType::Byte => quote!(byte),
150            JavaType::Char => quote!(char),
151            JavaType::Short => quote!(short),
152            JavaType::Int => quote!(int),
153            JavaType::Long => quote!(jlong),
154            JavaType::Float => quote!(float),
155            JavaType::Double => quote!(double),
156            JavaType::Object(x) => quote!(#x),
157            JavaType::Array(_) => quote!(todo!()),
158        }
159    }
160    fn to_jni_return_type(&self) -> TokenStream2 {
161        match self {
162            JavaType::Void => quote!(),
163            JavaType::Boolean => quote!(-> boolean),
164            JavaType::Byte => quote!(-> byte),
165            JavaType::Char => quote!(-> char),
166            JavaType::Short => quote!(-> short),
167            JavaType::Int => quote!(-> int),
168            JavaType::Long => quote!(-> jlong),
169            JavaType::Float => quote!(-> float),
170            JavaType::Double => quote!(-> double),
171            JavaType::Object(x) => quote!(-> #x),
172            JavaType::Array(_) => quote!(-> todo!()),
173        }
174    }
175}
176
177fn parse_java_type(input: ParseStream) -> Result<JavaType> {
178    let ty = if input.peek(Ident) {
179        let ident: Ident = input.parse()?;
180        match ident.to_string().as_str() {
181            "void" => JavaType::Void,
182            "boolean" => JavaType::Boolean,
183            "byte" => JavaType::Byte,
184            "char" => JavaType::Char,
185            "short" => JavaType::Short,
186            "int" => JavaType::Int,
187            "long" => JavaType::Long,
188            "float" => JavaType::Float,
189            "double" => JavaType::Double,
190            other => {
191                let mut name = other.to_string();
192                while input.peek(Token![.]) {
193                    input.parse::<Token![.]>()?;
194                    let next: Ident = input.parse()?;
195                    name.push('.');
196                    name.push_str(&next.to_string());
197                }
198                JavaType::Object(name)
199            }
200        }
201    } else {
202        return Err(input.error("expected Java type"));
203    };
204
205    if input.peek(syn::token::Bracket) {
206        let content;
207        syn::bracketed!(content in input);
208        let _ = content; // empty []
209        return Ok(JavaType::Array(Box::new(ty)));
210    }
211
212    Ok(ty)
213}
214
215impl Parse for JavaMethod {
216    fn parse(input: ParseStream) -> Result<Self> {
217        let mut is_static = false;
218        let mut is_native = false;
219
220        // parse modifiers
221        loop {
222            if input.peek(Token![static]) {
223                input.parse::<Token![static]>()?;
224                is_static = true;
225            } else if input.peek(Ident) {
226                let ident: Ident = input.fork().parse()?;
227                match ident.to_string().as_str() {
228                    "native" => { input.parse::<Ident>()?; is_native = true; }
229                    "public" | "private" | "protected" | "final" | "synchronized" => {
230                        input.parse::<Ident>()?;
231                    }
232                    _ => break,
233                }
234            } else {
235                break;
236            }
237        }
238
239        // peek: if next is Ident followed immediately by '(' it's a constructor
240        let is_constructor = input.peek(Ident) && {
241            let fork = input.fork();
242            let _: Ident = fork.parse()?;
243            fork.peek(syn::token::Paren)
244        };
245
246        if is_constructor {
247            let name: Ident = input.parse()?;
248            let content;
249            syn::parenthesized!(content in input);
250            let params = parse_params(&content)?;
251            input.parse::<Token![;]>()?;
252            return Ok(JavaMethod::Constructor { name, params });
253        }
254
255        let return_type = parse_java_type(input)?;
256        let name: Ident = input.parse()?;
257        let content;
258        syn::parenthesized!(content in input);
259        let params = parse_params(&content)?;
260        input.parse::<Token![;]>()?;
261
262        Ok(JavaMethod::Method { is_static, is_native, return_type, name, params })
263    }
264}
265
266fn parse_params(content: ParseStream) -> Result<Vec<JavaParam>> {
267    let mut params = Vec::new();
268    while !content.is_empty() {
269        let ty = parse_java_type(content)?;
270        let name: Ident = content.parse()?;
271        params.push(JavaParam { ty, name });
272        if content.peek(Token![,]) {
273            content.parse::<Token![,]>()?;
274        }
275    }
276    Ok(params)
277}
278
279impl Parse for JavaClass {
280    fn parse(input: ParseStream) -> Result<Self> {
281        let pkg_kw: Ident = input.parse()?;
282        if pkg_kw != "package" {
283            return Err(syn::Error::new(pkg_kw.span(), "expected 'package'"));
284        }
285
286        let mut package = String::new();
287        loop {
288            let ident: Ident = input.parse()?;
289            package.push_str(&ident.to_string());
290            if input.peek(Token![;]) {
291                input.parse::<Token![;]>()?;
292                break;
293            }
294            input.parse::<Token![.]>()?;
295            package.push('.');
296        }
297
298        let class_kw: Ident = input.parse()?;
299        if class_kw != "class" {
300            return Err(syn::Error::new(class_kw.span(), "expected 'class'"));
301        }
302
303        let name: Ident = input.parse()?;
304
305        let content;
306        syn::braced!(content in input);
307
308        let mut methods = Vec::new();
309        while !content.is_empty() {
310            methods.push(content.parse::<JavaMethod>()?);
311        }
312
313        Ok(JavaClass { package, name, methods })
314    }
315}
316
/// Joins a dotted package and a class name into the slash-separated form used
/// for JNI class lookups, e.g. (`com.example`, `Foo`) -> `com/example/Foo`.
/// Dots inside `name` are converted as well.
fn class_path(package: &str, name: &str) -> String {
    let mut path = package.replace('.', "/");
    path.push('/');
    path.push_str(&name.replace('.', "/"));
    path
}
320
321fn generate_method(
322    class_path_lit: &LitStr,
323    is_static: bool,
324    return_type: &JavaType,
325    name: &Ident,
326    params: &[JavaParam]
327) -> TokenStream2 {
328    let method_name = name;
329    let method_name_str = method_name.to_string();
330
331    let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
332    let return_sig = return_type.to_jni_sig();
333    let full_sig = format!("({}){}", param_sig, return_sig);
334    let sig_lit = LitStr::new(&full_sig, Span::call_site());
335    let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
336
337    let rust_params: Vec<TokenStream2> = params.iter().map(|p| {
338        let pname = &p.name;
339        let pty = p.ty.to_rust_type();
340        quote! { #pname: #pty }
341    }).collect();
342
343    let string_conversions: Vec<TokenStream2> = params.iter().map(|p| {
344        let pname = &p.name;
345        match &p.ty {
346            JavaType::Object(name) if name == "String" => {
347                let tmp = Ident::new(&format!("__jstr_{}", pname), Span::call_site());
348                quote! { let #tmp = env.new_string(#pname)?; }
349            }
350            _ => quote! {},
351        }
352    }).collect();
353
354    let jvalues: Vec<TokenStream2> = params.iter().map(|p| {
355        p.ty.to_jvalue(&p.name)
356    }).collect();
357
358    let return_type_ts = return_type.to_rust_return_type();
359
360    let call = if is_static {
361        quote! { env.call_static_method(jni_str!(#class_path_lit), jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
362    } else {
363        quote! { env.call_method(obj, jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
364    };
365
366    let body = match &return_type {
367        JavaType::Void => quote! {
368            #(#string_conversions)*
369            #call;
370            Ok(())
371        },
372        _ => {
373            let extract = return_type.extract_return(call);
374            quote! {
375                #(#string_conversions)*
376                #extract
377            }
378
379        }
380    };
381
382    if is_static {
383        quote! {
384            pub fn #method_name<'caller, 'refs>(
385                env: &'refs mut jni::Env<'caller>,
386                #(#rust_params),*
387            ) -> Result<#return_type_ts, jni::errors::Error> {
388                #body
389            }
390        }
391    } else {
392        quote! {
393            pub fn #method_name<'caller, 'refs>(
394                env: &'refs mut jni::Env<'caller>,
395                obj: &'refs jni::objects::JObject<'caller>,
396                #(#rust_params),*
397            ) -> Result<#return_type_ts, jni::errors::Error> {
398                #body
399            }
400        }
401    }
402}
403
404fn generate_constructor(
405    class_path_lit: &LitStr,
406    name: &Ident,
407    params: &[JavaParam],
408) -> TokenStream2 {
409    let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
410    let sig_lit = LitStr::new(&format!("({})V", param_sig), Span::call_site());
411
412    let rust_params: Vec<TokenStream2> = params.iter().map(|p| {
413        let pname = &p.name;
414        let pty = p.ty.to_rust_type();
415        quote! { #pname: #pty }
416    }).collect();
417
418    let string_conversions: Vec<TokenStream2> = params.iter().map(|p| {
419        let pname = &p.name;
420        match &p.ty {
421            JavaType::Object(n) if n == "String" => {
422                let tmp = Ident::new(&format!("__jstr_{}", pname), Span::call_site());
423                quote! { let #tmp = env.new_string(#pname)?; }
424            }
425            _ => quote! {},
426        }
427    }).collect();
428
429    let jvalues: Vec<TokenStream2> = params.iter().map(|p| {
430        p.ty.to_jvalue(&p.name)
431    }).collect();
432
433    quote! {
434        pub fn #name<'caller>(
435            env: &mut jni::Env<'caller>,
436            #(#rust_params),*
437        ) -> Result<jni::objects::JObject<'caller>, jni::errors::Error> {
438            #(#string_conversions)*
439            env.new_object(
440                jni_str!(#class_path_lit),
441                jni_sig!(#sig_lit),
442                &[#(#jvalues),*]
443            )
444        }
445    }
446}
447
448#[proc_macro]
449pub fn java_class_decl(input: TokenStream) -> TokenStream {
450    let class = parse_macro_input!(input as JavaClass);
451
452    let struct_name = &class.name;
453    let cp = class_path(&class.package, &class.name.to_string());
454    let class_path_lit = LitStr::new(&cp, Span::call_site());
455
456    let native_registrations: Vec<TokenStream2> = class.methods.iter()
457        .filter_map(|m| match m {
458            JavaMethod::Method { is_native: true, name, params, return_type, is_static, .. } => {
459
460                let package_class = format!("{}.{}", class.package, class.name);
461                let package_class_lit = LitStr::new(&package_class, Span::call_site());
462                let return_type = return_type.to_jni_return_type();
463                let params: Vec<TokenStream2> = params.iter().map(|p| {
464                    p.ty.to_jni_param_type()
465                }).collect();
466
467                Some(if *is_static {
468                    quote! {
469                        const _: jni::NativeMethod = jni::native_method! {
470                            java_type = #package_class_lit,
471                            static extern fn #name(#(#params),*) #return_type,
472                        };
473                    }
474                } else {
475                    quote! {
476                        const _: jni::NativeMethod = jni::native_method! {
477                            java_type = #package_class_lit,
478                            extern fn #name(#(#params),*) #return_type,
479                        };
480                    }
481                })
482            }
483            _ => None,
484        })
485        .collect();
486
487    let validate_checks: Vec<TokenStream2> = class.methods.iter()
488        .filter_map(|m| match m {
489            JavaMethod::Method { is_native: false, name, params, return_type, is_static, .. } => {
490
491                let method_name_str = name.to_string();
492                let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
493                let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
494                let return_sig = return_type.to_jni_sig();
495                let sig_lit = LitStr::new(&format!("({}){}", param_sig, return_sig), Span::call_site());
496                let class_path_lit_str = LitStr::new(&format!("{}.{}", class.package, class.name), Span::call_site());
497
498                Some(if *is_static {
499                    quote! {
500                        env.get_static_method_id(
501                            jni_str!(#class_path_lit_str),
502                            jni_str!(#method_name_lit),
503                            jni_sig!(#sig_lit),
504                        )?;
505                    }
506                } else {
507                    quote! {
508                        env.get_method_id(
509                            jni_str!(#class_path_lit_str),
510                            jni_str!(#method_name_lit),
511                            jni_sig!(#sig_lit),
512                        )?;
513                    }
514                })
515            }
516            _ => None,
517        })
518        .collect();
519
520    let methods: Vec<TokenStream2> = class.methods.iter()
521        .filter_map(|m| match m {
522            JavaMethod::Constructor { name, params } => {
523                Some(generate_constructor(&class_path_lit, name, params))
524            }
525            JavaMethod::Method { is_static, is_native: false, return_type, name, params, } => {
526                Some(generate_method(&class_path_lit, *is_static, return_type, name, params))
527            }
528            _ => None,
529        })
530        .collect();
531
532
533    let expanded = quote! {
534        #(#native_registrations)*
535
536        pub struct #struct_name;
537
538        impl #struct_name {
539            pub fn _validate_interface(env: &mut jni::Env<'_>) -> Result<(), jni::errors::Error> {
540                #(#validate_checks)*
541                Ok(())
542            }
543
544            #(#methods)*
545        }
546    };
547
548    expanded.into()
549}