1use candle::{Device, IndexOp, Result, Tensor};
6
7pub trait Dim: candle::shape::Dim + Copy {}
8impl<T: candle::shape::Dim + Copy> Dim for T {}
9
10#[derive(Clone)]
11pub struct StreamTensor(Option<Tensor>);
12
13#[derive(Debug, Clone)]
14struct MaskInner {
15 cpu: Vec<bool>,
16 mask: Tensor,
17}
18
19#[derive(Clone)]
20pub struct StreamMask(Option<MaskInner>);
21
22impl std::fmt::Debug for StreamMask {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match &self.0 {
25 Some(t) => write!(f, "{:?}", t.mask.shape()),
26 None => write!(f, "Empty"),
27 }
28 }
29}
30
31impl std::convert::From<()> for StreamMask {
32 fn from(_value: ()) -> Self {
33 Self(None)
34 }
35}
36
37impl StreamMask {
38 pub fn empty() -> Self {
39 Self(None)
40 }
41
42 pub fn new(cpu: Vec<bool>, device: &Device) -> Result<Self> {
43 let mask = cpu.iter().map(|&v| u8::from(v)).collect::<Vec<u8>>();
44 let mask = Tensor::new(mask, device)?;
45 Ok(Self(Some(MaskInner { cpu, mask })))
46 }
47
48 pub fn is_active(&self, batch_idx: usize) -> bool {
49 self.cpu().is_none_or(|v| v[batch_idx])
50 }
51
52 pub fn is_empty(&self) -> bool {
53 self.0.is_none()
54 }
55
56 pub fn shape(&self) -> Option<&candle::Shape> {
57 self.0.as_ref().map(|t| t.mask.shape())
58 }
59
60 pub fn as_option(&self) -> Option<&Tensor> {
61 self.0.as_ref().map(|v| &v.mask)
62 }
63
64 pub fn cpu(&self) -> Option<&[bool]> {
65 self.0.as_ref().map(|v| v.cpu.as_slice())
66 }
67}
68
69impl std::fmt::Debug for StreamTensor {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match &self.0 {
72 Some(t) => write!(f, "{:?}", t.shape()),
73 None => write!(f, "Empty"),
74 }
75 }
76}
77
78impl std::convert::From<Option<Tensor>> for StreamTensor {
79 fn from(value: Option<Tensor>) -> Self {
80 Self(value)
81 }
82}
83
84impl std::convert::From<Tensor> for StreamTensor {
85 fn from(value: Tensor) -> Self {
86 Self(Some(value))
87 }
88}
89
90impl std::convert::From<()> for StreamTensor {
91 fn from(_value: ()) -> Self {
92 Self(None)
93 }
94}
95
96impl StreamTensor {
97 pub fn empty() -> Self {
98 Self(None)
99 }
100
101 pub fn is_empty(&self) -> bool {
102 self.0.is_none()
103 }
104
105 pub fn from_tensor(tensor: Tensor) -> Self {
106 Self(Some(tensor))
107 }
108
109 pub fn shape(&self) -> Option<&candle::Shape> {
110 self.0.as_ref().map(|t| t.shape())
111 }
112
113 pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
114 let xs = match (&self.0, &rhs.0) {
115 (Some(lhs), Some(rhs)) => {
116 let xs = Tensor::cat(&[lhs, rhs], dim)?;
117 Some(xs)
118 }
119 (Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
120 (None, None) => None,
121 };
122 Ok(Self(xs))
123 }
124
125 pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
126 match &self.0 {
127 None => Ok(0),
128 Some(v) => v.dim(dim),
129 }
130 }
131
132 pub fn reset(&mut self) {
133 self.0 = None
134 }
135
136 pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
137 let t = match &self.0 {
138 None => None,
139 Some(t) => {
140 let seq_len = t.dim(dim)?;
141 if seq_len <= offset {
142 None
143 } else {
144 let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
145 Some(t)
146 }
147 }
148 };
149 Ok(Self(t))
150 }
151
152 pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
155 match &self.0 {
156 None => Ok((Self::empty(), Self::empty())),
157 Some(t) => {
158 let seq_len = t.dim(dim)?;
159 let lhs_len = usize::min(seq_len, lhs_len);
160 if lhs_len == 0 {
161 Ok((Self::empty(), t.clone().into()))
162 } else {
163 let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
164 let rhs_len = seq_len - lhs_len;
165 let rhs = if rhs_len == 0 {
166 Self::empty()
167 } else {
168 Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
169 };
170 Ok((lhs, rhs))
171 }
172 }
173 }
174 }
175
176 pub fn as_option(&self) -> Option<&Tensor> {
177 self.0.as_ref()
178 }
179
180 pub fn apply<M: candle::Module>(&self, m: &M) -> Result<Self> {
181 match &self.0 {
182 None => Ok(Self::empty()),
183 Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
184 }
185 }
186}
187
188pub trait StreamingModule {
189 fn step(&mut self, xs: &StreamTensor, mask: &StreamMask) -> Result<StreamTensor>;
191 fn reset_state(&mut self);
192}
193
/// Element-wise binary operation applied by `StreamingBinOp`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BinOp {
    /// `lhs + rhs`
    Add,
    /// `lhs * rhs`
    Mul,
    /// `lhs - rhs`
    Sub,
    /// `lhs / rhs`
    Div,
}
201
202#[derive(Debug, Clone)]
203pub struct StreamingBinOp {
204 prev_lhs: StreamTensor,
205 prev_rhs: StreamTensor,
206 pub op: BinOp,
207 pub dim: candle::D,
208}
209
210impl StreamingBinOp {
211 pub fn new(op: BinOp, dim: candle::D) -> Self {
212 Self { prev_lhs: StreamTensor::empty(), prev_rhs: StreamTensor::empty(), op, dim }
213 }
214
215 pub fn reset_state(&mut self) {
216 self.prev_lhs.reset();
217 self.prev_rhs.reset();
218 }
219
220 pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
221 match self.op {
222 BinOp::Add => Tensor::add(lhs, rhs),
223 BinOp::Mul => Tensor::mul(lhs, rhs),
224 BinOp::Sub => Tensor::sub(lhs, rhs),
225 BinOp::Div => Tensor::div(lhs, rhs),
226 }
227 }
228
229 pub fn step(
230 &mut self,
231 lhs: &StreamTensor,
232 rhs: &StreamTensor,
233 mask: &StreamMask,
234 ) -> Result<StreamTensor> {
235 let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
236 let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
237 let lhs_len = lhs.seq_len(self.dim)?;
238 let rhs_len = rhs.seq_len(self.dim)?;
239 let common_len = usize::min(lhs_len, rhs_len);
240 let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
241 let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
242 let ys = match (&lhs.0, &rhs.0) {
243 (Some(lhs), Some(rhs)) => {
244 let ys = self.forward(lhs, rhs)?;
245 StreamTensor::from_tensor(ys)
246 }
247 (None, None) => StreamTensor::empty(),
248 (lhs, rhs) => candle::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
249 };
250 if !mask.is_empty() && (!prev_lhs.is_empty() || !prev_rhs.is_empty()) {
251 candle::bail!(
252 "cannot use a stream mask with a streaming bin op {prev_lhs:?} {prev_rhs:?} {lhs:?} {rhs:?}"
253 );
254 }
255 self.prev_lhs = prev_lhs;
256 self.prev_rhs = prev_rhs;
257 Ok(ys)
258 }
259
260 pub fn reset_batch_idx(&mut self, batch_idx: usize, _batch_size: usize) -> Result<()> {
261 if let Some(v) = self.prev_lhs.as_option() {
262 let v = v.contiguous()?;
263 v.i(batch_idx..(1 + batch_idx))?.zero_set()?;
264 self.prev_lhs = StreamTensor::from_tensor(v);
265 }
266 if let Some(v) = self.prev_rhs.as_option() {
267 let v = v.contiguous()?;
268 v.i(batch_idx..(1 + batch_idx))?.zero_set()?;
269 self.prev_rhs = StreamTensor::from_tensor(v);
270 }
271 Ok(())
272 }
273}
274
/// Adapts a `candle::Module` into a `StreamingModule` by applying it to
/// every incoming chunk independently.
///
/// The `candle::Module` requirement is stated on the impls that need it
/// rather than on the struct definition, per the Rust API guidelines.
pub struct Map<T>(T);
277
278impl<T: candle::Module> StreamingModule for Map<T> {
279 fn reset_state(&mut self) {}
280
281 fn step(&mut self, xs: &StreamTensor, _: &StreamMask) -> Result<StreamTensor> {
282 xs.apply(&self.0)
283 }
284}