burn_train/learner/step/train.rs

use crate::components::{InputTrain, LearnerComponentTypes, OutputTrain};
use crate::{TrainOutput, TrainStep};
use burn_core::data::dataloader::DataLoaderIterator;
use burn_core::data::dataloader::Progress;
use burn_core::module::Module;
use burn_core::prelude::{Backend, DeviceOps};
use burn_core::tensor::backend::DeviceId;
use std::sync::mpsc::{Receiver, Sender};
use std::thread::spawn;

/// Multi-device train step.
///
/// Spawns one worker thread per device, fans training batches out to the workers,
/// and collects their outputs through a single shared channel.
pub struct MultiDevicesTrainStep<LC: LearnerComponentTypes> {
    workers: Vec<Worker<LC>>,
    receiver: Receiver<MultiTrainOutput<OutputTrain<LC>>>,
}

/// Message sent to a worker: one training input along with a copy of the model.
struct Message<M, TI> {
    item: TI,
    model: M,
}

/// Worker bound to a single device, fed through its own input channel.
struct Worker<LC: LearnerComponentTypes> {
    sender_input: Sender<Message<LC::Model, InputTrain<LC>>>,
    device: <LC::Backend as Backend>::Device,
}

impl<LC: LearnerComponentTypes> Worker<LC> {
    /// Send one training input and the current model to the worker thread.
    fn register(&self, item: InputTrain<LC>, model: &LC::Model) {
        let message = Message {
            item,
            model: model.clone(),
        };
        self.sender_input.send(message).unwrap();
    }

    /// Spawn the worker thread, which processes inputs until its channel is closed.
    fn start(
        &self,
        sender_output: Sender<MultiTrainOutput<OutputTrain<LC>>>,
        receiver_input: Receiver<Message<LC::Model, InputTrain<LC>>>,
    ) {
        let device = self.device.clone();

        spawn(move || {
            loop {
                match receiver_input.recv() {
                    Ok(item) => {
                        // Move the model to this worker's device, run one training
                        // step, and send the output back tagged with the device id.
                        let model = item.model.fork(&device);
                        let output = model.step(item.item);
                        let item = MultiTrainOutput {
                            output,
                            device: device.to_id(),
                        };

                        sender_output.send(item).unwrap();
                    }
                    Err(_err) => {
                        // The sender was dropped; no more work will arrive.
                        log::info!("Closing thread on device {device:?}");
                        break;
                    }
                }
            }
        });
    }
}

/// Training output produced by one worker during a multi-device step.
pub struct MultiTrainOutput<TO> {
    /// The training output.
    pub output: TrainOutput<TO>,
    /// The device on which the computation happened.
    pub device: DeviceId,
}

impl<LC: LearnerComponentTypes> MultiDevicesTrainStep<LC> {
    /// Create a new multi-device train step.
    ///
    /// # Arguments
    ///
    /// * `devices` - The devices, with one worker thread spawned per device.
    ///
    /// # Returns
    ///
    /// A `MultiDevicesTrainStep` instance.
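    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled here); the learner component type
    /// `MyComponents` and the `devices`, `dataloaders`, and `model` bindings are
    /// assumptions for illustration only:
    ///
    /// ```ignore
    /// let step = MultiDevicesTrainStep::<MyComponents>::new(&devices);
    ///
    /// loop {
    ///     // Dispatch one batch per device, then wait for every output.
    ///     let (outputs, progress) = step.step(&mut dataloaders, &model);
    ///     if outputs.is_empty() {
    ///         // Every dataloader is exhausted.
    ///         break;
    ///     }
    ///     // Aggregate the gradients in `outputs` and update `model` here.
    /// }
    /// ```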
    pub fn new(devices: &[<LC::Backend as Backend>::Device]) -> Self {
        let (sender_output, receiver_output) = std::sync::mpsc::channel();
        // One input channel per worker; every worker shares the same output channel.
        let workers = devices
            .iter()
            .map(|device| {
                let (sender_input, receiver_input) = std::sync::mpsc::channel();
                let worker = Worker {
                    sender_input,
                    device: device.clone(),
                };

                worker.start(sender_output.clone(), receiver_input);
                worker
            })
            .collect();

        Self {
            workers,
            receiver: receiver_output,
        }
    }

    /// Dispatch one batch to each worker and collect their outputs for one step.
    ///
    /// # Arguments
    ///
    /// * `dataloaders` - The data loader iterator for each worker.
    /// * `model` - The current model, cloned to each worker.
    ///
    /// # Returns
    ///
    /// The outputs produced by the workers and the aggregated dataloader progress.
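    ///
    /// # Example
    ///
    /// A sketch of a single call (not compiled here; `step`, `dataloaders`, and
    /// `model` are assumed to exist as in the constructor example above):
    ///
    /// ```ignore
    /// let (outputs, progress) = step.step(&mut dataloaders, &model);
    /// for out in &outputs {
    ///     // `out.device` identifies the worker that produced `out.output`.
    /// }
    /// println!("{}/{} items processed", progress.items_processed, progress.items_total);
    /// ```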
    pub fn step<'a>(
        &self,
        dataloaders: &mut [Box<dyn DataLoaderIterator<InputTrain<LC>> + 'a>],
        model: &LC::Model,
    ) -> (Vec<MultiTrainOutput<OutputTrain<LC>>>, Progress) {
        let mut num_send = 0;

        let mut items_total = 0;
        let mut items_processed = 0;

        // Dispatch one batch to each worker whose dataloader still has items,
        // summing the per-dataloader progress along the way.
        for (i, worker) in self.workers.iter().enumerate() {
            let dataloader = &mut dataloaders[i];
            if let Some(item) = dataloader.next() {
                worker.register(item, model);
                num_send += 1;
                let progress = dataloader.progress();
                items_total += progress.items_total;
                items_processed += progress.items_processed;
            }
        }

        let mut outputs = Vec::with_capacity(num_send);

        // Block until every dispatched batch has produced an output.
        for _ in 0..num_send {
            let output = self.receiver.recv().unwrap();
            outputs.push(output);
        }

        (outputs, Progress::new(items_processed, items_total))
    }
}
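
// The module below is not part of this file's API: it is a minimal, standalone
// sketch (using only `std`, with an illustrative module and test name) of the
// fan-out/fan-in pattern used above, with one input channel per worker thread, a
// single shared output channel, and a blocking collect of exactly as many results
// as were dispatched.
#[cfg(test)]
mod fan_out_fan_in_sketch {
    use std::sync::mpsc::{channel, Sender};
    use std::thread::spawn;

    #[test]
    fn collects_one_output_per_dispatched_input() {
        let (sender_output, receiver_output) = channel();

        // Fan out: one input channel and one thread per "device".
        let workers: Vec<Sender<i32>> = (0..4)
            .map(|id| {
                let (sender_input, receiver_input) = channel::<i32>();
                let sender_output = sender_output.clone();
                spawn(move || {
                    // Loop until the input sender is dropped, mirroring `Worker::start`.
                    while let Ok(item) = receiver_input.recv() {
                        sender_output.send((id, item * 2)).unwrap();
                    }
                });
                sender_input
            })
            .collect();

        // Dispatch one item per worker, then collect exactly that many outputs.
        let mut num_send = 0;
        for (i, worker) in workers.iter().enumerate() {
            worker.send(i as i32).unwrap();
            num_send += 1;
        }

        let outputs: Vec<(usize, i32)> = (0..num_send)
            .map(|_| receiver_output.recv().unwrap())
            .collect();

        assert_eq!(outputs.len(), workers.len());
    }
}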