burn_train/learner/step/
train.rs1use crate::{TrainOutput, TrainStep};
2use burn_core::{
3 data::dataloader::DataLoaderIterator, module::AutodiffModule, tensor::backend::AutodiffBackend,
4};
5use std::sync::mpsc::{Receiver, Sender};
6use std::thread::spawn;
7
8pub struct MultiDevicesTrainStep<B: AutodiffBackend, M, TI, TO> {
10 workers: Vec<Worker<B, M, TI>>,
11 receiver: Receiver<TrainOutput<TO>>,
12}
13
/// Unit of work sent to a worker: one input item plus the model to run it on.
struct Message<M, TI> {
    /// Input passed to the model's training step.
    item: TI,
    /// Model instance the worker uses for this item.
    model: M,
}
18
19struct Worker<B: AutodiffBackend, M, TI> {
20 sender_input: Sender<Message<M, TI>>,
21 device: B::Device,
22}
23
24impl<B, M, TI> Worker<B, M, TI>
25where
26 B: AutodiffBackend,
27 M: AutodiffModule<B>,
28{
29 fn register(&self, item: TI, model: &M) {
30 let message = Message {
31 item,
32 model: model.clone(),
33 };
34 self.sender_input.send(message).unwrap();
35 }
36
37 fn start<TO>(
38 &self,
39 sender_output: Sender<TrainOutput<TO>>,
40 receiver_input: Receiver<Message<M, TI>>,
41 ) where
42 TI: Send + 'static,
43 TO: Send + 'static,
44 M: TrainStep<TI, TO> + Send + 'static,
45 {
46 let device = self.device.clone();
47
48 spawn(move || loop {
49 match receiver_input.recv() {
50 Ok(item) => {
51 let step = item.model.fork(&device);
52 let output = step.step(item.item);
53
54 sender_output.send(output).unwrap();
55 }
56 Err(_err) => {
57 log::info!("Closing thread on device {:?}", device);
58 break;
59 }
60 }
61 });
62 }
63}
64
65impl<B, M, TI, TO> MultiDevicesTrainStep<B, M, TI, TO>
66where
67 B: AutodiffBackend,
68 M: AutodiffModule<B> + TrainStep<TI, TO> + Send + Clone + 'static,
69 TI: Send + 'static,
70 TO: Send + 'static,
71{
72 pub fn new(devices: &[B::Device]) -> Self
82 where
83 TI: Send + 'static,
84 {
85 let (sender_output, receiver_output) = std::sync::mpsc::channel();
86 let workers = devices
87 .iter()
88 .map(|device| {
89 let (sender_input, receiver_input) = std::sync::mpsc::channel();
90 let worker = Worker {
91 sender_input,
92 device: device.clone(),
93 };
94
95 worker.start(sender_output.clone(), receiver_input);
96 worker
97 })
98 .collect();
99
100 Self {
101 workers,
102 receiver: receiver_output,
103 }
104 }
105
106 pub fn step<'a>(
117 &self,
118 dataloader: &mut Box<dyn DataLoaderIterator<TI> + 'a>,
119 model: &M,
120 ) -> Vec<TrainOutput<TO>> {
121 let mut num_send = 0;
122
123 for worker in self.workers.iter() {
124 if let Some(item) = dataloader.next() {
125 worker.register(item, model);
126 num_send += 1;
127 }
128 }
129
130 let mut outputs = Vec::with_capacity(num_send);
131
132 for _ in 0..num_send {
133 let output = self.receiver.recv().unwrap();
134 outputs.push(output);
135 }
136
137 outputs
138 }
139}