1use crate::{Result, Shape, Tensor};
4
5pub trait Dim: crate::shape::Dim + Copy {}
6impl<T: crate::shape::Dim + Copy> Dim for T {}
7
8#[derive(Clone)]
11pub struct StreamTensor(Option<Tensor>);
12
13impl std::fmt::Debug for StreamTensor {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15 match &self.0 {
16 Some(t) => write!(f, "{:?}", t.shape()),
17 None => write!(f, "Empty"),
18 }
19 }
20}
21
22impl std::convert::From<Option<Tensor>> for StreamTensor {
23 fn from(value: Option<Tensor>) -> Self {
24 Self(value)
25 }
26}
27
28impl std::convert::From<Tensor> for StreamTensor {
29 fn from(value: Tensor) -> Self {
30 Self(Some(value))
31 }
32}
33
34impl std::convert::From<()> for StreamTensor {
35 fn from(_value: ()) -> Self {
36 Self(None)
37 }
38}
39
40impl StreamTensor {
41 pub fn empty() -> Self {
42 Self(None)
43 }
44
45 pub fn from_tensor(tensor: Tensor) -> Self {
46 Self(Some(tensor))
47 }
48
49 pub fn shape(&self) -> Option<&Shape> {
50 self.0.as_ref().map(|t| t.shape())
51 }
52
53 pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
54 let xs = match (&self.0, &rhs.0) {
55 (Some(lhs), Some(rhs)) => {
56 let xs = Tensor::cat(&[lhs, rhs], dim)?;
57 Some(xs)
58 }
59 (Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
60 (None, None) => None,
61 };
62 Ok(Self(xs))
63 }
64
65 pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
66 match &self.0 {
67 None => Ok(0),
68 Some(v) => v.dim(dim),
69 }
70 }
71
72 pub fn reset(&mut self) {
73 self.0 = None
74 }
75
76 pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
77 let t = match &self.0 {
78 None => None,
79 Some(t) => {
80 let seq_len = t.dim(dim)?;
81 if seq_len <= offset {
82 None
83 } else {
84 let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
85 Some(t)
86 }
87 }
88 };
89 Ok(Self(t))
90 }
91
92 pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
95 match &self.0 {
96 None => Ok((Self::empty(), Self::empty())),
97 Some(t) => {
98 let seq_len = t.dim(dim)?;
99 let lhs_len = usize::min(seq_len, lhs_len);
100 if lhs_len == 0 {
101 Ok((Self::empty(), t.clone().into()))
102 } else {
103 let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
104 let rhs_len = seq_len - lhs_len;
105 let rhs = if rhs_len == 0 {
106 Self::empty()
107 } else {
108 Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
109 };
110 Ok((lhs, rhs))
111 }
112 }
113 }
114 }
115
116 pub fn as_option(&self) -> Option<&Tensor> {
117 self.0.as_ref()
118 }
119
120 pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
121 match &self.0 {
122 None => Ok(Self::empty()),
123 Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
124 }
125 }
126}
127
128pub trait StreamingModule {
132 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
134 fn reset_state(&mut self);
135}
136
/// Element-wise binary operation selector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// Addition: `lhs + rhs`.
    Add,
    /// Multiplication: `lhs * rhs`.
    Mul,
    /// Subtraction: `lhs - rhs`.
    Sub,
    /// Division: `lhs / rhs`.
    Div,
}
144
145#[derive(Debug, Clone)]
146pub struct StreamingBinOp {
147 prev_lhs: StreamTensor,
148 prev_rhs: StreamTensor,
149 pub op: BinOp,
150 pub dim: crate::D,
151}
152
153impl StreamingBinOp {
154 pub fn new(op: BinOp, dim: crate::D) -> Self {
155 Self {
156 prev_lhs: StreamTensor::empty(),
157 prev_rhs: StreamTensor::empty(),
158 op,
159 dim,
160 }
161 }
162
163 pub fn reset_state(&mut self) {
164 self.prev_lhs.reset();
165 self.prev_rhs.reset();
166 }
167
168 pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
169 match self.op {
170 BinOp::Add => Tensor::add(lhs, rhs),
171 BinOp::Mul => Tensor::mul(lhs, rhs),
172 BinOp::Sub => Tensor::sub(lhs, rhs),
173 BinOp::Div => Tensor::div(lhs, rhs),
174 }
175 }
176
177 pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
178 let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
179 let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
180 let lhs_len = lhs.seq_len(self.dim)?;
181 let rhs_len = rhs.seq_len(self.dim)?;
182 let common_len = usize::min(lhs_len, rhs_len);
183 let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
184 let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
185 let ys = match (lhs.0, rhs.0) {
186 (Some(lhs), Some(rhs)) => {
187 let ys = self.forward(&lhs, &rhs)?;
188 StreamTensor::from_tensor(ys)
189 }
190 (None, None) => StreamTensor::empty(),
191 (lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
192 };
193 self.prev_lhs = prev_lhs;
194 self.prev_rhs = prev_rhs;
195 Ok(ys)
196 }
197}
198
199pub struct Map<T: crate::Module>(T);
201
202impl<T: crate::Module> StreamingModule for Map<T> {
203 fn reset_state(&mut self) {}
204
205 fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
206 xs.apply(&self.0)
207 }
208}