// burn_train/learner/step/train.rs
use crate::{TrainOutput, TrainStep};
2use burn_core::data::dataloader::Progress;
3use burn_core::{
4 data::dataloader::DataLoaderIterator, module::AutodiffModule, tensor::backend::AutodiffBackend,
5};
6use std::sync::mpsc::{Receiver, Sender};
7use std::thread::spawn;
8
/// Runs one training step per device in parallel by fanning batches out to
/// per-device worker threads and collecting their outputs.
pub struct MultiDevicesTrainStep<B: AutodiffBackend, M, TI, TO> {
    // One worker (and its background thread) per device.
    workers: Vec<Worker<B, M, TI>>,
    // Receiving end of the output channel shared by all workers.
    receiver: Receiver<TrainOutput<TO>>,
}
14
/// Work item sent to a worker thread: one batch plus a snapshot of the model.
struct Message<M, TI> {
    // Batch to run the training step on.
    item: TI,
    // Model clone; forked onto the worker's device before stepping.
    model: M,
}
19
/// Handle to a background worker thread pinned to one device.
struct Worker<B: AutodiffBackend, M, TI> {
    // Channel used to submit work to the spawned thread.
    sender_input: Sender<Message<M, TI>>,
    // Device the worker forks the model to before each step.
    device: B::Device,
}
24
25impl<B, M, TI> Worker<B, M, TI>
26where
27 B: AutodiffBackend,
28 M: AutodiffModule<B>,
29{
30 fn register(&self, item: TI, model: &M) {
31 let message = Message {
32 item,
33 model: model.clone(),
34 };
35 self.sender_input.send(message).unwrap();
36 }
37
38 fn start<TO>(
39 &self,
40 sender_output: Sender<TrainOutput<TO>>,
41 receiver_input: Receiver<Message<M, TI>>,
42 ) where
43 TI: Send + 'static,
44 TO: Send + 'static,
45 M: TrainStep<TI, TO> + Send + 'static,
46 {
47 let device = self.device.clone();
48
49 spawn(move || {
50 loop {
51 match receiver_input.recv() {
52 Ok(item) => {
53 let model = item.model.fork(&device);
54 let output = model.step(item.item);
55
56 sender_output.send(output).unwrap();
57 }
58 Err(_err) => {
59 log::info!("Closing thread on device {device:?}");
60 break;
61 }
62 }
63 }
64 });
65 }
66}
67
68impl<B, M, TI, TO> MultiDevicesTrainStep<B, M, TI, TO>
69where
70 B: AutodiffBackend,
71 M: AutodiffModule<B> + TrainStep<TI, TO> + Send + Clone + 'static,
72 TI: Send + 'static,
73 TO: Send + 'static,
74{
75 pub fn new(devices: &[B::Device]) -> Self
85 where
86 TI: Send + 'static,
87 {
88 let (sender_output, receiver_output) = std::sync::mpsc::channel();
89 let workers = devices
90 .iter()
91 .map(|device| {
92 let (sender_input, receiver_input) = std::sync::mpsc::channel();
93 let worker = Worker {
94 sender_input,
95 device: device.clone(),
96 };
97
98 worker.start(sender_output.clone(), receiver_input);
99 worker
100 })
101 .collect();
102
103 Self {
104 workers,
105 receiver: receiver_output,
106 }
107 }
108
109 pub fn step<'a>(
120 &self,
121 dataloaders: &mut [Box<dyn DataLoaderIterator<TI> + 'a>],
122 model: &M,
123 ) -> (Vec<TrainOutput<TO>>, Progress) {
124 let mut num_send = 0;
125
126 let mut items_total = 0;
127 let mut items_processed = 0;
128
129 for (i, worker) in self.workers.iter().enumerate() {
130 let dataloader = &mut dataloaders[i];
131 if let Some(item) = dataloader.next() {
132 worker.register(item, model);
133 num_send += 1;
134 let progress = dataloader.progress();
135 items_total += progress.items_total;
136 items_processed += progress.items_processed;
137 }
138 }
139
140 let mut outputs = Vec::with_capacity(num_send);
141
142 for _ in 0..num_send {
143 let output = self.receiver.recv().unwrap();
144 outputs.push(output);
145 }
146
147 (outputs, Progress::new(items_processed, items_total))
148 }
149}