burn_train/learner/step/train.rs

1use crate::components::{InputTrain, LearnerComponentTypes, OutputTrain};
2use crate::{TrainOutput, TrainStep};
3use burn_core::data::dataloader::DataLoaderIterator;
4use burn_core::data::dataloader::Progress;
5use burn_core::module::Module;
6use burn_core::prelude::Backend;
7use std::sync::mpsc::{Receiver, Sender};
8use std::thread::spawn;
9
10/// Multi devices train step.
11pub struct MultiDevicesTrainStep<LC: LearnerComponentTypes> {
12    workers: Vec<Worker<LC>>,
13    receiver: Receiver<TrainOutput<OutputTrain<LC>>>,
14}
15
/// Unit of work handed to a worker thread: one training item and the model to run it with.
struct Message<Model, Input> {
    /// Training input for a single step.
    item: Input,
    /// Model copy; the worker forks it onto its own device before stepping.
    model: Model,
}
20
21struct Worker<LC: LearnerComponentTypes> {
22    sender_input: Sender<Message<LC::Model, InputTrain<LC>>>,
23    device: <LC::Backend as Backend>::Device,
24}
25
26impl<LC: LearnerComponentTypes> Worker<LC> {
27    fn register(&self, item: InputTrain<LC>, model: &LC::Model) {
28        let message = Message {
29            item,
30            model: model.clone(),
31        };
32        self.sender_input.send(message).unwrap();
33    }
34
35    fn start(
36        &self,
37        sender_output: Sender<TrainOutput<OutputTrain<LC>>>,
38        receiver_input: Receiver<Message<LC::Model, InputTrain<LC>>>,
39    ) {
40        let device = self.device.clone();
41
42        spawn(move || {
43            loop {
44                match receiver_input.recv() {
45                    Ok(item) => {
46                        let model = item.model.fork(&device);
47                        let output = model.step(item.item);
48
49                        sender_output.send(output).unwrap();
50                    }
51                    Err(_err) => {
52                        log::info!("Closing thread on device {device:?}");
53                        break;
54                    }
55                }
56            }
57        });
58    }
59}
60
61impl<LC: LearnerComponentTypes> MultiDevicesTrainStep<LC> {
62    /// Create a new multi devices train step.
63    ///
64    /// # Arguments
65    ///
66    /// * `devices` - Devices.
67    ///
68    /// # Returns
69    ///
70    /// MultiDevicesTrainStep instance.
71    pub fn new(devices: &[<LC::Backend as Backend>::Device]) -> Self {
72        let (sender_output, receiver_output) = std::sync::mpsc::channel();
73        let workers = devices
74            .iter()
75            .map(|device| {
76                let (sender_input, receiver_input) = std::sync::mpsc::channel();
77                let worker = Worker {
78                    sender_input,
79                    device: device.clone(),
80                };
81
82                worker.start(sender_output.clone(), receiver_input);
83                worker
84            })
85            .collect();
86
87        Self {
88            workers,
89            receiver: receiver_output,
90        }
91    }
92
93    /// Collect outputs from workers for one step.
94    ///
95    /// # Arguments
96    ///
97    /// * `model` - Model.
98    /// * `dataloaders` - The data loader for each worker.
99    ///
100    /// # Returns
101    ///
102    /// Outputs.
103    pub fn step<'a>(
104        &self,
105        dataloaders: &mut [Box<dyn DataLoaderIterator<InputTrain<LC>> + 'a>],
106        model: &LC::Model,
107    ) -> (Vec<TrainOutput<OutputTrain<LC>>>, Progress) {
108        let mut num_send = 0;
109
110        let mut items_total = 0;
111        let mut items_processed = 0;
112
113        for (i, worker) in self.workers.iter().enumerate() {
114            let dataloader = &mut dataloaders[i];
115            if let Some(item) = dataloader.next() {
116                worker.register(item, model);
117                num_send += 1;
118                let progress = dataloader.progress();
119                items_total += progress.items_total;
120                items_processed += progress.items_processed;
121            }
122        }
123
124        let mut outputs = Vec::with_capacity(num_send);
125
126        for _ in 0..num_send {
127            let output = self.receiver.recv().unwrap();
128            outputs.push(output);
129        }
130
131        (outputs, Progress::new(items_processed, items_total))
132    }
133}