//! `cjc_runtime/aligned_pool.rs` — 16-byte aligned allocation helpers.
use std::fmt;
use std::rc::Rc;

use crate::error::RuntimeError;
use crate::tensor::Tensor;
6
// ---------------------------------------------------------------------------
// 2c. AlignedPool — 16-byte aligned allocation for SIMD readiness
// ---------------------------------------------------------------------------

/// A pre-allocated memory pool with 16-byte alignment guarantee.
///
/// Used by `AlignedByteSlice` to ensure that f32/f64 data mapped from raw
/// bytes starts on a SIMD-friendly boundary. When source bytes are already
/// aligned, no copy is needed; when misaligned, a one-time aligned copy is
/// performed into the pool.
#[derive(Debug)]
pub struct AlignedPool {
    /// Backing storage. A `Vec<u8>` only guarantees byte alignment, so we
    /// over-allocate by 15 bytes and track the 16-byte-aligned offset.
    storage: Vec<u8>,
    /// Byte offset into `storage` where the aligned region begins.
    aligned_offset: usize,
    /// Usable capacity (bytes) from the aligned offset.
    capacity: usize,
    /// Number of bytes currently written.
    len: usize,
}

/// Hand-written `Clone`. A derived clone would copy `aligned_offset`
/// verbatim, but the cloned `Vec` lives at a different heap address, so the
/// copied offset would in general no longer point at a 16-byte boundary —
/// silently breaking the pool's alignment guarantee. Instead we allocate a
/// fresh buffer, recompute its own aligned offset, and copy the used bytes
/// into the new aligned region.
impl Clone for AlignedPool {
    fn clone(&self) -> Self {
        // Mirror the allocation/alignment logic of `AlignedPool::new`.
        let storage = vec![0u8; self.capacity + 15];
        let base_ptr = storage.as_ptr() as usize;
        let aligned_offset = (16 - (base_ptr % 16)) % 16;
        let mut pool = AlignedPool {
            storage,
            aligned_offset,
            capacity: self.capacity,
            len: self.len,
        };
        let dst_start = pool.aligned_offset;
        pool.storage[dst_start..dst_start + self.len]
            .copy_from_slice(&self.storage[self.aligned_offset..self.aligned_offset + self.len]);
        pool
    }
}
29
30impl AlignedPool {
31    /// Create a new pool with capacity for at least `capacity_bytes` of
32    /// 16-byte-aligned data. The actual allocation may be slightly larger.
33    pub fn new(capacity_bytes: usize) -> Self {
34        // Over-allocate by 15 bytes so we can always find a 16-byte boundary.
35        let alloc_size = capacity_bytes + 15;
36        let storage = vec![0u8; alloc_size];
37        let base_ptr = storage.as_ptr() as usize;
38        let aligned_offset = (16 - (base_ptr % 16)) % 16;
39        AlignedPool {
40            storage,
41            aligned_offset,
42            capacity: capacity_bytes,
43            len: 0,
44        }
45    }
46
47    /// Returns a pointer to the aligned region.
48    pub fn as_ptr(&self) -> *const u8 {
49        // SAFETY: aligned_offset is always within bounds by construction.
50        unsafe { self.storage.as_ptr().add(self.aligned_offset) }
51    }
52
53    /// Returns a mutable pointer to the aligned region.
54    pub fn as_mut_ptr(&mut self) -> *mut u8 {
55        unsafe { self.storage.as_mut_ptr().add(self.aligned_offset) }
56    }
57
58    /// Returns the aligned region as a byte slice.
59    pub fn as_bytes(&self) -> &[u8] {
60        &self.storage[self.aligned_offset..self.aligned_offset + self.len]
61    }
62
63    /// Check if a raw pointer is 16-byte aligned.
64    pub fn is_aligned_16(ptr: *const u8) -> bool {
65        (ptr as usize) % 16 == 0
66    }
67
68    /// Copy `data` into the pool, returning the aligned byte slice.
69    /// Returns an error if data exceeds pool capacity.
70    pub fn copy_from(&mut self, data: &[u8]) -> Result<(), RuntimeError> {
71        if data.len() > self.capacity {
72            return Err(RuntimeError::InvalidOperation(
73                format!(
74                    "AlignedPool: data length {} exceeds capacity {}",
75                    data.len(),
76                    self.capacity
77                ),
78            ));
79        }
80        let dest = &mut self.storage[self.aligned_offset..self.aligned_offset + data.len()];
81        dest.copy_from_slice(data);
82        self.len = data.len();
83        Ok(())
84    }
85
86    /// Current number of bytes stored.
87    pub fn len(&self) -> usize {
88        self.len
89    }
90
91    /// Whether the pool is empty.
92    pub fn is_empty(&self) -> bool {
93        self.len == 0
94    }
95
96    /// Total capacity in bytes.
97    pub fn capacity(&self) -> usize {
98        self.capacity
99    }
100
101    /// Verify that the aligned pointer is indeed 16-byte aligned.
102    pub fn check_alignment(&self) -> bool {
103        Self::is_aligned_16(self.as_ptr())
104    }
105}
106
107impl fmt::Display for AlignedPool {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        write!(
110            f,
111            "AlignedPool(len={}, capacity={}, aligned={})",
112            self.len, self.capacity, self.check_alignment()
113        )
114    }
115}
116
/// An alignment-aware byte slice that guarantees 16-byte alignment for
/// tensor weight mapping. If the source bytes are already aligned, it
/// wraps them directly. If misaligned, it copies into an `AlignedPool`.
#[derive(Debug, Clone)]
pub struct AlignedByteSlice {
    /// Holds the aligned copy; `Some` only when a realigning copy was needed.
    pool: Option<AlignedPool>,
    /// Original bytes, shared via `Rc` (kept for reference / fallback; also
    /// the source of truth for `len()` and `is_empty()`).
    original: Rc<Vec<u8>>,
    /// Whether a copy was performed (true = source was misaligned).
    was_copied: bool,
}
129
130impl AlignedByteSlice {
131    /// Create an aligned byte slice from raw bytes.
132    ///
133    /// If the data is already 16-byte aligned, no copy is performed.
134    /// If misaligned, the data is copied into an aligned pool and a
135    /// warning flag is set.
136    pub fn from_bytes(data: Rc<Vec<u8>>) -> Self {
137        let ptr = data.as_ptr();
138        if AlignedPool::is_aligned_16(ptr) {
139            AlignedByteSlice {
140                pool: None,
141                original: data,
142                was_copied: false,
143            }
144        } else {
145            let mut pool = AlignedPool::new(data.len());
146            // This cannot fail: pool capacity == data.len()
147            pool.copy_from(&data).unwrap();
148            AlignedByteSlice {
149                pool: Some(pool),
150                original: data,
151                was_copied: true,
152            }
153        }
154    }
155
156    /// Get the aligned bytes. If a copy was needed, returns the pool's
157    /// bytes; otherwise returns the original directly.
158    pub fn as_bytes(&self) -> &[u8] {
159        match &self.pool {
160            Some(pool) => pool.as_bytes(),
161            None => &self.original,
162        }
163    }
164
165    /// Whether a copy was required for alignment.
166    pub fn was_realigned(&self) -> bool {
167        self.was_copied
168    }
169
170    /// Length in bytes.
171    pub fn len(&self) -> usize {
172        self.original.len()
173    }
174
175    /// Whether empty.
176    pub fn is_empty(&self) -> bool {
177        self.original.is_empty()
178    }
179
180    /// Map these aligned bytes to a Tensor, identical to Tensor::from_bytes
181    /// but with alignment guarantee.
182    pub fn as_tensor(&self, shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
183        Tensor::from_bytes(self.as_bytes(), shape, dtype)
184    }
185}
186
187impl fmt::Display for AlignedByteSlice {
188    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189        write!(
190            f,
191            "AlignedByteSlice(len={}, realigned={})",
192            self.len(),
193            self.was_copied
194        )
195    }
196}
197