use crate::util::{ts2, CompileErrors};
use quote::quote;
use syn::__private::TokenStream2;
fn extract_rust_type(ty: &TokenStream2) -> (String, String) {
let ty = ty.to_string().replace("& mut", "").replace(' ', "");
let generic = ty.rfind("<").unwrap_or(ty.len());
let Some(split_index) = ty.rfind("::") else {
return ("".to_string(), ty);
};
if split_index > generic {
return ("".to_string(), ty);
}
let (left, right) = ty.split_at(split_index + 2);
(left.to_string(), right.to_string())
}
pub fn to_java_list(rust_type: String, errors: &mut CompileErrors) -> String {
let default = "List<Object>".to_string();
let Some(split_index) = rust_type.find('<') else {
return default;
};
let (_, right) = rust_type.split_at(split_index + 1);
let Some(split_index) = right.rfind('>') else {
return default;
};
let (ty, _) = right.split_at(split_index);
if ty == "u8" || ty == "i8" {
return "List<Byte>".to_string();
}
if ty == "i16" {
return "List<Short>".to_string();
}
if ty == "i32" {
return "List<Integer>".to_string();
}
if ty == "i64" {
return "List<Long>".to_string();
}
if ty == "f32" {
return "List<Float>".to_string();
}
if ty == "f64" {
return "List<Double>".to_string();
}
if ty == "char" {
return "List<Character>".to_string();
}
if ty == "bool" {
return "List<Boolean>".to_string();
}
let obj = rewrite_rust_to_java(&ts2(ty), errors).unwrap_or("Object".to_string());
format!("List<{obj}>")
}
fn extract_from_option(rust_type: String, errors: &mut CompileErrors) -> String {
let default = "void".to_string();
let Some(split_index) = rust_type.find('<') else {
return default;
};
let (_, right) = rust_type.split_at(split_index + 1);
let Some(split_index) = right.rfind('>') else {
return default;
};
let (ty, _) = right.split_at(split_index);
rewrite_rust_to_java(&ts2(ty), errors).unwrap_or("void".to_string())
}
pub fn rewrite_rust_to_java(ty: &TokenStream2, errors: &mut CompileErrors) -> Option<String> {
let (_, rust_type) = extract_rust_type(&ty);
if rust_type.starts_with("JNIEnv<") {
return None;
};
if rust_type.starts_with("JClass<") {
return None;
};
if rust_type.starts_with("JList<") {
return Some(to_java_list(rust_type, errors));
};
if rust_type.starts_with("Option<") {
return Some(extract_from_option(rust_type, errors));
};
if rust_type == "()" || rust_type == "" {
return Some("void".to_string());
};
if rust_type == "jbyte" {
return Some("byte".to_string());
};
if rust_type == "jchar" {
return Some("char".to_string());
};
if rust_type == "jboolean" {
return Some("boolean".to_string());
};
if rust_type == "jint" {
return Some("int".to_string());
};
if rust_type == "jshort" {
return Some("short".to_string());
};
if rust_type == "jlong" {
return Some("long".to_string());
};
if rust_type == "jfloat" {
return Some("float".to_string());
};
if rust_type == "jdouble" {
return Some("double".to_string());
};
if rust_type.starts_with("JString<") {
return Some("String".to_string());
};
if rust_type.starts_with("JObject<") {
return Some("Object".to_string());
};
if rust_type.starts_with("JByteArray<") {
return Some("byte[]".to_string());
};
if rust_type == "JByte" {
return Some("Byte".to_string());
};
if rust_type == "JShort" {
return Some("Short".to_string());
};
if rust_type == "JInt" {
return Some("Integer".to_string());
};
if rust_type == "JLong" {
return Some("Long".to_string());
};
if rust_type == "JFloat" {
return Some("Float".to_string());
};
if rust_type == "JDouble" {
return Some("Double".to_string());
};
if rust_type == "JBoolean" {
return Some("Boolean".to_string());
};
if rust_type == "JChar" {
return Some("Character".to_string());
};
if rust_type == "u8" || rust_type == "i8" {
return Some("byte".to_string());
}
if rust_type == "i16" {
return Some("short".to_string());
}
if rust_type == "i32" {
return Some("int".to_string());
}
if rust_type == "i64" {
return Some("long".to_string());
}
if rust_type == "f32" {
return Some("float".to_string());
}
if rust_type == "f64" {
return Some("double".to_string());
}
if rust_type == "bool" {
return Some("boolean".to_string());
}
if rust_type == "char" {
return Some("char".to_string());
}
if rust_type == "Vec<u8>" {
return Some("byte[]".to_string());
}
if rust_type == "String" {
return Some("String".to_string());
};
if rust_type == "Vec<u8>" {
return Some("byte[]".to_string());
};
Some(rust_type)
}
fn extract_jni_from_option(
rust_type: String,
lifetime: &TokenStream2,
errors: &mut CompileErrors,
) -> Option<TokenStream2> {
let Some(split_index) = rust_type.find('<') else {
return None;
};
let (_, right) = rust_type.split_at(split_index + 1);
let Some(split_index) = right.rfind('>') else {
return None;
};
let (ty, _) = right.split_at(split_index);
rewrite_rust_type_to_jni(&ts2(ty), lifetime, errors)
}
const OBJECT_TYPES: &[&str] = &[
"()", "JByte", "JShort", "JInt", "JLong", "JFloat", "JDouble", "JBoolean", "JChar",
];
pub fn rewrite_rust_type_to_jni(
ty: &TokenStream2,
lifetime: &TokenStream2,
errors: &mut CompileErrors,
) -> Option<TokenStream2> {
let (_, rust_type) = extract_rust_type(&ty);
if rust_type.starts_with("JNIEnv<") {
return None;
};
if rust_type.starts_with("JClass<") {
return None;
};
if rust_type.starts_with("Option<") {
return extract_jni_from_option(rust_type, lifetime, errors);
};
if rust_type.starts_with("JString<") {
return Some(quote! { jni::objects::JString #lifetime });
};
if rust_type.starts_with("JByteArray<") {
return Some(quote! { jni::objects::JByteArray #lifetime });
};
if OBJECT_TYPES.contains(&rust_type.as_str()) {
return Some(quote! { jni::objects::JObject #lifetime });
};
if rust_type == "jbyte" || rust_type == "u8" || rust_type == "i8" {
return Some(quote! { jni::sys::jbyte });
};
if rust_type == "jboolean" || rust_type == "bool" {
return Some(quote! { jni::sys::jboolean });
};
if rust_type == "jshort" || rust_type == "i16" {
return Some(quote! { jni::sys::jshort });
};
if rust_type == "jint" || rust_type == "i32" {
return Some(quote! { jni::sys::jint });
};
if rust_type == "jlong" || rust_type == "i64" {
return Some(quote! { jni::sys::jlong });
};
if rust_type == "jfloat" || rust_type == "f32" {
return Some(quote! { jni::sys::jfloat });
};
if rust_type == "jdouble" || rust_type == "f64" {
return Some(quote! { jni::sys::jdouble });
};
if rust_type == "jchar" || rust_type == "char" {
return Some(quote! { jni::sys::jchar });
};
if rust_type == "String" {
return Some(quote! { jni::objects::JString #lifetime });
};
if rust_type == "Vec<u8>" {
return Some(quote! { jni::objects::JByteArray #lifetime });
};
Some(quote! { jni::objects::JObject #lifetime })
}
#[cfg(test)]
pub mod tests {
use super::{extract_rust_type, rewrite_rust_to_java};
use crate::{
types_conversion::rewrite_rust_type_to_jni,
util::{ts2, CompileErrors},
};
#[test]
fn should_extract_type() {
let ty = extract_rust_type(&ts2("String"));
assert_eq!(("".to_string(), "String".to_string()), ty);
let ty = extract_rust_type(&ts2("std::string::String"));
assert_eq!(("std::string::".to_string(), "String".to_string()), ty);
let ty = extract_rust_type(&ts2("JList<std::string::String>"));
assert_eq!(
("".to_string(), "JList<std::string::String>".to_string()),
ty
);
}
#[test]
fn should_rewrite_to_java() {
let errors = &mut CompileErrors::default();
let ty = rewrite_rust_to_java(&ts2("Option<String>"), errors);
assert_eq!(Some("String".to_string()), ty);
let ty = rewrite_rust_to_java(&ts2("Option<std::string::String>"), errors);
assert_eq!(Some("String".to_string()), ty);
let ty = rewrite_rust_to_java(&ts2("JList<std::string::String>"), errors);
assert_eq!(Some("List<String>".to_string()), ty);
let ty = rewrite_rust_to_java(&ts2("Option<JList<std::string::String>>"), errors);
assert_eq!(Some("List<String>".to_string()), ty);
let ty = rewrite_rust_to_java(&ts2("jni::sys::jint"), errors);
assert_eq!(Some("int".to_string()), ty);
let ty = rewrite_rust_to_java(&ts2("java_bindgen::interop::JLong"), errors);
assert_eq!(Some("Long".to_string()), ty);
}
#[test]
fn should_rewrite_to_jni() {
let errors = &mut CompileErrors::default();
let lifetime = ts2("<'local>");
let ty = rewrite_rust_type_to_jni(&ts2("JNIEnv<'l>"), &lifetime, errors).map(|ts| ts.to_string());
assert_eq!(None, ty);
let ty = rewrite_rust_type_to_jni(&ts2("&mut JNIEnv<'l>"), &lifetime, errors).map(|ts| ts.to_string());
assert_eq!(None, ty);
let ty = rewrite_rust_type_to_jni(&ts2("JDouble"), &lifetime, errors).map(|ts| ts.to_string());
assert_eq!(Some("jni :: objects :: JObject <'local >"), ty.as_deref());
let ty = rewrite_rust_type_to_jni(&ts2("MyCustomClassStruct"), &lifetime, errors).map(|ts| ts.to_string());
assert_eq!(Some("jni :: objects :: JObject <'local >"), ty.as_deref());
}
}