Skip to main content

array_format/
ndarray_ext.rs

//! Optional integration with the `ndarray` crate.
//!
//! Enabled via the `ndarray` Cargo feature.
4
5use std::ops::Range;
6
7use crate::array::ArrayElement;
8use crate::error::{Error, Result};
9use crate::file::{ArrayFile, ChunkedSchema};
10
11pub(crate) fn make_si(
12    ranges: &[Range<usize>],
13) -> ndarray::SliceInfo<Vec<ndarray::SliceInfoElem>, ndarray::IxDyn, ndarray::IxDyn> {
14    let elems: Vec<ndarray::SliceInfoElem> = ranges
15        .iter()
16        .map(|r| ndarray::SliceInfoElem::Slice {
17            start: r.start as isize,
18            end: Some(r.end as isize),
19            step: 1,
20        })
21        .collect();
22    // SAFETY: caller ensures elems.len() equals the target array's ndim.
23    unsafe { ndarray::SliceInfo::new(elems).expect("ndim/slice length mismatch") }
24}
25
/// Iterates over every N-dimensional coordinate inside `ranges`, in row-major
/// order (the last axis varies fastest).
///
/// Each yielded `Vec<u32>` has `ranges.len()` entries, with entry `d` drawn
/// from `ranges[d]`. If any range is empty (including a reversed range whose
/// `start` exceeds its `end`) nothing is yielded. With zero ranges a single
/// empty coordinate is yielded.
pub(crate) fn iter_nd_coords(ranges: &[Range<u32>]) -> impl Iterator<Item = Vec<u32>> + '_ {
    // Lengths are kept as `usize` so the linear index can be decomposed without
    // first being narrowed to 32 bits. `saturating_sub` maps a reversed range
    // to length 0 instead of underflowing.
    let counts: Vec<usize> = ranges
        .iter()
        .map(|r| r.end.saturating_sub(r.start) as usize)
        .collect();
    let total: usize = counts.iter().product();
    (0..total).map(move |mut i| {
        let mut coord = vec![0u32; ranges.len()];
        // Peel off axes from last to first. `total > 0` here, so every count is
        // non-zero and the division/remainder below cannot divide by zero.
        for ((slot, r), &count) in coord.iter_mut().zip(ranges).zip(&counts).rev() {
            // The remainder is `< count <= u32::MAX`, so the cast is lossless.
            *slot = r.start + (i % count) as u32;
            i /= count;
        }
        coord
    })
}
38
39/// Assembles an array from `file`, reading only chunks that overlap `slice`
40/// (or the full array when `slice` is `None`).
41pub(crate) async fn assemble_nd<T>(
42    file: &ArrayFile,
43    name: &str,
44    slice: Option<&[Range<usize>]>,
45) -> Result<ndarray::ArcArray<T, ndarray::IxDyn>>
46where
47    T: ArrayElement,
48{
49    let ChunkedSchema {
50        full_shape: full_shape_u32,
51        chunk_shape: chunk_shape_u32,
52        dtype,
53        all_coords,
54    } = file.get_chunked_schema(name)?;
55    if dtype != T::DTYPE {
56        return Err(Error::DTypeMismatch {
57            expected: dtype,
58            actual: T::DTYPE,
59        });
60    }
61    let fill = T::fill_element(file.get_array(name)?.fill_value.as_ref());
62    let full_shape: Vec<usize> = full_shape_u32.iter().map(|&x| x as usize).collect();
63    let chunk_shape: Vec<usize> = chunk_shape_u32.iter().map(|&x| x as usize).collect();
64    let ndim = full_shape.len();
65
66    let effective: Vec<Range<usize>> = match slice {
67        None => full_shape.iter().map(|&s| 0..s).collect(),
68        Some(s) => {
69            if s.len() != ndim {
70                return Err(Error::InvalidFooter(format!(
71                    "slice has {} axes but '{name}' has {ndim}",
72                    s.len()
73                )));
74            }
75            s.iter()
76                .zip(&full_shape)
77                .map(|(r, &s)| r.start.min(s)..r.end.min(s))
78                .collect()
79        }
80    };
81
82    let output_shape: Vec<usize> = effective.iter().map(|r| r.end - r.start).collect();
83    let mut output =
84        ndarray::Array::<T, ndarray::IxDyn>::from_elem(ndarray::IxDyn(&output_shape), fill);
85
86    for coord in all_coords {
87        let chunk_range: Vec<Range<usize>> = (0..ndim)
88            .map(|i| {
89                let start = coord[i] as usize * chunk_shape[i];
90                let end = (start + chunk_shape[i]).min(full_shape[i]);
91                start..end
92            })
93            .collect();
94
95        let overlap: Vec<Range<usize>> = (0..ndim)
96            .map(|i| {
97                effective[i].start.max(chunk_range[i].start)
98                    ..effective[i].end.min(chunk_range[i].end)
99            })
100            .collect();
101
102        if overlap.iter().any(|r| r.is_empty()) {
103            continue;
104        }
105
106        let values = file.read_chunk::<T>(name, &coord).await?;
107        let chunk_actual_shape: Vec<usize> = chunk_range.iter().map(|r| r.end - r.start).collect();
108
109        let chunk_nd = ndarray::Array::from_shape_vec(ndarray::IxDyn(&chunk_actual_shape), values)
110            .map_err(|e| Error::InvalidFooter(e.to_string()))?;
111
112        let chunk_si = make_si(
113            &(0..ndim)
114                .map(|i| {
115                    (overlap[i].start - chunk_range[i].start)
116                        ..(overlap[i].end - chunk_range[i].start)
117                })
118                .collect::<Vec<_>>(),
119        );
120
121        let out_si = make_si(
122            &(0..ndim)
123                .map(|i| {
124                    (overlap[i].start - effective[i].start)..(overlap[i].end - effective[i].start)
125                })
126                .collect::<Vec<_>>(),
127        );
128
129        output.slice_mut(out_si).assign(&chunk_nd.slice(chunk_si));
130    }
131
132    Ok(output.into_shared())
133}
134
135/// Writes `data` into a chunked array at `offset`, performing
136/// read-modify-write for partial chunks.
137pub(crate) async fn write_nd<T>(
138    file: &mut ArrayFile,
139    name: &str,
140    data: ndarray::ArrayView<'_, T, ndarray::IxDyn>,
141    offset: &[usize],
142) -> Result<()>
143where
144    T: ArrayElement,
145{
146    let ChunkedSchema {
147        full_shape: full_shape_u32,
148        chunk_shape: chunk_shape_u32,
149        dtype,
150        ..
151    } = file.get_chunked_schema(name)?;
152
153    if dtype != T::DTYPE {
154        return Err(Error::DTypeMismatch {
155            expected: dtype,
156            actual: T::DTYPE,
157        });
158    }
159    let ndim = full_shape_u32.len();
160    if offset.len() != ndim || data.ndim() != ndim {
161        return Err(Error::InvalidFooter(format!(
162            "'{name}' has {ndim} dimensions but offset has {} and data has {}",
163            offset.len(),
164            data.ndim()
165        )));
166    }
167
168    let full_shape: Vec<usize> = full_shape_u32.iter().map(|&x| x as usize).collect();
169    let chunk_shape: Vec<usize> = chunk_shape_u32.iter().map(|&x| x as usize).collect();
170
171    for i in 0..ndim {
172        let end = offset[i]
173            .checked_add(data.shape()[i])
174            .ok_or_else(|| Error::InvalidFooter(format!("offset overflow on axis {i}")))?;
175        if end > full_shape[i] {
176            return Err(Error::InvalidFooter(format!(
177                "write region [{}, {}) exceeds array size {} on axis {i}",
178                offset[i], end, full_shape[i]
179            )));
180        }
181    }
182
183    let write_end: Vec<usize> = (0..ndim).map(|i| offset[i] + data.shape()[i]).collect();
184
185    let chunk_ranges: Vec<Range<u32>> = (0..ndim)
186        .map(|i| {
187            let start = (offset[i] / chunk_shape[i]) as u32;
188            let end = write_end[i]
189                .div_ceil(chunk_shape[i])
190                .min(full_shape[i].div_ceil(chunk_shape[i])) as u32;
191            start..end
192        })
193        .collect();
194
195    // Phase 1: collect (coord, encoded_bytes) — reads allowed.
196    let mut writes: Vec<(Vec<u32>, Vec<u8>)> = Vec::new();
197
198    for coord in iter_nd_coords(&chunk_ranges) {
199        let chunk_global: Vec<Range<usize>> = (0..ndim)
200            .map(|i| {
201                let start = coord[i] as usize * chunk_shape[i];
202                let end = (start + chunk_shape[i]).min(full_shape[i]);
203                start..end
204            })
205            .collect();
206
207        let chunk_actual_shape: Vec<usize> = chunk_global.iter().map(|r| r.end - r.start).collect();
208
209        let overlap: Vec<Range<usize>> = (0..ndim)
210            .map(|i| offset[i].max(chunk_global[i].start)..write_end[i].min(chunk_global[i].end))
211            .collect();
212
213        if overlap.iter().any(|r| r.is_empty()) {
214            continue;
215        }
216
217        let full_cover = (0..ndim).all(|i| overlap[i] == chunk_global[i]);
218
219        let input_local: Vec<Range<usize>> = (0..ndim)
220            .map(|i| (overlap[i].start - offset[i])..(overlap[i].end - offset[i]))
221            .collect();
222
223        let encoded: Vec<u8> = if full_cover {
224            let v: Vec<T> = data.slice(make_si(&input_local)).iter().cloned().collect();
225            T::encode_chunk(&v)
226        } else {
227            let mut base = file.read_chunk::<T>(name, &coord).await?;
228            let mut chunk_nd =
229                ndarray::Array::from_shape_vec(ndarray::IxDyn(&chunk_actual_shape), base.clone())
230                    .map_err(|e| Error::InvalidFooter(e.to_string()))?;
231
232            let chunk_local: Vec<Range<usize>> = (0..ndim)
233                .map(|i| {
234                    (overlap[i].start - chunk_global[i].start)
235                        ..(overlap[i].end - chunk_global[i].start)
236                })
237                .collect();
238
239            chunk_nd
240                .slice_mut(make_si(&chunk_local))
241                .assign(&data.slice(make_si(&input_local)));
242
243            base = chunk_nd.iter().cloned().collect();
244            T::encode_chunk(&base)
245        };
246
247        writes.push((coord, encoded));
248    }
249
250    // Phase 2: apply writes — mutable, no reads.
251    for (coord, bytes) in writes {
252        file.write_chunk_raw(name, coord, bytes)?;
253    }
254
255    Ok(())
256}