1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
use super::*;

use std::marker::*;
use std::ops::*;

/// A Java Array of some POD-like type such as bool, jbyte, jchar, jshort, jint, jlong, jfloat, or jdouble.
/// 
/// See also [ObjectArray] for arrays of reference types.
/// 
/// | JNI Type      | PrimitiveArray Implementation |
/// | ------------- | ----------------- |
/// | [bool]\[\]    | [BooleanArray]    |
/// | [jbyte]\[\]   | [ByteArray]       |
/// | [jchar]\[\]   | [CharArray]       |
/// | [jshort]\[\]  | [ShortArray]      |
/// | [jint]\[\]    | [IntArray]        |
/// | [jlong]\[\]   | [LongArray]       |
/// | [jfloat]\[\]  | [FloatArray]      |
/// | [jdouble]\[\] | [DoubleArray]     |
/// 
/// [bool]:         https://doc.rust-lang.org/std/primitive.bool.html
/// [jbyte]:        https://docs.rs/jni-sys/0.3.0/jni_sys/type.jbyte.html
/// [jchar]:        struct.jchar.html
/// [jshort]:       https://docs.rs/jni-sys/0.3.0/jni_sys/type.jshort.html
/// [jint]:         https://docs.rs/jni-sys/0.3.0/jni_sys/type.jint.html
/// [jlong]:        https://docs.rs/jni-sys/0.3.0/jni_sys/type.jlong.html
/// [jfloat]:       https://docs.rs/jni-sys/0.3.0/jni_sys/type.jfloat.html
/// [jdouble]:      https://docs.rs/jni-sys/0.3.0/jni_sys/type.jdouble.html
/// 
/// [BooleanArray]: struct.BooleanArray.html
/// [ByteArray]:    struct.ByteArray.html
/// [CharArray]:    struct.CharArray.html
/// [ShortArray]:   struct.ShortArray.html
/// [IntArray]:     struct.IntArray.html
/// [LongArray]:    struct.LongArray.html
/// [FloatArray]:   struct.FloatArray.html
/// [DoubleArray]:  struct.DoubleArray.html
/// [ObjectArray]:  struct.ObjectArray.html
/// 
pub trait PrimitiveArray<T> where Self : Sized + AsValidJObjectAndEnv, T : Clone + Default {
    /// Uses env.New{Type}Array to create a new java array containing "size" elements.
    fn new<'env>(env: &'env Env, size: usize) -> Local<'env, Self>;

    /// Uses env.GetArrayLength to get the length of the java array.
    fn len(&self) -> usize;

    /// Uses env.Get{Type}ArrayRegion to read the contents of the java array from \[start .. start + elements.len())
    fn get_region(&self, start: usize, elements: &mut [T]);

    /// Uses env.Set{Type}ArrayRegion to set the contents of the java array from \[start .. start + elements.len())
    fn set_region(&self, start: usize, elements: &[T]);

    /// Uses env.New{Type}Array + Set{Type}ArrayRegion to create a new java array containing a copy of "elements".
    fn from<'env>(env: &'env Env, elements: &[T]) -> Local<'env, Self> {
        let array = Self::new(env, elements.len());
        array.set_region(0, elements);
        array
    }

    /// Uses env.GetArrayLength + env.Get{Type}ArrayRegion to read the contents of the java array from range into a new Vec.
    ///
    /// # Panics
    ///
    /// If the range is reversed (start > end), extends past the end of the array, or has a
    /// bound of `usize::MAX` that cannot be converted to a half-open range.
    fn get_region_as_vec(&self, range: impl RangeBounds<usize>) -> Vec<T> {
        let len = self.len();

        // Normalize the range to half-open [start, end). checked_add keeps a `usize::MAX`
        // bound from silently wrapping to 0 in release builds (e.g. `..=usize::MAX`).
        let start = match range.start_bound() {
            Bound::Unbounded => 0,
            Bound::Included(&n) => n,
            Bound::Excluded(&n) => n.checked_add(1).expect("range start bound overflows usize"),
        };

        let end = match range.end_bound() {
            Bound::Unbounded => len,
            Bound::Included(&n) => n.checked_add(1).expect("range end bound overflows usize"),
            Bound::Excluded(&n) => n,
        };

        assert!(start <= end, "range start ({}) is greater than range end ({})", start, end);
        assert!(end <= len, "range end ({}) is out of bounds for array of length {}", end, len);

        let mut vec = vec![T::default(); end - start];
        self.get_region(start, &mut vec[..]);
        vec
    }

    /// Uses env.GetArrayLength + env.Get{Type}ArrayRegion to read the contents of the entire java array into a new Vec.
    fn as_vec(&self) -> Vec<T> {
        // `..` lets get_region_as_vec query the length once, instead of once here and once there.
        self.get_region_as_vec(..)
    }
}

// jboolean, as used exclusively by the JNI/JVM, is assumed to be interchangeable with Rust's bool.
// In the general case that would be unsound: a jboolean may hold any u8 bit pattern, while only
// 0 and 1 are valid bools. The assumption is believed to hold when dealing only with JNI/JVM APIs,
// which *should* only ever produce JNI_TRUE or JNI_FALSE - bit patterns identical to Rust's
// true / false. This test pins down the parts of that assumption that can be checked statically.
#[test]
fn bool_ffi_assumptions_test() {
    use std::mem::{size_of, transmute};

    // Both types occupy exactly one byte. bool's size has been 1 since
    // https://github.com/rust-lang/rust/pull/46156/commits/219ba511c824bc44149d55c570f723dcd0f0217d
    let sizes = (size_of::<jboolean>(), size_of::<bool>());
    assert_eq!(sizes, (1, 1));

    // Rust's true / false share their bit patterns with JNI_TRUE / JNI_FALSE.
    for &(value, expected) in &[(true, JNI_TRUE), (false, JNI_FALSE)] {
        let bits = unsafe { transmute::<bool, u8>(value) };
        assert_eq!(bits, expected);
    }
}

/// Generates a `#[repr(transparent)]` wrapper struct for one Java primitive array type, together
/// with its `AsValidJObjectAndEnv`, `AsJValue`, `JniType` and `PrimitiveArray` implementations.
///
/// Arguments: the struct name, its NUL-terminated JNI array descriptor, the Rust element type,
/// and the names of the JNIEnv function-table entries for New{Type}Array, Set{Type}ArrayRegion
/// and Get{Type}ArrayRegion (in that order).
macro_rules! primitive_array {
    (#[repr(transparent)] pub struct $name:ident = $type_str:expr, $type:ident { $new_array:ident $set_region:ident $get_region:ident } ) => {
        /// A [PrimitiveArray](trait.PrimitiveArray.html) implementation.
        #[repr(transparent)] pub struct $name(ObjectAndEnv);

        unsafe impl AsValidJObjectAndEnv for $name {}
        unsafe impl AsJValue for $name { fn as_jvalue(&self) -> jni_sys::jvalue { jni_sys::jvalue { l: self.0.object } } }
        unsafe impl JniType for $name { fn static_with_jni_type<R>(callback: impl FnOnce(&str) -> R) -> R { callback($type_str) } }

        impl $name {
            /// Validates that \[start .. start + count) lies within this array and converts both
            /// values to `jsize` for the JNI region functions.
            ///
            /// # Panics
            ///
            /// If `start` or `count` exceeds `i32::MAX` (jsize == jint == i32), or if the region
            /// extends past the end of the array.
            fn checked_region(&self, start: usize, count: usize) -> (jsize, jsize) {
                assert!(start <= std::i32::MAX as usize, "region start {} exceeds jsize range", start);
                assert!(count <= std::i32::MAX as usize, "region length {} exceeds jsize range", count);
                // Summed in usize rather than jsize: two values <= i32::MAX can overflow an i32,
                // but not a usize of at least 32 bits.
                let end = start + count;
                let len = self.len();
                assert!(end <= len, "region [{}, {}) is out of bounds for array of length {}", start, end, len);
                (start as jsize, count as jsize)
            }
        }

        impl PrimitiveArray<$type> for $name {
            fn new<'env>(env: &'env Env, size: usize) -> Local<'env, Self> {
                assert!(size <= std::i32::MAX as usize); // jsize == jint == i32
                let size = size as jsize;
                let env = env.as_jni_env();
                unsafe {
                    let object = (**env).$new_array.unwrap()(env, size);
                    let exception = (**env).ExceptionOccurred.unwrap()(env);
                    assert!(exception.is_null()); // Only sane exception here is an OOM exception
                    Local::from_env_object(env, object)
                }
            }

            fn from<'env>(env: &'env Env, elements: &[$type]) -> Local<'env, Self> {
                let array  = Self::new(env, elements.len());
                // Cannot truncate: Self::new above asserted elements.len() <= i32::MAX.
                let size   = elements.len() as jsize;
                let env    = array.0.env as *mut JNIEnv;
                let object = array.0.object;
                // SAFETY: the freshly created array has exactly `size` elements, so the region
                // [0, size) is in bounds and `elements` provides `size` readable values.
                unsafe {
                    (**env).$set_region.unwrap()(env, object, 0, size, elements.as_ptr() as *const _);
                }
                array
            }

            fn len(&self) -> usize {
                unsafe { (**self.0.env).GetArrayLength.unwrap()(self.0.env as *mut _, self.0.object) as usize }
            }

            fn get_region(&self, start: usize, elements: &mut [$type]) {
                let (start, count) = self.checked_region(start, elements.len());
                // SAFETY: checked_region verified the region is within the array, and `elements`
                // has room for exactly `count` values.
                unsafe { (**self.0.env).$get_region.unwrap()(self.0.env as *mut _, self.0.object, start, count, elements.as_mut_ptr() as *mut _) };
            }

            fn set_region(&self, start: usize, elements: &[$type]) {
                let (start, count) = self.checked_region(start, elements.len());
                // SAFETY: checked_region verified the region is within the array, and `elements`
                // provides exactly `count` readable values.
                unsafe { (**self.0.env).$set_region.unwrap()(self.0.env as *mut _, self.0.object, start, count, elements.as_ptr() as *const _) };
            }
        }
    };
}

// One PrimitiveArray implementation per Java primitive element type. Each invocation supplies the
// NUL-terminated JNI array descriptor, the Rust element type, and the JNIEnv function-table entries
// for New{Type}Array, Set{Type}ArrayRegion and Get{Type}ArrayRegion.
primitive_array! {
    #[repr(transparent)] pub struct BooleanArray = "[Z\0", bool
    { NewBooleanArray SetBooleanArrayRegion GetBooleanArrayRegion }
}
primitive_array! {
    #[repr(transparent)] pub struct ByteArray = "[B\0", jbyte
    { NewByteArray SetByteArrayRegion GetByteArrayRegion }
}
primitive_array! {
    #[repr(transparent)] pub struct CharArray = "[C\0", jchar
    { NewCharArray SetCharArrayRegion GetCharArrayRegion }
}
primitive_array! {
    #[repr(transparent)] pub struct ShortArray = "[S\0", jshort
    { NewShortArray SetShortArrayRegion GetShortArrayRegion }
}
primitive_array! {
    #[repr(transparent)] pub struct IntArray = "[I\0", jint
    { NewIntArray SetIntArrayRegion GetIntArrayRegion }
}
primitive_array! {
    #[repr(transparent)] pub struct LongArray = "[J\0", jlong
    { NewLongArray SetLongArrayRegion GetLongArrayRegion }
}
primitive_array! {
    #[repr(transparent)] pub struct FloatArray = "[F\0", jfloat
    { NewFloatArray SetFloatArrayRegion GetFloatArrayRegion }
}
primitive_array! {
    #[repr(transparent)] pub struct DoubleArray = "[D\0", jdouble
    { NewDoubleArray SetDoubleArrayRegion GetDoubleArrayRegion }
}

/// A Java Array of reference types (classes, interfaces, other arrays, etc.)
/// 
/// See also [PrimitiveArray] for arrays of primitive types.
/// 
/// `T` is the element type. `E` is the throwable type handed back when element access
/// (`get` / `set`) raises a Java exception.
/// 
/// [PrimitiveArray]:   trait.PrimitiveArray.html
/// 
#[repr(transparent)]
pub struct ObjectArray<T: AsValidJObjectAndEnv, E: ThrowableType>(ObjectAndEnv, PhantomData<(T,E)>);

// NOTE(review): presumably sound because ObjectArray is #[repr(transparent)] over ObjectAndEnv
// (its only other field is a zero-sized PhantomData), like the primitive array wrappers.
unsafe impl<T, E> AsValidJObjectAndEnv for ObjectArray<T, E>
where
    T: AsValidJObjectAndEnv,
    E: ThrowableType,
{
}

unsafe impl<T: AsValidJObjectAndEnv, E: ThrowableType> JniType for ObjectArray<T, E> {
    /// Passes the JNI type descriptor of `T[]` to `callback`.
    ///
    /// Element descriptors already in field-descriptor form are prefixed with `[`. That covers
    /// nested arrays (`[I`, `[[Z`, ...) and `L...;` class descriptors. A bare internal class
    /// name (e.g. `java/lang/String`) is wrapped as `[Ljava/lang/String;`. FindClass requires
    /// that form for object arrays; `[java/lang/String` does not name a class. Any trailing NUL
    /// terminator on the element descriptor is kept at the end of the result.
    fn static_with_jni_type<R>(callback: impl FnOnce(&str) -> R) -> R {
        T::static_with_jni_type(|inner| {
            let name = inner.trim_end_matches('\0');
            let terminator = &inner[name.len()..];
            let descriptor = if name.starts_with('[') || name.ends_with(';') {
                format!("[{}{}", name, terminator)
            } else {
                format!("[L{};{}", name, terminator)
            };
            callback(descriptor.as_str())
        })
    }
}

unsafe impl<T, E> AsJValue for ObjectArray<T, E>
where
    T: AsValidJObjectAndEnv,
    E: ThrowableType,
{
    /// Passes the array to JNI as an object reference (the `l` member of `jvalue`).
    fn as_jvalue(&self) -> jni_sys::jvalue {
        let object = self.0.object;
        jni_sys::jvalue { l: object }
    }
}

impl<T: AsValidJObjectAndEnv, E: ThrowableType> ObjectArray<T, E> {
    pub fn new<'env>(env: &'env Env, size: usize) -> Local<'env, Self> {
        assert!(size <= std::i32::MAX as usize); // jsize == jint == i32
        let class = Self::static_with_jni_type(|t| unsafe { env.require_class(t) });
        let size = size as jsize;
        let env = env.as_jni_env();
        unsafe {
            let fill = null_mut();
            let object = (**env).NewObjectArray.unwrap()(env, size, class, fill);
            let exception = (**env).ExceptionOccurred.unwrap()(env);
            assert!(exception.is_null()); // Only sane exception here is an OOM exception
            Local::from_env_object(env, object)
        }
    }

    pub fn iter<'env>(&'env self) -> ObjectArrayIter<'env, T, E> {
        ObjectArrayIter {
            array:  self,
            index:  0,
            length: self.len(),
        }
    }

    pub fn from<'env>(env: &'env Env, elements: impl 'env + ExactSizeIterator + Iterator<Item = impl Into<Option<&'env T>>>) -> Local<'env, Self> {
        let size    = elements.len();
        let array   = Self::new(env, size);
        let env     = array.0.env as *mut JNIEnv;
        let this    = array.0.object;
        let set     = unsafe { (**env) }.SetObjectArrayElement.unwrap();

        for (index, element) in elements.enumerate() {
            assert!(index < size); // Should only be violated by an invalid ExactSizeIterator implementation.
            let value = element.into().map(|v| unsafe { AsJValue::as_jvalue(v.into()).l }).unwrap_or(null_mut());
            unsafe { set(env, this, index as jsize, value) };
        }
        array
    }

    pub fn len(&self) -> usize {
        unsafe { (**self.0.env).GetArrayLength.unwrap()(self.0.env as *mut _, self.0.object) as usize }
    }

    /// XXX: Expose this via std::ops::Index
    pub fn get<'env>(&'env self, index: usize) -> Result<Option<Local<'env, T>>, Local<'env, E>> {
        assert!(index <= std::i32::MAX as usize); // jsize == jint == i32 XXX: Should maybe be treated as an exception?
        let index   = index as jsize;
        let env     = self.0.env as *mut JNIEnv;
        let this    = self.0.object;
        unsafe {
            let result = (**env).GetObjectArrayElement.unwrap()(env, this, index);
            let exception = (**env).ExceptionOccurred.unwrap()(env);
            if !exception.is_null() {
                (**env).ExceptionClear.unwrap()(env);
                Err(Local::from_env_object(env, exception))
            } else if result.is_null() {
                Ok(None)
            } else {
                Ok(Some(Local::from_env_object(env, result)))
            }
        }
    }

    /// XXX: I don't think there's a way to expose this via std::ops::IndexMut sadly?
    pub fn set<'env>(&'env self, index: usize, value: impl Into<Option<&'env T>>) -> Result<(), Local<'env, E>> {
        assert!(index <= std::i32::MAX as usize); // jsize == jint == i32 XXX: Should maybe be treated as an exception?
        let value   = value.into().map(|v| unsafe { AsJValue::as_jvalue(v.into()).l }).unwrap_or(null_mut());
        let index   = index as jsize;
        let env     = self.0.env as *mut JNIEnv;
        let this    = self.0.object;
        unsafe {
            (**env).SetObjectArrayElement.unwrap()(env, this, index, value);
            let exception = (**env).ExceptionOccurred.unwrap()(env);
            if !exception.is_null() {
                (**env).ExceptionClear.unwrap()(env);
                Err(Local::from_env_object(env, exception))
            } else {
                Ok(())
            }
        }
    }
}



/// Iterator over the elements of an [ObjectArray](struct.ObjectArray.html), created by `ObjectArray::iter`.
///
/// Yields `Some(local)` for non-null elements and `None` for null elements. The array length is
/// captured once when the iterator is created.
pub struct ObjectArrayIter<'env, T: AsValidJObjectAndEnv, E: ThrowableType> {
    array:  &'env ObjectArray<T, E>,
    index:  usize,
    length: usize,
}

impl<'env, T: AsValidJObjectAndEnv, E: ThrowableType> Iterator for ObjectArrayIter<'env, T, E> {
    type Item = Option<Local<'env, T>>;

    fn next(&mut self) -> Option<Self::Item> {
        let index = self.index;
        if index < self.length {
            self.index = index + 1;
            // `Item` cannot carry an error, so a java exception raised while reading this element
            // (already cleared by `get`) is reported as a null element.
            Some(self.array.get(index).unwrap_or(None))
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.length.saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}

// The remaining count is known exactly: one item per index in [index, length).
impl<'env, T: AsValidJObjectAndEnv, E: ThrowableType> ExactSizeIterator for ObjectArrayIter<'env, T, E> {}

// Once index reaches length, `next` keeps returning None.
impl<'env, T: AsValidJObjectAndEnv, E: ThrowableType> std::iter::FusedIterator for ObjectArrayIter<'env, T, E> {}