Skip to main content

java_jni_extras/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::quote;
5use syn::{Ident, LitStr, Token, Result, parse::{Parse, ParseStream}, parse_macro_input};
6
7struct JavaClass {
8    package: String,
9    name: Ident,
10    methods: Vec<JavaMethod>,
11}
12
13enum JavaMethod {
14    Constructor {
15        name: Ident,
16        alias: Option<Ident>,
17        params: Vec<JavaParam>,
18    },
19    Method {
20        is_static: bool,
21        is_native: bool,
22        return_type: JavaType,
23        name: Ident,
24        params: Vec<JavaParam>,
25    },
26}
27
28struct JavaParam {
29    ty: JavaType,
30    name: Ident,
31}
32
/// A Java type as written in a declaration.
#[derive(Clone)]
enum JavaType {
    /// `void`; only meaningful as a return type.
    Void,
    /// `boolean`
    Boolean,
    /// `byte`
    Byte,
    /// `char`
    Char,
    /// `short`
    Short,
    /// `int`
    Int,
    /// `long`
    Long,
    /// `float`
    Float,
    /// `double`
    Double,
    /// A class type named as written, e.g. `String` or `java.util.List`.
    Object(String),
    /// An array whose element type is boxed, e.g. `int[]`.
    Array(Box<JavaType>),
}
47
48impl JavaType {
49    fn to_jni_sig(&self) -> String {
50        match self {
51            JavaType::Void => "V".to_string(),
52            JavaType::Boolean => "Z".to_string(),
53            JavaType::Byte => "B".to_string(),
54            JavaType::Char => "C".to_string(),
55            JavaType::Short => "S".to_string(),
56            JavaType::Int => "I".to_string(),
57            JavaType::Long => "J".to_string(),
58            JavaType::Float => "F".to_string(),
59            JavaType::Double => "D".to_string(),
60            JavaType::Object(name) => {
61                let slashed = name.replace('.', "/");
62                let resolved = match slashed.as_str() {
63                    "String" => "java/lang/String".to_string(),
64                    "Object" => "java/lang/Object".to_string(),
65                    other => other.to_string(),
66                };
67                format!("L{};", resolved)
68            }
69            JavaType::Array(inner) => format!("[{}", inner.to_jni_sig()),
70        }
71    }
72
73    fn to_rust_type(&self) -> TokenStream2 {
74        match self {
75            JavaType::Void => quote! { () },
76            JavaType::Boolean => quote! { bool },
77            JavaType::Byte => quote! { i8 },
78            JavaType::Char => quote! { u16 },
79            JavaType::Short => quote! { i16 },
80            JavaType::Int => quote! { i32 },
81            JavaType::Long => quote! { i64 },
82            JavaType::Float => quote! { f32 },
83            JavaType::Double => quote! { f64 },
84            JavaType::Object(name) if name == "String" => quote! { &'refs str },
85            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
86            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
87        }
88    }
89    fn to_rust_return_type(&self) -> TokenStream2 {
90        match self {
91            JavaType::Void => quote! { () },
92            JavaType::Boolean => quote! { bool },
93            JavaType::Byte => quote! { i8 },
94            JavaType::Char => quote! { u16 },
95            JavaType::Short => quote! { i16 },
96            JavaType::Int => quote! { i32 },
97            JavaType::Long => quote! { i64 },
98            JavaType::Float => quote! { f32 },
99            JavaType::Double => quote! { f64 },
100            JavaType::Object(name) if name == "String" => quote! { String },
101            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
102            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
103        }
104    }
105
106    fn to_jvalue(&self, ident: &Ident) -> TokenStream2 {
107        match self {
108            JavaType::Boolean => quote! { jni::objects::JValue::Bool(#ident) },
109            JavaType::Byte => quote! { jni::objects::JValue::Byte(#ident) },
110            JavaType::Char => quote! { jni::objects::JValue::Char(#ident) },
111            JavaType::Short => quote! { jni::objects::JValue::Short(#ident) },
112            JavaType::Int => quote! { jni::objects::JValue::Int(#ident) },
113            JavaType::Long => quote! { jni::objects::JValue::Long(#ident) },
114            JavaType::Float => quote! { jni::objects::JValue::Float(#ident) },
115            JavaType::Double => quote! { jni::objects::JValue::Double(#ident) },
116            JavaType::Object(name) if name == "String" => {
117                let tmp = Ident::new(&format!("__jstr_{}", ident), Span::call_site());
118                quote! { jni::objects::JValue::Object(&(#tmp).into()) }
119            },
120            JavaType::Object(_) | JavaType::Array(_) => {
121                quote! { jni::objects::JValue::Object(&#ident) }
122            }
123            JavaType::Void => quote! { compile_error!("void cannot be a parameter type") },
124        }
125    }
126
127    fn extract_return(&self, call: TokenStream2) -> TokenStream2 {
128        match self {
129            JavaType::Void => quote! { #call; },
130            JavaType::Boolean => quote! { #call.z() },
131            JavaType::Byte => quote! { #call.b() },
132            JavaType::Char => quote! { #call.c() },
133            JavaType::Short => quote! { #call.s() },
134            JavaType::Int => quote! { #call.i() },
135            JavaType::Long => quote! { #call.j() },
136            JavaType::Float => quote! { #call.f() },
137            JavaType::Double => quote! { #call.d() },
138            JavaType::Object(name) if name == "String" => {
139                quote!( let o = #call.l()?; Ok(JString::cast_local(env, o)?.to_string()) )
140            }
141            JavaType::Object(_) | JavaType::Array(_) => quote! { #call.l() },
142        }
143    }
144
145    fn to_jni_param_type(&self) -> TokenStream2 {
146        match self {
147            JavaType::Void => quote!(),
148            JavaType::Boolean => quote!(boolean),
149            JavaType::Byte => quote!(byte),
150            JavaType::Char => quote!(char),
151            JavaType::Short => quote!(short),
152            JavaType::Int => quote!(int),
153            JavaType::Long => quote!(jlong),
154            JavaType::Float => quote!(float),
155            JavaType::Double => quote!(double),
156            JavaType::Object(x) => {
157                let x: TokenStream2 = syn::parse_str(x).unwrap();
158                quote!(#x)
159            },
160            JavaType::Array(_) => quote!(java.lang.Object),
161        }
162    }
163    fn to_jni_return_type(&self) -> TokenStream2 {
164        match self {
165            JavaType::Void => quote!(),
166            JavaType::Boolean => quote!(-> boolean),
167            JavaType::Byte => quote!(-> byte),
168            JavaType::Char => quote!(-> char),
169            JavaType::Short => quote!(-> short),
170            JavaType::Int => quote!(-> int),
171            JavaType::Long => quote!(-> jlong),
172            JavaType::Float => quote!(-> float),
173            JavaType::Double => quote!(-> double),
174            JavaType::Object(x) => {
175                let x: TokenStream2 = syn::parse_str(x).unwrap();
176                quote!(-> #x)
177            },
178            JavaType::Array(_) => quote!(-> java.lang.Object),
179        }
180    }
181}
182
183fn parse_java_type(input: ParseStream) -> Result<JavaType> {
184    let ty = if input.peek(Ident) {
185        let ident: Ident = input.parse()?;
186        match ident.to_string().as_str() {
187            "void" => JavaType::Void,
188            "boolean" => JavaType::Boolean,
189            "byte" => JavaType::Byte,
190            "char" => JavaType::Char,
191            "short" => JavaType::Short,
192            "int" => JavaType::Int,
193            "long" => JavaType::Long,
194            "float" => JavaType::Float,
195            "double" => JavaType::Double,
196            other => {
197                let mut name = other.to_string();
198                while input.peek(Token![.]) {
199                    input.parse::<Token![.]>()?;
200                    let next: Ident = input.parse()?;
201                    name.push('.');
202                    name.push_str(&next.to_string());
203                }
204                JavaType::Object(name)
205            }
206        }
207    } else {
208        return Err(input.error("expected Java type"));
209    };
210
211    if input.peek(syn::token::Bracket) {
212        let content;
213        syn::bracketed!(content in input);
214        let _ = content;
215        return Ok(JavaType::Array(Box::new(ty)));
216    }
217
218    Ok(ty)
219}
220
221impl Parse for JavaMethod {
222    fn parse(input: ParseStream) -> Result<Self> {
223        let alias: Option<Ident> = if input.peek(Token![#]) {
224            input.parse::<Token![#]>()?;
225            let content;
226            syn::bracketed!(content in input);
227            let attr_name: Ident = content.parse()?;
228            if attr_name != "alias" {
229                return Err(syn::Error::new(attr_name.span(), "expected 'alias'"));
230            }
231            let inner;
232            syn::parenthesized!(inner in content);
233            Some(inner.parse::<Ident>()?)
234        } else {
235            None
236        };
237
238        let mut is_static = false;
239        let mut is_native = false;
240
241        loop {
242            if input.peek(Token![static]) {
243                input.parse::<Token![static]>()?;
244                is_static = true;
245            } else if input.peek(Ident) {
246                let ident: Ident = input.fork().parse()?;
247                match ident.to_string().as_str() {
248                    "native" => { input.parse::<Ident>()?; is_native = true; }
249                    "public" | "private" | "protected" | "final" | "synchronized" => {
250                        input.parse::<Ident>()?;
251                    }
252                    _ => break,
253                }
254            } else {
255                break;
256            }
257        }
258
259        let is_constructor = input.peek(Ident) && {
260            let fork = input.fork();
261            let _: Ident = fork.parse()?;
262            fork.peek(syn::token::Paren)
263        };
264
265        if is_constructor {
266            let name: Ident = input.parse()?;
267            let content;
268            syn::parenthesized!(content in input);
269            let params = parse_params(&content)?;
270            input.parse::<Token![;]>()?;
271            return Ok(JavaMethod::Constructor { name, alias, params });
272        }
273
274        let return_type = parse_java_type(input)?;
275        let name: Ident = input.parse()?;
276        let content;
277        syn::parenthesized!(content in input);
278        let params = parse_params(&content)?;
279        input.parse::<Token![;]>()?;
280
281        Ok(JavaMethod::Method { is_static, is_native, return_type, name, params })
282    }
283}
284
285fn parse_params(content: ParseStream) -> Result<Vec<JavaParam>> {
286    let mut params = Vec::new();
287    while !content.is_empty() {
288        let ty = parse_java_type(content)?;
289        let name: Ident = content.parse()?;
290        params.push(JavaParam { ty, name });
291        if content.peek(Token![,]) {
292            content.parse::<Token![,]>()?;
293        }
294    }
295    Ok(params)
296}
297
298impl Parse for JavaClass {
299    fn parse(input: ParseStream) -> Result<Self> {
300        let pkg_kw: Ident = input.parse()?;
301        if pkg_kw != "package" {
302            return Err(syn::Error::new(pkg_kw.span(), "expected 'package'"));
303        }
304
305        let mut package = String::new();
306        loop {
307            let ident: Ident = input.parse()?;
308            package.push_str(&ident.to_string());
309            if input.peek(Token![;]) {
310                input.parse::<Token![;]>()?;
311                break;
312            }
313            input.parse::<Token![.]>()?;
314            package.push('.');
315        }
316
317        let class_kw: Ident = input.parse()?;
318        if class_kw != "class" {
319            return Err(syn::Error::new(class_kw.span(), "expected 'class'"));
320        }
321
322        let name: Ident = input.parse()?;
323
324        let content;
325        syn::braced!(content in input);
326
327        let mut methods = Vec::new();
328        while !content.is_empty() {
329            methods.push(content.parse::<JavaMethod>()?);
330        }
331
332        Ok(JavaClass { package, name, methods })
333    }
334}
335
/// Joins `package` and `name` with a separator and turns every `.` into `/`.
///
/// For example, `("com.example", "Foo")` becomes `com/example/Foo`.
fn class_path(package: &str, name: &str) -> String {
    package
        .chars()
        .chain(std::iter::once('.'))
        .chain(name.chars())
        .map(|c| if c == '.' { '/' } else { c })
        .collect()
}
339
340fn generate_method(
341    class_path_lit: &LitStr,
342    is_static: bool,
343    return_type: &JavaType,
344    name: &Ident,
345    params: &[JavaParam]
346) -> TokenStream2 {
347    let method_name = name;
348    let method_name_str = method_name.to_string();
349
350    let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
351    let return_sig = return_type.to_jni_sig();
352    let full_sig = format!("({}){}", param_sig, return_sig);
353    let sig_lit = LitStr::new(&full_sig, Span::call_site());
354    let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
355
356    let rust_params: Vec<TokenStream2> = params.iter().map(|p| {
357        let pname = &p.name;
358        let pty = p.ty.to_rust_type();
359        quote! { #pname: #pty }
360    }).collect();
361
362    let string_conversions: Vec<TokenStream2> = params.iter().map(|p| {
363        let pname = &p.name;
364        match &p.ty {
365            JavaType::Object(name) if name == "String" => {
366                let tmp = Ident::new(&format!("__jstr_{}", pname), Span::call_site());
367                quote! { let #tmp = env.new_string(#pname)?; }
368            }
369            _ => quote! {},
370        }
371    }).collect();
372
373    let jvalues: Vec<TokenStream2> = params.iter().map(|p| {
374        p.ty.to_jvalue(&p.name)
375    }).collect();
376
377    let return_type_ts = return_type.to_rust_return_type();
378
379    let call = if is_static {
380        quote! { env.call_static_method(jni_str!(#class_path_lit), jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
381    } else {
382        quote! { env.call_method(obj, jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
383    };
384
385    let body = match &return_type {
386        JavaType::Void => quote! {
387            #(#string_conversions)*
388            #call;
389            Ok(())
390        },
391        _ => {
392            let extract = return_type.extract_return(call);
393            quote! {
394                #(#string_conversions)*
395                #extract
396            }
397
398        }
399    };
400
401    if is_static {
402        quote! {
403            pub fn #method_name<'caller, 'refs>(
404                env: &'refs mut jni::Env<'caller>,
405                #(#rust_params),*
406            ) -> Result<#return_type_ts, jni::errors::Error> {
407                #body
408            }
409        }
410    } else {
411        quote! {
412            pub fn #method_name<'caller, 'refs>(
413                env: &'refs mut jni::Env<'caller>,
414                obj: &'refs jni::objects::JObject<'caller>,
415                #(#rust_params),*
416            ) -> Result<#return_type_ts, jni::errors::Error> {
417                #body
418            }
419        }
420    }
421}
422
423fn generate_constructor(
424    class_path_lit: &LitStr,
425    name: &Ident,
426    params: &[JavaParam],
427    alias: Option<&Ident>
428) -> TokenStream2 {
429    let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
430    let sig_lit = LitStr::new(&format!("({})V", param_sig), Span::call_site());
431
432    let name = alias.unwrap_or(name);
433
434    let rust_params: Vec<TokenStream2> = params.iter().map(|p| {
435        let pname = &p.name;
436        let pty = p.ty.to_rust_type();
437        quote! { #pname: #pty }
438    }).collect();
439
440    let string_conversions: Vec<TokenStream2> = params.iter().map(|p| {
441        let pname = &p.name;
442        match &p.ty {
443            JavaType::Object(n) if n == "String" => {
444                let tmp = Ident::new(&format!("__jstr_{}", pname), Span::call_site());
445                quote! { let #tmp = env.new_string(#pname)?; }
446            }
447            _ => quote! {},
448        }
449    }).collect();
450
451    let jvalues: Vec<TokenStream2> = params.iter().map(|p| {
452        p.ty.to_jvalue(&p.name)
453    }).collect();
454
455    quote! {
456        pub fn #name<'caller>(
457            env: &mut jni::Env<'caller>,
458            #(#rust_params),*
459        ) -> Result<jni::objects::JObject<'caller>, jni::errors::Error> {
460            #(#string_conversions)*
461            env.new_object(
462                jni_str!(#class_path_lit),
463                jni_sig!(#sig_lit),
464                &[#(#jvalues),*]
465            )
466        }
467    }
468}
469
470#[proc_macro]
471pub fn java_class_decl(input: TokenStream) -> TokenStream {
472    let class = parse_macro_input!(input as JavaClass);
473
474    let struct_name = &class.name;
475    let cp = class_path(&class.package, &class.name.to_string());
476    let class_path_lit = LitStr::new(&cp, Span::call_site());
477
478    let native_registrations: Vec<TokenStream2> = class.methods.iter()
479        .filter_map(|m| match m {
480            JavaMethod::Method { is_native: true, name, params, return_type, is_static, .. } => {
481
482                let package_class = format!("{}.{}", class.package, class.name);
483                let package_class_lit = LitStr::new(&package_class, Span::call_site());
484                let return_type = return_type.to_jni_return_type();
485                let params: Vec<TokenStream2> = params.iter().map(|p| {
486                    p.ty.to_jni_param_type()
487                }).collect();
488
489                Some(if *is_static {
490                    quote! {
491                        const _: jni::NativeMethod = jni::native_method! {
492                            java_type = #package_class_lit,
493                            static extern fn #name(#(#params),*) #return_type,
494                        };
495                    }
496                } else {
497                    quote! {
498                        const _: jni::NativeMethod = jni::native_method! {
499                            java_type = #package_class_lit,
500                            extern fn #name(#(#params),*) #return_type,
501                        };
502                    }
503                })
504            }
505            _ => None,
506        })
507        .collect();
508
509    let validate_checks: Vec<TokenStream2> = class.methods.iter()
510        .filter_map(|m| match m {
511            JavaMethod::Method { is_native: false, name, params, return_type, is_static, .. } => {
512
513                let method_name_str = name.to_string();
514                let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
515                let param_sig: String = params.iter().map(|p| p.ty.to_jni_sig()).collect();
516                let return_sig = return_type.to_jni_sig();
517                let sig_lit = LitStr::new(&format!("({}){}", param_sig, return_sig), Span::call_site());
518                let class_path_lit_str = LitStr::new(&format!("{}.{}", class.package, class.name), Span::call_site());
519
520                Some(if *is_static {
521                    quote! {
522                        env.get_static_method_id(
523                            jni_str!(#class_path_lit_str),
524                            jni_str!(#method_name_lit),
525                            jni_sig!(#sig_lit),
526                        )?;
527                    }
528                } else {
529                    quote! {
530                        env.get_method_id(
531                            jni_str!(#class_path_lit_str),
532                            jni_str!(#method_name_lit),
533                            jni_sig!(#sig_lit),
534                        )?;
535                    }
536                })
537            }
538            _ => None,
539        })
540        .collect();
541
542    let methods: Vec<TokenStream2> = class.methods.iter()
543        .filter_map(|m| match m {
544            JavaMethod::Constructor { name, params, alias } => {
545                Some(generate_constructor(&class_path_lit, name, params, alias.as_ref()))
546            }
547            JavaMethod::Method { is_static, is_native: false, return_type, name, params, } => {
548                Some(generate_method(&class_path_lit, *is_static, return_type, name, params))
549            }
550            _ => None,
551        })
552        .collect();
553
554
555    let expanded = quote! {
556        #(#native_registrations)*
557
558        pub struct #struct_name;
559
560        impl #struct_name {
561            pub fn _validate_interface(env: &mut jni::Env<'_>) -> Result<(), jni::errors::Error> {
562                #(#validate_checks)*
563                Ok(())
564            }
565
566            #(#methods)*
567        }
568    };
569
570    expanded.into()
571}