Skip to main content

java_jni_extras/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::quote;
5use syn::{Ident, LitStr, Token, Result, parse::{Parse, ParseStream}, parse_macro_input};
6
7struct JavaClass {
8    package: String,
9    name: Ident,
10    methods: Vec<JavaMethod>,
11}
12
13enum JavaMethod {
14    Constructor {
15        name: Ident,
16        alias: Option<Ident>,
17        params: Vec<JavaParam>,
18    },
19    Method {
20        is_static: bool,
21        is_native: bool,
22        return_type: JavaType,
23        name: Ident,
24        params: Vec<JavaParam>,
25    },
26}
27
28struct JavaParam {
29    ty: JavaType,
30    name: Ident,
31}
32
/// A Java type as written in a member declaration.
///
/// Primitive variants correspond one-to-one to Java primitives. `Object`
/// holds the class name exactly as written (dotted, or the bare shorthands
/// `String` / `Object`), and `Array` wraps its element type.
#[derive(Clone, Debug, PartialEq, Eq)]
enum JavaType {
    Void,
    Boolean,
    Byte,
    Char,
    Short,
    Int,
    Long,
    Float,
    Double,
    Object(String),
    Array(Box<JavaType>),
}
47
48impl JavaType {
49    fn to_jni_sig(&self) -> String {
50        match self {
51            JavaType::Void => "V".to_string(),
52            JavaType::Boolean => "Z".to_string(),
53            JavaType::Byte => "B".to_string(),
54            JavaType::Char => "C".to_string(),
55            JavaType::Short => "S".to_string(),
56            JavaType::Int => "I".to_string(),
57            JavaType::Long => "J".to_string(),
58            JavaType::Float => "F".to_string(),
59            JavaType::Double => "D".to_string(),
60            JavaType::Object(name) => {
61                let slashed = name.replace('.', "/");
62                let resolved = match slashed.as_str() {
63                    "String" => "java/lang/String".to_string(),
64                    "Object" => "java/lang/Object".to_string(),
65                    other => other.to_string(),
66                };
67                format!("L{};", resolved)
68            }
69            JavaType::Array(inner) => format!("[{}", inner.to_jni_sig()),
70        }
71    }
72
73    fn to_rust_type(&self) -> TokenStream2 {
74        match self {
75            JavaType::Void => quote! { () },
76            JavaType::Boolean => quote! { bool },
77            JavaType::Byte => quote! { i8 },
78            JavaType::Char => quote! { u16 },
79            JavaType::Short => quote! { i16 },
80            JavaType::Int => quote! { i32 },
81            JavaType::Long => quote! { i64 },
82            JavaType::Float => quote! { f32 },
83            JavaType::Double => quote! { f64 },
84            JavaType::Object(name) if name == "String" => quote! { &'refs str },
85            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
86            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
87        }
88    }
89    fn to_rust_return_type(&self) -> TokenStream2 {
90        match self {
91            JavaType::Void => quote! { () },
92            JavaType::Boolean => quote! { bool },
93            JavaType::Byte => quote! { i8 },
94            JavaType::Char => quote! { u16 },
95            JavaType::Short => quote! { i16 },
96            JavaType::Int => quote! { i32 },
97            JavaType::Long => quote! { i64 },
98            JavaType::Float => quote! { f32 },
99            JavaType::Double => quote! { f64 },
100            JavaType::Object(name) if name == "String" => quote! { String },
101            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
102            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
103        }
104    }
105
106    fn to_jvalue(&self, ident: &Ident) -> TokenStream2 {
107        match self {
108            JavaType::Boolean => quote! { jni::objects::JValue::Bool(#ident) },
109            JavaType::Byte => quote! { jni::objects::JValue::Byte(#ident) },
110            JavaType::Char => quote! { jni::objects::JValue::Char(#ident) },
111            JavaType::Short => quote! { jni::objects::JValue::Short(#ident) },
112            JavaType::Int => quote! { jni::objects::JValue::Int(#ident) },
113            JavaType::Long => quote! { jni::objects::JValue::Long(#ident) },
114            JavaType::Float => quote! { jni::objects::JValue::Float(#ident) },
115            JavaType::Double => quote! { jni::objects::JValue::Double(#ident) },
116            JavaType::Object(name) if name == "String" => {
117                let tmp = Ident::new(&format!("__jstr_{}", ident), Span::call_site());
118                quote! { jni::objects::JValue::Object(&(#tmp).into()) }
119            },
120            JavaType::Object(_) | JavaType::Array(_) => {
121                quote! { jni::objects::JValue::Object(&#ident) }
122            }
123            JavaType::Void => quote! { compile_error!("void cannot be a parameter type") },
124        }
125    }
126
127    fn extract_return(&self, call: TokenStream2) -> TokenStream2 {
128        match self {
129            JavaType::Void => quote! { #call; },
130            JavaType::Boolean => quote! { #call.z() },
131            JavaType::Byte => quote! { #call.b() },
132            JavaType::Char => quote! { #call.c() },
133            JavaType::Short => quote! { #call.s() },
134            JavaType::Int => quote! { #call.i() },
135            JavaType::Long => quote! { #call.j() },
136            JavaType::Float => quote! { #call.f() },
137            JavaType::Double => quote! { #call.d() },
138            JavaType::Object(name) if name == "String" => {
139                quote!( let o = #call.l()?; Ok(JString::cast_local(env, o)?.to_string()) )
140            }
141            JavaType::Object(_) | JavaType::Array(_) => quote! { #call.l() },
142        }
143    }
144
145    fn to_jni_param_type(&self) -> TokenStream2 {
146        match self {
147            JavaType::Void => quote!(),
148            JavaType::Boolean => quote!(boolean),
149            JavaType::Byte => quote!(byte),
150            JavaType::Char => quote!(char),
151            JavaType::Short => quote!(short),
152            JavaType::Int => quote!(int),
153            JavaType::Long => quote!(jlong),
154            JavaType::Float => quote!(float),
155            JavaType::Double => quote!(double),
156            JavaType::Object(x) => {
157                let x: TokenStream2 = syn::parse_str(x).unwrap();
158                quote!(#x)
159            },
160            JavaType::Array(x) => {
161                let x = x.to_jni_param_type();
162                quote!(#x[])
163            },
164        }
165    }
166    fn to_jni_return_type(&self) -> TokenStream2 {
167        match self {
168            JavaType::Void => quote!(),
169            JavaType::Boolean => quote!(-> boolean),
170            JavaType::Byte => quote!(-> byte),
171            JavaType::Char => quote!(-> char),
172            JavaType::Short => quote!(-> short),
173            JavaType::Int => quote!(-> int),
174            JavaType::Long => quote!(-> jlong),
175            JavaType::Float => quote!(-> float),
176            JavaType::Double => quote!(-> double),
177            JavaType::Object(x) => {
178                let x: TokenStream2 = syn::parse_str(x).unwrap();
179                quote!(-> #x)
180            },
181            JavaType::Array(x) => {
182                let x = x.to_jni_param_type();
183                quote!(-> #x[])
184            },
185        }
186    }
187}
188
189fn parse_java_type(input: ParseStream) -> Result<JavaType> {
190    let ty = if input.peek(Ident) {
191        let ident: Ident = input.parse()?;
192        match ident.to_string().as_str() {
193            "void" => JavaType::Void,
194            "boolean" => JavaType::Boolean,
195            "byte" => JavaType::Byte,
196            "char" => JavaType::Char,
197            "short" => JavaType::Short,
198            "int" => JavaType::Int,
199            "long" => JavaType::Long,
200            "float" => JavaType::Float,
201            "double" => JavaType::Double,
202            other => {
203                let mut name = other.to_string();
204                while input.peek(Token![.]) {
205                    input.parse::<Token![.]>()?;
206                    let next: Ident = input.parse()?;
207                    name.push('.');
208                    name.push_str(&next.to_string());
209                }
210                JavaType::Object(name)
211            }
212        }
213    } else {
214        return Err(input.error("expected Java type"));
215    };
216
217    if input.peek(syn::token::Bracket) {
218        let content;
219        syn::bracketed!(content in input);
220        let _ = content;
221        return Ok(JavaType::Array(Box::new(ty)));
222    }
223
224    Ok(ty)
225}
226
227impl Parse for JavaMethod {
228    fn parse(input: ParseStream) -> Result<Self> {
229        let alias: Option<Ident> = if input.peek(Token![#]) {
230            input.parse::<Token![#]>()?;
231            let content;
232            syn::bracketed!(content in input);
233            let attr_name: Ident = content.parse()?;
234            if attr_name != "alias" {
235                return Err(syn::Error::new(attr_name.span(), "expected 'alias'"));
236            }
237            let inner;
238            syn::parenthesized!(inner in content);
239            Some(inner.parse::<Ident>()?)
240        } else {
241            None
242        };
243
244        let mut is_static = false;
245        let mut is_native = false;
246
247        loop {
248            if input.peek(Token![static]) {
249                input.parse::<Token![static]>()?;
250                is_static = true;
251            } else if input.peek(Ident) {
252                let ident: Ident = input.fork().parse()?;
253                match ident.to_string().as_str() {
254                    "native" => { input.parse::<Ident>()?; is_native = true; }
255                    "public" | "private" | "protected" | "final" | "synchronized" => {
256                        input.parse::<Ident>()?;
257                    }
258                    _ => break,
259                }
260            } else {
261                break;
262            }
263        }
264
265        let is_constructor = input.peek(Ident) && {
266            let fork = input.fork();
267            let _: Ident = fork.parse()?;
268            fork.peek(syn::token::Paren)
269        };
270
271        if is_constructor {
272            let name: Ident = input.parse()?;
273            let content;
274            syn::parenthesized!(content in input);
275            let params = parse_params(&content)?;
276            input.parse::<Token![;]>()?;
277            return Ok(JavaMethod::Constructor { name, alias, params });
278        }
279
280        let return_type = parse_java_type(input)?;
281        let name: Ident = input.parse()?;
282        let content;
283        syn::parenthesized!(content in input);
284        let params = parse_params(&content)?;
285        input.parse::<Token![;]>()?;
286
287        Ok(JavaMethod::Method { is_static, is_native, return_type, name, params })
288    }
289}
290
291fn parse_params(content: ParseStream) -> Result<Vec<JavaParam>> {
292    let mut params = Vec::new();
293    while !content.is_empty() {
294        let ty = parse_java_type(content)?;
295        let name: Ident = content.parse()?;
296        params.push(JavaParam { ty, name });
297        if content.peek(Token![,]) {
298            content.parse::<Token![,]>()?;
299        }
300    }
301    Ok(params)
302}
303
304impl Parse for JavaClass {
305    fn parse(input: ParseStream) -> Result<Self> {
306        let pkg_kw: Ident = input.parse()?;
307        if pkg_kw != "package" {
308            return Err(syn::Error::new(pkg_kw.span(), "expected 'package'"));
309        }
310
311        let mut package = String::new();
312        loop {
313            let ident: Ident = input.parse()?;
314            package.push_str(&ident.to_string());
315            if input.peek(Token![;]) {
316                input.parse::<Token![;]>()?;
317                break;
318            }
319            input.parse::<Token![.]>()?;
320            package.push('.');
321        }
322
323        let class_kw: Ident = input.parse()?;
324        if class_kw != "class" {
325            return Err(syn::Error::new(class_kw.span(), "expected 'class'"));
326        }
327
328        let name: Ident = input.parse()?;
329
330        let content;
331        syn::braced!(content in input);
332
333        let mut methods = Vec::new();
334        while !content.is_empty() {
335            methods.push(content.parse::<JavaMethod>()?);
336        }
337
338        Ok(JavaClass { package, name, methods })
339    }
340}
341
/// Joins a dotted package and a class name into the slash-separated form,
/// e.g. (`"com.example"`, `"Foo"`) becomes `"com/example/Foo"`.
///
/// Dots inside `name` are converted as well.
fn class_path(package: &str, name: &str) -> String {
    let mut path = package.replace('.', "/");
    path.push('/');
    path.push_str(&name.replace('.', "/"));
    path
}
345
346fn generate_method(
347    class_path_lit: &LitStr,
348    is_static: bool,
349    return_type: &JavaType,
350    name: &Ident,
351    params: &[JavaParam]
352) -> TokenStream2 {
353    let method_name = name;
354    let method_name_str = method_name.to_string();
355
356    let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
357    let return_sig = return_type.to_jni_sig();
358    let full_sig = format!("({}){}", param_sig, return_sig);
359    let sig_lit = LitStr::new(&full_sig, Span::call_site());
360    let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
361
362    let rust_params: Vec<TokenStream2> = params.iter().map(|p| {
363        let pname = &p.name;
364        let pty = p.ty.to_rust_type();
365        quote! { #pname: #pty }
366    }).collect();
367
368    let string_conversions: Vec<TokenStream2> = params.iter().map(|p| {
369        let pname = &p.name;
370        match &p.ty {
371            JavaType::Object(name) if name == "String" => {
372                let tmp = Ident::new(&format!("__jstr_{}", pname), Span::call_site());
373                quote! { let #tmp = env.new_string(#pname)?; }
374            }
375            _ => quote! {},
376        }
377    }).collect();
378
379    let jvalues: Vec<TokenStream2> = params.iter().map(|p| {
380        p.ty.to_jvalue(&p.name)
381    }).collect();
382
383    let return_type_ts = return_type.to_rust_return_type();
384
385    let call = if is_static {
386        quote! { env.call_static_method(jni_str!(#class_path_lit), jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
387    } else {
388        quote! { env.call_method(obj, jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
389    };
390
391    let body = match &return_type {
392        JavaType::Void => quote! {
393            #(#string_conversions)*
394            #call;
395            Ok(())
396        },
397        _ => {
398            let extract = return_type.extract_return(call);
399            quote! {
400                #(#string_conversions)*
401                #extract
402            }
403
404        }
405    };
406
407    if is_static {
408        quote! {
409            pub fn #method_name<'caller, 'refs>(
410                env: &'refs mut jni::Env<'caller>,
411                #(#rust_params),*
412            ) -> Result<#return_type_ts, jni::errors::Error> {
413                #body
414            }
415        }
416    } else {
417        quote! {
418            pub fn #method_name<'caller, 'refs>(
419                env: &'refs mut jni::Env<'caller>,
420                obj: &'refs jni::objects::JObject<'caller>,
421                #(#rust_params),*
422            ) -> Result<#return_type_ts, jni::errors::Error> {
423                #body
424            }
425        }
426    }
427}
428
429fn generate_constructor(
430    class_path_lit: &LitStr,
431    name: &Ident,
432    params: &[JavaParam],
433    alias: Option<&Ident>
434) -> TokenStream2 {
435    let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
436    let sig_lit = LitStr::new(&format!("({})V", param_sig), Span::call_site());
437
438    let name = alias.unwrap_or(name);
439
440    let rust_params: Vec<TokenStream2> = params.iter().map(|p| {
441        let pname = &p.name;
442        let pty = p.ty.to_rust_type();
443        quote! { #pname: #pty }
444    }).collect();
445
446    let string_conversions: Vec<TokenStream2> = params.iter().map(|p| {
447        let pname = &p.name;
448        match &p.ty {
449            JavaType::Object(n) if n == "String" => {
450                let tmp = Ident::new(&format!("__jstr_{}", pname), Span::call_site());
451                quote! { let #tmp = env.new_string(#pname)?; }
452            }
453            _ => quote! {},
454        }
455    }).collect();
456
457    let jvalues: Vec<TokenStream2> = params.iter().map(|p| {
458        p.ty.to_jvalue(&p.name)
459    }).collect();
460
461    quote! {
462        pub fn #name<'caller, 'refs>(
463            env: &'refs mut jni::Env<'caller>,
464            #(#rust_params),*
465        ) -> Result<jni::objects::JObject<'caller>, jni::errors::Error> {
466            #(#string_conversions)*
467            env.new_object(
468                jni_str!(#class_path_lit),
469                jni_sig!(#sig_lit),
470                &[#(#jvalues),*]
471            )
472        }
473    }
474}
475
476#[proc_macro]
477pub fn java_class_decl(input: TokenStream) -> TokenStream {
478    let class = parse_macro_input!(input as JavaClass);
479
480    let struct_name = &class.name;
481    let cp = class_path(&class.package, &class.name.to_string());
482    let class_path_lit = LitStr::new(&cp, Span::call_site());
483
484    let native_registrations: Vec<TokenStream2> = class.methods.iter()
485        .filter_map(|m| match m {
486            JavaMethod::Method { is_native: true, name, params, return_type, is_static, .. } => {
487
488                let package_class = format!("{}.{}", class.package, class.name);
489                let package_class_lit = LitStr::new(&package_class, Span::call_site());
490                let return_type = return_type.to_jni_return_type();
491                let params: Vec<TokenStream2> = params.iter().map(|p| {
492                    p.ty.to_jni_param_type()
493                }).collect();
494
495                Some(if *is_static {
496                    quote! {
497                        const _: jni::NativeMethod = jni::native_method! {
498                            java_type = #package_class_lit,
499                            static extern fn #name(#(#params),*) #return_type,
500                        };
501                    }
502                } else {
503                    quote! {
504                        const _: jni::NativeMethod = jni::native_method! {
505                            java_type = #package_class_lit,
506                            extern fn #name(#(#params),*) #return_type,
507                        };
508                    }
509                })
510            }
511            _ => None,
512        })
513        .collect();
514
515    let validate_checks: Vec<TokenStream2> = class.methods.iter()
516        .filter_map(|m| match m {
517            JavaMethod::Method { is_native: false, name, params, return_type, is_static, .. } => {
518
519                let method_name_str = name.to_string();
520                let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
521                let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
522                let return_sig = return_type.to_jni_sig();
523                let sig_lit = LitStr::new(&format!("({}){}", param_sig, return_sig), Span::call_site());
524                let class_path_lit_str = LitStr::new(&format!("{}.{}", class.package, class.name), Span::call_site());
525
526                Some(if *is_static {
527                    quote! {
528                        env.get_static_method_id(
529                            jni_str!(#class_path_lit_str),
530                            jni_str!(#method_name_lit),
531                            jni_sig!(#sig_lit),
532                        )?;
533                    }
534                } else {
535                    quote! {
536                        env.get_method_id(
537                            jni_str!(#class_path_lit_str),
538                            jni_str!(#method_name_lit),
539                            jni_sig!(#sig_lit),
540                        )?;
541                    }
542                })
543            }
544            _ => None,
545        })
546        .collect();
547
548    let methods: Vec<TokenStream2> = class.methods.iter()
549        .filter_map(|m| match m {
550            JavaMethod::Constructor { name, params, alias } => {
551                Some(generate_constructor(&class_path_lit, name, params, alias.as_ref()))
552            }
553            JavaMethod::Method { is_static, is_native: false, return_type, name, params, } => {
554                Some(generate_method(&class_path_lit, *is_static, return_type, name, params))
555            }
556            _ => None,
557        })
558        .collect();
559
560
561    let expanded = quote! {
562        #(#native_registrations)*
563
564        pub struct #struct_name;
565
566        impl #struct_name {
567            pub fn _validate_interface(env: &mut jni::Env<'_>) -> Result<(), jni::errors::Error> {
568                #(#validate_checks)*
569                Ok(())
570            }
571
572            #(#methods)*
573        }
574    };
575
576    expanded.into()
577}