1use std::sync::Arc;
2
3use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr, TensorStatus};
4use hashbrown::{HashMap, HashSet};
5
6use super::{
7 StreamId,
8 execution::{ExecutionMode, Operation, Processor, StreamSegment},
9 queue::OperationQueue,
10 shared_tensors::SharedTensors,
11 store::{ExecutionPlanId, ExecutionPlanStore},
12};
13use crate::{
14 DropOp, FusionRuntime,
15 stream::shared_tensors::{SharedTensorAnalysis, SharedTensorDropAction},
16};
17
/// Manages multiple execution streams for a single fusion device, routing
/// operations to per-stream queues and coordinating tensors that are shared
/// across streams.
pub struct MultiStream<R: FusionRuntime> {
    /// Active streams keyed by id; a stream is removed once its queue no
    /// longer tracks any variables.
    streams: HashMap<StreamId, Stream<R>>,
    /// Store of execution plans (optimizations) shared by all streams.
    optimizations: ExecutionPlanStore<R::Optimization>,
    /// Bookkeeping for tensors used by more than one stream.
    shared_tensors: SharedTensors,
    /// Device on which new streams are created.
    device: R::FusionDevice,
    #[cfg(feature = "memory-checks")]
    memory_checks: super::memory_checks::MemoryChecks,
}
27
/// Outcome of intercepting a `Drop` operation on a possibly shared tensor.
#[derive(Debug)]
enum DropAction {
    /// The shared-tensor bookkeeping decided the drop must be skipped for now
    /// (the tensor is presumably still needed elsewhere — see `SharedTensors::on_drop`).
    SkipSharedTensor,
    /// The drop must be forced: purge the tensor id from the queues of the
    /// listed streams, then proceed with the (promoted) drop.
    ForceSharedTensor(Vec<StreamId>, TensorId),
    /// Not a shared tensor; proceed with a normal drop.
    ContinueDrop,
}
34
35impl<R: FusionRuntime> MultiStream<R> {
36 pub(crate) fn new(device: R::FusionDevice) -> Self {
37 Self {
38 streams: HashMap::new(),
39 optimizations: ExecutionPlanStore::new(),
40 shared_tensors: SharedTensors::default(),
41 device,
42 #[cfg(feature = "memory-checks")]
43 memory_checks: super::memory_checks::MemoryChecks::default(),
44 }
45 }
46
47 pub(crate) fn register(
49 &mut self,
50 streams: OperationStreams,
51 mut repr: OperationIr,
52 operation: Arc<dyn Operation<R>>,
53 handles: &mut HandleContainer<R::FusionHandle>,
54 ) {
55 let id = self.resolve_streams(&streams, handles, &mut repr);
56
57 let drop_action = match &mut repr {
58 OperationIr::Drop(tensor_ir) => Some(self.handle_drop_op(id, tensor_ir)),
59 _ => None,
60 };
61
62 let sync = match drop_action {
63 Some(DropAction::SkipSharedTensor) => return,
64 Some(DropAction::ContinueDrop) => true,
65 Some(DropAction::ForceSharedTensor(stream_ids, tid)) => {
66 for stream_id in stream_ids {
67 if let Some(stream) = self.streams.get_mut(&stream_id) {
68 stream.queue.variables.remove(&tid);
69 if stream.queue.variables.is_empty() {
70 self.streams.remove(&stream_id);
71 }
72 }
73 }
74 true
75 }
76 None => false,
77 };
78
79 let num_executed = self.enqueue_operation(id, repr, &streams, operation, handles);
80
81 if num_executed > 0
82 && let Some(stream) = self.streams.get_mut(&id)
83 {
84 let cleared = self.shared_tensors.on_executed_ops(id, stream);
85 self.clear_shared_tensors(&cleared, id);
86 let to_drop = self.shared_tensors.clear_tensors(cleared);
87 self.drop_shared_tensors(to_drop, handles, id);
88 }
89
90 let stream = match self.streams.get(&id) {
91 Some(val) => val,
92 None => {
93 #[cfg(feature = "memory-checks")]
94 self.memory_checks.check(&self.streams, handles);
95 return;
96 }
97 };
98
99 if !stream.queue.variables.is_empty() && sync {
100 self.drain(handles, id);
102 }
103
104 #[cfg(feature = "memory-checks")]
105 self.memory_checks.check(&self.streams, handles);
106 }
107
108 fn handle_drop_op(&mut self, id: StreamId, tensor_ir: &mut TensorIr) -> DropAction {
113 match !matches!(tensor_ir.status, TensorStatus::ReadWrite) {
114 true => {
115 let stream = self.streams.get(&id);
116 let on_drop = self
117 .shared_tensors
118 .on_drop(id, tensor_ir.id, stream.is_none());
119
120 match on_drop {
121 SharedTensorDropAction::ForceDrop(streams) => {
122 tensor_ir.status = TensorStatus::ReadWrite;
123 DropAction::ForceSharedTensor(streams, tensor_ir.id)
124 }
125 SharedTensorDropAction::Skip => DropAction::SkipSharedTensor,
126 }
127 }
128 false => DropAction::ContinueDrop,
129 }
130 }
131
132 fn enqueue_operation(
134 &mut self,
135 id: StreamId,
136 repr: OperationIr,
137 streams: &OperationStreams,
138 operation: Arc<dyn Operation<R>>,
139 handles: &mut HandleContainer<R::FusionHandle>,
140 ) -> usize {
141 let stream = match self.streams.get_mut(&id) {
142 Some(stream) => stream,
143 None => {
144 let stream = Stream::new(self.device.clone());
145 self.streams.insert(id, stream);
146 self.streams
147 .get_mut(&id)
148 .expect("Just added, so should be included in the hashmap.")
149 }
150 };
151
152 stream.queue.add(repr, operation, streams, id);
153
154 let len_before = stream.queue.global.len();
155 stream.processor.process(
156 Segment::new(&mut stream.queue, handles),
157 &mut self.optimizations,
158 ExecutionMode::Lazy,
159 );
160 let len_after = stream.queue.global.len();
161 let num_executed = len_before - len_after;
162
163 stream.cursor += num_executed as u64;
164
165 num_executed
166 }
167
168 #[allow(unused_variables)]
170 pub fn mark_read(
171 &mut self,
172 id: StreamId,
173 ir: &TensorIr,
174 handles: &HandleContainer<R::FusionHandle>,
175 ) {
176 if !matches!(ir.status, TensorStatus::ReadWrite) {
177 return;
178 };
179
180 let stream = match self.streams.get_mut(&id) {
181 Some(val) => val,
182 None => return,
183 };
184
185 stream.queue.variables.remove(&ir.id);
186
187 if stream.queue.variables.is_empty() {
188 self.streams.remove(&id);
189 }
190
191 #[cfg(feature = "memory-checks")]
192 self.memory_checks.check(&self.streams, handles);
193 }
194
    /// Synchronously executes every operation queued on the given stream, then
    /// performs shared-tensor cleanup for the operations that ran.
    pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {
        if let Some(stream) = self.streams.get_mut(&id) {
            // NOTE(review): `StreamId::swap` appears to temporarily make `id`
            // the thread's current stream so operations registered during the
            // drain (e.g. forced drops below) target it — confirm its contract;
            // it is restored symmetrically before returning.
            let old = unsafe { StreamId::swap(id) };
            // Sync mode executes the whole queue, so the cursor advances by
            // the full queued length captured before processing.
            let num_executed = stream.queue.global.len();
            stream.processor.process(
                Segment::new(&mut stream.queue, handles),
                &mut self.optimizations,
                ExecutionMode::Sync,
            );
            stream.cursor += num_executed as u64;

            // Release shared tensors whose lifetime ended with the executed
            // operations, then drop the ones no longer used anywhere.
            let cleared = self.shared_tensors.on_executed_ops(id, stream);
            self.clear_shared_tensors(&cleared, id);
            let to_drop = self.shared_tensors.clear_tensors(cleared);

            self.drop_shared_tensors(to_drop, handles, id);
            unsafe {
                StreamId::swap(old);
            };
        }
    }
217
    /// Resolves which stream the operation should run on.
    ///
    /// Analyses which of the operation's tensors are shared across streams,
    /// merges the timelines of dependent streams (draining them if needed),
    /// and tags shared tensors as read-only in `op` so they can be manually
    /// dropped later. Returns the current stream id from `streams`.
    fn resolve_streams(
        &mut self,
        streams: &OperationStreams,
        handles: &mut HandleContainer<R::FusionHandle>,
        op: &mut OperationIr,
    ) -> StreamId {
        let current = streams.current;
        let nodes = op.nodes();

        let analysis = self.analyse_shared_tensors(&nodes, streams, current);

        self.merge_streams_timelines(handles, &analysis, current, &nodes);
        self.register_shared_tensors_drop(&analysis, op);

        current
    }
237
238 fn resolve_stream(
241 &mut self,
242 handles: &mut HandleContainer<R::FusionHandle>,
243 id: StreamId,
244 nodes: &[&TensorIr],
245 ) {
246 if let Some(stream) = self.streams.get(&id) {
247 for node in nodes {
248 if stream.queue.variables.contains_key(&node.id) {
249 self.drain(handles, id);
250 return;
251 }
252 }
253 }
254 }
255
256 fn analyse_shared_tensors(
257 &mut self,
258 nodes: &[&TensorIr],
259 streams: &OperationStreams,
260 current: StreamId,
261 ) -> MultiSharedTensorAnalysis {
262 let mut shared_analysis = MultiSharedTensorAnalysis::default();
263
264 for node in nodes.iter() {
265 let analysis = self
266 .shared_tensors
267 .analyse(current, node, streams, &self.streams);
268 match analysis {
269 SharedTensorAnalysis::SharedFromCurrentStream => {
270 shared_analysis.current.push(node.id);
271 }
272 SharedTensorAnalysis::NotShared => {}
273 SharedTensorAnalysis::SharedFromExistingStream {
274 stream_id,
275 original_cursor,
276 } => {
277 shared_analysis
278 .existing
279 .push((node.id, stream_id, original_cursor));
280 }
281 SharedTensorAnalysis::SharedFromNewStream { stream_id } => {
282 shared_analysis.new.push((node.id, stream_id));
283 }
284 }
285 }
286
287 shared_analysis
288 }
289
290 fn merge_streams_timelines(
291 &mut self,
292 handles: &mut HandleContainer<R::FusionHandle>,
293 analysis: &MultiSharedTensorAnalysis,
294 current: StreamId,
295 nodes: &[&TensorIr],
296 ) {
297 if analysis.new.is_empty() && analysis.existing.is_empty() {
299 return;
300 }
301
302 let mut streams_to_sync = HashSet::new();
303 for (_tensor_id, stream_id) in analysis.new.iter() {
304 streams_to_sync.insert(*stream_id);
305 }
306
307 for (_tensor_id, stream_id, original_cursor) in analysis.existing.iter() {
308 if let Some(stream) = self.streams.get(stream_id) {
309 if stream.cursor <= *original_cursor && *stream_id != current {
312 streams_to_sync.insert(*stream_id);
313 }
314 }
315 }
316
317 for id in streams_to_sync.drain() {
318 log::trace!("Drain stream {id} for use in current {current}");
319 self.resolve_stream(handles, id, nodes);
320 }
321 }
322
323 fn register_shared_tensors_drop(
324 &mut self,
325 analysis: &MultiSharedTensorAnalysis,
326 op: &mut OperationIr,
327 ) {
328 let mut readonly_tensors = Vec::new();
329
330 for (tensor_id, _stream_id) in analysis.new.iter() {
331 readonly_tensors.push(*tensor_id);
332 }
333 for (tensor_id, _stream_id, _cursor) in analysis.existing.iter() {
334 readonly_tensors.push(*tensor_id);
335 }
336 for tensor_id in analysis.current.iter() {
337 readonly_tensors.push(*tensor_id);
338 }
339
340 self.shared_tensors
341 .tag_manual_drop(op.mark_read_only(&readonly_tensors));
342 }
343
344 fn drop_shared_tensors(
345 &mut self,
346 tensors: Vec<TensorIr>,
347 handles: &mut HandleContainer<R::FusionHandle>,
348 current: StreamId,
349 ) {
350 for (stream_id, s) in self.streams.iter_mut() {
351 for tensor in tensors.iter() {
352 if let Some((original, _status)) = s.queue.variables.get(&tensor.id)
353 && original != stream_id
354 {
355 s.queue.variables.remove(&tensor.id);
356 }
357 }
358 }
359 for tensor in tensors {
360 let streams = OperationStreams {
361 streams: HashMap::new(),
362 current,
363 };
364
365 let op = Arc::new(DropOp { id: tensor.id });
366 self.register(streams, OperationIr::Drop(tensor), op, handles);
367 }
368 }
369 fn clear_shared_tensors(&mut self, tensors: &[TensorId], current: StreamId) {
370 let mut to_remove = Vec::new();
371 for (stream_id, s) in self.streams.iter_mut() {
372 for tensor in tensors.iter() {
373 s.queue.variables.remove(tensor);
374 }
375
376 if s.queue.variables.is_empty() && current != *stream_id {
377 to_remove.push(*stream_id);
378 }
379 }
380
381 for s in to_remove {
382 self.streams.remove(&s);
383 }
384 }
385}
386
/// A single execution stream: its pending operation queue, the processor that
/// executes (and optimizes) it, and a cursor tracking progress.
pub(crate) struct Stream<R: FusionRuntime> {
    pub(crate) queue: OperationQueue<R>,
    processor: Processor<R::Optimization>,
    /// Total number of operations executed on this stream so far.
    pub(crate) cursor: u64,
}
392
/// A view over one stream's queue plus the handle container, handed to the
/// processor as an executable segment.
#[derive(new)]
struct Segment<'a, R: FusionRuntime> {
    queue: &'a mut OperationQueue<R>,
    handles: &'a mut HandleContainer<R::FusionHandle>,
}
398
impl<R: FusionRuntime> StreamSegment<R::Optimization> for Segment<'_, R> {
    /// The queued operations in their relative representation.
    fn operations(&self) -> &[OperationIr] {
        &self.queue.relative
    }

    /// Executes the plan stored under `id` against the queue and its handles.
    fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<R::Optimization>) {
        self.queue.execute(id, self.handles, store)
    }
}
408
409impl<R: FusionRuntime> Stream<R> {
410 fn new(device: R::FusionDevice) -> Self {
411 Self {
412 processor: Processor::new(R::fusers(device)),
413 queue: OperationQueue::new(),
414 cursor: 0,
415 }
416 }
417}
418
/// The streams involved in an operation: the stream associated with each
/// input tensor plus the stream the operation targets.
#[derive(Debug)]
pub struct OperationStreams {
    /// Stream associated with each input tensor (filled via [`Self::tensor`]).
    pub(crate) streams: HashMap<TensorId, StreamId>,
    /// Stream the operation targets (the thread's current stream by default).
    pub(crate) current: StreamId,
}
425
426impl Default for OperationStreams {
427 fn default() -> Self {
428 Self {
429 streams: HashMap::new(),
430 current: StreamId::current(),
431 }
432 }
433}
434
435impl OperationStreams {
436 pub fn tensor<R: FusionRuntime>(&mut self, tensor: &crate::FusionTensor<R>) {
441 self.streams.insert(tensor.id, tensor.stream);
442 }
443
444 pub(crate) fn get(&self, id: TensorId) -> Option<StreamId> {
445 self.streams.get(&id).cloned()
446 }
447
448 pub fn with_inputs<'a, R: FusionRuntime + 'a, I>(tensors: I) -> Self
452 where
453 I: IntoIterator<Item = &'a crate::FusionTensor<R>>,
454 {
455 let mut streams = OperationStreams::default();
456 for tensor in tensors.into_iter() {
457 streams.tensor(tensor)
458 }
459 streams
460 }
461}
462
/// Per-operation classification of shared tensors, grouped by where the
/// sharing originates.
#[derive(Default, Debug)]
struct MultiSharedTensorAnalysis {
    /// Tensors shared from the current stream itself.
    current: Vec<TensorId>,
    /// Tensors shared from a stream not currently tracked.
    new: Vec<(TensorId, StreamId)>,
    /// Tensors shared from an existing stream, with the cursor value recorded
    /// when the sharing was first registered.
    existing: Vec<(TensorId, StreamId, u64)>,
}