burn_train/learner/step/
train.rs

1use crate::{TrainOutput, TrainStep};
2use burn_core::{
3    data::dataloader::DataLoaderIterator, module::AutodiffModule, tensor::backend::AutodiffBackend,
4};
5use std::sync::mpsc::{Receiver, Sender};
6use std::thread::spawn;
7
/// Multi devices train step.
///
/// Fans a training step out across several devices: each device gets a
/// dedicated worker thread, and all workers report their [`TrainOutput`]s
/// back through a single shared channel drained by `step`.
pub struct MultiDevicesTrainStep<B: AutodiffBackend, M, TI, TO> {
    // One worker (own input channel + spawned thread) per device.
    workers: Vec<Worker<B, M, TI>>,
    // Receiving end of the output channel shared by all worker threads.
    receiver: Receiver<TrainOutput<TO>>,
}
13
/// Work item sent to a worker thread: one batch paired with a clone of the
/// current model (the worker forks the model onto its device on receipt).
struct Message<M, TI> {
    // Batch/input item for a single training step.
    item: TI,
    // Snapshot of the model at dispatch time.
    model: M,
}
18
/// Handle to a single device's worker thread.
struct Worker<B: AutodiffBackend, M, TI> {
    // Sending half of this worker's private input channel.
    sender_input: Sender<Message<M, TI>>,
    // Device this worker's thread runs its steps on.
    device: B::Device,
}
23
24impl<B, M, TI> Worker<B, M, TI>
25where
26    B: AutodiffBackend,
27    M: AutodiffModule<B>,
28{
29    fn register(&self, item: TI, model: &M) {
30        let message = Message {
31            item,
32            model: model.clone(),
33        };
34        self.sender_input.send(message).unwrap();
35    }
36
37    fn start<TO>(
38        &self,
39        sender_output: Sender<TrainOutput<TO>>,
40        receiver_input: Receiver<Message<M, TI>>,
41    ) where
42        TI: Send + 'static,
43        TO: Send + 'static,
44        M: TrainStep<TI, TO> + Send + 'static,
45    {
46        let device = self.device.clone();
47
48        spawn(move || loop {
49            match receiver_input.recv() {
50                Ok(item) => {
51                    let step = item.model.fork(&device);
52                    let output = step.step(item.item);
53
54                    sender_output.send(output).unwrap();
55                }
56                Err(_err) => {
57                    log::info!("Closing thread on device {:?}", device);
58                    break;
59                }
60            }
61        });
62    }
63}
64
65impl<B, M, TI, TO> MultiDevicesTrainStep<B, M, TI, TO>
66where
67    B: AutodiffBackend,
68    M: AutodiffModule<B> + TrainStep<TI, TO> + Send + Clone + 'static,
69    TI: Send + 'static,
70    TO: Send + 'static,
71{
72    /// Create a new multi devices train step.
73    ///
74    /// # Arguments
75    ///
76    /// * `devices` - Devices.
77    ///
78    /// # Returns
79    ///
80    /// MultiDevicesTrainStep instance.
81    pub fn new(devices: &[B::Device]) -> Self
82    where
83        TI: Send + 'static,
84    {
85        let (sender_output, receiver_output) = std::sync::mpsc::channel();
86        let workers = devices
87            .iter()
88            .map(|device| {
89                let (sender_input, receiver_input) = std::sync::mpsc::channel();
90                let worker = Worker {
91                    sender_input,
92                    device: device.clone(),
93                };
94
95                worker.start(sender_output.clone(), receiver_input);
96                worker
97            })
98            .collect();
99
100        Self {
101            workers,
102            receiver: receiver_output,
103        }
104    }
105
106    /// Collect outputs from workers for one step.
107    ///
108    /// # Arguments
109    ///
110    /// * `dataloader` - Dataloader.
111    /// * `model` - Model.
112    ///
113    /// # Returns
114    ///
115    /// Outputs.
116    pub fn step<'a>(
117        &self,
118        dataloader: &mut Box<dyn DataLoaderIterator<TI> + 'a>,
119        model: &M,
120    ) -> Vec<TrainOutput<TO>> {
121        let mut num_send = 0;
122
123        for worker in self.workers.iter() {
124            if let Some(item) = dataloader.next() {
125                worker.register(item, model);
126                num_send += 1;
127            }
128        }
129
130        let mut outputs = Vec::with_capacity(num_send);
131
132        for _ in 0..num_send {
133            let output = self.receiver.recv().unwrap();
134            outputs.push(output);
135        }
136
137        outputs
138    }
139}