burn_remote/client/
runner.rs1use super::{RemoteChannel, RemoteClient};
2use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
3use burn_backend::{DeviceId, DeviceOps, ExecutionError, TensorData};
4use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
5use burn_ir::TensorIr;
6use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
7use burn_std::{backtrace::BackTrace, future::DynFut};
8use std::sync::OnceLock;
9use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
10
/// Process-wide table mapping an address string to the numeric id assigned to it.
///
/// Created lazily on first access through [`get_address_registry`].
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();

/// Returns the shared address registry, creating an empty one on first use.
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
    // `Mutex::default` builds a mutex around an empty `HashMap`.
    ADDRESS_REGISTRY.get_or_init(Mutex::default)
}
17
18pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
24 let registry = get_address_registry();
25 let mut registry = registry.lock().unwrap();
26 let next_id = registry.len() as u32;
27 *registry
28 .entry(address.as_ref().to_string())
29 .or_insert_with(|| next_id)
30}
31
32pub fn id_to_address(id: u32) -> Option<String> {
36 let registry = get_address_registry();
37 let registry = registry.lock().unwrap();
38 for entry in registry.iter() {
39 if entry.1 == &id {
40 return Some(entry.0.clone());
41 }
42 }
43 None
44}
45
46impl RunnerClient for RemoteClient {
52 type Device = RemoteDevice;
53
54 fn register_op(&self, op: burn_ir::OperationIr) {
55 self.sender
56 .send(ComputeTask::RegisterOperation(Box::new(op)));
57 }
58
59 fn read_tensor_async(
60 &self,
61 tensor: burn_ir::TensorIr,
62 ) -> DynFut<Result<TensorData, ExecutionError>> {
63 let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor));
65
66 Box::pin(async move {
67 match fut.await {
68 Ok(response) => match response {
69 TaskResponseContent::ReadTensor(res) => res,
70 _ => panic!("Invalid message type"),
71 },
72 Err(e) => Err(ExecutionError::Generic {
73 reason: format!("Failed to read tensor: {:?}", e),
74 backtrace: BackTrace::capture(),
75 }),
76 }
77 })
78 }
79
80 fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
81 let id = self.sender.new_tensor_id();
82 let shape = data.shape.clone();
83 let dtype = data.dtype;
84
85 self.sender.send(ComputeTask::RegisterTensor(id, data));
86
87 RouterTensor::new(id, shape, dtype, self.clone())
88 }
89
90 fn device(&self) -> Self::Device {
91 self.device.clone()
92 }
93
94 fn sync(&self) -> Result<(), ExecutionError> {
95 let fut = self.sender.send_async(ComputeTask::SyncBackend);
97
98 match self.runtime.block_on(fut) {
99 Ok(response) => match response {
100 TaskResponseContent::SyncBackend(res) => res,
101 _ => panic!("Invalid message type"),
102 },
103 Err(e) => Err(ExecutionError::Generic {
104 reason: format!("Failed to sync: {:?}", e),
105 backtrace: BackTrace::capture(),
106 }),
107 }
108 }
109
110 fn seed(&self, seed: u64) {
111 self.sender.send(ComputeTask::Seed(seed));
112 }
113
114 fn create_empty_handle(&self) -> burn_ir::TensorId {
115 self.sender.new_tensor_id()
116 }
117
118 fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
119 let fut = self.sender.send_async(ComputeTask::DTypeUsage(dtype));
120
121 match self.runtime.block_on(fut) {
122 Ok(response) => match response {
123 TaskResponseContent::DTypeUsage(res) => res,
124 other => panic!("Invalid message type {other:?}"),
125 },
126 Err(e) => panic!("Failed to check dtype support: {:?}", e),
127 }
128 }
129}
130
131#[derive(Clone, PartialEq, Eq, Debug)]
132pub struct RemoteDevice {
134 pub(crate) address: Address,
135 pub(crate) id: u32,
137}
138
139impl RemoteDevice {
140 pub fn new(address: &str) -> Self {
142 let id = address_to_id(address);
143 Self {
144 address: Address::from_str(address).unwrap(),
145 id,
146 }
147 }
148}
149
150impl Default for RemoteDevice {
151 fn default() -> Self {
152 let address = match std::env::var("BURN_REMOTE_ADDRESS") {
153 Ok(address) => address,
154 Err(_) => String::from("ws://127.0.0.1:3000"),
155 };
156
157 Self::new(&address)
158 }
159}
160
161impl burn_std::device::Device for RemoteDevice {
162 fn from_id(device_id: DeviceId) -> Self {
163 if device_id.type_id != 0 {
164 panic!("Invalid device id: {device_id} (expected type 0)");
165 }
166 let address = id_to_address(device_id.index_id as u32)
167 .unwrap_or_else(|| panic!("Invalid device id: {device_id}"));
168 Self::new(&address)
169 }
170
171 fn to_id(&self) -> DeviceId {
172 DeviceId {
173 type_id: 0,
174 index_id: self.id as u16,
175 }
176 }
177}
178
179impl DeviceOps for RemoteDevice {}
180
181pub struct RemoteBridge<C: ProtocolClient> {
182 _p: PhantomData<C>,
183}
184
185pub struct RemoteTensorHandle<C: ProtocolClient> {
186 pub(crate) client: RemoteClient,
187 pub(crate) tensor: TensorIr,
188 pub(crate) _p: PhantomData<C>,
189}
190
191static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
192
193fn get_next_transfer_id() -> TensorTransferId {
194 let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
195 if transfer_counter.is_none() {
196 *transfer_counter = Some(0.into());
197
198 transfer_counter.unwrap()
199 } else {
200 let mut transfer_counter = transfer_counter.unwrap();
201 transfer_counter.next();
202
203 transfer_counter
204 }
205}
206
207impl<C: ProtocolClient> RemoteTensorHandle<C> {
208 pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
214 let transfer_id = get_next_transfer_id();
215 self.client.sender.send(ComputeTask::ExposeTensorRemote {
216 tensor: self.tensor.clone(),
217 count: 1,
218 transfer_id,
219 });
220
221 let target_client = get_client::<RemoteChannel<C>>(target_device);
222
223 let new_id = target_client.sender.new_tensor_id();
224
225 let remote_tensor = TensorRemote {
226 transfer_id,
227 address: self.client.device.address.clone(),
228 };
229 target_client
230 .sender
231 .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
232
233 self.tensor.id = new_id;
234 self.client = target_client;
235
236 self
237 }
238}
239
240impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
241 type TensorHandle = RemoteTensorHandle<C>;
242 type Device = RemoteDevice;
243
244 fn change_backend_float(
245 tensor: Self::TensorHandle,
246 _shape: burn_backend::Shape,
247 target_device: &Self::Device,
248 ) -> Self::TensorHandle {
249 tensor.change_backend(target_device)
250 }
251
252 fn change_backend_int(
253 tensor: Self::TensorHandle,
254 _shape: burn_backend::Shape,
255 target_device: &Self::Device,
256 ) -> Self::TensorHandle {
257 tensor.change_backend(target_device)
258 }
259
260 fn change_backend_bool(
261 tensor: Self::TensorHandle,
262 _shape: burn_backend::Shape,
263 target_device: &Self::Device,
264 ) -> Self::TensorHandle {
265 tensor.change_backend(target_device)
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Distinct addresses get distinct ids, ids are stable across calls, and
    /// the reverse lookup returns the original address (or `None` for an id
    /// that was never assigned).
    #[test]
    fn test_address_to_id() {
        let addresses = ["ws://127.0.0.1:3000", "ws://127.0.0.1:3001"];
        let ids = addresses.map(|address| address_to_id(address));

        assert_ne!(ids[0], ids[1]);

        for (&address, id) in addresses.iter().zip(ids) {
            assert_eq!(address_to_id(address), id);
            assert_eq!(id_to_address(id), Some(address.to_string()));
        }

        let unused_id = u32::MAX;

        assert_eq!(id_to_address(unused_id), None);
    }
}