1use std::cmp::Ordering;
2use std::collections::{binary_heap, BinaryHeap};
3use std::fmt::{Display, Write};
4use std::ops::{Add, Mul, Range, Sub};
5
6use ndarray::{ArrayView1, ArrayView2, ArrayView3};
7
/// Errors produced by ragged-buffer operations.
#[derive(Debug)]
pub enum Error {
    /// A failure described only by a human-readable message.
    Generic(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Generic(message) => f.write_str(message),
        }
    }
}

// Lets `Error` participate in `?` conversions to `Box<dyn std::error::Error>` and
// other standard error-handling machinery.
impl std::error::Error for Error {}
12
13impl Error {
14 fn generic<S: Into<String>>(s: S) -> Self {
15 Self::Generic(s.into())
16 }
17}
18
19pub type Result<T> = std::result::Result<T, Error>;
20
/// A batch of variable-length sequences in which every item holds `features` values.
///
/// Item `k` occupies `data[k * features..(k + 1) * features]`, and `subarrays[i]` is
/// the range of item indices that belong to sequence `i`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RaggedBuffer<T> {
    /// Flattened item values, `features` consecutive entries per item.
    pub data: Vec<T>,
    /// Item-index range of each sequence, in sequence order.
    pub subarrays: Vec<Range<usize>>,
    /// Number of values stored per item.
    pub features: usize,
}
29
/// A binary operator that `RaggedBuffer` applies element-wise.
pub trait BinOp<T> {
    /// Combines `lhs` and `rhs` into a single value.
    fn op(lhs: T, rhs: T) -> T;
}

/// Element-wise addition: `lhs + rhs`.
pub struct BinOpAdd;

impl<T> BinOp<T> for BinOpAdd
where
    T: Add<Output = T>,
{
    #[inline]
    fn op(lhs: T, rhs: T) -> T {
        <T as Add>::add(lhs, rhs)
    }
}

/// Element-wise subtraction: `lhs - rhs`.
pub struct BinOpSub;

impl<T> BinOp<T> for BinOpSub
where
    T: Sub<Output = T>,
{
    #[inline]
    fn op(lhs: T, rhs: T) -> T {
        <T as Sub>::sub(lhs, rhs)
    }
}

/// Element-wise multiplication: `lhs * rhs`.
pub struct BinOpMul;

impl<T> BinOp<T> for BinOpMul
where
    T: Mul<Output = T>,
{
    #[inline]
    fn op(lhs: T, rhs: T) -> T {
        <T as Mul>::mul(lhs, rhs)
    }
}
60
61impl<T: Copy + Display + std::fmt::Debug> RaggedBuffer<T> {
62 pub fn new(features: usize) -> Self {
63 RaggedBuffer {
64 data: Vec::new(),
65 subarrays: Vec::new(),
66 features,
67 }
68 }
69
70 pub fn from_array(data: ArrayView3<T>) -> Self {
71 let features = data.shape()[2];
72 RaggedBuffer {
73 data: data.iter().cloned().collect(),
74 subarrays: (0..data.shape()[0])
75 .map(|i| i * data.shape()[1]..(i + 1) * data.shape()[1])
76 .collect(),
77 features,
78 }
79 }
80
81 pub fn from_flattened(data: ArrayView2<T>, lengths: ArrayView1<i64>) -> Result<Self> {
82 let features = data.shape()[1];
83 let mut subarrays = Vec::new();
84 let mut item = 0;
85 for len in lengths.iter().cloned() {
86 subarrays.push(item..(item + len as usize));
87 item += len as usize;
88 }
89 if item != data.shape()[0] {
90 Err(Error::generic(format!(
91 "Lengths array specifies {} items, but data array has {} items",
92 item,
93 data.shape()[0]
94 )))
95 } else {
96 Ok(RaggedBuffer {
97 data: data.iter().cloned().collect(),
98 subarrays,
99 features,
100 })
101 }
102 }
103
104 pub fn extend(&mut self, other: &RaggedBuffer<T>) -> Result<()> {
105 if self.features != other.features {
106 return Err(Error::generic(format!(
107 "Features mismatch: {} != {}",
108 self.features, other.features
109 )));
110 }
111 let item = self.items();
112 self.data.extend(other.data.iter());
113 self.subarrays
114 .extend(other.subarrays.iter().map(|r| r.start + item..r.end + item));
115 Ok(())
116 }
117
118 pub fn clear(&mut self) {
119 self.data.clear();
120 self.subarrays.clear();
121 }
122
123 pub fn push(&mut self, data: &ArrayView2<T>) -> Result<()> {
133 if data.dim().1 != self.features {
134 return Err(Error::generic(format!(
135 "Features mismatch: {} != {}",
136 self.features,
137 data.dim().1
138 )));
139 }
140 self.subarrays
141 .push(self.items()..(self.items() + data.dim().0));
142 match data.as_slice() {
143 Some(slice) => self.data.extend_from_slice(slice),
144 None => {
145 for x in data.iter() {
146 self.data.push(*x);
147 }
148 }
149 }
150 Ok(())
151 }
152
153 pub fn push_empty(&mut self) {
154 self.subarrays.push(self.items()..self.items());
155 }
156
157 pub fn swizzle(&self, indices: ArrayView1<i64>) -> Result<RaggedBuffer<T>> {
158 let indices = indices
159 .as_slice()
160 .ok_or_else(|| Error::generic("Indices must be a **contiguous** 1D array"))?;
161 let mut subarrays = Vec::with_capacity(indices.len());
162 let mut item = 0usize;
163 for i in indices {
164 let sublen = self.subarrays[*i as usize].end - self.subarrays[*i as usize].start;
165 subarrays.push(item..(item + sublen));
166 item += sublen;
167 }
168 let mut data = Vec::with_capacity(item * self.features);
169 for i in indices {
170 let Range { start, end } = self.subarrays[*i as usize];
171 data.extend_from_slice(&self.data[start * self.features..end * self.features]);
172 }
173 Ok(RaggedBuffer {
174 data,
175 subarrays,
176 features: self.features,
177 })
178 }
179
180 pub fn swizzle_usize(&self, indices: &[usize]) -> Result<RaggedBuffer<T>> {
182 let mut subarrays = Vec::with_capacity(indices.len());
183 let mut item = 0usize;
184 for &i in indices {
185 let sublen = self.subarrays[i].end - self.subarrays[i].start;
186 subarrays.push(item..(item + sublen));
187 item += sublen;
188 }
189 let mut data = Vec::with_capacity(item * self.features);
190 for i in indices {
191 let Range { start, end } = self.subarrays[*i as usize];
192 data.extend_from_slice(&self.data[start * self.features..end * self.features]);
193 }
194 Ok(RaggedBuffer {
195 data,
196 subarrays,
197 features: self.features,
198 })
199 }
200
201 pub fn get(&self, i: usize) -> RaggedBuffer<T> {
202 let subarray = self.subarrays[i].clone();
203 let Range { start, end } = subarray;
204 RaggedBuffer {
205 subarrays: vec![0..subarray.len()],
206 data: self.data[start * self.features..end * self.features].to_vec(),
207 features: self.features,
208 }
209 }
210
211 pub fn size0(&self) -> usize {
212 self.subarrays.len()
213 }
214
215 pub fn lengths(&self) -> Vec<i64> {
216 self.subarrays
217 .iter()
218 .map(|r| (r.end - r.start) as i64)
219 .collect::<Vec<_>>()
220 }
221
222 pub fn size1(&self, i: usize) -> Result<usize> {
223 if i >= self.subarrays.len() {
224 Err(Error::generic(format!("Index {} out of range", i)))
225 } else {
226 Ok(self.subarrays[i].end - self.subarrays[i].start)
227 }
228 }
229
230 pub fn size2(&self) -> usize {
231 self.features
232 }
233
234 pub fn __str__(&self) -> Result<String> {
235 let mut array = String::new();
236 array.push_str("RaggedBuffer([");
237 array.push('\n');
238 for range in &self.subarrays {
239 let slice = range.start * self.features..range.end * self.features;
240 if range.start == range.end {
241 writeln!(array, " [],").unwrap();
242 } else if range.start + 1 == range.end {
243 writeln!(array, " [{:?}],", &self.data[slice]).unwrap();
244 } else {
245 writeln!(array, " [").unwrap();
246 for i in slice.clone() {
247 if i % self.features == 0 {
248 if i != slice.start {
249 writeln!(array, "],").unwrap();
250 }
251 write!(array, " [").unwrap();
252 }
253 write!(array, "{}", self.data[i]).unwrap();
254 if i % self.features != self.features - 1 {
255 write!(array, ", ").unwrap();
256 }
257 }
258 writeln!(array, "],").unwrap();
259 writeln!(array, " ],").unwrap();
260 }
261 }
262 write!(
263 array,
264 "], '{} * var * {} * {})",
265 self.subarrays.len(),
266 self.features,
267 std::any::type_name::<T>(),
268 )
269 .unwrap();
270
271 Ok(array)
272 }
273
274 pub fn binop<Op: BinOp<T>>(&self, rhs: &RaggedBuffer<T>) -> Result<RaggedBuffer<T>> {
275 if self.features == rhs.features && self.subarrays == rhs.subarrays {
276 let mut data = Vec::with_capacity(self.data.len());
277 for i in 0..self.data.len() {
278 data.push(Op::op(self.data[i], rhs.data[i]));
279 }
280 Ok(RaggedBuffer {
281 data,
282 subarrays: self.subarrays.clone(),
283 features: self.features,
284 })
285 } else if self.features == rhs.features
286 && self.subarrays.len() == rhs.subarrays.len()
287 && rhs.subarrays.iter().all(|r| r.end - r.start == 1)
288 {
289 let mut data = Vec::with_capacity(self.data.len());
290 for (subarray, rhs_subarray) in self.subarrays.iter().zip(rhs.subarrays.iter()) {
291 for item in subarray.clone() {
292 let lhs_offset = item * self.features;
293 let rhs_offset = rhs_subarray.start * self.features;
294 for i in 0..self.features {
295 data.push(Op::op(self.data[lhs_offset + i], rhs.data[rhs_offset + i]));
296 }
297 }
298 }
299 Ok(RaggedBuffer {
300 data,
301 subarrays: self.subarrays.clone(),
302 features: self.features,
303 })
304 } else if self.features == rhs.features
305 && self.subarrays.len() == rhs.subarrays.len()
306 && self.subarrays.iter().all(|r| r.end - r.start == 1)
307 {
308 rhs.binop::<Op>(self)
309 } else {
310 Err(Error::generic(format!(
311 "Dimensions mismatch: ({}, {:?}, {}) != ({}, {:?}, {})",
312 self.size0(),
313 self.subarrays
314 .iter()
315 .map(|r| r.end - r.start)
316 .collect::<Vec<_>>(),
317 self.size2(),
318 rhs.size0(),
319 rhs.subarrays
320 .iter()
321 .map(|r| r.end - r.start)
322 .collect::<Vec<_>>(),
323 rhs.size2(),
324 )))
325 }
326 }
327
328 pub fn op_scalar<Op: BinOp<T>>(&self, scalar: T) -> RaggedBuffer<T> {
329 RaggedBuffer {
330 data: self.data.iter().map(|x| Op::op(*x, scalar)).collect(),
331 subarrays: self.subarrays.clone(),
332 features: self.features,
333 }
334 }
335
336 pub fn indices(&self, dim: usize) -> Result<RaggedBuffer<i64>> {
337 match dim {
338 0 => {
339 let mut indices = Vec::with_capacity(self.items());
340 for (index, subarray) in self.subarrays.iter().enumerate() {
341 for _ in subarray.clone() {
342 indices.push(index as i64);
343 }
344 }
345 Ok(RaggedBuffer {
346 subarrays: self.subarrays.clone(),
347 data: indices,
348 features: 1,
349 })
350 }
351 1 => {
352 let mut indices = Vec::with_capacity(self.items());
353 for subarray in &self.subarrays {
354 for (i, _) in subarray.clone().enumerate() {
355 indices.push(i as i64);
356 }
357 }
358 Ok(RaggedBuffer {
359 subarrays: self.subarrays.clone(),
360 data: indices,
361 features: 1,
362 })
363 }
364 _ => Err(Error::generic(format!("Invalid dimension {}", dim))),
365 }
366 }
367
368 pub fn flat_indices(&self) -> Result<RaggedBuffer<i64>> {
369 Ok(RaggedBuffer {
370 subarrays: self.subarrays.clone(),
371 data: (0..self.items()).map(|i| i as i64).collect(),
372 features: 1,
373 })
374 }
375
376 pub fn cat(buffers: &[&RaggedBuffer<T>], dim: usize) -> Result<RaggedBuffer<T>> {
377 match dim {
378 0 => {
379 if buffers.iter().any(|b| b.features != buffers[0].features) {
380 return Err(Error::generic(format!(
381 "All buffers must have the same number of features, but found {}",
382 buffers
383 .iter()
384 .map(|b| b.features.to_string())
385 .collect::<Vec<_>>()
386 .join(", ")
387 )));
388 }
389 let mut data = Vec::with_capacity(buffers.iter().map(|b| b.data.len()).sum());
390 for buffer in buffers {
391 data.extend_from_slice(&buffer.data);
392 }
393 let mut subarrays =
394 Vec::with_capacity(buffers.iter().map(|b| b.subarrays.len()).sum());
395 let mut item = 0;
396 for buffer in buffers {
397 subarrays.extend_from_slice(
398 &buffer
399 .subarrays
400 .iter()
401 .map(|r| {
402 let start = r.start + item;
403 let end = r.end + item;
404 start..end
405 })
406 .collect::<Vec<_>>(),
407 );
408 item += buffer.items();
409 }
410 Ok(RaggedBuffer {
411 data,
412 subarrays,
413 features: buffers[0].features,
414 })
415 }
416 1 => {
417 if buffers
418 .iter()
419 .any(|b| b.subarrays.len() != buffers[0].subarrays.len())
420 {
421 return Err(Error::generic(format!(
422 "All buffers must have the same number of subarrays, but found {}",
423 buffers
424 .iter()
425 .map(|b| b.subarrays.len().to_string())
426 .collect::<Vec<_>>()
427 .join(", ")
428 )));
429 }
430 if buffers.iter().any(|b| b.features != buffers[0].features) {
431 return Err(Error::generic(format!(
432 "All buffers must have the same number of features, but found {}",
433 buffers
434 .iter()
435 .map(|b| b.features.to_string())
436 .collect::<Vec<_>>()
437 .join(", ")
438 )));
439 }
440 let mut data = Vec::with_capacity(buffers.iter().map(|b| b.data.len()).sum());
441 let mut subarrays =
442 Vec::with_capacity(buffers.iter().map(|b| b.subarrays.len()).sum());
443 let mut item = 0;
444 let mut last_item = 0;
445 for i in 0..buffers[0].subarrays.len() {
446 for buffer in buffers {
447 let Range { start, end } = &buffer.subarrays[i];
448 data.extend_from_slice(
449 &buffer.data[start * buffer.features..end * buffer.features],
450 );
451 item += end - start;
452 }
453 subarrays.push(Range {
454 start: last_item,
455 end: item,
456 });
457 last_item = item;
458 }
459 Ok(RaggedBuffer {
460 data,
461 subarrays,
462 features: buffers[0].features,
463 })
464 }
465 2 => {
466 let sequences = buffers[0].size0();
469 if buffers.iter().any(|b| b.size0() != sequences) {
470 return Err(Error::generic(format!(
471 "All buffers must have the same number of sequences, but found {}",
472 buffers
473 .iter()
474 .map(|b| b.size0().to_string())
475 .collect::<Vec<_>>()
476 .join(", ")
477 )));
478 }
479
480 let features = buffers.iter().map(|b| b.features).sum();
481 let mut subarrays = Vec::with_capacity(sequences);
482 let mut data = Vec::with_capacity(sequences * features);
483 let mut items = 0;
484 for iseq in 0..sequences {
485 let seqlen = if buffers.iter().any(|b| {
486 b.size1(iseq)
487 .expect("All sequences should be the same length.")
488 == 0
489 }) {
490 0
491 } else {
492 buffers
493 .iter()
494 .map(|b| {
495 b.size1(iseq)
496 .expect("All sequences should be the same length.")
497 })
498 .max()
499 .expect("There should be at least one buffer.")
500 };
501 subarrays.push(items..items + seqlen);
502 items += seqlen;
503 for iitem in 0..seqlen {
504 for (ibuf, buffer) in buffers.iter().enumerate() {
505 let _items = buffer.subarrays[iseq].len();
506 if _items == 1 {
507 data.extend_from_slice(
508 &buffer.data[buffer.subarrays[iseq].start * buffer.features
509 ..buffer.subarrays[iseq].end * buffer.features],
510 );
511 } else {
512 if _items != seqlen {
513 return Err(Error::generic(format!(
514 "Buffer {} has {} items for sequence {}, but expected {}",
515 ibuf, _items, iseq, seqlen
516 )));
517 }
518 let start_item = buffer.subarrays[iseq].start + iitem;
519 data.extend_from_slice(
520 &buffer.data[start_item * buffer.features
521 ..(start_item + 1) * buffer.features],
522 );
523 }
524 }
525 }
526 }
527
528 Ok(RaggedBuffer {
529 data,
530 subarrays,
531 features,
532 })
533 }
534 _ => Err(Error::generic(format!(
535 "Invalid dimension {}, RaggedBuffer only has 3 dimensions",
536 dim
537 ))),
538 }
539 }
540
541 #[allow(clippy::type_complexity)]
542 pub fn padpack(&self) -> Option<(Vec<i64>, Vec<f32>, Vec<i64>, (usize, usize))> {
543 if self.subarrays.is_empty()
544 || self
545 .subarrays
546 .iter()
547 .all(|r| r.end - r.start == self.subarrays[0].end - self.subarrays[0].start)
548 {
549 return None;
550 }
551
552 let mut padbpack_index = vec![];
553 let mut padpack_batch = vec![];
554 let mut padpack_inverse_index = vec![];
555 let max_seq_len = self
556 .subarrays
557 .iter()
558 .map(|r| r.end - r.start)
559 .max()
560 .unwrap();
561 let mut sequences: BinaryHeap<Sequence> = binary_heap::BinaryHeap::new();
562
563 for (batch_index, subarray) in self.subarrays.iter().enumerate() {
564 let (free, packed_batch_index) = match sequences.peek().cloned() {
565 Some(seq) if seq.free >= subarray.end - subarray.start => {
566 sequences.pop();
567 (seq.free, seq.batch_index)
568 }
569 _ => {
570 for _ in 0..max_seq_len {
571 padbpack_index.push(0);
572 padpack_batch.push(f32::NAN);
573 }
574 (max_seq_len, sequences.len())
575 }
576 };
577
578 for (i, item) in subarray.clone().enumerate() {
579 let packed_index = packed_batch_index * max_seq_len + max_seq_len - free + i;
580 padbpack_index[packed_index] = item as i64;
581 padpack_batch[packed_index] = batch_index as f32;
582 padpack_inverse_index.push(packed_index as i64);
583 }
584 sequences.push(Sequence {
585 batch_index: packed_batch_index,
586 free: free - (subarray.end - subarray.start),
587 });
588 }
589
590 Some((
591 padbpack_index,
592 padpack_batch,
593 padpack_inverse_index,
594 (sequences.len(), max_seq_len),
595 ))
596 }
597
598 pub fn items(&self) -> usize {
599 self.subarrays.last().map(|r| r.end).unwrap_or(0)
600 }
601
602 pub fn len(&self) -> usize {
603 self.data.len()
604 }
605
606 pub fn is_empty(&self) -> bool {
607 self.data.is_empty()
608 }
609}
610
/// A packing row used by `RaggedBuffer::padpack`. `batch_index` identifies the packed
/// row and `free` counts its unused slots.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
struct Sequence {
    free: usize,
    batch_index: usize,
}

/// Orders by `free` ascending. Ties are broken so that the lower `batch_index` compares
/// greater, which means a `BinaryHeap` (a max-heap) yields the row with the most free
/// slots first and prefers the earliest row among equals.
impl Ord for Sequence {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.free.cmp(&other.free) {
            Ordering::Equal => other.batch_index.cmp(&self.batch_index),
            unequal => unequal,
        }
    }
}

impl PartialOrd for Sequence {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
629}