Skip to main content

burn_fusion/stream/
multi.rs

1use std::sync::Arc;
2
3use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr, TensorStatus};
4use hashbrown::{HashMap, HashSet};
5
6use super::{
7    StreamId,
8    execution::{ExecutionMode, Operation, Processor, StreamSegment},
9    queue::OperationQueue,
10    shared_tensors::SharedTensors,
11    store::{ExecutionPlanId, ExecutionPlanStore},
12};
13use crate::{
14    DropOp, FusionRuntime,
15    stream::shared_tensors::{SharedTensorAnalysis, SharedTensorDropAction},
16};
17
/// Keep track of multiple concurrent lazy streams of operations.
///
/// Each stream owns its own operation queue and processor, while all streams share the same
/// execution-plan store and the same shared-tensor bookkeeping.
pub struct MultiStream<R: FusionRuntime> {
    /// Active streams keyed by id. A stream is created when an operation is first enqueued on it
    /// and may be removed once it no longer tracks any variable.
    streams: HashMap<StreamId, Stream<R>>,
    /// Execution plans handed to the processor of every stream.
    optimizations: ExecutionPlanStore<R::Optimization>,
    /// Bookkeeping for tensors used by more than one stream.
    shared_tensors: SharedTensors,
    /// Device used to build the processor (fusers) of each newly created stream.
    device: R::FusionDevice,
    /// Consistency checks run after stream updates when the `memory-checks` feature is enabled.
    #[cfg(feature = "memory-checks")]
    memory_checks: super::memory_checks::MemoryChecks,
}
27
28#[derive(Debug)]
29enum DropAction {
30    SkipSharedTensor,
31    ForceSharedTensor(Vec<StreamId>, TensorId),
32    ContinueDrop,
33}
34
35impl<R: FusionRuntime> MultiStream<R> {
36    pub(crate) fn new(device: R::FusionDevice) -> Self {
37        Self {
38            streams: HashMap::new(),
39            optimizations: ExecutionPlanStore::new(),
40            shared_tensors: SharedTensors::default(),
41            device,
42            #[cfg(feature = "memory-checks")]
43            memory_checks: super::memory_checks::MemoryChecks::default(),
44        }
45    }
46
47    /// Register a new tensor operation.
48    pub(crate) fn register(
49        &mut self,
50        streams: OperationStreams,
51        mut repr: OperationIr,
52        operation: Arc<dyn Operation<R>>,
53        handles: &mut HandleContainer<R::FusionHandle>,
54    ) {
55        let id = self.resolve_streams(&streams, handles, &mut repr);
56
57        let drop_action = match &mut repr {
58            OperationIr::Drop(tensor_ir) => Some(self.handle_drop_op(id, tensor_ir)),
59            _ => None,
60        };
61
62        let sync = match drop_action {
63            Some(DropAction::SkipSharedTensor) => return,
64            Some(DropAction::ContinueDrop) => true,
65            Some(DropAction::ForceSharedTensor(stream_ids, tid)) => {
66                for stream_id in stream_ids {
67                    if let Some(stream) = self.streams.get_mut(&stream_id) {
68                        stream.queue.variables.remove(&tid);
69                        if stream.queue.variables.is_empty() {
70                            self.streams.remove(&stream_id);
71                        }
72                    }
73                }
74                true
75            }
76            None => false,
77        };
78
79        let num_executed = self.enqueue_operation(id, repr, &streams, operation, handles);
80
81        if num_executed > 0
82            && let Some(stream) = self.streams.get_mut(&id)
83        {
84            let cleared = self.shared_tensors.on_executed_ops(id, stream);
85            self.clear_shared_tensors(&cleared, id);
86            let to_drop = self.shared_tensors.clear_tensors(cleared);
87            self.drop_shared_tensors(to_drop, handles, id);
88        }
89
90        let stream = match self.streams.get(&id) {
91            Some(val) => val,
92            None => {
93                #[cfg(feature = "memory-checks")]
94                self.memory_checks.check(&self.streams, handles);
95                return;
96            }
97        };
98
99        if !stream.queue.variables.is_empty() && sync {
100            // Not draining the queue can cause a memory leak when a stream is closing.
101            self.drain(handles, id);
102        }
103
104        #[cfg(feature = "memory-checks")]
105        self.memory_checks.check(&self.streams, handles);
106    }
107
108    /// Checks if the current operation is a drop.
109    ///
110    /// When a tensor is shared across multiple concurrent streams, dropping a tensor might cause a
111    /// problem when the same tensor is registered lazily on another stream, but not yet executed.
112    fn handle_drop_op(&mut self, id: StreamId, tensor_ir: &mut TensorIr) -> DropAction {
113        match !matches!(tensor_ir.status, TensorStatus::ReadWrite) {
114            true => {
115                let stream = self.streams.get(&id);
116                let on_drop = self
117                    .shared_tensors
118                    .on_drop(id, tensor_ir.id, stream.is_none());
119
120                match on_drop {
121                    SharedTensorDropAction::ForceDrop(streams) => {
122                        tensor_ir.status = TensorStatus::ReadWrite;
123                        DropAction::ForceSharedTensor(streams, tensor_ir.id)
124                    }
125                    SharedTensorDropAction::Skip => DropAction::SkipSharedTensor,
126                }
127            }
128            false => DropAction::ContinueDrop,
129        }
130    }
131
132    /// Enqueue an operation on the queue.
133    fn enqueue_operation(
134        &mut self,
135        id: StreamId,
136        repr: OperationIr,
137        streams: &OperationStreams,
138        operation: Arc<dyn Operation<R>>,
139        handles: &mut HandleContainer<R::FusionHandle>,
140    ) -> usize {
141        let stream = match self.streams.get_mut(&id) {
142            Some(stream) => stream,
143            None => {
144                let stream = Stream::new(self.device.clone());
145                self.streams.insert(id, stream);
146                self.streams
147                    .get_mut(&id)
148                    .expect("Just added, so should be included in the hashmap.")
149            }
150        };
151
152        stream.queue.add(repr, operation, streams, id);
153
154        let len_before = stream.queue.global.len();
155        stream.processor.process(
156            Segment::new(&mut stream.queue, handles),
157            &mut self.optimizations,
158            ExecutionMode::Lazy,
159        );
160        let len_after = stream.queue.global.len();
161        let num_executed = len_before - len_after;
162
163        stream.cursor += num_executed as u64;
164
165        num_executed
166    }
167
168    /// Mark a tensor as read.
169    #[allow(unused_variables)]
170    pub fn mark_read(
171        &mut self,
172        id: StreamId,
173        ir: &TensorIr,
174        handles: &HandleContainer<R::FusionHandle>,
175    ) {
176        if !matches!(ir.status, TensorStatus::ReadWrite) {
177            return;
178        };
179
180        let stream = match self.streams.get_mut(&id) {
181            Some(val) => val,
182            None => return,
183        };
184
185        stream.queue.variables.remove(&ir.id);
186
187        if stream.queue.variables.is_empty() {
188            self.streams.remove(&id);
189        }
190
191        #[cfg(feature = "memory-checks")]
192        self.memory_checks.check(&self.streams, handles);
193    }
194
195    /// Drain a stream
196    pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {
197        if let Some(stream) = self.streams.get_mut(&id) {
198            let old = unsafe { StreamId::swap(id) };
199            let num_executed = stream.queue.global.len();
200            stream.processor.process(
201                Segment::new(&mut stream.queue, handles),
202                &mut self.optimizations,
203                ExecutionMode::Sync,
204            );
205            stream.cursor += num_executed as u64;
206
207            let cleared = self.shared_tensors.on_executed_ops(id, stream);
208            self.clear_shared_tensors(&cleared, id);
209            let to_drop = self.shared_tensors.clear_tensors(cleared);
210
211            self.drop_shared_tensors(to_drop, handles, id);
212            unsafe {
213                StreamId::swap(old);
214            };
215        }
216    }
217
218    /// When one of the provided streams is different from the current stream, we drain them.
219    ///
220    /// Returns the selected stream id.
221    fn resolve_streams(
222        &mut self,
223        streams: &OperationStreams,
224        handles: &mut HandleContainer<R::FusionHandle>,
225        op: &mut OperationIr,
226    ) -> StreamId {
227        let current = streams.current;
228        let nodes = op.nodes();
229
230        let analysis = self.analyse_shared_tensors(&nodes, streams, current);
231
232        self.merge_streams_timelines(handles, &analysis, current, &nodes);
233        self.register_shared_tensors_drop(&analysis, op);
234
235        current
236    }
237
238    /// Drain the stream only if one of the tensor in the given nodes is also included in the
239    /// stream queue.
240    fn resolve_stream(
241        &mut self,
242        handles: &mut HandleContainer<R::FusionHandle>,
243        id: StreamId,
244        nodes: &[&TensorIr],
245    ) {
246        if let Some(stream) = self.streams.get(&id) {
247            for node in nodes {
248                if stream.queue.variables.contains_key(&node.id) {
249                    self.drain(handles, id);
250                    return;
251                }
252            }
253        }
254    }
255
256    fn analyse_shared_tensors(
257        &mut self,
258        nodes: &[&TensorIr],
259        streams: &OperationStreams,
260        current: StreamId,
261    ) -> MultiSharedTensorAnalysis {
262        let mut shared_analysis = MultiSharedTensorAnalysis::default();
263
264        for node in nodes.iter() {
265            let analysis = self
266                .shared_tensors
267                .analyse(current, node, streams, &self.streams);
268            match analysis {
269                SharedTensorAnalysis::SharedFromCurrentStream => {
270                    shared_analysis.current.push(node.id);
271                }
272                SharedTensorAnalysis::NotShared => {}
273                SharedTensorAnalysis::SharedFromExistingStream {
274                    stream_id,
275                    original_cursor,
276                } => {
277                    shared_analysis
278                        .existing
279                        .push((node.id, stream_id, original_cursor));
280                }
281                SharedTensorAnalysis::SharedFromNewStream { stream_id } => {
282                    shared_analysis.new.push((node.id, stream_id));
283                }
284            }
285        }
286
287        shared_analysis
288    }
289
290    fn merge_streams_timelines(
291        &mut self,
292        handles: &mut HandleContainer<R::FusionHandle>,
293        analysis: &MultiSharedTensorAnalysis,
294        current: StreamId,
295        nodes: &[&TensorIr],
296    ) {
297        // If we only have current tensors that are shared, we're safe to not sync the timelines.
298        if analysis.new.is_empty() && analysis.existing.is_empty() {
299            return;
300        }
301
302        let mut streams_to_sync = HashSet::new();
303        for (_tensor_id, stream_id) in analysis.new.iter() {
304            streams_to_sync.insert(*stream_id);
305        }
306
307        for (_tensor_id, stream_id, original_cursor) in analysis.existing.iter() {
308            if let Some(stream) = self.streams.get(stream_id) {
309                // We only have to sync a stream when the stream isn't up to date with
310                // the original cursor of the current operation.
311                if stream.cursor <= *original_cursor && *stream_id != current {
312                    streams_to_sync.insert(*stream_id);
313                }
314            }
315        }
316
317        for id in streams_to_sync.drain() {
318            log::trace!("Drain stream {id} for use in current {current}");
319            self.resolve_stream(handles, id, nodes);
320        }
321    }
322
323    fn register_shared_tensors_drop(
324        &mut self,
325        analysis: &MultiSharedTensorAnalysis,
326        op: &mut OperationIr,
327    ) {
328        let mut readonly_tensors = Vec::new();
329
330        for (tensor_id, _stream_id) in analysis.new.iter() {
331            readonly_tensors.push(*tensor_id);
332        }
333        for (tensor_id, _stream_id, _cursor) in analysis.existing.iter() {
334            readonly_tensors.push(*tensor_id);
335        }
336        for tensor_id in analysis.current.iter() {
337            readonly_tensors.push(*tensor_id);
338        }
339
340        self.shared_tensors
341            .tag_manual_drop(op.mark_read_only(&readonly_tensors));
342    }
343
344    fn drop_shared_tensors(
345        &mut self,
346        tensors: Vec<TensorIr>,
347        handles: &mut HandleContainer<R::FusionHandle>,
348        current: StreamId,
349    ) {
350        for (stream_id, s) in self.streams.iter_mut() {
351            for tensor in tensors.iter() {
352                if let Some((original, _status)) = s.queue.variables.get(&tensor.id)
353                    && original != stream_id
354                {
355                    s.queue.variables.remove(&tensor.id);
356                }
357            }
358        }
359        for tensor in tensors {
360            let streams = OperationStreams {
361                streams: HashMap::new(),
362                current,
363            };
364
365            let op = Arc::new(DropOp { id: tensor.id });
366            self.register(streams, OperationIr::Drop(tensor), op, handles);
367        }
368    }
369    fn clear_shared_tensors(&mut self, tensors: &[TensorId], current: StreamId) {
370        let mut to_remove = Vec::new();
371        for (stream_id, s) in self.streams.iter_mut() {
372            for tensor in tensors.iter() {
373                s.queue.variables.remove(tensor);
374            }
375
376            if s.queue.variables.is_empty() && current != *stream_id {
377                to_remove.push(*stream_id);
378            }
379        }
380
381        for s in to_remove {
382            self.streams.remove(&s);
383        }
384    }
385}
386
/// A single lazy stream of operations.
pub(crate) struct Stream<R: FusionRuntime> {
    /// Operations registered on this stream and not executed yet, plus the variables
    /// (tensor id -> origin stream and status) they use.
    pub(crate) queue: OperationQueue<R>,
    /// Processes the queue, lazily or synchronously, using the shared execution-plan store.
    processor: Processor<R::Optimization>,
    /// Number of operations executed on this stream so far.
    pub(crate) cursor: u64,
}
392
/// Mutable view over a stream queue and the tensor handles, handed to the stream processor.
///
/// The field order defines the argument order of the derived `new` constructor.
#[derive(new)]
struct Segment<'a, R: FusionRuntime> {
    /// Queue whose operations are processed.
    queue: &'a mut OperationQueue<R>,
    /// Handles used when executing the queued operations.
    handles: &'a mut HandleContainer<R::FusionHandle>,
}
398
399impl<R: FusionRuntime> StreamSegment<R::Optimization> for Segment<'_, R> {
400    fn operations(&self) -> &[OperationIr] {
401        &self.queue.relative
402    }
403
404    fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<R::Optimization>) {
405        self.queue.execute(id, self.handles, store)
406    }
407}
408
409impl<R: FusionRuntime> Stream<R> {
410    fn new(device: R::FusionDevice) -> Self {
411        Self {
412            processor: Processor::new(R::fusers(device)),
413            queue: OperationQueue::new(),
414            cursor: 0,
415        }
416    }
417}
418
/// Manage the streams used for the current [operation](OperationIr).
#[derive(Debug)]
pub struct OperationStreams {
    /// Stream each registered input tensor belongs to.
    pub(crate) streams: HashMap<TensorId, StreamId>,
    /// Stream on which the operation is registered.
    pub(crate) current: StreamId,
}
425
426impl Default for OperationStreams {
427    fn default() -> Self {
428        Self {
429            streams: HashMap::new(),
430            current: StreamId::current(),
431        }
432    }
433}
434
435impl OperationStreams {
436    /// Register a tensor in the list of streams used for the current [operation](OperationIr).
437    ///
438    /// You only need to register input tensors, not the outputs.
439    /// So init tensor operations should have no streams registered.
440    pub fn tensor<R: FusionRuntime>(&mut self, tensor: &crate::FusionTensor<R>) {
441        self.streams.insert(tensor.id, tensor.stream);
442    }
443
444    pub(crate) fn get(&self, id: TensorId) -> Option<StreamId> {
445        self.streams.get(&id).cloned()
446    }
447
448    /// Create new operation streams with the given inputs.
449    ///
450    /// The inputs are automatically registered.
451    pub fn with_inputs<'a, R: FusionRuntime + 'a, I>(tensors: I) -> Self
452    where
453        I: IntoIterator<Item = &'a crate::FusionTensor<R>>,
454    {
455        let mut streams = OperationStreams::default();
456        for tensor in tensors.into_iter() {
457            streams.tensor(tensor)
458        }
459        streams
460    }
461}
462
/// Classification of an operation's tensors with respect to sharing between streams.
#[derive(Default, Debug)]
struct MultiSharedTensorAnalysis {
    /// Tensors that are shared with other streams, but the operation is executing on the same
    /// stream where the tensor was originally created.
    current: Vec<TensorId>,
    /// Tensors that are shared with new streams, with the id of that stream.
    new: Vec<(TensorId, StreamId)>,
    /// Tensors that are shared with existing streams, with the id of that stream and the original
    /// cursor reported by the analysis (compared against the stream cursor before syncing).
    existing: Vec<(TensorId, StreamId, u64)>,
}