burn_remote/client/
runner.rs

1use super::{RemoteChannel, RemoteClient};
2use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
3use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
4use burn_ir::TensorIr;
5use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
6use burn_std::{backtrace::BackTrace, future::DynFut};
7use burn_tensor::{
8    Shape, TensorData,
9    backend::{DeviceId, DeviceOps, ExecutionError},
10};
11use std::sync::OnceLock;
12use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
13
// TODO: we should work with the parsed structure of Address, not the string.
/// Process-wide table associating each address string with its locally assigned id.
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();

/// Returns the shared address registry, creating the (empty) map on first access.
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
    ADDRESS_REGISTRY.get_or_init(Mutex::default)
}
20
21/// Map a string network address to a (local runtime) global unique u32.
22///
23/// Globally stable over the lifetime of the process, shared between threads,
24/// If the address has never been seen, a new id will be created.
25/// If the address has been seen, the previous id will be returned.
26pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
27    let registry = get_address_registry();
28    let mut registry = registry.lock().unwrap();
29    let next_id = registry.len() as u32;
30    *registry
31        .entry(address.as_ref().to_string())
32        .or_insert_with(|| next_id)
33}
34
35/// Look up an address by id.
36///
37/// Returns the same address given ids by [`address_to_id`].
38pub fn id_to_address(id: u32) -> Option<String> {
39    let registry = get_address_registry();
40    let registry = registry.lock().unwrap();
41    for entry in registry.iter() {
42        if entry.1 == &id {
43            return Some(entry.0.clone());
44        }
45    }
46    None
47}
48
49// It is very important to block on any request made with the sender, since ordering is crucial
50// when registering operation or creating tensors.
51//
52// The overhead is minimal, since we only wait for the task to be sent to the async
53// channel, but not sent to the server and even less processed by the server.
54impl RunnerClient for RemoteClient {
55    type Device = RemoteDevice;
56
57    fn register_op(&self, op: burn_ir::OperationIr) {
58        self.sender
59            .send(ComputeTask::RegisterOperation(Box::new(op)));
60    }
61
62    fn read_tensor_async(
63        &self,
64        tensor: burn_ir::TensorIr,
65    ) -> DynFut<Result<TensorData, ExecutionError>> {
66        // Important for ordering to call the creation of the future sync.
67        let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor));
68
69        Box::pin(async move {
70            match fut.await {
71                Ok(response) => match response {
72                    TaskResponseContent::ReadTensor(res) => res,
73                    _ => panic!("Invalid message type"),
74                },
75                Err(e) => Err(ExecutionError::Generic {
76                    reason: format!("Failed to read tensor: {:?}", e),
77                    backtrace: BackTrace::capture(),
78                }),
79            }
80        })
81    }
82
83    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
84        let id = self.sender.new_tensor_id();
85        let shape = data.shape.clone();
86        let dtype = data.dtype;
87
88        self.sender.send(ComputeTask::RegisterTensor(id, data));
89
90        RouterTensor::new(id, Shape::from(shape), dtype, self.clone())
91    }
92
93    fn device(&self) -> Self::Device {
94        self.device.clone()
95    }
96
97    fn sync(&self) -> Result<(), ExecutionError> {
98        // Important for ordering to call the creation of the future sync.
99        let fut = self.sender.send_async(ComputeTask::SyncBackend);
100
101        match self.runtime.block_on(fut) {
102            Ok(response) => match response {
103                TaskResponseContent::SyncBackend(res) => res,
104                _ => panic!("Invalid message type"),
105            },
106            Err(e) => Err(ExecutionError::Generic {
107                reason: format!("Failed to sync: {:?}", e),
108                backtrace: BackTrace::capture(),
109            }),
110        }
111    }
112
113    fn seed(&self, seed: u64) {
114        self.sender.send(ComputeTask::Seed(seed));
115    }
116
117    fn create_empty_handle(&self) -> burn_ir::TensorId {
118        self.sender.new_tensor_id()
119    }
120
121    fn supports_dtype(&self, dtype: burn_std::DType) -> bool {
122        let fut = self.sender.send_async(ComputeTask::SupportsDType(dtype));
123
124        match self.runtime.block_on(fut) {
125            Ok(response) => match response {
126                TaskResponseContent::SupportsDType(res) => res,
127                _ => panic!("Invalid message type"),
128            },
129            Err(e) => panic!("Failed to check dtype support: {:?}", e),
130        }
131    }
132}
133
134#[derive(Clone, PartialEq, Eq, Debug)]
135/// The device contains the connection information of the server.
136pub struct RemoteDevice {
137    pub(crate) address: Address,
138    /// The id of the device in the local registry, see [`address_to_id`].
139    pub(crate) id: u32,
140}
141
142impl RemoteDevice {
143    /// Create a device from an url.
144    pub fn new(address: &str) -> Self {
145        let id = address_to_id(address);
146        Self {
147            address: Address::from_str(address).unwrap(),
148            id,
149        }
150    }
151}
152
153impl Default for RemoteDevice {
154    fn default() -> Self {
155        let address = match std::env::var("BURN_REMOTE_ADDRESS") {
156            Ok(address) => address,
157            Err(_) => String::from("ws://127.0.0.1:3000"),
158        };
159
160        Self::new(&address)
161    }
162}
163
164impl burn_std::device::Device for RemoteDevice {
165    fn from_id(device_id: DeviceId) -> Self {
166        if device_id.type_id != 0 {
167            panic!("Invalid device id: {device_id} (expected type 0)");
168        }
169        let address = id_to_address(device_id.index_id)
170            .unwrap_or_else(|| panic!("Invalid device id: {device_id}"));
171        Self::new(&address)
172    }
173
174    fn to_id(&self) -> DeviceId {
175        DeviceId {
176            type_id: 0,
177            index_id: self.id,
178        }
179    }
180
181    fn device_count(_type_id: u16) -> usize {
182        1
183    }
184}
185
// Empty impl: `RemoteDevice` overrides nothing and relies entirely on what `DeviceOps`
// itself provides.
impl DeviceOps for RemoteDevice {}
187
/// Bridge type on which [`MultiBackendBridge`] is implemented for [`RemoteDevice`]s.
///
/// It stores no data: the protocol client type `C` is carried at the type level only.
pub struct RemoteBridge<C: ProtocolClient> {
    _p: PhantomData<C>,
}
191
/// Handle to a tensor living on a remote server, as moved around by [`RemoteBridge`].
pub struct RemoteTensorHandle<C: ProtocolClient> {
    /// The client through which tasks concerning this tensor are sent.
    pub(crate) client: RemoteClient,
    /// The intermediate representation of the tensor.
    pub(crate) tensor: TensorIr,
    /// Carries the protocol client type `C` without storing a value of it.
    pub(crate) _p: PhantomData<C>,
}
197
198static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
199
200fn get_next_transfer_id() -> TensorTransferId {
201    let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
202    if transfer_counter.is_none() {
203        *transfer_counter = Some(0.into());
204
205        transfer_counter.unwrap()
206    } else {
207        let mut transfer_counter = transfer_counter.unwrap();
208        transfer_counter.next();
209
210        transfer_counter
211    }
212}
213
214impl<C: ProtocolClient> RemoteTensorHandle<C> {
215    /// Changes the backend of the tensor via a dWebSocket.
216    /// We ask the original server to expose the tensor, then ask the target server to fetch
217    /// the tensor. The target server will open a new network connection to the original server
218    /// to download the data.
219    /// This way the client never sees the tensor's data, and we avoid a bottleneck.
220    pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
221        let transfer_id = get_next_transfer_id();
222        self.client.sender.send(ComputeTask::ExposeTensorRemote {
223            tensor: self.tensor.clone(),
224            count: 1,
225            transfer_id,
226        });
227
228        let target_client = get_client::<RemoteChannel<C>>(target_device);
229
230        let new_id = target_client.sender.new_tensor_id();
231
232        let remote_tensor = TensorRemote {
233            transfer_id,
234            address: self.client.device.address.clone(),
235        };
236        target_client
237            .sender
238            .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
239
240        self.tensor.id = new_id;
241        self.client = target_client;
242
243        self
244    }
245}
246
247impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
248    type TensorHandle = RemoteTensorHandle<C>;
249    type Device = RemoteDevice;
250
251    fn change_backend_float(
252        tensor: Self::TensorHandle,
253        _shape: burn_tensor::Shape,
254        target_device: &Self::Device,
255    ) -> Self::TensorHandle {
256        tensor.change_backend(target_device)
257    }
258
259    fn change_backend_int(
260        tensor: Self::TensorHandle,
261        _shape: burn_tensor::Shape,
262        target_device: &Self::Device,
263    ) -> Self::TensorHandle {
264        tensor.change_backend(target_device)
265    }
266
267    fn change_backend_bool(
268        tensor: Self::TensorHandle,
269        _shape: burn_tensor::Shape,
270        target_device: &Self::Device,
271    ) -> Self::TensorHandle {
272        tensor.change_backend(target_device)
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Ids are distinct per address, stable across repeated calls, and can be mapped back to
    /// the address they were assigned to; an id that was never assigned maps to nothing.
    #[test]
    fn test_address_to_id() {
        let addresses = ["ws://127.0.0.1:3000", "ws://127.0.0.1:3001"];
        let ids = addresses.map(|address| address_to_id(address));

        assert_ne!(ids[0], ids[1]);

        for (address, id) in addresses.into_iter().zip(ids) {
            assert_eq!(address_to_id(address), id);
            assert_eq!(id_to_address(id), Some(address.to_string()));
        }

        let unused_id = u32::MAX;

        assert_eq!(id_to_address(unused_id), None);
    }
}