burn_train/learner/step/train.rs

use crate::components::{InputTrain, LearnerComponentTypes, OutputTrain};
use crate::{TrainOutput, TrainStep};
use burn_core::data::dataloader::DataLoaderIterator;
use burn_core::data::dataloader::Progress;
use burn_core::module::Module;
use burn_core::prelude::{Backend, DeviceOps};
use burn_core::tensor::backend::DeviceId;
use std::sync::mpsc::{Receiver, Sender};
use std::thread::spawn;
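/// Runs one training step per device in parallel: each device gets its own worker
/// thread, and the per-device outputs are collected back over a single mpsc channel.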
pub struct MultiDevicesTrainStep<LC: LearnerComponentTypes> {
    workers: Vec<Worker<LC>>,
    receiver: Receiver<MultiTrainOutput<OutputTrain<LC>>>,
}
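/// Work item sent to a worker thread: one batch plus a copy of the current model.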
struct Message<M, TI> {
    item: TI,
    model: M,
}
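/// Handle to one worker thread, pinned to a single device.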
struct Worker<LC: LearnerComponentTypes> {
    sender_input: Sender<Message<LC::Model, InputTrain<LC>>>,
    device: <LC::Backend as Backend>::Device,
}
impl<LC: LearnerComponentTypes> Worker<LC> {
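    /// Queues a batch, together with a clone of the latest model, on this worker's thread.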
    fn register(&self, item: InputTrain<LC>, model: &LC::Model) {
        let message = Message {
            item,
            model: model.clone(),
        };
        self.sender_input.send(message).unwrap();
    }
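    /// Spawns the worker thread: for every received message it forks the model onto this
    /// worker's device, runs one training step, and sends the output back. The loop ends
    /// when the input channel is dropped.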
    fn start(
        &self,
        sender_output: Sender<MultiTrainOutput<OutputTrain<LC>>>,
        receiver_input: Receiver<Message<LC::Model, InputTrain<LC>>>,
    ) {
        let device = self.device.clone();
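        // Process incoming batches until the sending half of the channel is dropped,
        // i.e. when the owning MultiDevicesTrainStep goes away.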
        spawn(move || {
            loop {
                match receiver_input.recv() {
                    Ok(item) => {
                        let model = item.model.fork(&device);
                        let output = model.step(item.item);
                        let item = MultiTrainOutput {
                            output,
                            device: device.to_id(),
                        };

                        sender_output.send(item).unwrap();
                    }
                    Err(_err) => {
                        log::info!("Closing thread on device {device:?}");
                        break;
                    }
                }
            }
        });
    }
}
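/// The result of one training step, tagged with the device that produced it.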
pub struct MultiTrainOutput<TO> {
    pub output: TrainOutput<TO>,
    pub device: DeviceId,
}
impl<LC: LearnerComponentTypes> MultiDevicesTrainStep<LC> {
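    /// Creates one worker per device. Every worker gets its own input channel and a clone
    /// of the shared output sender, then starts its thread immediately.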
    pub fn new(devices: &[<LC::Backend as Backend>::Device]) -> Self {
        let (sender_output, receiver_output) = std::sync::mpsc::channel();
        let workers = devices
            .iter()
            .map(|device| {
                let (sender_input, receiver_input) = std::sync::mpsc::channel();
                let worker = Worker {
                    sender_input,
                    device: device.clone(),
                };

                worker.start(sender_output.clone(), receiver_input);
                worker
            })
            .collect();

        Self {
            workers,
            receiver: receiver_output,
        }
    }
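    /// Runs one synchronous multi-device step: dispatches the next batch from each
    /// device's dataloader to its worker, then blocks until every dispatched batch has
    /// produced an output. Returns the outputs together with the accumulated progress.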
    pub fn step<'a>(
        &self,
        dataloaders: &mut [Box<dyn DataLoaderIterator<InputTrain<LC>> + 'a>],
        model: &LC::Model,
    ) -> (Vec<MultiTrainOutput<OutputTrain<LC>>>, Progress) {
        let mut num_send = 0;

        let mut items_total = 0;
        let mut items_processed = 0;
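        // Dispatch: hand each worker the next batch from its own dataloader, if any is left.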
        for (i, worker) in self.workers.iter().enumerate() {
            let dataloader = &mut dataloaders[i];
            if let Some(item) = dataloader.next() {
                worker.register(item, model);
                num_send += 1;
                let progress = dataloader.progress();
                items_total += progress.items_total;
                items_processed += progress.items_processed;
            }
        }
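        // Collect: block until one output arrives for every batch dispatched above.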
        let mut outputs = Vec::with_capacity(num_send);

        for _ in 0..num_send {
            let output = self.receiver.recv().unwrap();
            outputs.push(output);
        }

        (outputs, Progress::new(items_processed, items_total))
    }
}
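// Usage sketch (not part of the original file): a minimal, hypothetical way to drive one
// epoch with `MultiDevicesTrainStep`, assuming the caller already built one dataloader
// iterator per device. `step` returns an empty output vector once every per-device
// iterator is exhausted, which ends the loop; merging the outputs into a model update is
// learner-specific and deliberately left out here.
#[allow(dead_code)]
fn train_epoch_sketch<'a, LC: LearnerComponentTypes>(
    step: &MultiDevicesTrainStep<LC>,
    dataloaders: &mut [Box<dyn DataLoaderIterator<InputTrain<LC>> + 'a>],
    model: &LC::Model,
) {
    loop {
        // One synchronized round: one batch per device, then wait for all outputs.
        let (outputs, progress) = step.step(dataloaders, model);
        if outputs.is_empty() {
            break;
        }

        log::info!(
            "{}/{} items processed across {} devices",
            progress.items_processed,
            progress.items_total,
            outputs.len()
        );
        // Hypothetical hook: fold `outputs` into an optimizer update on `model` here.
    }
}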