ragged_buffer/
ragged_buffer.rs

1use std::cmp::Ordering;
2use std::collections::{binary_heap, BinaryHeap};
3use std::fmt::{Display, Write};
4use std::ops::{Add, Mul, Range, Sub};
5
6use ndarray::{ArrayView1, ArrayView2, ArrayView3};
7
/// Errors produced by [`RaggedBuffer`] operations.
#[derive(Debug)]
pub enum Error {
    /// A failure described only by a human-readable message.
    Generic(String),
}

/// Displays the bare error message so the error can be printed and converted to a string.
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Generic(msg) => f.write_str(msg),
        }
    }
}

/// Allows `Error` to be boxed as `dyn std::error::Error` and propagated with `?` into generic
/// error types.
impl std::error::Error for Error {}
12
13impl Error {
14    fn generic<S: Into<String>>(s: S) -> Self {
15        Self::Generic(s.into())
16    }
17}
18
19pub type Result<T> = std::result::Result<T, Error>;
20
/// A batch of variable-length subarrays of fixed-width items, stored in one flat vector.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RaggedBuffer<T> {
    /// Flattened item values; item `k` occupies `data[k * features..(k + 1) * features]`.
    pub data: Vec<T>,
    /// One half-open range of item indices per subarray (step size 1). Multiply an item index
    /// by `features` to get the offset of its first value in `data`.
    pub subarrays: Vec<Range<usize>>,
    /// Number of values stored per item.
    pub features: usize,
}
29
/// A binary operation on two values of type `T`, selected at compile time.
///
/// Implementors are zero-sized marker types passed as a type parameter (for example
/// `buffer.binop::<BinOpAdd>(&other)`), so no operator value exists at runtime.
pub trait BinOp<T> {
    /// Combines `lhs` with `rhs`; `lhs` is the left-hand operand.
    fn op(lhs: T, rhs: T) -> T;
}

/// Addition: `lhs + rhs`.
pub struct BinOpAdd;

impl<T> BinOp<T> for BinOpAdd
where
    T: Add<T, Output = T>,
{
    #[inline]
    fn op(lhs: T, rhs: T) -> T {
        Add::add(lhs, rhs)
    }
}

/// Subtraction: `lhs - rhs`.
pub struct BinOpSub;

impl<T> BinOp<T> for BinOpSub
where
    T: Sub<T, Output = T>,
{
    #[inline]
    fn op(lhs: T, rhs: T) -> T {
        Sub::sub(lhs, rhs)
    }
}

/// Multiplication: `lhs * rhs`.
pub struct BinOpMul;

impl<T> BinOp<T> for BinOpMul
where
    T: Mul<T, Output = T>,
{
    #[inline]
    fn op(lhs: T, rhs: T) -> T {
        Mul::mul(lhs, rhs)
    }
}
60
61impl<T: Copy + Display + std::fmt::Debug> RaggedBuffer<T> {
62    pub fn new(features: usize) -> Self {
63        RaggedBuffer {
64            data: Vec::new(),
65            subarrays: Vec::new(),
66            features,
67        }
68    }
69
70    pub fn from_array(data: ArrayView3<T>) -> Self {
71        let features = data.shape()[2];
72        RaggedBuffer {
73            data: data.iter().cloned().collect(),
74            subarrays: (0..data.shape()[0])
75                .map(|i| i * data.shape()[1]..(i + 1) * data.shape()[1])
76                .collect(),
77            features,
78        }
79    }
80
81    pub fn from_flattened(data: ArrayView2<T>, lengths: ArrayView1<i64>) -> Result<Self> {
82        let features = data.shape()[1];
83        let mut subarrays = Vec::new();
84        let mut item = 0;
85        for len in lengths.iter().cloned() {
86            subarrays.push(item..(item + len as usize));
87            item += len as usize;
88        }
89        if item != data.shape()[0] {
90            Err(Error::generic(format!(
91                "Lengths array specifies {} items, but data array has {} items",
92                item,
93                data.shape()[0]
94            )))
95        } else {
96            Ok(RaggedBuffer {
97                data: data.iter().cloned().collect(),
98                subarrays,
99                features,
100            })
101        }
102    }
103
104    pub fn extend(&mut self, other: &RaggedBuffer<T>) -> Result<()> {
105        if self.features != other.features {
106            return Err(Error::generic(format!(
107                "Features mismatch: {} != {}",
108                self.features, other.features
109            )));
110        }
111        let item = self.items();
112        self.data.extend(other.data.iter());
113        self.subarrays
114            .extend(other.subarrays.iter().map(|r| r.start + item..r.end + item));
115        Ok(())
116    }
117
118    pub fn clear(&mut self) {
119        self.data.clear();
120        self.subarrays.clear();
121    }
122
123    // pub fn as_array<'a>(
124    //     &self,
125    //     py: Python<'a>,
126    // ) -> PyResult<&'a numpy::PyArray<T, numpy::ndarray::Dim<[usize; 2]>>> {
127    //     self.data
128    //         .to_pyarray(py)
129    //         .reshape((self.items, self.features))
130    // }
131
132    pub fn push(&mut self, data: &ArrayView2<T>) -> Result<()> {
133        if data.dim().1 != self.features {
134            return Err(Error::generic(format!(
135                "Features mismatch: {} != {}",
136                self.features,
137                data.dim().1
138            )));
139        }
140        self.subarrays
141            .push(self.items()..(self.items() + data.dim().0));
142        match data.as_slice() {
143            Some(slice) => self.data.extend_from_slice(slice),
144            None => {
145                for x in data.iter() {
146                    self.data.push(*x);
147                }
148            }
149        }
150        Ok(())
151    }
152
153    pub fn push_empty(&mut self) {
154        self.subarrays.push(self.items()..self.items());
155    }
156
157    pub fn swizzle(&self, indices: ArrayView1<i64>) -> Result<RaggedBuffer<T>> {
158        let indices = indices
159            .as_slice()
160            .ok_or_else(|| Error::generic("Indices must be a **contiguous** 1D array"))?;
161        let mut subarrays = Vec::with_capacity(indices.len());
162        let mut item = 0usize;
163        for i in indices {
164            let sublen = self.subarrays[*i as usize].end - self.subarrays[*i as usize].start;
165            subarrays.push(item..(item + sublen));
166            item += sublen;
167        }
168        let mut data = Vec::with_capacity(item * self.features);
169        for i in indices {
170            let Range { start, end } = self.subarrays[*i as usize];
171            data.extend_from_slice(&self.data[start * self.features..end * self.features]);
172        }
173        Ok(RaggedBuffer {
174            data,
175            subarrays,
176            features: self.features,
177        })
178    }
179
180    // TODO: dedupe with swizzle
181    pub fn swizzle_usize(&self, indices: &[usize]) -> Result<RaggedBuffer<T>> {
182        let mut subarrays = Vec::with_capacity(indices.len());
183        let mut item = 0usize;
184        for &i in indices {
185            let sublen = self.subarrays[i].end - self.subarrays[i].start;
186            subarrays.push(item..(item + sublen));
187            item += sublen;
188        }
189        let mut data = Vec::with_capacity(item * self.features);
190        for i in indices {
191            let Range { start, end } = self.subarrays[*i as usize];
192            data.extend_from_slice(&self.data[start * self.features..end * self.features]);
193        }
194        Ok(RaggedBuffer {
195            data,
196            subarrays,
197            features: self.features,
198        })
199    }
200
201    pub fn get(&self, i: usize) -> RaggedBuffer<T> {
202        let subarray = self.subarrays[i].clone();
203        let Range { start, end } = subarray;
204        RaggedBuffer {
205            subarrays: vec![0..subarray.len()],
206            data: self.data[start * self.features..end * self.features].to_vec(),
207            features: self.features,
208        }
209    }
210
211    pub fn size0(&self) -> usize {
212        self.subarrays.len()
213    }
214
215    pub fn lengths(&self) -> Vec<i64> {
216        self.subarrays
217            .iter()
218            .map(|r| (r.end - r.start) as i64)
219            .collect::<Vec<_>>()
220    }
221
222    pub fn size1(&self, i: usize) -> Result<usize> {
223        if i >= self.subarrays.len() {
224            Err(Error::generic(format!("Index {} out of range", i)))
225        } else {
226            Ok(self.subarrays[i].end - self.subarrays[i].start)
227        }
228    }
229
230    pub fn size2(&self) -> usize {
231        self.features
232    }
233
234    pub fn __str__(&self) -> Result<String> {
235        let mut array = String::new();
236        array.push_str("RaggedBuffer([");
237        array.push('\n');
238        for range in &self.subarrays {
239            let slice = range.start * self.features..range.end * self.features;
240            if range.start == range.end {
241                writeln!(array, "    [],").unwrap();
242            } else if range.start + 1 == range.end {
243                writeln!(array, "    [{:?}],", &self.data[slice]).unwrap();
244            } else {
245                writeln!(array, "    [").unwrap();
246                for i in slice.clone() {
247                    if i % self.features == 0 {
248                        if i != slice.start {
249                            writeln!(array, "],").unwrap();
250                        }
251                        write!(array, "        [").unwrap();
252                    }
253                    write!(array, "{}", self.data[i]).unwrap();
254                    if i % self.features != self.features - 1 {
255                        write!(array, ", ").unwrap();
256                    }
257                }
258                writeln!(array, "],").unwrap();
259                writeln!(array, "    ],").unwrap();
260            }
261        }
262        write!(
263            array,
264            "], '{} * var * {} * {})",
265            self.subarrays.len(),
266            self.features,
267            std::any::type_name::<T>(),
268        )
269        .unwrap();
270
271        Ok(array)
272    }
273
274    pub fn binop<Op: BinOp<T>>(&self, rhs: &RaggedBuffer<T>) -> Result<RaggedBuffer<T>> {
275        if self.features == rhs.features && self.subarrays == rhs.subarrays {
276            let mut data = Vec::with_capacity(self.data.len());
277            for i in 0..self.data.len() {
278                data.push(Op::op(self.data[i], rhs.data[i]));
279            }
280            Ok(RaggedBuffer {
281                data,
282                subarrays: self.subarrays.clone(),
283                features: self.features,
284            })
285        } else if self.features == rhs.features
286            && self.subarrays.len() == rhs.subarrays.len()
287            && rhs.subarrays.iter().all(|r| r.end - r.start == 1)
288        {
289            let mut data = Vec::with_capacity(self.data.len());
290            for (subarray, rhs_subarray) in self.subarrays.iter().zip(rhs.subarrays.iter()) {
291                for item in subarray.clone() {
292                    let lhs_offset = item * self.features;
293                    let rhs_offset = rhs_subarray.start * self.features;
294                    for i in 0..self.features {
295                        data.push(Op::op(self.data[lhs_offset + i], rhs.data[rhs_offset + i]));
296                    }
297                }
298            }
299            Ok(RaggedBuffer {
300                data,
301                subarrays: self.subarrays.clone(),
302                features: self.features,
303            })
304        } else if self.features == rhs.features
305            && self.subarrays.len() == rhs.subarrays.len()
306            && self.subarrays.iter().all(|r| r.end - r.start == 1)
307        {
308            rhs.binop::<Op>(self)
309        } else {
310            Err(Error::generic(format!(
311                "Dimensions mismatch: ({}, {:?}, {}) != ({}, {:?}, {})",
312                self.size0(),
313                self.subarrays
314                    .iter()
315                    .map(|r| r.end - r.start)
316                    .collect::<Vec<_>>(),
317                self.size2(),
318                rhs.size0(),
319                rhs.subarrays
320                    .iter()
321                    .map(|r| r.end - r.start)
322                    .collect::<Vec<_>>(),
323                rhs.size2(),
324            )))
325        }
326    }
327
328    pub fn op_scalar<Op: BinOp<T>>(&self, scalar: T) -> RaggedBuffer<T> {
329        RaggedBuffer {
330            data: self.data.iter().map(|x| Op::op(*x, scalar)).collect(),
331            subarrays: self.subarrays.clone(),
332            features: self.features,
333        }
334    }
335
336    pub fn indices(&self, dim: usize) -> Result<RaggedBuffer<i64>> {
337        match dim {
338            0 => {
339                let mut indices = Vec::with_capacity(self.items());
340                for (index, subarray) in self.subarrays.iter().enumerate() {
341                    for _ in subarray.clone() {
342                        indices.push(index as i64);
343                    }
344                }
345                Ok(RaggedBuffer {
346                    subarrays: self.subarrays.clone(),
347                    data: indices,
348                    features: 1,
349                })
350            }
351            1 => {
352                let mut indices = Vec::with_capacity(self.items());
353                for subarray in &self.subarrays {
354                    for (i, _) in subarray.clone().enumerate() {
355                        indices.push(i as i64);
356                    }
357                }
358                Ok(RaggedBuffer {
359                    subarrays: self.subarrays.clone(),
360                    data: indices,
361                    features: 1,
362                })
363            }
364            _ => Err(Error::generic(format!("Invalid dimension {}", dim))),
365        }
366    }
367
368    pub fn flat_indices(&self) -> Result<RaggedBuffer<i64>> {
369        Ok(RaggedBuffer {
370            subarrays: self.subarrays.clone(),
371            data: (0..self.items()).map(|i| i as i64).collect(),
372            features: 1,
373        })
374    }
375
376    pub fn cat(buffers: &[&RaggedBuffer<T>], dim: usize) -> Result<RaggedBuffer<T>> {
377        match dim {
378            0 => {
379                if buffers.iter().any(|b| b.features != buffers[0].features) {
380                    return Err(Error::generic(format!(
381                        "All buffers must have the same number of features, but found {}",
382                        buffers
383                            .iter()
384                            .map(|b| b.features.to_string())
385                            .collect::<Vec<_>>()
386                            .join(", ")
387                    )));
388                }
389                let mut data = Vec::with_capacity(buffers.iter().map(|b| b.data.len()).sum());
390                for buffer in buffers {
391                    data.extend_from_slice(&buffer.data);
392                }
393                let mut subarrays =
394                    Vec::with_capacity(buffers.iter().map(|b| b.subarrays.len()).sum());
395                let mut item = 0;
396                for buffer in buffers {
397                    subarrays.extend_from_slice(
398                        &buffer
399                            .subarrays
400                            .iter()
401                            .map(|r| {
402                                let start = r.start + item;
403                                let end = r.end + item;
404                                start..end
405                            })
406                            .collect::<Vec<_>>(),
407                    );
408                    item += buffer.items();
409                }
410                Ok(RaggedBuffer {
411                    data,
412                    subarrays,
413                    features: buffers[0].features,
414                })
415            }
416            1 => {
417                if buffers
418                    .iter()
419                    .any(|b| b.subarrays.len() != buffers[0].subarrays.len())
420                {
421                    return Err(Error::generic(format!(
422                        "All buffers must have the same number of subarrays, but found {}",
423                        buffers
424                            .iter()
425                            .map(|b| b.subarrays.len().to_string())
426                            .collect::<Vec<_>>()
427                            .join(", ")
428                    )));
429                }
430                if buffers.iter().any(|b| b.features != buffers[0].features) {
431                    return Err(Error::generic(format!(
432                        "All buffers must have the same number of features, but found {}",
433                        buffers
434                            .iter()
435                            .map(|b| b.features.to_string())
436                            .collect::<Vec<_>>()
437                            .join(", ")
438                    )));
439                }
440                let mut data = Vec::with_capacity(buffers.iter().map(|b| b.data.len()).sum());
441                let mut subarrays =
442                    Vec::with_capacity(buffers.iter().map(|b| b.subarrays.len()).sum());
443                let mut item = 0;
444                let mut last_item = 0;
445                for i in 0..buffers[0].subarrays.len() {
446                    for buffer in buffers {
447                        let Range { start, end } = &buffer.subarrays[i];
448                        data.extend_from_slice(
449                            &buffer.data[start * buffer.features..end * buffer.features],
450                        );
451                        item += end - start;
452                    }
453                    subarrays.push(Range {
454                        start: last_item,
455                        end: item,
456                    });
457                    last_item = item;
458                }
459                Ok(RaggedBuffer {
460                    data,
461                    subarrays,
462                    features: buffers[0].features,
463                })
464            }
465            2 => {
466                // TODO: disallow broadcasting on some sequences but not other?
467                // TODO: think more about empty sequences
468                let sequences = buffers[0].size0();
469                if buffers.iter().any(|b| b.size0() != sequences) {
470                    return Err(Error::generic(format!(
471                        "All buffers must have the same number of sequences, but found {}",
472                        buffers
473                            .iter()
474                            .map(|b| b.size0().to_string())
475                            .collect::<Vec<_>>()
476                            .join(", ")
477                    )));
478                }
479
480                let features = buffers.iter().map(|b| b.features).sum();
481                let mut subarrays = Vec::with_capacity(sequences);
482                let mut data = Vec::with_capacity(sequences * features);
483                let mut items = 0;
484                for iseq in 0..sequences {
485                    let seqlen = if buffers.iter().any(|b| {
486                        b.size1(iseq)
487                            .expect("All sequences should be the same length.")
488                            == 0
489                    }) {
490                        0
491                    } else {
492                        buffers
493                            .iter()
494                            .map(|b| {
495                                b.size1(iseq)
496                                    .expect("All sequences should be the same length.")
497                            })
498                            .max()
499                            .expect("There should be at least one buffer.")
500                    };
501                    subarrays.push(items..items + seqlen);
502                    items += seqlen;
503                    for iitem in 0..seqlen {
504                        for (ibuf, buffer) in buffers.iter().enumerate() {
505                            let _items = buffer.subarrays[iseq].len();
506                            if _items == 1 {
507                                data.extend_from_slice(
508                                    &buffer.data[buffer.subarrays[iseq].start * buffer.features
509                                        ..buffer.subarrays[iseq].end * buffer.features],
510                                );
511                            } else {
512                                if _items != seqlen {
513                                    return Err(Error::generic(format!(
514                                        "Buffer {} has {} items for sequence {}, but expected {}",
515                                        ibuf, _items, iseq, seqlen
516                                    )));
517                                }
518                                let start_item = buffer.subarrays[iseq].start + iitem;
519                                data.extend_from_slice(
520                                    &buffer.data[start_item * buffer.features
521                                        ..(start_item + 1) * buffer.features],
522                                );
523                            }
524                        }
525                    }
526                }
527
528                Ok(RaggedBuffer {
529                    data,
530                    subarrays,
531                    features,
532                })
533            }
534            _ => Err(Error::generic(format!(
535                "Invalid dimension {}, RaggedBuffer only has 3 dimensions",
536                dim
537            ))),
538        }
539    }
540
541    #[allow(clippy::type_complexity)]
542    pub fn padpack(&self) -> Option<(Vec<i64>, Vec<f32>, Vec<i64>, (usize, usize))> {
543        if self.subarrays.is_empty()
544            || self
545                .subarrays
546                .iter()
547                .all(|r| r.end - r.start == self.subarrays[0].end - self.subarrays[0].start)
548        {
549            return None;
550        }
551
552        let mut padbpack_index = vec![];
553        let mut padpack_batch = vec![];
554        let mut padpack_inverse_index = vec![];
555        let max_seq_len = self
556            .subarrays
557            .iter()
558            .map(|r| r.end - r.start)
559            .max()
560            .unwrap();
561        let mut sequences: BinaryHeap<Sequence> = binary_heap::BinaryHeap::new();
562
563        for (batch_index, subarray) in self.subarrays.iter().enumerate() {
564            let (free, packed_batch_index) = match sequences.peek().cloned() {
565                Some(seq) if seq.free >= subarray.end - subarray.start => {
566                    sequences.pop();
567                    (seq.free, seq.batch_index)
568                }
569                _ => {
570                    for _ in 0..max_seq_len {
571                        padbpack_index.push(0);
572                        padpack_batch.push(f32::NAN);
573                    }
574                    (max_seq_len, sequences.len())
575                }
576            };
577
578            for (i, item) in subarray.clone().enumerate() {
579                let packed_index = packed_batch_index * max_seq_len + max_seq_len - free + i;
580                padbpack_index[packed_index] = item as i64;
581                padpack_batch[packed_index] = batch_index as f32;
582                padpack_inverse_index.push(packed_index as i64);
583            }
584            sequences.push(Sequence {
585                batch_index: packed_batch_index,
586                free: free - (subarray.end - subarray.start),
587            });
588        }
589
590        Some((
591            padbpack_index,
592            padpack_batch,
593            padpack_inverse_index,
594            (sequences.len(), max_seq_len),
595        ))
596    }
597
598    pub fn items(&self) -> usize {
599        self.subarrays.last().map(|r| r.end).unwrap_or(0)
600    }
601
602    pub fn len(&self) -> usize {
603        self.data.len()
604    }
605
606    pub fn is_empty(&self) -> bool {
607        self.data.is_empty()
608    }
609}
610
/// A partially filled row in `padpack`'s packed output, ordered for use in a max-heap.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
struct Sequence {
    /// Number of unused slots left in the row.
    free: usize,
    /// Index of the row within the packed output.
    batch_index: usize,
}

/// Orders by `free` ascending, then by `batch_index` descending. A `BinaryHeap` (a max-heap)
/// therefore yields the row with the most free slots first, breaking ties toward the lower
/// row index.
impl Ord for Sequence {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.free.cmp(&other.free) {
            Ordering::Equal => other.batch_index.cmp(&self.batch_index),
            unequal => unequal,
        }
    }
}

impl PartialOrd for Sequence {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}