metatensor 0.3.0-rc2

Self-describing sparse tensor data format for atomistic machine learning and beyond
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
use std::sync::{Arc, RwLock, TryLockError};

use dlpk::sys::{DLDevice, DLPackVersion, DLDataType};
use dlpk::{DLDataTypeCode, DLPackPointerCast, DLPackTensor, GetDLPackDataType, ReadOnly};

use crate::errors::Error;
use crate::c_api::mts_data_movement_t;

use super::{Array, MtsArray};

impl<T> From<ndarray::ArrayD<T>> for MtsArray
where
    T: 'static + Clone + Send + Default + Sync + GetDLPackDataType + DLPackPointerCast,
{
    /// Take ownership of an `ndarray` array and expose it as an `MtsArray`.
    ///
    /// The array is placed behind `Arc<RwLock<..>>` (the type implementing
    /// `Array` in this file) and then boxed as a `dyn Array`.
    fn from(value: ndarray::ArrayD<T>) -> Self {
        let shared = Arc::new(RwLock::new(value));
        MtsArray::from(Box::new(shared) as Box<dyn Array>)
    }
}

/// `Array` implementation for lock-protected, reference-counted `ndarray`
/// arrays living on the CPU.
///
/// Every lock acquisition below uses `try_read`/`try_write` and panics
/// (instead of blocking) when the lock is poisoned or already held in a
/// conflicting mode.
impl<T> Array for Arc<RwLock<ndarray::ArrayD<T>>>
where
    T: 'static + Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast,
{
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    // Create a new CPU array with the given `shape`, every entry set to the
    // scalar stored in `fill_value`.
    //
    // Panics if `fill_value` can not be exported as DLPack on the CPU, is not
    // a 0-d tensor, or does not have the same dtype as this array.
    fn create(&self, shape: &[usize], fill_value: MtsArray) -> Box<dyn Array> {
        let cpu_device = DLDevice::cpu();
        let max_version = DLPackVersion::current();
        let fill_value_dlpack = fill_value.as_dlpack(cpu_device, None, max_version).expect("failed to extract fill_value as DLPack");

        // Validate fill_value shape from the DLPack tensor directly
        assert_eq!(fill_value_dlpack.shape(), &[], "fill_value must be a single scalar");
        assert_eq!(fill_value_dlpack.device(), cpu_device, "fill_value must be on CPU");

        let fill_value_ptr = fill_value_dlpack.data_ptr::<T>().expect("dtype mismatch between array and fill_value");
        // SAFETY: `data_ptr::<T>()` succeeded, which (per the `expect` message
        // above) means the tensor dtype matches `T`; the tensor was checked to
        // be 0-d and on the CPU, and `fill_value_dlpack` outlives this borrow.
        // This assumes `data_ptr` yields a valid, aligned pointer to one `T`.
        // Cloning through a shared reference (instead of `std::ptr::read`)
        // avoids creating a second owner of the value, which would double-drop
        // any non-`Copy` `T`.
        let fill_value_scalar = unsafe { &*fill_value_ptr }.clone();

        let array = ndarray::Array::from_elem(shape, fill_value_scalar);
        return Box::new(Arc::new(RwLock::new(array)));
    }

    // Make an independent deep copy of the data. Only the CPU is accepted as
    // the target device.
    fn copy(&self, device: DLDevice) -> Box<dyn Array> {
        assert_eq!(device, DLDevice::cpu(), "Rust ndarray data can only be copied to CPU device");

        let lock = match self.try_read() {
            Ok(lock) => lock,
            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
        };

        // `self.clone()` would only bump the `Arc` reference count, making the
        // "copy" share (and mutate) the very same data as the original. Clone
        // the array itself and wrap it in a fresh lock instead.
        let data: ndarray::ArrayD<T> = (*lock).clone();
        return Box::new(Arc::new(RwLock::new(data)));
    }

    fn shape(&self) -> Vec<usize> {
        match self.try_read() {
            Ok(lock) => lock.shape().to_vec(),
            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
        }
    }

    // Reshape the array in place, using row-major (logical) element order.
    fn reshape(&mut self, shape: &[usize]) {
        let mut lock = match self.try_write() {
            Ok(lock) => lock,
            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
        };
        // `into_shape_clone` consumes the array, so temporarily move it out
        // of the lock (leaving an empty default array behind).
        let array = std::mem::take(&mut *lock);
        *lock = array.into_shape_clone(shape).expect("invalid shape");
    }

    fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
        let mut lock = match self.try_write() {
            Ok(lock) => lock,
            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
        };
        lock.swap_axes(axis_1, axis_2);
    }

    // Copy data from `input` into `self` according to `movements`. Each
    // movement copies `properties_length` entries of the last axis, from
    // sample `sample_in` of `input` to sample `sample_out` of `self`.
    //
    // When all movements share the same property range and both the input
    // and output samples are consecutive, the whole copy is done with a single
    // slice assignment; otherwise each movement is applied one by one.
    fn move_data(
        &mut self,
        input: &dyn Array,
        movements: &[mts_data_movement_t],
    ) {
        use ndarray::{Axis, Slice};

        // build the slice `start..start + length`
        fn span(start: usize, length: usize) -> Slice {
            Slice::from(start..(start + length))
        }

        let input = input.as_any().downcast_ref::<Self>().expect("input must be a ndarray of the same type");

        // nothing to do: avoid acquiring (and possibly panicking on) the locks
        if movements.is_empty() {
            return;
        }

        let input = match input.try_read() {
            Ok(lock) => lock,
            Err(TryLockError::Poisoned(_)) => panic!("input array lock is poisoned"),
            Err(TryLockError::WouldBlock) => panic!("input array is already locked"),
        };

        let mut output = match self.try_write() {
            Ok(lock) => lock,
            Err(TryLockError::Poisoned(_)) => panic!("output array lock is poisoned"),
            Err(TryLockError::WouldBlock) => panic!("output array is already locked"),
        };

        let first = &movements[0];
        let constant_properties = movements.iter().all(|m| {
            m.properties_start_in == first.properties_start_in
                && m.properties_start_out == first.properties_start_out
                && m.properties_length == first.properties_length
        });
        let contiguous_input_samples = movements.windows(2).all(|w| w[1].sample_in == w[0].sample_in + 1);
        let contiguous_output_samples = movements.windows(2).all(|w| w[1].sample_out == w[0].sample_out + 1);

        let property_axis = output.ndim() - 1;

        if constant_properties && contiguous_input_samples && contiguous_output_samples {
            // optimized path: a single block assignment
            let sample_count = movements.len();

            let input_samples = input.slice_axis(Axis(0), span(first.sample_in, sample_count));
            let mut output_samples = output.slice_axis_mut(Axis(0), span(first.sample_out, sample_count));

            let value = input_samples.slice_axis(
                Axis(property_axis),
                span(first.properties_start_in, first.properties_length),
            );
            let mut output_location = output_samples.slice_axis_mut(
                Axis(property_axis),
                span(first.properties_start_out, first.properties_length),
            );
            output_location.assign(&value);
        } else {
            // general path: one sample at a time
            for item in movements {
                let input_sample = input.index_axis(Axis(0), item.sample_in);
                let mut output_sample = output.index_axis_mut(Axis(0), item.sample_out);

                // `index_axis` removes the sample axis, so the property axis
                // is now at `property_axis - 1`
                let value = input_sample.slice_axis(
                    Axis(property_axis - 1),
                    span(item.properties_start_in, item.properties_length),
                );
                let mut output_location = output_sample.slice_axis_mut(
                    Axis(property_axis - 1),
                    span(item.properties_start_out, item.properties_length),
                );
                output_location.assign(&value);
            }
        }
    }

    fn device(&self) -> DLDevice {
        DLDevice::cpu()
    }

    fn dtype(&self) -> DLDataType {
        T::get_dlpack_data_type()
    }

    // Export the data as a read-only DLPack tensor sharing this array's
    // storage (through a clone of the `Arc`).
    //
    // Errors if a stream is requested, if `max_version` has a different major
    // version than the vendored DLPack, or if `device` is not the CPU.
    fn as_dlpack(
        &self,
        device: DLDevice,
        stream: Option<i64>,
        max_version: DLPackVersion,
    ) -> Result<DLPackTensor, Error> {
        if stream.is_some() {
            // ndarray data always lives on the CPU, where streams have no meaning
            return Err(Error {
                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
                message: "CPU arrays can not be used with a stream".into(),
            });
        }

        let vendored_version = DLPackVersion::current();
        if max_version.major != vendored_version.major {
            return Err(Error {
                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
                message: format!(
                    "invalid `max_version` in ndarray::ArrayD<T>::as_dlpack: \
                    we got v{}.{}, but we support v{}.{}",
                    max_version.major, max_version.minor,
                    vendored_version.major, vendored_version.minor
                ),
            });
        }

        let ndarray_device = DLDevice::cpu();
        if device != ndarray_device {
            return Err(Error {
                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
                message: format!(
                    "Requested DLPack device ({}) does not match array device ({})",
                    device, ndarray_device
                ),
            });
        }

        let tensor: DLPackTensor = ReadOnly(Arc::clone(self)).try_into().map_err(|e| Error {
            code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
            message: format!("failed to convert ndarray to DLPack: {:?}", e),
        })?;

        Ok(tensor)
    }

    // Build a new ndarray-backed array from a DLPack tensor. The element type
    // of the result is chosen from the tensor dtype (not from `T`); only
    // single-lane float, signed/unsigned integer and 8-bit bool dtypes are
    // supported.
    #[allow(clippy::enum_glob_use)]
    fn from_dlpack(&self, dlpack_tensor: DLPackTensor) -> Result<Box<dyn Array>, Error> {
        use DLDataTypeCode::*;

        let dtype = dlpack_tensor.dtype();

        if dtype.lanes != 1 {
            return Err(Error {
                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
                message: "Only DLPack tensors with lanes == 1 are supported".into(),
            });
        }

        let map_error = |e| Error {
            code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
            message: format!("failed to convert DLPack to ndarray: {:?}", e),
        };

        if dtype.code == kDLFloat && dtype.bits == 64 {
            let array: ndarray::ArrayD<f64> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLFloat && dtype.bits == 32 {
            let array: ndarray::ArrayD<f32> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLInt && dtype.bits == 8 {
            let array: ndarray::ArrayD<i8> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLInt && dtype.bits == 16 {
            let array: ndarray::ArrayD<i16> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLInt && dtype.bits == 32 {
            let array: ndarray::ArrayD<i32> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLInt && dtype.bits == 64 {
            let array: ndarray::ArrayD<i64> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLUInt && dtype.bits == 8 {
            let array: ndarray::ArrayD<u8> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLUInt && dtype.bits == 16 {
            let array: ndarray::ArrayD<u16> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLUInt && dtype.bits == 32 {
            let array: ndarray::ArrayD<u32> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLUInt && dtype.bits == 64 {
            let array: ndarray::ArrayD<u64> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else if dtype.code == kDLBool && dtype.bits == 8 {
            let array: ndarray::ArrayD<bool> = dlpack_tensor.try_into().map_err(map_error)?;
            return Ok(Box::new(Arc::new(RwLock::new(array))));
        } else {
            return Err(Error {
                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
                message: format!("Unsupported DLPack dtype {}", dtype),
            });
        }
    }
}

// Tests for the ndarray-backed `Array` implementation: construction, `create`,
// DLPack export (including rejected parameters) and DLPack import.
#[cfg(test)]
mod tests {
    use dlpk::{DLPackPointerCast, GetDLPackDataType, sys::{DLDataTypeCode, DLDevice, DLPackVersion}};
    use crate::MtsArray;

    #[test]
    fn ndarray_as_mts_array() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3, 4]);
        let mts_array = MtsArray::from(data);

        assert_eq!(mts_array.shape().unwrap(), [2, 3, 4]);

        let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], 42.0));

        let created = mts_array.create(&[2, 3, 4], fill_value.as_ref()).unwrap();
        assert_eq!(created.shape().unwrap(), [2, 3, 4]);
    }

    #[test]
    fn ndarray_as_mts_array_dlpack() {
        let data = ndarray::Array::<f64, _>::zeros(vec![4, 5, 6]);
        let mts_array = MtsArray::from(data);

        let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        assert_eq!(dl_managed.n_dims(), 3);
        assert_eq!(dl_managed.shape(), [4, 5, 6]);

        assert_eq!(dl_managed.dtype().code, DLDataTypeCode::kDLFloat);
        assert_eq!(dl_managed.dtype().bits, 64);
        assert_eq!(dl_managed.dtype().lanes, 1);
    }

    #[test]
    fn ndarray_all_dtypes() {
        // Check DLPack export and `create` for an array of element type `T`,
        // which should map to the DLPack dtype (`code`, `bits`, lanes = 1).
        fn test_for_dtype<T>(code: DLDataTypeCode, bits: u8) where T: Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast + 'static {
            let data = ndarray::Array::<T, _>::from_elem(vec![2, 2], T::default());
            let mts_array = MtsArray::from(data);

            assert_eq!(mts_array.shape().unwrap(), [2, 2]);

            // Should be able to export as DLPack
            let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
            assert_eq!(dl_managed.dtype().code, code);
            assert_eq!(dl_managed.dtype().bits, bits);
            assert_eq!(dl_managed.dtype().lanes, 1);

            // And `create` should make an array with the same dtype as `T`
            let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], T::default()));

            let created = mts_array.create(&[1, 1], fill_value.as_ref()).unwrap();
            let dl_managed = created.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

            assert_eq!(dl_managed.dtype().code, code);
            assert_eq!(dl_managed.dtype().bits, bits);
            assert_eq!(dl_managed.dtype().lanes, 1);
        }

        test_for_dtype::<bool>(DLDataTypeCode::kDLBool, 8);
        test_for_dtype::<f64>(DLDataTypeCode::kDLFloat, 64);
        test_for_dtype::<f32>(DLDataTypeCode::kDLFloat, 32);
        test_for_dtype::<i8>(DLDataTypeCode::kDLInt, 8);
        test_for_dtype::<i16>(DLDataTypeCode::kDLInt, 16);
        test_for_dtype::<i32>(DLDataTypeCode::kDLInt, 32);
        test_for_dtype::<i64>(DLDataTypeCode::kDLInt, 64);
        test_for_dtype::<u8>(DLDataTypeCode::kDLUInt, 8);
        test_for_dtype::<u16>(DLDataTypeCode::kDLUInt, 16);
        test_for_dtype::<u32>(DLDataTypeCode::kDLUInt, 32);
        test_for_dtype::<u64>(DLDataTypeCode::kDLUInt, 64);
    }

    #[test]
    fn ndarray_device() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);

        assert_eq!(mts_array.device().unwrap(), DLDevice::cpu());
    }

    #[test]
    fn as_dlpack_rejects_stream() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);
        match mts_array.as_dlpack(DLDevice::cpu(), Some(42), DLPackVersion::current()) {
            Err(e) => assert!(e.message.contains("stream"), "{}", e.message),
            Ok(_) => panic!("expected error for non-null stream"),
        }
    }

    #[test]
    fn as_dlpack_rejects_wrong_device() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);
        let cuda = DLDevice {
            device_type: dlpk::sys::DLDeviceType::kDLCUDA,
            device_id: 0,
        };
        match mts_array.as_dlpack(cuda, None, DLPackVersion::current()) {
            Err(e) => assert!(e.message.contains("does not match"), "{}", e.message),
            Ok(_) => panic!("expected error for CUDA device on CPU array"),
        }
    }

    #[test]
    fn as_dlpack_rejects_incompatible_version() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);

        let bad_version = DLPackVersion { major: 99, minor: 0 };
        match mts_array.as_dlpack(DLDevice::cpu(), None, bad_version) {
            Err(e) => assert!(e.message.contains("version"), "{}", e.message),
            Ok(_) => panic!("expected error for incompatible DLPack version"),
        }
    }

    // Round-trip f64 and i16 data through DLPack export and `from_dlpack`,
    // checking that the origin and the values are preserved.
    #[test]
    #[allow(clippy::float_cmp)]
    fn from_dlpack() {
        let mut f64_data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        f64_data[[0, 0]] = 1.573;
        f64_data[[1, 2]] = -42.0;
        let f64_array = MtsArray::from(f64_data);

        let mut i16_data = ndarray::Array::<i16, _>::zeros(vec![2, 5, 10]);
        i16_data[[0, 1, 3]] = 3;
        i16_data[[1, 2, 4]] = -42;
        let i16_array = MtsArray::from(i16_data);

        let f64_dl_tensor = f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
        let i16_dl_tensor = i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        let new_f64_array = f64_array.from_dlpack(f64_dl_tensor).unwrap();
        let new_i16_array = i16_array.from_dlpack(i16_dl_tensor).unwrap();

        assert_eq!(f64_array.origin().unwrap(), i16_array.origin().unwrap());
        assert_eq!(new_f64_array.origin().unwrap(), f64_array.origin().unwrap());
        assert_eq!(new_i16_array.origin().unwrap(), i16_array.origin().unwrap());

        let new_f64_dl_tensor = new_f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
        let new_i16_dl_tensor = new_i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        let new_f64_data: ndarray::ArrayD<f64> = new_f64_dl_tensor.try_into().unwrap();
        let new_i16_data: ndarray::ArrayD<i16> = new_i16_dl_tensor.try_into().unwrap();

        assert_eq!(new_f64_data[[0, 0]], 1.573);
        assert_eq!(new_f64_data[[1, 2]], -42.0);

        assert_eq!(new_i16_data[[0, 1, 3]], 3);
        assert_eq!(new_i16_data[[1, 2, 4]], -42);
    }
}