burn_train/learner/supervised/step/train.rs

use crate::{LearningComponentsTypes, TrainingModel};
use crate::{TrainOutput, TrainStep, TrainingBackend, TrainingModelInput, TrainingModelOutput};
use burn_core::data::dataloader::DataLoaderIterator;
use burn_core::data::dataloader::Progress;
use burn_core::module::Module;
use burn_core::prelude::DeviceOps;
use burn_core::tensor::Device;
use burn_core::tensor::backend::DeviceId;
use std::sync::mpsc::{Receiver, Sender};
use std::thread::spawn;

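/// Runs one training step on multiple devices in parallel, dispatching one
/// batch to each device and gathering the resulting outputs.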
pub struct MultiDevicesTrainStep<LC: LearningComponentsTypes> {
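    /// One worker per device, each processing batches on its own thread.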
    workers: Vec<Worker<LC>>,
    receiver: Receiver<MultiTrainOutput<TrainingModelOutput<LC>>>,
}

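/// A unit of work sent to a worker: the batch to process together with the
/// current model, which the worker forks onto its device.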
struct Message<M, TI> {
    item: TI,
    model: M,
}

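/// Handle to a spawned worker thread bound to a single device.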
struct Worker<LC: LearningComponentsTypes> {
    #[allow(clippy::type_complexity)]
    sender_input: Sender<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,
    device: Device<TrainingBackend<LC>>,
}

impl<LC: LearningComponentsTypes> Worker<LC> {
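    /// Queues a batch for this worker, sending it a clone of the current model.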
    fn register(&self, item: TrainingModelInput<LC>, model: &TrainingModel<LC>) {
        let message = Message {
            item,
            model: model.clone(),
        };
        self.sender_input.send(message).unwrap();
    }

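    /// Spawns the worker thread, which processes incoming messages until the
    /// input channel is closed.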
    #[allow(clippy::type_complexity)]
    fn start(
        &self,
        sender_output: Sender<MultiTrainOutput<TrainingModelOutput<LC>>>,
        receiver_input: Receiver<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,
    ) {
        let device = self.device.clone();

        spawn(move || {
            loop {
                match receiver_input.recv() {
                    Ok(item) => {
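                        // Fork the model onto this worker's device before
                        // running the training step.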
                        let model = item.model.fork(&device);
                        let output = model.step(item.item);
                        let item = MultiTrainOutput {
                            output,
                            device: device.to_id(),
                        };

                        sender_output.send(item).unwrap();
                    }
                    Err(_err) => {
                        log::info!("Closing thread on device {device:?}");
                        break;
                    }
                }
            }
        });
    }
}

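/// The output of a single worker's training step, tagged with the device it
/// was computed on.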
pub struct MultiTrainOutput<TO> {
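    /// The train step output produced by the worker.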
    pub output: TrainOutput<TO>,
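    /// The identifier of the device that produced the output.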
    pub device: DeviceId,
}

impl<LC: LearningComponentsTypes> MultiDevicesTrainStep<LC> {
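    /// Creates a new multi-device train step, spawning one worker thread per
    /// device in `devices`.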
    pub fn new(devices: &[Device<TrainingBackend<LC>>]) -> Self {
        let (sender_output, receiver_output) = std::sync::mpsc::channel();
        let workers = devices
            .iter()
            .map(|device| {
                let (sender_input, receiver_input) = std::sync::mpsc::channel();
                let worker = Worker {
                    sender_input,
                    device: device.clone(),
                };

                worker.start(sender_output.clone(), receiver_input);
                worker
            })
            .collect();

        Self {
            workers,
            receiver: receiver_output,
        }
    }

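    /// Performs one training step across all devices.
    ///
    /// Fetches the next batch from each device's dataloader, dispatches it to
    /// the corresponding worker, then blocks until every dispatched batch has
    /// produced an output. Returns the outputs together with the accumulated
    /// dataloader progress.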
    pub fn step<'a>(
        &self,
        dataloaders: &mut [Box<dyn DataLoaderIterator<TrainingModelInput<LC>> + 'a>],
        model: &TrainingModel<LC>,
    ) -> (Vec<MultiTrainOutput<TrainingModelOutput<LC>>>, Progress) {
        let mut num_send = 0;

        let mut items_total = 0;
        let mut items_processed = 0;

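        // Dispatch the next batch from each dataloader to its worker; an
        // exhausted dataloader simply contributes no work this step.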
        for (i, worker) in self.workers.iter().enumerate() {
            let dataloader = &mut dataloaders[i];
            if let Some(item) = dataloader.next() {
                worker.register(item, model);
                num_send += 1;
                let progress = dataloader.progress();
                items_total += progress.items_total;
                items_processed += progress.items_processed;
            }
        }

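        // Block until every dispatched batch has been processed, collecting
        // one output per worker that received work.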
        let mut outputs = Vec::with_capacity(num_send);

        for _ in 0..num_send {
            let output = self.receiver.recv().unwrap();
            outputs.push(output);
        }

        (outputs, Progress::new(items_processed, items_total))
    }
}