candle_core/
streaming.rs

//! `StreamTensor` and related helpers useful for streaming operations.
//!
3use crate::{Result, Shape, Tensor};
4
5pub trait Dim: crate::shape::Dim + Copy {}
6impl<T: crate::shape::Dim + Copy> Dim for T {}
7
8/// A stream tensor is used in streaming module. It can either contain an actual tensor or be
9/// empty.
10#[derive(Clone)]
11pub struct StreamTensor(Option<Tensor>);
12
13impl std::fmt::Debug for StreamTensor {
14    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15        match &self.0 {
16            Some(t) => write!(f, "{:?}", t.shape()),
17            None => write!(f, "Empty"),
18        }
19    }
20}
21
22impl std::convert::From<Option<Tensor>> for StreamTensor {
23    fn from(value: Option<Tensor>) -> Self {
24        Self(value)
25    }
26}
27
28impl std::convert::From<Tensor> for StreamTensor {
29    fn from(value: Tensor) -> Self {
30        Self(Some(value))
31    }
32}
33
34impl std::convert::From<()> for StreamTensor {
35    fn from(_value: ()) -> Self {
36        Self(None)
37    }
38}
39
40impl StreamTensor {
41    pub fn empty() -> Self {
42        Self(None)
43    }
44
45    pub fn from_tensor(tensor: Tensor) -> Self {
46        Self(Some(tensor))
47    }
48
49    pub fn shape(&self) -> Option<&Shape> {
50        self.0.as_ref().map(|t| t.shape())
51    }
52
53    pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
54        let xs = match (&self.0, &rhs.0) {
55            (Some(lhs), Some(rhs)) => {
56                let xs = Tensor::cat(&[lhs, rhs], dim)?;
57                Some(xs)
58            }
59            (Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
60            (None, None) => None,
61        };
62        Ok(Self(xs))
63    }
64
65    pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
66        match &self.0 {
67            None => Ok(0),
68            Some(v) => v.dim(dim),
69        }
70    }
71
72    pub fn reset(&mut self) {
73        self.0 = None
74    }
75
76    pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
77        let t = match &self.0 {
78            None => None,
79            Some(t) => {
80                let seq_len = t.dim(dim)?;
81                if seq_len <= offset {
82                    None
83                } else {
84                    let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
85                    Some(t)
86                }
87            }
88        };
89        Ok(Self(t))
90    }
91
92    /// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements
93    /// returned in the first output and the remaining in the second output.
94    pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
95        match &self.0 {
96            None => Ok((Self::empty(), Self::empty())),
97            Some(t) => {
98                let seq_len = t.dim(dim)?;
99                let lhs_len = usize::min(seq_len, lhs_len);
100                if lhs_len == 0 {
101                    Ok((Self::empty(), t.clone().into()))
102                } else {
103                    let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
104                    let rhs_len = seq_len - lhs_len;
105                    let rhs = if rhs_len == 0 {
106                        Self::empty()
107                    } else {
108                        Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
109                    };
110                    Ok((lhs, rhs))
111                }
112            }
113        }
114    }
115
116    pub fn as_option(&self) -> Option<&Tensor> {
117        self.0.as_ref()
118    }
119
120    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
121        match &self.0 {
122            None => Ok(Self::empty()),
123            Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
124        }
125    }
126}
127
128/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform
129/// some internal buffering so that enough data has been received for the module to be able to
130/// perform some operations.
131pub trait StreamingModule {
132    // TODO: Should we also have a flush method?
133    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
134    fn reset_state(&mut self);
135}
136
/// Binary operations that a `StreamingBinOp` can apply to its two inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// `lhs + rhs`
    Add,
    /// `lhs * rhs`
    Mul,
    /// `lhs - rhs`
    Sub,
    /// `lhs / rhs`
    Div,
}
144
145#[derive(Debug, Clone)]
146pub struct StreamingBinOp {
147    prev_lhs: StreamTensor,
148    prev_rhs: StreamTensor,
149    pub op: BinOp,
150    pub dim: crate::D,
151}
152
153impl StreamingBinOp {
154    pub fn new(op: BinOp, dim: crate::D) -> Self {
155        Self {
156            prev_lhs: StreamTensor::empty(),
157            prev_rhs: StreamTensor::empty(),
158            op,
159            dim,
160        }
161    }
162
163    pub fn reset_state(&mut self) {
164        self.prev_lhs.reset();
165        self.prev_rhs.reset();
166    }
167
168    pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
169        match self.op {
170            BinOp::Add => Tensor::add(lhs, rhs),
171            BinOp::Mul => Tensor::mul(lhs, rhs),
172            BinOp::Sub => Tensor::sub(lhs, rhs),
173            BinOp::Div => Tensor::div(lhs, rhs),
174        }
175    }
176
177    pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
178        let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
179        let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
180        let lhs_len = lhs.seq_len(self.dim)?;
181        let rhs_len = rhs.seq_len(self.dim)?;
182        let common_len = usize::min(lhs_len, rhs_len);
183        let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
184        let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
185        let ys = match (lhs.0, rhs.0) {
186            (Some(lhs), Some(rhs)) => {
187                let ys = self.forward(&lhs, &rhs)?;
188                StreamTensor::from_tensor(ys)
189            }
190            (None, None) => StreamTensor::empty(),
191            (lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
192        };
193        self.prev_lhs = prev_lhs;
194        self.prev_rhs = prev_rhs;
195        Ok(ys)
196    }
197}
198
199/// Simple wrapper that doesn't do any buffering.
200pub struct Map<T: crate::Module>(T);
201
202impl<T: crate::Module> StreamingModule for Map<T> {
203    fn reset_state(&mut self) {}
204
205    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
206        xs.apply(&self.0)
207    }
208}