// burn_train/learner/step/train.rs
use crate::components::{InputTrain, LearnerComponentTypes, OutputTrain};
2use crate::{TrainOutput, TrainStep};
3use burn_core::data::dataloader::DataLoaderIterator;
4use burn_core::data::dataloader::Progress;
5use burn_core::module::Module;
6use burn_core::prelude::Backend;
7use std::sync::mpsc::{Receiver, Sender};
8use std::thread::spawn;
9
10pub struct MultiDevicesTrainStep<LC: LearnerComponentTypes> {
12 workers: Vec<Worker<LC>>,
13 receiver: Receiver<TrainOutput<OutputTrain<LC>>>,
14}
15
/// Unit of work sent to a worker thread: one training input plus the model to run it on.
struct Message<M, TI> {
    /// Training input to process.
    item: TI,
    /// Model instance that will process `item`.
    model: M,
}
20
21struct Worker<LC: LearnerComponentTypes> {
22 sender_input: Sender<Message<LC::Model, InputTrain<LC>>>,
23 device: <LC::Backend as Backend>::Device,
24}
25
26impl<LC: LearnerComponentTypes> Worker<LC> {
27 fn register(&self, item: InputTrain<LC>, model: &LC::Model) {
28 let message = Message {
29 item,
30 model: model.clone(),
31 };
32 self.sender_input.send(message).unwrap();
33 }
34
35 fn start(
36 &self,
37 sender_output: Sender<TrainOutput<OutputTrain<LC>>>,
38 receiver_input: Receiver<Message<LC::Model, InputTrain<LC>>>,
39 ) {
40 let device = self.device.clone();
41
42 spawn(move || {
43 loop {
44 match receiver_input.recv() {
45 Ok(item) => {
46 let model = item.model.fork(&device);
47 let output = model.step(item.item);
48
49 sender_output.send(output).unwrap();
50 }
51 Err(_err) => {
52 log::info!("Closing thread on device {device:?}");
53 break;
54 }
55 }
56 }
57 });
58 }
59}
60
61impl<LC: LearnerComponentTypes> MultiDevicesTrainStep<LC> {
62 pub fn new(devices: &[<LC::Backend as Backend>::Device]) -> Self {
72 let (sender_output, receiver_output) = std::sync::mpsc::channel();
73 let workers = devices
74 .iter()
75 .map(|device| {
76 let (sender_input, receiver_input) = std::sync::mpsc::channel();
77 let worker = Worker {
78 sender_input,
79 device: device.clone(),
80 };
81
82 worker.start(sender_output.clone(), receiver_input);
83 worker
84 })
85 .collect();
86
87 Self {
88 workers,
89 receiver: receiver_output,
90 }
91 }
92
93 pub fn step<'a>(
104 &self,
105 dataloaders: &mut [Box<dyn DataLoaderIterator<InputTrain<LC>> + 'a>],
106 model: &LC::Model,
107 ) -> (Vec<TrainOutput<OutputTrain<LC>>>, Progress) {
108 let mut num_send = 0;
109
110 let mut items_total = 0;
111 let mut items_processed = 0;
112
113 for (i, worker) in self.workers.iter().enumerate() {
114 let dataloader = &mut dataloaders[i];
115 if let Some(item) = dataloader.next() {
116 worker.register(item, model);
117 num_send += 1;
118 let progress = dataloader.progress();
119 items_total += progress.items_total;
120 items_processed += progress.items_processed;
121 }
122 }
123
124 let mut outputs = Vec::with_capacity(num_send);
125
126 for _ in 0..num_send {
127 let output = self.receiver.recv().unwrap();
128 outputs.push(output);
129 }
130
131 (outputs, Progress::new(items_processed, items_total))
132 }
133}