1use super::{RemoteChannel, RemoteClient};
2use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
3use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
4use burn_ir::TensorIr;
5use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
6use burn_std::{backtrace::BackTrace, future::DynFut};
7use burn_tensor::{
8 Shape, TensorData,
9 backend::{DeviceId, DeviceOps, ExecutionError},
10};
11use std::sync::OnceLock;
12use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
13
/// Process-wide table that maps each address string to the numeric id assigned to it.
///
/// It starts out unset; [`get_address_registry`] installs the empty map on first use.
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();

/// Returns the shared address registry, creating an empty one on the very first call.
///
/// Every call yields a reference to the same mutex-protected map.
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
    // `Mutex<HashMap<_, _>>::default()` is an unlocked mutex around an empty map.
    ADDRESS_REGISTRY.get_or_init(Mutex::default)
}
20
21pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
27 let registry = get_address_registry();
28 let mut registry = registry.lock().unwrap();
29 let next_id = registry.len() as u32;
30 *registry
31 .entry(address.as_ref().to_string())
32 .or_insert_with(|| next_id)
33}
34
35pub fn id_to_address(id: u32) -> Option<String> {
39 let registry = get_address_registry();
40 let registry = registry.lock().unwrap();
41 for entry in registry.iter() {
42 if entry.1 == &id {
43 return Some(entry.0.clone());
44 }
45 }
46 None
47}
48
49impl RunnerClient for RemoteClient {
55 type Device = RemoteDevice;
56
57 fn register_op(&self, op: burn_ir::OperationIr) {
58 self.sender
59 .send(ComputeTask::RegisterOperation(Box::new(op)));
60 }
61
62 fn read_tensor_async(
63 &self,
64 tensor: burn_ir::TensorIr,
65 ) -> DynFut<Result<TensorData, ExecutionError>> {
66 let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor));
68
69 Box::pin(async move {
70 match fut.await {
71 Ok(response) => match response {
72 TaskResponseContent::ReadTensor(res) => res,
73 _ => panic!("Invalid message type"),
74 },
75 Err(e) => Err(ExecutionError::Generic {
76 reason: format!("Failed to read tensor: {:?}", e),
77 backtrace: BackTrace::capture(),
78 }),
79 }
80 })
81 }
82
83 fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
84 let id = self.sender.new_tensor_id();
85 let shape = data.shape.clone();
86 let dtype = data.dtype;
87
88 self.sender.send(ComputeTask::RegisterTensor(id, data));
89
90 RouterTensor::new(id, Shape::from(shape), dtype, self.clone())
91 }
92
93 fn device(&self) -> Self::Device {
94 self.device.clone()
95 }
96
97 fn sync(&self) -> Result<(), ExecutionError> {
98 let fut = self.sender.send_async(ComputeTask::SyncBackend);
100
101 match self.runtime.block_on(fut) {
102 Ok(response) => match response {
103 TaskResponseContent::SyncBackend(res) => res,
104 _ => panic!("Invalid message type"),
105 },
106 Err(e) => Err(ExecutionError::Generic {
107 reason: format!("Failed to sync: {:?}", e),
108 backtrace: BackTrace::capture(),
109 }),
110 }
111 }
112
113 fn seed(&self, seed: u64) {
114 self.sender.send(ComputeTask::Seed(seed));
115 }
116
117 fn create_empty_handle(&self) -> burn_ir::TensorId {
118 self.sender.new_tensor_id()
119 }
120
121 fn supports_dtype(&self, dtype: burn_std::DType) -> bool {
122 let fut = self.sender.send_async(ComputeTask::SupportsDType(dtype));
123
124 match self.runtime.block_on(fut) {
125 Ok(response) => match response {
126 TaskResponseContent::SupportsDType(res) => res,
127 _ => panic!("Invalid message type"),
128 },
129 Err(e) => panic!("Failed to check dtype support: {:?}", e),
130 }
131 }
132}
133
/// Device handle that identifies a remote endpoint by its parsed address and a numeric id.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct RemoteDevice {
    /// Address parsed from the string passed to [`RemoteDevice::new`].
    pub(crate) address: Address,
    /// Id that [`address_to_id`] returned for that same address string.
    pub(crate) id: u32,
}
141
142impl RemoteDevice {
143 pub fn new(address: &str) -> Self {
145 let id = address_to_id(address);
146 Self {
147 address: Address::from_str(address).unwrap(),
148 id,
149 }
150 }
151}
152
153impl Default for RemoteDevice {
154 fn default() -> Self {
155 let address = match std::env::var("BURN_REMOTE_ADDRESS") {
156 Ok(address) => address,
157 Err(_) => String::from("ws://127.0.0.1:3000"),
158 };
159
160 Self::new(&address)
161 }
162}
163
164impl burn_std::device::Device for RemoteDevice {
165 fn from_id(device_id: DeviceId) -> Self {
166 if device_id.type_id != 0 {
167 panic!("Invalid device id: {device_id} (expected type 0)");
168 }
169 let address = id_to_address(device_id.index_id)
170 .unwrap_or_else(|| panic!("Invalid device id: {device_id}"));
171 Self::new(&address)
172 }
173
174 fn to_id(&self) -> DeviceId {
175 DeviceId {
176 type_id: 0,
177 index_id: self.id,
178 }
179 }
180
181 fn device_count(_type_id: u16) -> usize {
182 1
183 }
184}
185
// `DeviceOps` is implemented with an empty body: no items of the trait are overridden here.
impl DeviceOps for RemoteDevice {}
187
/// Marker type carrying the [`MultiBackendBridge`] implementation for remote tensors.
///
/// It stores no data besides the marker for the protocol client type `C`.
pub struct RemoteBridge<C: ProtocolClient> {
    /// Ties the otherwise unused type parameter `C` to the struct.
    _p: PhantomData<C>,
}
191
/// Handle pairing a tensor description with the remote client used to send tasks about it.
pub struct RemoteTensorHandle<C: ProtocolClient> {
    /// Client through which tasks concerning `tensor` are sent.
    pub(crate) client: RemoteClient,
    /// Description of the tensor, including its id.
    pub(crate) tensor: TensorIr,
    /// Ties the otherwise unused type parameter `C` to the struct.
    pub(crate) _p: PhantomData<C>,
}
197
198static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
199
200fn get_next_transfer_id() -> TensorTransferId {
201 let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
202 if transfer_counter.is_none() {
203 *transfer_counter = Some(0.into());
204
205 transfer_counter.unwrap()
206 } else {
207 let mut transfer_counter = transfer_counter.unwrap();
208 transfer_counter.next();
209
210 transfer_counter
211 }
212}
213
214impl<C: ProtocolClient> RemoteTensorHandle<C> {
215 pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
221 let transfer_id = get_next_transfer_id();
222 self.client.sender.send(ComputeTask::ExposeTensorRemote {
223 tensor: self.tensor.clone(),
224 count: 1,
225 transfer_id,
226 });
227
228 let target_client = get_client::<RemoteChannel<C>>(target_device);
229
230 let new_id = target_client.sender.new_tensor_id();
231
232 let remote_tensor = TensorRemote {
233 transfer_id,
234 address: self.client.device.address.clone(),
235 };
236 target_client
237 .sender
238 .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
239
240 self.tensor.id = new_id;
241 self.client = target_client;
242
243 self
244 }
245}
246
247impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
248 type TensorHandle = RemoteTensorHandle<C>;
249 type Device = RemoteDevice;
250
251 fn change_backend_float(
252 tensor: Self::TensorHandle,
253 _shape: burn_tensor::Shape,
254 target_device: &Self::Device,
255 ) -> Self::TensorHandle {
256 tensor.change_backend(target_device)
257 }
258
259 fn change_backend_int(
260 tensor: Self::TensorHandle,
261 _shape: burn_tensor::Shape,
262 target_device: &Self::Device,
263 ) -> Self::TensorHandle {
264 tensor.change_backend(target_device)
265 }
266
267 fn change_backend_bool(
268 tensor: Self::TensorHandle,
269 _shape: burn_tensor::Shape,
270 target_device: &Self::Device,
271 ) -> Self::TensorHandle {
272 tensor.change_backend(target_device)
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Distinct addresses receive distinct ids, ids are stable across repeated calls, the
    /// reverse lookup round-trips, and an id that was never assigned resolves to `None`.
    #[test]
    fn test_address_to_id() {
        let first = "ws://127.0.0.1:3000";
        let second = "ws://127.0.0.1:3001";

        let first_id = address_to_id(first);
        let second_id = address_to_id(second);
        assert_ne!(first_id, second_id);

        for (address, id) in [(first, first_id), (second, second_id)] {
            assert_eq!(address_to_id(address), id);
            assert_eq!(id_to_address(id), Some(address.to_string()));
        }

        assert_eq!(id_to_address(u32::MAX), None);
    }
}