burn_compute/channel/
mpsc.rs1use std::{
2 sync::{mpsc, Arc},
3 thread,
4};
5
6use burn_common::reader::Reader;
7
8use super::ComputeChannel;
9use crate::server::{ComputeServer, Handle};
10
11#[derive(Debug)]
14pub struct MpscComputeChannel<Server>
15where
16 Server: ComputeServer,
17{
18 state: Arc<MpscComputeChannelState<Server>>,
19}
20
21#[derive(Debug)]
22struct MpscComputeChannelState<Server>
23where
24 Server: ComputeServer,
25{
26 _handle: thread::JoinHandle<()>,
27 sender: mpsc::Sender<Message<Server>>,
28}
29
/// Sending half of a one-off channel used to hand a reply of type `T` back to
/// whoever posted the request.
type Callback<T> = mpsc::Sender<T>;
31
32enum Message<Server>
33where
34 Server: ComputeServer,
35{
36 Read(Handle<Server>, Callback<Reader<Vec<u8>>>),
37 Create(Vec<u8>, Callback<Handle<Server>>),
38 Empty(usize, Callback<Handle<Server>>),
39 ExecuteKernel(Server::Kernel, Vec<Handle<Server>>),
40 Sync(Callback<()>),
41}
42
43impl<Server> MpscComputeChannel<Server>
44where
45 Server: ComputeServer + 'static,
46{
47 pub fn new(mut server: Server) -> Self {
49 let (sender, receiver) = mpsc::channel();
50
51 let _handle = thread::spawn(move || {
52 while let Ok(message) = receiver.recv() {
53 match message {
54 Message::Read(handle, callback) => {
55 let data = server.read(&handle);
56 core::mem::drop(handle);
57 callback.send(data).unwrap();
58 }
59 Message::Create(data, callback) => {
60 let handle = server.create(&data);
61 callback.send(handle).unwrap();
62 }
63 Message::Empty(size, callback) => {
64 let handle = server.empty(size);
65 callback.send(handle).unwrap();
66 }
67 Message::ExecuteKernel(kernel, handles) => {
68 server.execute(kernel, &handles.iter().collect::<Vec<_>>());
69 }
70 Message::Sync(callback) => {
71 server.sync();
72 callback.send(()).unwrap();
73 }
74 };
75 }
76 });
77
78 let state = Arc::new(MpscComputeChannelState { sender, _handle });
79
80 Self { state }
81 }
82}
83
84impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
85 fn clone(&self) -> Self {
86 Self {
87 state: self.state.clone(),
88 }
89 }
90}
91
92impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
93where
94 Server: ComputeServer + 'static,
95{
96 fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
97 let (callback, response) = mpsc::channel();
98
99 self.state
100 .sender
101 .send(Message::Read(handle.clone(), callback))
102 .unwrap();
103
104 self.response(response)
105 }
106
107 fn create(&self, data: &[u8]) -> Handle<Server> {
108 let (callback, response) = mpsc::channel();
109
110 self.state
111 .sender
112 .send(Message::Create(data.to_vec(), callback))
113 .unwrap();
114
115 self.response(response)
116 }
117
118 fn empty(&self, size: usize) -> Handle<Server> {
119 let (callback, response) = mpsc::channel();
120
121 self.state
122 .sender
123 .send(Message::Empty(size, callback))
124 .unwrap();
125
126 self.response(response)
127 }
128
129 fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
130 self.state
131 .sender
132 .send(Message::ExecuteKernel(
133 kernel,
134 handles
135 .iter()
136 .map(|h| (*h).clone())
137 .collect::<Vec<Handle<Server>>>(),
138 ))
139 .unwrap()
140 }
141
142 fn sync(&self) {
143 let (callback, response) = mpsc::channel();
144
145 self.state.sender.send(Message::Sync(callback)).unwrap();
146
147 self.response(response)
148 }
149}
150
151impl<Server: ComputeServer> MpscComputeChannel<Server> {
152 fn response<Response>(&self, response: mpsc::Receiver<Response>) -> Response {
153 match response.recv() {
154 Ok(val) => val,
155 Err(err) => panic!("Can't connect to the server correctly {err:?}"),
156 }
157 }
158}