burn_compute/channel/mpsc.rs

1use std::{
2    sync::{mpsc, Arc},
3    thread,
4};
5
6use burn_common::reader::Reader;
7
8use super::ComputeChannel;
9use crate::server::{ComputeServer, Handle};
10
11/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
12/// the compute server spawn on its own thread.
13#[derive(Debug)]
14pub struct MpscComputeChannel<Server>
15where
16    Server: ComputeServer,
17{
18    state: Arc<MpscComputeChannelState<Server>>,
19}
20
21#[derive(Debug)]
22struct MpscComputeChannelState<Server>
23where
24    Server: ComputeServer,
25{
26    _handle: thread::JoinHandle<()>,
27    sender: mpsc::Sender<Message<Server>>,
28}
29
/// Sending half through which the server thread answers a single request.
type Callback<T> = mpsc::Sender<T>;
31
32enum Message<Server>
33where
34    Server: ComputeServer,
35{
36    Read(Handle<Server>, Callback<Reader<Vec<u8>>>),
37    Create(Vec<u8>, Callback<Handle<Server>>),
38    Empty(usize, Callback<Handle<Server>>),
39    ExecuteKernel(Server::Kernel, Vec<Handle<Server>>),
40    Sync(Callback<()>),
41}
42
43impl<Server> MpscComputeChannel<Server>
44where
45    Server: ComputeServer + 'static,
46{
47    /// Create a new mpsc compute channel.
48    pub fn new(mut server: Server) -> Self {
49        let (sender, receiver) = mpsc::channel();
50
51        let _handle = thread::spawn(move || {
52            while let Ok(message) = receiver.recv() {
53                match message {
54                    Message::Read(handle, callback) => {
55                        let data = server.read(&handle);
56                        core::mem::drop(handle);
57                        callback.send(data).unwrap();
58                    }
59                    Message::Create(data, callback) => {
60                        let handle = server.create(&data);
61                        callback.send(handle).unwrap();
62                    }
63                    Message::Empty(size, callback) => {
64                        let handle = server.empty(size);
65                        callback.send(handle).unwrap();
66                    }
67                    Message::ExecuteKernel(kernel, handles) => {
68                        server.execute(kernel, &handles.iter().collect::<Vec<_>>());
69                    }
70                    Message::Sync(callback) => {
71                        server.sync();
72                        callback.send(()).unwrap();
73                    }
74                };
75            }
76        });
77
78        let state = Arc::new(MpscComputeChannelState { sender, _handle });
79
80        Self { state }
81    }
82}
83
84impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
85    fn clone(&self) -> Self {
86        Self {
87            state: self.state.clone(),
88        }
89    }
90}
91
92impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
93where
94    Server: ComputeServer + 'static,
95{
96    fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
97        let (callback, response) = mpsc::channel();
98
99        self.state
100            .sender
101            .send(Message::Read(handle.clone(), callback))
102            .unwrap();
103
104        self.response(response)
105    }
106
107    fn create(&self, data: &[u8]) -> Handle<Server> {
108        let (callback, response) = mpsc::channel();
109
110        self.state
111            .sender
112            .send(Message::Create(data.to_vec(), callback))
113            .unwrap();
114
115        self.response(response)
116    }
117
118    fn empty(&self, size: usize) -> Handle<Server> {
119        let (callback, response) = mpsc::channel();
120
121        self.state
122            .sender
123            .send(Message::Empty(size, callback))
124            .unwrap();
125
126        self.response(response)
127    }
128
129    fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
130        self.state
131            .sender
132            .send(Message::ExecuteKernel(
133                kernel,
134                handles
135                    .iter()
136                    .map(|h| (*h).clone())
137                    .collect::<Vec<Handle<Server>>>(),
138            ))
139            .unwrap()
140    }
141
142    fn sync(&self) {
143        let (callback, response) = mpsc::channel();
144
145        self.state.sender.send(Message::Sync(callback)).unwrap();
146
147        self.response(response)
148    }
149}
150
151impl<Server: ComputeServer> MpscComputeChannel<Server> {
152    fn response<Response>(&self, response: mpsc::Receiver<Response>) -> Response {
153        match response.recv() {
154            Ok(val) => val,
155            Err(err) => panic!("Can't connect to the server correctly {err:?}"),
156        }
157    }
158}