Skip to main content

shape_jit/
jit_matrix.rs

1//! Native JIT matrix with guaranteed C-compatible layout.
2//!
3//! Mirrors `Arc<MatrixData>` for the JIT. Holds the Arc alive via a leaked
4//! raw pointer so the flat f64 data buffer remains valid for direct SIMD
5//! access from Cranelift-generated code.
6//!
7//! Memory layout (`#[repr(C)]`):
8//! ```text
9//!   offset  0: data       — *const f64 (pointer into MatrixData.data, NOT owned)
10//!   offset  8: rows       — u32
11//!   offset 12: cols       — u32
12//!   offset 16: total_len  — u64  (rows * cols, cached for bounds checks)
//!   offset 24: owner      — *const MatrixData (leaked Arc<MatrixData>, reconstituted on drop)
14//! ```
15
16use std::sync::Arc;
17
18use shape_value::heap_value::MatrixData;
19
/// Byte offset of `JitMatrix::data` (`*const f64`).
pub const MATRIX_DATA_OFFSET: i32 = 0;
/// Byte offset of `JitMatrix::rows` (`u32`).
pub const MATRIX_ROWS_OFFSET: i32 = 8;
/// Byte offset of `JitMatrix::cols` (`u32`).
pub const MATRIX_COLS_OFFSET: i32 = 12;
/// Byte offset of `JitMatrix::total_len` (`u64`).
pub const MATRIX_TOTAL_LEN_OFFSET: i32 = 16;
/// Byte offset of the private `JitMatrix::owner` field (`*const MatrixData`).
pub const MATRIX_OWNER_OFFSET: i32 = 24;

// The offsets above are literals for a 64-bit layout. With 4-byte pointers
// `data` would occupy only bytes 0..4, `rows` would move to offset 4, and every
// field load emitted from these constants would read the wrong bytes. Refuse
// to build on such targets rather than silently miscompile.
const _: () = assert!(
    std::mem::size_of::<*const f64>() == 8,
    "JitMatrix field offsets assume 64-bit pointers"
);
25
26/// Native JIT matrix — a flat f64 buffer with row/col dimensions.
27///
28/// The `data` pointer points directly into the owned `Arc<MatrixData>`'s
29/// `AlignedVec<f64>`, giving the JIT zero-copy access to the underlying
30/// SIMD-aligned storage.
31#[repr(C)]
32pub struct JitMatrix {
33    /// Pointer to the flat f64 data buffer (row-major order).
34    /// NOT owned — lifetime tied to `owner`.
35    pub data: *const f64,
36    /// Number of rows.
37    pub rows: u32,
38    /// Number of columns.
39    pub cols: u32,
40    /// Total element count (rows * cols), cached.
41    pub total_len: u64,
42    /// Leaked `Arc<MatrixData>` that owns the data buffer.
43    /// Reconstituted and dropped in `Drop`.
44    owner: *const MatrixData,
45}
46
47impl JitMatrix {
48    /// Create a JitMatrix from an `Arc<MatrixData>`.
49    ///
50    /// Leaks one Arc strong reference to keep the data alive. The `Drop`
51    /// impl reconstitutes the Arc and releases it.
52    pub fn from_arc(arc: &Arc<MatrixData>) -> Self {
53        let mat = arc.as_ref();
54        let data = mat.data.as_slice().as_ptr();
55        let rows = mat.rows;
56        let cols = mat.cols;
57        let total_len = mat.data.len() as u64;
58        // Increment refcount; raw pointer keeps data alive.
59        let owner = Arc::into_raw(Arc::clone(arc));
60        Self {
61            data,
62            rows,
63            cols,
64            total_len,
65            owner,
66        }
67    }
68
69    /// Reconstitute the owned `Arc<MatrixData>` without dropping it.
70    ///
71    /// Returns a new Arc clone. The JitMatrix retains its own reference
72    /// (will be released on drop).
73    pub fn to_arc(&self) -> Arc<MatrixData> {
74        assert!(!self.owner.is_null(), "JitMatrix has null owner");
75        // Safety: owner was created by Arc::into_raw in from_arc.
76        let arc = unsafe { Arc::from_raw(self.owner) };
77        let cloned = Arc::clone(&arc);
78        // Leak back so Drop still has a reference to release.
79        std::mem::forget(arc);
80        cloned
81    }
82}
83
84impl Drop for JitMatrix {
85    fn drop(&mut self) {
86        if !self.owner.is_null() {
87            // Reconstitute and drop the leaked Arc.
88            unsafe {
89                let _ = Arc::from_raw(self.owner);
90            }
91        }
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use shape_value::aligned_vec::AlignedVec;

    /// Build an `Arc<MatrixData>` of shape `rows x cols` whose flat element
    /// `i` holds the value `i as f64`.
    fn make_test_matrix(rows: u32, cols: u32) -> Arc<MatrixData> {
        let n = (rows as usize) * (cols as usize);
        let mut data = AlignedVec::with_capacity(n);
        for i in 0..n {
            data.push(i as f64);
        }
        Arc::new(MatrixData::from_flat(data, rows, cols))
    }

    /// The `MATRIX_*_OFFSET` constants must match the real `repr(C)` layout.
    #[test]
    fn test_layout() {
        assert_eq!(std::mem::offset_of!(JitMatrix, data), MATRIX_DATA_OFFSET as usize);
        assert_eq!(std::mem::offset_of!(JitMatrix, rows), MATRIX_ROWS_OFFSET as usize);
        assert_eq!(std::mem::offset_of!(JitMatrix, cols), MATRIX_COLS_OFFSET as usize);
        assert_eq!(std::mem::offset_of!(JitMatrix, total_len), MATRIX_TOTAL_LEN_OFFSET as usize);
        assert_eq!(std::mem::offset_of!(JitMatrix, owner), MATRIX_OWNER_OFFSET as usize);
        assert_eq!(std::mem::size_of::<JitMatrix>(), 32);
        assert_eq!(std::mem::align_of::<JitMatrix>(), 8);
    }

    #[test]
    fn test_round_trip() {
        let arc = make_test_matrix(3, 4);
        let jm = JitMatrix::from_arc(&arc);
        assert_eq!(jm.rows, 3);
        assert_eq!(jm.cols, 4);
        assert_eq!(jm.total_len, 12);

        // Zero-copy: the JIT pointer aliases the Arc's own buffer.
        assert_eq!(jm.data, arc.data.as_slice().as_ptr());

        // SAFETY: `jm` keeps the buffer alive and `total_len` is its length.
        let slice = unsafe { std::slice::from_raw_parts(jm.data, jm.total_len as usize) };
        assert_eq!(slice[0], 0.0);
        assert_eq!(slice[11], 11.0);

        // Round-trip back to Arc.
        let recovered = jm.to_arc();
        assert_eq!(recovered.rows, 3);
        assert_eq!(recovered.cols, 4);
        assert_eq!(recovered.data[0], 0.0);
        assert_eq!(recovered.data[11], 11.0);

        // Original Arc is still valid.
        assert_eq!(arc.data[5], 5.0);
    }

    #[test]
    fn test_arc_refcount() {
        let arc = make_test_matrix(2, 2);
        assert_eq!(Arc::strong_count(&arc), 1);

        let jm = JitMatrix::from_arc(&arc);
        assert_eq!(Arc::strong_count(&arc), 2); // jm holds one ref

        let recovered = jm.to_arc();
        assert_eq!(Arc::strong_count(&arc), 3); // jm + recovered

        drop(recovered);
        assert_eq!(Arc::strong_count(&arc), 2);

        drop(jm);
        assert_eq!(Arc::strong_count(&arc), 1); // back to original
    }

    /// The JitMatrix alone must keep the buffer alive once the caller's Arc is gone.
    #[test]
    fn test_data_outlives_source_arc() {
        let arc = make_test_matrix(2, 3);
        let jm = JitMatrix::from_arc(&arc);
        drop(arc);

        // SAFETY: `jm` holds the last strong reference, so the buffer is live.
        let slice = unsafe { std::slice::from_raw_parts(jm.data, jm.total_len as usize) };
        assert_eq!(slice, &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);

        let recovered = jm.to_arc();
        assert_eq!(Arc::strong_count(&recovered), 2); // jm + recovered

        drop(jm);
        assert_eq!(Arc::strong_count(&recovered), 1);
        assert_eq!(recovered.data[5], 5.0);
    }
}