burn_train/learner/step/
train.rs

1use crate::{TrainOutput, TrainStep};
2use burn_core::data::dataloader::Progress;
3use burn_core::{
4    data::dataloader::DataLoaderIterator, module::AutodiffModule, tensor::backend::AutodiffBackend,
5};
6use std::sync::mpsc::{Receiver, Sender};
7use std::thread::spawn;
8
9/// Multi devices train step.
10pub struct MultiDevicesTrainStep<B: AutodiffBackend, M, TI, TO> {
11    workers: Vec<Worker<B, M, TI>>,
12    receiver: Receiver<TrainOutput<TO>>,
13}
14
/// Unit of work handed to a worker thread: one input item paired with the
/// model that should process it.
struct Message<M, TI> {
    /// Input item the train step is run on.
    item: TI,
    /// Model used to process `item`.
    model: M,
}
19
20struct Worker<B: AutodiffBackend, M, TI> {
21    sender_input: Sender<Message<M, TI>>,
22    device: B::Device,
23}
24
25impl<B, M, TI> Worker<B, M, TI>
26where
27    B: AutodiffBackend,
28    M: AutodiffModule<B>,
29{
30    fn register(&self, item: TI, model: &M) {
31        let message = Message {
32            item,
33            model: model.clone(),
34        };
35        self.sender_input.send(message).unwrap();
36    }
37
38    fn start<TO>(
39        &self,
40        sender_output: Sender<TrainOutput<TO>>,
41        receiver_input: Receiver<Message<M, TI>>,
42    ) where
43        TI: Send + 'static,
44        TO: Send + 'static,
45        M: TrainStep<TI, TO> + Send + 'static,
46    {
47        let device = self.device.clone();
48
49        spawn(move || {
50            loop {
51                match receiver_input.recv() {
52                    Ok(item) => {
53                        let model = item.model.fork(&device);
54                        let output = model.step(item.item);
55
56                        sender_output.send(output).unwrap();
57                    }
58                    Err(_err) => {
59                        log::info!("Closing thread on device {device:?}");
60                        break;
61                    }
62                }
63            }
64        });
65    }
66}
67
68impl<B, M, TI, TO> MultiDevicesTrainStep<B, M, TI, TO>
69where
70    B: AutodiffBackend,
71    M: AutodiffModule<B> + TrainStep<TI, TO> + Send + Clone + 'static,
72    TI: Send + 'static,
73    TO: Send + 'static,
74{
75    /// Create a new multi devices train step.
76    ///
77    /// # Arguments
78    ///
79    /// * `devices` - Devices.
80    ///
81    /// # Returns
82    ///
83    /// MultiDevicesTrainStep instance.
84    pub fn new(devices: &[B::Device]) -> Self
85    where
86        TI: Send + 'static,
87    {
88        let (sender_output, receiver_output) = std::sync::mpsc::channel();
89        let workers = devices
90            .iter()
91            .map(|device| {
92                let (sender_input, receiver_input) = std::sync::mpsc::channel();
93                let worker = Worker {
94                    sender_input,
95                    device: device.clone(),
96                };
97
98                worker.start(sender_output.clone(), receiver_input);
99                worker
100            })
101            .collect();
102
103        Self {
104            workers,
105            receiver: receiver_output,
106        }
107    }
108
109    /// Collect outputs from workers for one step.
110    ///
111    /// # Arguments
112    ///
113    /// * `model` - Model.
114    /// * `dataloaders` - The data loader for each worker.
115    ///
116    /// # Returns
117    ///
118    /// Outputs.
119    pub fn step<'a>(
120        &self,
121        dataloaders: &mut [Box<dyn DataLoaderIterator<TI> + 'a>],
122        model: &M,
123    ) -> (Vec<TrainOutput<TO>>, Progress) {
124        let mut num_send = 0;
125
126        let mut items_total = 0;
127        let mut items_processed = 0;
128
129        for (i, worker) in self.workers.iter().enumerate() {
130            let dataloader = &mut dataloaders[i];
131            if let Some(item) = dataloader.next() {
132                worker.register(item, model);
133                num_send += 1;
134                let progress = dataloader.progress();
135                items_total += progress.items_total;
136                items_processed += progress.items_processed;
137            }
138        }
139
140        let mut outputs = Vec::with_capacity(num_send);
141
142        for _ in 0..num_send {
143            let output = self.receiver.recv().unwrap();
144            outputs.push(output);
145        }
146
147        (outputs, Progress::new(items_processed, items_total))
148    }
149}