//! Multi-device supervised training step.
//! (burn_train/learner/supervised/step/train.rs)
1use crate::{LearningComponentsTypes, TrainingModel};
2use crate::{TrainOutput, TrainStep, TrainingBackend, TrainingModelInput, TrainingModelOutput};
3use burn_core::data::dataloader::DataLoaderIterator;
4use burn_core::data::dataloader::Progress;
5use burn_core::module::Module;
6use burn_core::prelude::DeviceOps;
7use burn_core::tensor::Device;
8use burn_core::tensor::backend::DeviceId;
9use std::sync::mpsc::{Receiver, Sender};
10use std::thread::spawn;
11
/// Multi devices train step.
///
/// Runs one training step per registered device: each device is served by a
/// dedicated worker thread, and all workers report their results back through
/// a single shared channel.
pub struct MultiDevicesTrainStep<LC: LearningComponentsTypes> {
    // One worker (input channel + spawned thread) per device.
    workers: Vec<Worker<LC>>,
    // Shared receiving end on which every worker posts its step output.
    receiver: Receiver<MultiTrainOutput<TrainingModelOutput<LC>>>,
}
17
/// Work item sent to a worker thread: one input batch plus a copy of the
/// model that should process it.
struct Message<M, TI> {
    // The input batch for this step.
    item: TI,
    // Copy of the current model; the worker forks it onto its own device
    // before running the step.
    model: M,
}
22
/// Handle to one device's worker thread: the channel used to feed it work,
/// plus the device it computes on.
struct Worker<LC: LearningComponentsTypes> {
    // Not that complex. Extracting into another type would only make it more confusing.
    #[allow(clippy::type_complexity)]
    // Sending end of the worker's private input channel (see `start`).
    sender_input: Sender<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,
    // The device this worker runs its training steps on.
    device: Device<TrainingBackend<LC>>,
}
29
30impl<LC: LearningComponentsTypes> Worker<LC> {
31    fn register(&self, item: TrainingModelInput<LC>, model: &TrainingModel<LC>) {
32        let message = Message {
33            item,
34            model: model.clone(),
35        };
36        self.sender_input.send(message).unwrap();
37    }
38
39    // Not that complex. Extracting into another type would only make it more confusing.
40    #[allow(clippy::type_complexity)]
41    fn start(
42        &self,
43        sender_output: Sender<MultiTrainOutput<TrainingModelOutput<LC>>>,
44        receiver_input: Receiver<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,
45    ) {
46        let device = self.device.clone();
47
48        spawn(move || {
49            loop {
50                match receiver_input.recv() {
51                    Ok(item) => {
52                        let model = item.model.fork(&device);
53                        let output = model.step(item.item);
54                        let item = MultiTrainOutput {
55                            output,
56                            device: device.to_id(),
57                        };
58
59                        sender_output.send(item).unwrap();
60                    }
61                    Err(_err) => {
62                        log::info!("Closing thread on device {device:?}");
63                        break;
64                    }
65                }
66            }
67        });
68    }
69}
70
/// The output of one training step, tagged with the device that produced it.
pub struct MultiTrainOutput<TO> {
    /// The training output.
    pub output: TrainOutput<TO>,
    /// The device on which the computing happened.
    pub device: DeviceId,
}
78
79impl<LC: LearningComponentsTypes> MultiDevicesTrainStep<LC> {
80    /// Create a new multi devices train step.
81    ///
82    /// # Arguments
83    ///
84    /// * `devices` - Devices.
85    ///
86    /// # Returns
87    ///
88    /// MultiDevicesTrainStep instance.
89    pub fn new(devices: &[Device<TrainingBackend<LC>>]) -> Self {
90        let (sender_output, receiver_output) = std::sync::mpsc::channel();
91        let workers = devices
92            .iter()
93            .map(|device| {
94                let (sender_input, receiver_input) = std::sync::mpsc::channel();
95                let worker = Worker {
96                    sender_input,
97                    device: device.clone(),
98                };
99
100                worker.start(sender_output.clone(), receiver_input);
101                worker
102            })
103            .collect();
104
105        Self {
106            workers,
107            receiver: receiver_output,
108        }
109    }
110
111    /// Collect outputs from workers for one step.
112    ///
113    /// # Arguments
114    ///
115    /// * `model` - Model.
116    /// * `dataloaders` - The data loader for each worker.
117    ///
118    /// # Returns
119    ///
120    /// Outputs.
121    pub fn step<'a>(
122        &self,
123        dataloaders: &mut [Box<dyn DataLoaderIterator<TrainingModelInput<LC>> + 'a>],
124        model: &TrainingModel<LC>,
125    ) -> (Vec<MultiTrainOutput<TrainingModelOutput<LC>>>, Progress) {
126        let mut num_send = 0;
127
128        let mut items_total = 0;
129        let mut items_processed = 0;
130
131        for (i, worker) in self.workers.iter().enumerate() {
132            let dataloader = &mut dataloaders[i];
133            if let Some(item) = dataloader.next() {
134                worker.register(item, model);
135                num_send += 1;
136                let progress = dataloader.progress();
137                items_total += progress.items_total;
138                items_processed += progress.items_processed;
139            }
140        }
141
142        let mut outputs = Vec::with_capacity(num_send);
143
144        for _ in 0..num_send {
145            let output = self.receiver.recv().unwrap();
146            outputs.push(output);
147        }
148
149        (outputs, Progress::new(items_processed, items_total))
150    }
151}