oo_bindgen/backend/java/jni/mod.rs

1use crate::backend::*;
2use crate::model::*;
3
4use crate::backend::java::jni::conversion::*;
5
6use std::path::Path;
7
8mod classes;
9mod conversion;
10mod enums;
11mod exceptions;
12mod interface;
13mod structs;
14
/// Configuration for JNI (Rust) generation
///
/// Holds only borrowed strings, so it is cheap to copy and print.
#[derive(Debug, Clone, Copy)]
pub struct JniBindgenConfig<'a> {
    /// Maven group id (e.g. io.stepfunc)
    pub group_id: &'a str,
    /// Name of the FFI target
    pub ffi_name: &'a str,
}

impl JniBindgenConfig<'_> {
    /// Build the slash-separated class path prefix for `libname`.
    ///
    /// Every `.` of the Maven group id becomes `/` and `libname` is appended as the last
    /// component, e.g. group id `io.stepfunc` with `libname` `foo` gives `io/stepfunc/foo`.
    fn java_signature_path(&self, libname: &str) -> String {
        format!("{}/{}", self.group_id.replace('.', "/"), libname)
    }
}
31
32fn module_string(name: &str, f: &mut dyn Printer, content: &str) -> FormattingResult<()> {
33    module(name, f, |f| {
34        for line in content.lines() {
35            f.writeln(line)?;
36        }
37        Ok(())
38    })
39}
40
41fn module<F>(name: &str, f: &mut dyn Printer, write: F) -> FormattingResult<()>
42where
43    F: Fn(&mut dyn Printer) -> FormattingResult<()>,
44{
45    f.newline()?;
46    f.writeln(&format!("pub(crate) mod {name} {{"))?;
47    indented(f, |f| write(f))?;
48    f.writeln("}")?;
49    Ok(())
50}
51
52/// Generate all of the JNI (Rust) source code that glues the Java and FFI together
53///
54/// This function is typically called from a build.rs script in a target that builds
55/// the JNI shared library
56pub fn generate_jni(path: &Path, lib: &Library, config: &JniBindgenConfig) -> FormattingResult<()> {
57    let mut f = FilePrinter::new(path)?;
58
59    generate_cache(&mut f)?;
60    write_functions(&mut f, lib, config)?;
61    write_collection_conversions(&mut f, lib, config)?;
62    write_iterator_conversions(&mut f, lib, config)?;
63
64    module("classes", &mut f, |f| {
65        classes::generate_classes_cache(f, lib, config)
66    })?;
67
68    module("enums", &mut f, |f| {
69        enums::generate_enums_cache(f, lib, config)
70    })?;
71
72    module("structs", &mut f, |f| structs::generate(f, lib, config))?;
73
74    module("interfaces", &mut f, |f| {
75        interface::generate_interfaces_cache(f, lib, config)
76    })?;
77
78    module("exceptions", &mut f, |f| {
79        exceptions::generate_exceptions_cache(f, lib, config)
80    })?;
81
82    // Copy the modules that never change
83    module_string("primitives", &mut f, include_str!("copy/primitives.rs"))?;
84    module_string("unsigned", &mut f, include_str!("copy/unsigned.rs"))?;
85    module_string("duration", &mut f, include_str!("copy/duration.rs"))?;
86    module_string("collection", &mut f, include_str!("copy/collection.rs"))?;
87    module_string("pointers", &mut f, include_str!("copy/pointers.rs"))?;
88    module_string("util", &mut f, include_str!("copy/util.rs"))?;
89
90    Ok(())
91}
92
93fn generate_cache(f: &mut dyn Printer) -> FormattingResult<()> {
94    // Create cache
95    f.writeln("pub(crate) struct JCache")?;
96    blocked(f, |f| {
97        f.writeln("vm: jni::JavaVM,")?;
98        f.writeln("primitives: primitives::Primitives,")?;
99        f.writeln("unsigned: unsigned::Unsigned,")?;
100        f.writeln("duration: duration::Duration,")?;
101        f.writeln("collection: collection::Collection,")?;
102        f.writeln("classes: classes::Classes,")?;
103        f.writeln("enums: enums::Enums,")?;
104        f.writeln("structs: structs::Structs,")?;
105        f.writeln("interfaces: interfaces::Interfaces,")?;
106        f.writeln("exceptions: exceptions::Exceptions,")?;
107        Ok(())
108    })?;
109
110    f.newline()?;
111
112    f.writeln("impl JCache")?;
113    blocked(f, |f| {
114        f.writeln("fn init(vm: jni::JavaVM) -> Self")?;
115        blocked(f, |f| {
116            f.writeln("let env = vm.get_env().unwrap();")?;
117            f.writeln("let primitives = primitives::Primitives::init(&env);")?;
118            f.writeln("let unsigned = unsigned::Unsigned::init(&env);")?;
119            f.writeln("let duration = duration::Duration::init(&env);")?;
120            f.writeln("let collection = collection::Collection::init(&env);")?;
121            f.writeln("let classes = classes::Classes::init(&env);")?;
122            f.writeln("let enums = enums::Enums::init(&env);")?;
123            f.writeln("let structs = structs::Structs::init(&env);")?;
124            f.writeln("let interfaces = interfaces::Interfaces::init(&env);")?;
125            f.writeln("let exceptions = exceptions::Exceptions::init(&env);")?;
126            f.writeln("Self")?;
127            blocked(f, |f| {
128                f.writeln("vm,")?;
129                f.writeln("primitives,")?;
130                f.writeln("unsigned,")?;
131                f.writeln("duration,")?;
132                f.writeln("collection,")?;
133                f.writeln("classes,")?;
134                f.writeln("enums,")?;
135                f.writeln("structs,")?;
136                f.writeln("interfaces,")?;
137                f.writeln("exceptions,")?;
138                Ok(())
139            })
140        })
141    })?;
142
143    f.newline()?;
144
145    f.writeln("static mut JCACHE: Option<JCache> = None;")?;
146
147    f.newline()?;
148
149    f.writeln("pub(crate) fn get_cache<'a>() -> &'a JCache {")?;
150    indented(f, |f| {
151        f.writeln("// safety: this is only called after initialization / JVM load")
152    })?;
153    indented(f, |f| f.writeln("#[allow(static_mut_refs)]"))?;
154    indented(f, |f| f.writeln("unsafe { JCACHE.as_ref().unwrap() }"))?;
155    f.writeln("}")?;
156
157    f.newline()?;
158
159    // OnLoad function
160    f.writeln("#[no_mangle]")?;
161    f.writeln("pub extern \"C\" fn JNI_OnLoad(vm: *mut jni::sys::JavaVM, _: *mut std::ffi::c_void) -> jni::sys::jint")?;
162    blocked(f, |f| {
163        f.writeln("let vm = unsafe { jni::JavaVM::from_raw(vm).unwrap() };")?;
164        f.writeln("let jcache = JCache::init(vm);")?;
165        f.writeln("// safety: this is only called during library loading")?;
166        f.writeln("unsafe { JCACHE = Some(jcache); };")?;
167        f.writeln("jni::JNIVersion::V8.into()")
168    })?;
169
170    f.newline()?;
171
172    // OnUnload function
173    f.writeln("#[no_mangle]")?;
174    f.writeln("pub extern \"C\" fn JNI_OnUnload(_vm: *mut jni::sys::JavaVM, _: *mut std::ffi::c_void) -> jni::sys::jint")?;
175    blocked(f, |f| {
176        f.writeln("// safety: this is only called during library unloading / JVM shutdown")?;
177        f.writeln("unsafe { JCACHE = None; }")?;
178        f.writeln("return 0;")
179    })
180}
181
182fn write_collection_conversions(
183    f: &mut dyn Printer,
184    lib: &Library,
185    config: &JniBindgenConfig,
186) -> FormattingResult<()> {
187    f.newline()?;
188    f.writeln("/// convert Java lists into native API collections")?;
189    f.writeln("pub(crate) mod collections {")?;
190    indented(f, |f| {
191        for col in lib.collections() {
192            f.newline()?;
193            write_collection_guard(f, config, col)?;
194        }
195        Ok(())
196    })?;
197    f.writeln("}")
198}
199
200fn write_iterator_conversions(
201    f: &mut dyn Printer,
202    lib: &Library,
203    config: &JniBindgenConfig,
204) -> FormattingResult<()> {
205    f.newline()?;
206    f.writeln("/// functions that convert native API iterators into Java lists")?;
207    f.writeln("pub(crate) mod iterators {")?;
208    indented(f, |f| {
209        for iter in lib.iterators() {
210            f.newline()?;
211            write_iterator_conversion(f, config, iter)?;
212        }
213        Ok(())
214    })?;
215    f.writeln("}")
216}
217
218fn write_iterator_conversion(
219    f: &mut dyn Printer,
220    config: &JniBindgenConfig,
221    iter: &Handle<AbstractIterator<Validated>>,
222) -> FormattingResult<()> {
223    f.writeln(&format!("pub(crate) fn {}(_env: &jni::JNIEnv, _cache: &crate::JCache, iter: {}) -> jni::sys::jobject {{", iter.name(), iter.iter_class.get_rust_type(config.ffi_name)))?;
224    indented(f, |f| {
225        f.writeln("let list = _cache.collection.new_array_list(&_env);")?;
226        f.writeln(&format!(
227            "while let Some(next) = unsafe {{ {}::ffi::{}_{}(iter).as_ref() }} {{",
228            config.ffi_name, iter.iter_class.settings.c_ffi_prefix, iter.next_function.name
229        ))?;
230        indented(f, |f| {
231            match &iter.item_type {
232                IteratorItemType::Primitive(x) => {
233                    let converted = x
234                        .maybe_convert("*next")
235                        .unwrap_or_else(|| "*next".to_string());
236                    f.writeln(&format!("let next = _env.auto_local({converted});"))?;
237                }
238                IteratorItemType::Struct(x) => {
239                    f.writeln(&format!(
240                        "let next = _env.auto_local({});",
241                        x.convert("next")
242                    ))?;
243                }
244            }
245            f.writeln("_cache.collection.add_to_array_list(&_env, list, next.as_obj().into());")
246        })?;
247        f.writeln("}")?;
248        f.writeln("list.into_inner()")
249    })?;
250    f.writeln("}")
251}
252
253fn write_collection_guard(
254    f: &mut dyn Printer,
255    config: &JniBindgenConfig,
256    col: &Handle<Collection<Validated>>,
257) -> FormattingResult<()> {
258    let collection_name = col.collection_class.name.camel_case();
259    let c_ffi_prefix = col.collection_class.settings.c_ffi_prefix.clone();
260
261    f.writeln("/// Guard that builds the C collection type from a Java list")?;
262    f.writeln(&format!("pub(crate) struct {collection_name} {{"))?;
263    indented(f, |f| {
264        f.writeln(&format!(
265            "inner: *mut {}::{}",
266            config.ffi_name, collection_name
267        ))
268    })?;
269    f.writeln("}")?;
270
271    f.newline()?;
272
273    f.writeln(&format!("impl std::ops::Deref for {collection_name} {{"))?;
274    indented(f, |f| {
275        f.writeln(&format!(
276            "type Target = *mut {}::{};",
277            config.ffi_name, collection_name
278        ))?;
279        f.newline()?;
280        f.writeln("fn deref(&self) -> &Self::Target {")?;
281        indented(f, |f| f.writeln("&self.inner"))?;
282        f.writeln("}")
283    })?;
284    f.writeln("}")?;
285
286    f.newline()?;
287
288    f.writeln(&format!("impl {collection_name} {{"))?;
289    indented(f, |f| {
290        f.writeln("pub(crate) fn new(_env: jni::JNIEnv, list: jni::sys::jobject) -> Result<Self, jni::errors::Error> {")?;
291        indented(f, |f| {
292            f.writeln("let _cache = crate::get_cache();")?;
293            let size = if col.has_reserve {
294                f.writeln("let size = _cache.collection.get_size(&_env, list.into());")?;
295                "size"
296            } else {
297                ""
298            };
299            f.writeln(&format!(
300                "let col = Self {{ inner: unsafe {{ {}::ffi::{}_{}({}) }} }};",
301                config.ffi_name, c_ffi_prefix, col.create_func.name, size
302            ))?;
303            f.writeln(
304                "let it = _env.auto_local(_cache.collection.get_iterator(&_env, list.into()));",
305            )?;
306            f.writeln("while _cache.collection.has_next(&_env, it.as_obj()) {")?;
307            indented(f, |f| {
308                f.writeln(
309                    "let next = _env.auto_local(_cache.collection.next(&_env, it.as_obj()));",
310                )?;
311                if let Some(converted) = col
312                    .item_type
313                    .to_rust_from_object("next.as_obj().into_inner()")
314                {
315                    // perform  primary conversion that shadows the variable
316                    f.writeln(&format!("let next = {converted};"))?;
317                }
318                let arg = col
319                    .item_type
320                    .call_site("next")
321                    .unwrap_or_else(|| "next".to_string());
322                f.writeln(&format!(
323                    "unsafe {{ {}::ffi::{}_{}(col.inner, {}) }};",
324                    config.ffi_name, c_ffi_prefix, col.add_func.name, arg
325                ))?;
326                Ok(())
327            })?;
328            f.writeln("}")?;
329            f.writeln("Ok(col)")?;
330            Ok(())
331        })?;
332        f.writeln("}")
333    })?;
334    f.writeln("}")?;
335
336    f.newline()?;
337
338    f.writeln("/// Destroy the C collection on drop")?;
339    f.writeln(&format!("impl Drop for {collection_name} {{"))?;
340    indented(f, |f| {
341        f.writeln("fn drop(&mut self) {")?;
342        indented(f, |f| {
343            f.writeln(&format!(
344                "unsafe {{ {}::ffi::{}_{}(self.inner) }}",
345                config.ffi_name, c_ffi_prefix, col.delete_func.name
346            ))
347        })?;
348        f.writeln("}")
349    })?;
350    f.writeln("}")
351}
352
353fn write_functions(
354    f: &mut dyn Printer,
355    lib: &Library,
356    config: &JniBindgenConfig,
357) -> FormattingResult<()> {
358    fn skip(c: FunctionCategory) -> bool {
359        match c {
360            FunctionCategory::Native => false,
361            // these all get used internally to the JNI and
362            // don't need external wrappers accessed from Java
363            FunctionCategory::CollectionCreate => true,
364            FunctionCategory::CollectionDestroy => true,
365            FunctionCategory::CollectionAdd => true,
366            FunctionCategory::IteratorNext => true,
367        }
368    }
369
370    for handle in lib.functions().filter(|f| !skip(f.category)) {
371        f.newline()?;
372        write_function(f, lib, config, handle)?;
373    }
374    Ok(())
375}
376
377fn write_function_signature(
378    f: &mut dyn Printer,
379    lib: &Library,
380    config: &JniBindgenConfig,
381    handle: &Handle<Function<Validated>>,
382) -> FormattingResult<()> {
383    let args = handle
384        .arguments
385        .iter()
386        .map(|param| format!("{}: {}", param.name, param.arg_type.jni_signature_type()))
387        .collect::<Vec<String>>()
388        .join(", ");
389
390    let returns = match handle.return_type.get_value() {
391        None => "".to_string(),
392        Some(x) => {
393            format!(" -> {}", x.jni_signature_type())
394        }
395    };
396
397    f.writeln("#[no_mangle]")?;
398    f.writeln(
399        &format!(
400            "pub extern \"C\" fn Java_{}_{}_NativeFunctions_{}(_env: jni::JNIEnv, _: jni::sys::jobject, {}){}",
401            config.group_id.replace('.', "_"),
402            lib.settings.name,
403            handle.name.replace('_', "_1"),
404            args,
405            returns
406        )
407    )
408}
409
410fn write_function(
411    f: &mut dyn Printer,
412    lib: &Library,
413    config: &JniBindgenConfig,
414    handle: &Handle<Function<Validated>>,
415) -> FormattingResult<()> {
416    write_function_signature(f, lib, config, handle)?;
417    blocked(f, |f| {
418        // Get the JCache
419        f.writeln("let _cache = get_cache();")?;
420
421        f.newline()?;
422
423        // Perform the primary conversion of the parameters if required
424        for param in &handle.arguments {
425            if let Some(converted) = param.arg_type.to_rust(&param.name) {
426                let conversion = format!("let {} = {};", param.name, converted);
427                f.writeln(&conversion)?;
428            }
429        }
430
431        f.newline()?;
432
433        let extra_param = match handle.get_signature_type() {
434            SignatureType::NoErrorNoReturn => None,
435            SignatureType::NoErrorWithReturn(_, _) => None,
436            SignatureType::ErrorNoReturn(_) => None,
437            SignatureType::ErrorWithReturn(_, _, _) => Some("_out.as_mut_ptr()".to_string()),
438        };
439
440        // list of arguments in the invocation
441        let args = handle
442            .arguments
443            .iter()
444            .map(|param| {
445                param
446                    .arg_type
447                    .call_site(&param.name)
448                    .unwrap_or_else(|| param.name.to_string())
449            })
450            .chain(extra_param)
451            .collect::<Vec<String>>()
452            .join(", ");
453
454        // the invocation of the native function
455        let invocation = format!(
456            "unsafe {{ {}::ffi::{}_{}({}) }}",
457            config.ffi_name, lib.settings.c_ffi_prefix, handle.name, args
458        );
459
460        // Call the C FFI
461        match handle.get_signature_type() {
462            SignatureType::NoErrorNoReturn => {
463                f.writeln(&format!("{invocation};"))?;
464            }
465            SignatureType::NoErrorWithReturn(_, _) | SignatureType::ErrorNoReturn(_) => {
466                f.writeln(&format!("let _result = {invocation};"))?;
467            }
468            SignatureType::ErrorWithReturn(_, _, _) => {
469                f.writeln("let mut _out = std::mem::MaybeUninit::uninit();")?;
470                f.writeln(&format!("let _result = {invocation};"))?;
471            }
472        };
473
474        // Convert return value
475        match handle.get_signature_type() {
476            SignatureType::NoErrorNoReturn => (),
477            SignatureType::NoErrorWithReturn(return_type, _) => {
478                if let Some(conversion) = return_type.maybe_convert("_result") {
479                    f.writeln(&format!("let _result = {conversion};"))?;
480                }
481            }
482            SignatureType::ErrorNoReturn(error_type) => {
483                f.writeln("if _result != 0")?;
484                blocked(f, |f| {
485                    error_type.inner.convert("_result");
486                    f.writeln(&format!(
487                        "let _error = {};",
488                        error_type.inner.convert("_result")
489                    ))?;
490                    f.writeln(&format!(
491                        "let error = _cache.exceptions.{}.throw(&_env, _error);",
492                        error_type.exception_name
493                    ))
494                })?;
495            }
496            SignatureType::ErrorWithReturn(error_type, return_type, _) => {
497                f.writeln("let _result = if _result == 0")?;
498                blocked(f, |f| {
499                    f.writeln("let _result = unsafe { _out.assume_init() };")?;
500                    if let Some(conversion) = return_type.maybe_convert("_result") {
501                        f.writeln(&conversion)?;
502                    }
503                    Ok(())
504                })?;
505                f.writeln("else")?;
506                blocked(f, |f| {
507                    f.writeln(&format!(
508                        "let _error = {};",
509                        error_type.inner.convert("_result")
510                    ))?;
511                    f.writeln(&format!(
512                        "let error = _cache.exceptions.{}.throw(&_env, _error);",
513                        error_type.exception_name
514                    ))?;
515                    f.writeln(return_type.get_default_value())
516                })?;
517                f.write(";")?;
518            }
519        }
520
521        // Return value
522        if !handle.return_type.is_none() {
523            f.writeln("return _result.into();")?;
524        }
525
526        Ok(())
527    })
528}