Skip to main content

java_jni_extras/
lib.rs

extern crate proc_macro;

use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{
    ext::IdentExt,
    parse::{Parse, ParseStream},
    parse_macro_input, Ident, LitStr, Result, Token,
};
6
7struct JavaClass {
8    package: String,
9    name: Ident,
10    methods: Vec<JavaMethod>,
11}
12
13struct JavaMethod {
14    is_static: bool,
15    is_native: bool,
16    return_type: JavaType,
17    name: Ident,
18    params: Vec<JavaParam>,
19}
20
21struct JavaParam {
22    ty: JavaType,
23    name: Ident,
24}
25
/// A Java type as written in the `java_class_decl!` input.
///
/// Object types keep the name exactly as written (dotted, possibly unqualified);
/// turning it into a JNI descriptor is the job of `JavaType::to_jni_sig`.
#[derive(Clone, Debug, PartialEq, Eq)]
enum JavaType {
    /// `void`; the code generators reject it anywhere but a return position.
    Void,
    /// `boolean`
    Boolean,
    /// `byte`
    Byte,
    /// `char`
    Char,
    /// `short`
    Short,
    /// `int`
    Int,
    /// `long`
    Long,
    /// `float`
    Float,
    /// `double`
    Double,
    /// A class or interface type, e.g. `"String"` or `"java.util.List"`.
    Object(String),
    /// An array of the boxed element type; may nest for multi-dimensional arrays.
    Array(Box<JavaType>),
}
40
41impl JavaType {
42    fn to_jni_sig(&self) -> String {
43        match self {
44            JavaType::Void => "V".to_string(),
45            JavaType::Boolean => "Z".to_string(),
46            JavaType::Byte => "B".to_string(),
47            JavaType::Char => "C".to_string(),
48            JavaType::Short => "S".to_string(),
49            JavaType::Int => "I".to_string(),
50            JavaType::Long => "J".to_string(),
51            JavaType::Float => "F".to_string(),
52            JavaType::Double => "D".to_string(),
53            JavaType::Object(name) => {
54                let slashed = name.replace('.', "/");
55                // handle common unqualified names
56                let resolved = match slashed.as_str() {
57                    "String" => "java/lang/String".to_string(),
58                    "Object" => "java/lang/Object".to_string(),
59                    other => other.to_string(),
60                };
61                format!("L{};", resolved)
62            }
63            JavaType::Array(inner) => format!("[{}", inner.to_jni_sig()),
64        }
65    }
66
67    fn to_rust_type(&self) -> TokenStream2 {
68        match self {
69            JavaType::Void => quote! { () },
70            JavaType::Boolean => quote! { bool },
71            JavaType::Byte => quote! { i8 },
72            JavaType::Char => quote! { u16 },
73            JavaType::Short => quote! { i16 },
74            JavaType::Int => quote! { i32 },
75            JavaType::Long => quote! { i64 },
76            JavaType::Float => quote! { f32 },
77            JavaType::Double => quote! { f64 },
78            JavaType::Object(name) if name == "String" => quote! { &'refs str },
79            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
80            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
81        }
82    }
83    fn to_rust_return_type(&self) -> TokenStream2 {
84        match self {
85            JavaType::Void => quote! { () },
86            JavaType::Boolean => quote! { bool },
87            JavaType::Byte => quote! { i8 },
88            JavaType::Char => quote! { u16 },
89            JavaType::Short => quote! { i16 },
90            JavaType::Int => quote! { i32 },
91            JavaType::Long => quote! { i64 },
92            JavaType::Float => quote! { f32 },
93            JavaType::Double => quote! { f64 },
94            JavaType::Object(name) if name == "String" => quote! { String },
95            JavaType::Object(_) => quote! { jni::objects::JObject<'caller> },
96            JavaType::Array(_) => quote! { jni::objects::JObject<'caller> },
97        }
98    }
99
100    fn to_jvalue(&self, ident: &Ident) -> TokenStream2 {
101        match self {
102            JavaType::Boolean => quote! { jni::objects::JValue::Bool(#ident as u8) },
103            JavaType::Byte => quote! { jni::objects::JValue::Byte(#ident) },
104            JavaType::Char => quote! { jni::objects::JValue::Char(#ident) },
105            JavaType::Short => quote! { jni::objects::JValue::Short(#ident) },
106            JavaType::Int => quote! { jni::objects::JValue::Int(#ident) },
107            JavaType::Long => quote! { jni::objects::JValue::Long(#ident) },
108            JavaType::Float => quote! { jni::objects::JValue::Float(#ident) },
109            JavaType::Double => quote! { jni::objects::JValue::Double(#ident) },
110            JavaType::Object(name) if name == "String" => {
111                let tmp = Ident::new(&format!("__jstr_{}", ident), Span::call_site());
112                quote! { jni::objects::JValue::Object(&(#tmp).into()) }
113            },
114            JavaType::Object(_) | JavaType::Array(_) => {
115                quote! { jni::objects::JValue::Object(&#ident) }
116            }
117            JavaType::Void => quote! { compile_error!("void cannot be a parameter type") },
118        }
119    }
120
121    fn extract_return(&self, call: TokenStream2) -> TokenStream2 {
122        match self {
123            JavaType::Void => quote! { #call; },
124            JavaType::Boolean => quote! { #call.z() },
125            JavaType::Byte => quote! { #call.b() },
126            JavaType::Char => quote! { #call.c() },
127            JavaType::Short => quote! { #call.s() },
128            JavaType::Int => quote! { #call.i() },
129            JavaType::Long => quote! { #call.j() },
130            JavaType::Float => quote! { #call.f() },
131            JavaType::Double => quote! { #call.d() },
132            JavaType::Object(name) if name == "String" => {
133                quote!( let o = #call.l()?; Ok(JString::cast_local(env, o)?.to_string()) )
134            }
135            JavaType::Object(_) | JavaType::Array(_) => quote! { #call.l() },
136        }
137    }
138
139    fn to_jni_param_type(&self) -> TokenStream2 {
140        match self {
141            JavaType::Void => quote!(),
142            JavaType::Boolean => quote!(boolean),
143            JavaType::Byte => quote!(byte),
144            JavaType::Char => quote!(char),
145            JavaType::Short => quote!(short),
146            JavaType::Int => quote!(int),
147            JavaType::Long => quote!(jlong),
148            JavaType::Float => quote!(float),
149            JavaType::Double => quote!(double),
150            JavaType::Object(x) => quote!(#x),
151            JavaType::Array(_) => quote!(todo!()),
152        }
153    }
154    fn to_jni_return_type(&self) -> TokenStream2 {
155        match self {
156            JavaType::Void => quote!(),
157            JavaType::Boolean => quote!(-> boolean),
158            JavaType::Byte => quote!(-> byte),
159            JavaType::Char => quote!(-> char),
160            JavaType::Short => quote!(-> short),
161            JavaType::Int => quote!(-> int),
162            JavaType::Long => quote!(-> jlong),
163            JavaType::Float => quote!(-> float),
164            JavaType::Double => quote!(-> double),
165            JavaType::Object(x) => quote!(-> #x),
166            JavaType::Array(_) => quote!(-> todo!()),
167        }
168    }
169}
170
171fn parse_java_type(input: ParseStream) -> Result<JavaType> {
172    let ty = if input.peek(Ident) {
173        let ident: Ident = input.parse()?;
174        match ident.to_string().as_str() {
175            "void" => JavaType::Void,
176            "boolean" => JavaType::Boolean,
177            "byte" => JavaType::Byte,
178            "char" => JavaType::Char,
179            "short" => JavaType::Short,
180            "int" => JavaType::Int,
181            "long" => JavaType::Long,
182            "float" => JavaType::Float,
183            "double" => JavaType::Double,
184            other => {
185                let mut name = other.to_string();
186                while input.peek(Token![.]) {
187                    input.parse::<Token![.]>()?;
188                    let next: Ident = input.parse()?;
189                    name.push('.');
190                    name.push_str(&next.to_string());
191                }
192                JavaType::Object(name)
193            }
194        }
195    } else {
196        return Err(input.error("expected Java type"));
197    };
198
199    if input.peek(syn::token::Bracket) {
200        let content;
201        syn::bracketed!(content in input);
202        let _ = content; // empty []
203        return Ok(JavaType::Array(Box::new(ty)));
204    }
205
206    Ok(ty)
207}
208
209impl Parse for JavaMethod {
210    fn parse(input: ParseStream) -> Result<Self> {
211        let mut is_static = false;
212        let mut is_native = false;
213
214        loop {
215            if input.peek(Token![static]) {
216                input.parse::<Token![static]>()?;
217                is_static = true;
218            } else if input.peek(Ident) {
219                let ident: Ident = input.fork().parse()?;
220                match ident.to_string().as_str() {
221                    "native" => { input.parse::<Ident>()?; is_native = true; }
222                    "public" | "private" | "protected" | "final" | "synchronized" => {
223                        input.parse::<Ident>()?;
224                    }
225                    _ => break,
226                }
227            } else {
228                break;
229            }
230        }
231
232        let return_type = parse_java_type(input)?;
233        let name: Ident = input.parse()?;
234
235        let content;
236        syn::parenthesized!(content in input);
237
238        let mut params = Vec::new();
239        while !content.is_empty() {
240            let ty = parse_java_type(&content)?;
241            let param_name: Ident = content.parse()?;
242            params.push(JavaParam { ty, name: param_name });
243            if content.peek(Token![,]) {
244                content.parse::<Token![,]>()?;
245            }
246        }
247
248        input.parse::<Token![;]>()?;
249
250        Ok(JavaMethod { is_static, is_native, return_type, name, params })
251    }
252}
253
254impl Parse for JavaClass {
255    fn parse(input: ParseStream) -> Result<Self> {
256        let pkg_kw: Ident = input.parse()?;
257        if pkg_kw != "package" {
258            return Err(syn::Error::new(pkg_kw.span(), "expected 'package'"));
259        }
260
261        let mut package = String::new();
262        loop {
263            let ident: Ident = input.parse()?;
264            package.push_str(&ident.to_string());
265            if input.peek(Token![;]) {
266                input.parse::<Token![;]>()?;
267                break;
268            }
269            input.parse::<Token![.]>()?;
270            package.push('.');
271        }
272
273        let class_kw: Ident = input.parse()?;
274        if class_kw != "class" {
275            return Err(syn::Error::new(class_kw.span(), "expected 'class'"));
276        }
277
278        let name: Ident = input.parse()?;
279
280        let content;
281        syn::braced!(content in input);
282
283        let mut methods = Vec::new();
284        while !content.is_empty() {
285            methods.push(content.parse::<JavaMethod>()?);
286        }
287
288        Ok(JavaClass { package, name, methods })
289    }
290}
291
/// Joins a dotted package and a class name into the slash-separated binary name
/// JNI expects, e.g. (`"com.example"`, `"Foo"`) becomes `"com/example/Foo"`.
///
/// Dots inside `name` are converted too.
fn class_path(package: &str, name: &str) -> String {
    let mut path = package.replace('.', "/");
    path.push('/');
    path.push_str(&name.replace('.', "/"));
    path
}
295
296fn generate_method(
297    class_path_lit: &LitStr,
298    method: &JavaMethod,
299) -> TokenStream2 {
300    let method_name = &method.name;
301    let method_name_str = method_name.to_string();
302
303    let param_sig: String = method.params.iter().map(|p| p.ty.to_jni_sig()).collect();
304    let return_sig = method.return_type.to_jni_sig();
305    let full_sig = format!("({}){}", param_sig, return_sig);
306    let sig_lit = LitStr::new(&full_sig, Span::call_site());
307    let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
308
309    let rust_params: Vec<TokenStream2> = method.params.iter().map(|p| {
310        let pname = &p.name;
311        let pty = p.ty.to_rust_type();
312        quote! { #pname: #pty }
313    }).collect();
314
315    let string_conversions: Vec<TokenStream2> = method.params.iter().map(|p| {
316        let pname = &p.name;
317        match &p.ty {
318            JavaType::Object(name) if name == "String" => {
319                let tmp = Ident::new(&format!("__jstr_{}", pname), Span::call_site());
320                quote! { let #tmp = env.new_string(#pname)?; }
321            }
322            _ => quote! {},
323        }
324    }).collect();
325
326    let jvalues: Vec<TokenStream2> = method.params.iter().map(|p| {
327        p.ty.to_jvalue(&p.name)
328    }).collect();
329
330    let return_type = method.return_type.to_rust_return_type();
331
332    let call = if method.is_static {
333        quote! { env.call_static_method(jni_str!(#class_path_lit), jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
334    } else {
335        quote! { env.call_method(obj, jni_str!(#method_name_lit), jni_sig!(#sig_lit), &[#(#jvalues),*])? }
336    };
337
338    let body = match &method.return_type {
339        JavaType::Void => quote! {
340            #(#string_conversions)*
341            #call;
342            Ok(())
343        },
344        _ => {
345            let extract = method.return_type.extract_return(call);
346            quote! {
347                #(#string_conversions)*
348                #extract
349            }
350
351        }
352    };
353
354    if method.is_static {
355        quote! {
356            pub fn #method_name<'caller, 'refs>(
357                env: &'refs mut jni::Env<'caller>,
358                #(#rust_params),*
359            ) -> Result<#return_type, jni::errors::Error> {
360                #body
361            }
362        }
363    } else {
364        quote! {
365            pub fn #method_name<'caller, 'refs>(
366                env: &'refs mut jni::Env<'caller>,
367                obj: &'refs jni::objects::JObject<'caller>,
368                #(#rust_params),*
369            ) -> Result<#return_type, jni::errors::Error> {
370                #body
371            }
372        }
373    }
374}
375
376#[proc_macro]
377pub fn java_class_decl(input: TokenStream) -> TokenStream {
378    let class = parse_macro_input!(input as JavaClass);
379
380    let struct_name = &class.name;
381    let cp = class_path(&class.package, &class.name.to_string());
382    let class_path_lit = LitStr::new(&cp, Span::call_site());
383
384    let native_registrations: Vec<TokenStream2> = class.methods.iter()
385        .filter(|m| m.is_native)
386        .map(|m| {
387            let name = &m.name;
388            let package_class = format!("{}.{}", class.package, class.name);
389            let package_class_lit = LitStr::new(&package_class, Span::call_site());
390            let return_type = m.return_type.to_jni_return_type();
391            let params: Vec<TokenStream2> = m.params.iter().map(|p| {
392                p.ty.to_jni_param_type()
393            }).collect();
394
395            if m.is_static {
396                quote! {
397                    const _: jni::NativeMethod = jni::native_method! {
398                        java_type = #package_class_lit,
399                        static extern fn #name(#(#params),*) #return_type,
400                    };
401                }
402            } else {
403                quote! {
404                    const _: jni::NativeMethod = jni::native_method! {
405                        java_type = #package_class_lit,
406                        extern fn #name(#(#params),*) #return_type,
407                    };
408                }
409            }
410        })
411        .collect();
412
413    let validate_checks: Vec<TokenStream2> = class.methods.iter()
414        .filter(|m| !m.is_native)
415        .map(|m| {
416            let method_name_str = m.name.to_string();
417            let method_name_lit = LitStr::new(&method_name_str, Span::call_site());
418            let param_sig: String = m.params.iter().map(|p| p.ty.to_jni_sig()).collect();
419            let return_sig = m.return_type.to_jni_sig();
420            let sig_lit = LitStr::new(&format!("({}){}", param_sig, return_sig), Span::call_site());
421            let class_path_lit_str = LitStr::new(&format!("{}.{}", class.package, class.name), Span::call_site());
422
423            if m.is_static {
424                quote! {
425                env.get_static_method_id(
426                    jni_str!(#class_path_lit_str),
427                    jni_str!(#method_name_lit),
428                    jni_sig!(#sig_lit),
429                )?;
430            }
431            } else {
432                quote! {
433                env.get_method_id(
434                    jni_str!(#class_path_lit_str),
435                    jni_str!(#method_name_lit),
436                    jni_sig!(#sig_lit),
437                )?;
438            }
439            }
440        })
441        .collect();
442
443    let methods: Vec<TokenStream2> = class.methods.iter()
444        .filter(|m| !m.is_native)
445        .map(|m| generate_method(&class_path_lit, m))
446        .collect();
447
448
449    let expanded = quote! {
450        #(#native_registrations)*
451
452        pub struct #struct_name;
453
454        impl #struct_name {
455            pub fn _validate_interface(env: &mut jni::Env<'_>) -> Result<(), jni::errors::Error> {
456                #(#validate_checks)*
457                Ok(())
458            }
459
460            #(#methods)*
461        }
462    };
463
464    expanded.into()
465}