Skip to main content

burn_remote/client/
runner.rs

1use super::{RemoteChannel, RemoteClient};
2use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
3use burn_backend::{DeviceId, DeviceOps, ExecutionError, TensorData};
4use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
5use burn_ir::TensorIr;
6use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
7use burn_std::{backtrace::BackTrace, future::DynFut};
8use std::sync::OnceLock;
9use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
10
// TODO: we should work with the parsed structure of Address, not the string.
/// Process-wide table mapping each address string to its locally assigned id.
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();

/// Access the shared address registry, creating the empty table on first use.
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
    // `Mutex<HashMap<..>>::default()` is an unlocked mutex around an empty map.
    ADDRESS_REGISTRY.get_or_init(Default::default)
}
17
18/// Map a string network address to a (local runtime) global unique u32.
19///
20/// Globally stable over the lifetime of the process, shared between threads,
21/// If the address has never been seen, a new id will be created.
22/// If the address has been seen, the previous id will be returned.
23pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
24    let registry = get_address_registry();
25    let mut registry = registry.lock().unwrap();
26    let next_id = registry.len() as u32;
27    *registry
28        .entry(address.as_ref().to_string())
29        .or_insert_with(|| next_id)
30}
31
32/// Look up an address by id.
33///
34/// Returns the same address given ids by [`address_to_id`].
35pub fn id_to_address(id: u32) -> Option<String> {
36    let registry = get_address_registry();
37    let registry = registry.lock().unwrap();
38    for entry in registry.iter() {
39        if entry.1 == &id {
40            return Some(entry.0.clone());
41        }
42    }
43    None
44}
45
46// It is very important to block on any request made with the sender, since ordering is crucial
47// when registering operation or creating tensors.
48//
49// The overhead is minimal, since we only wait for the task to be sent to the async
50// channel, but not sent to the server and even less processed by the server.
51impl RunnerClient for RemoteClient {
52    type Device = RemoteDevice;
53
54    fn register_op(&self, op: burn_ir::OperationIr) {
55        self.sender
56            .send(ComputeTask::RegisterOperation(Box::new(op)));
57    }
58
59    fn read_tensor_async(
60        &self,
61        tensor: burn_ir::TensorIr,
62    ) -> DynFut<Result<TensorData, ExecutionError>> {
63        // Important for ordering to call the creation of the future sync.
64        let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor));
65
66        Box::pin(async move {
67            match fut.await {
68                Ok(response) => match response {
69                    TaskResponseContent::ReadTensor(res) => res,
70                    _ => panic!("Invalid message type"),
71                },
72                Err(e) => Err(ExecutionError::Generic {
73                    reason: format!("Failed to read tensor: {:?}", e),
74                    backtrace: BackTrace::capture(),
75                }),
76            }
77        })
78    }
79
80    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
81        let id = self.sender.new_tensor_id();
82        let shape = data.shape.clone();
83        let dtype = data.dtype;
84
85        self.sender.send(ComputeTask::RegisterTensor(id, data));
86
87        RouterTensor::new(id, shape, dtype, self.clone())
88    }
89
90    fn device(&self) -> Self::Device {
91        self.device.clone()
92    }
93
94    fn sync(&self) -> Result<(), ExecutionError> {
95        // Important for ordering to call the creation of the future sync.
96        let fut = self.sender.send_async(ComputeTask::SyncBackend);
97
98        match self.runtime.block_on(fut) {
99            Ok(response) => match response {
100                TaskResponseContent::SyncBackend(res) => res,
101                _ => panic!("Invalid message type"),
102            },
103            Err(e) => Err(ExecutionError::Generic {
104                reason: format!("Failed to sync: {:?}", e),
105                backtrace: BackTrace::capture(),
106            }),
107        }
108    }
109
110    fn seed(&self, seed: u64) {
111        self.sender.send(ComputeTask::Seed(seed));
112    }
113
114    fn create_empty_handle(&self) -> burn_ir::TensorId {
115        self.sender.new_tensor_id()
116    }
117
118    fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
119        let fut = self.sender.send_async(ComputeTask::DTypeUsage(dtype));
120
121        match self.runtime.block_on(fut) {
122            Ok(response) => match response {
123                TaskResponseContent::DTypeUsage(res) => res,
124                other => panic!("Invalid message type {other:?}"),
125            },
126            Err(e) => panic!("Failed to check dtype support: {:?}", e),
127        }
128    }
129}
130
131#[derive(Clone, PartialEq, Eq, Debug)]
132/// The device contains the connection information of the server.
133pub struct RemoteDevice {
134    pub(crate) address: Address,
135    /// The id of the device in the local registry, see [`address_to_id`].
136    pub(crate) id: u32,
137}
138
139impl RemoteDevice {
140    /// Create a device from an url.
141    pub fn new(address: &str) -> Self {
142        let id = address_to_id(address);
143        Self {
144            address: Address::from_str(address).unwrap(),
145            id,
146        }
147    }
148}
149
150impl Default for RemoteDevice {
151    fn default() -> Self {
152        let address = match std::env::var("BURN_REMOTE_ADDRESS") {
153            Ok(address) => address,
154            Err(_) => String::from("ws://127.0.0.1:3000"),
155        };
156
157        Self::new(&address)
158    }
159}
160
161impl burn_std::device::Device for RemoteDevice {
162    fn from_id(device_id: DeviceId) -> Self {
163        if device_id.type_id != 0 {
164            panic!("Invalid device id: {device_id} (expected type 0)");
165        }
166        let address = id_to_address(device_id.index_id as u32)
167            .unwrap_or_else(|| panic!("Invalid device id: {device_id}"));
168        Self::new(&address)
169    }
170
171    fn to_id(&self) -> DeviceId {
172        DeviceId {
173            type_id: 0,
174            index_id: self.id as u16,
175        }
176    }
177}
178
179impl DeviceOps for RemoteDevice {}
180
181pub struct RemoteBridge<C: ProtocolClient> {
182    _p: PhantomData<C>,
183}
184
185pub struct RemoteTensorHandle<C: ProtocolClient> {
186    pub(crate) client: RemoteClient,
187    pub(crate) tensor: TensorIr,
188    pub(crate) _p: PhantomData<C>,
189}
190
191static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
192
193fn get_next_transfer_id() -> TensorTransferId {
194    let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
195    if transfer_counter.is_none() {
196        *transfer_counter = Some(0.into());
197
198        transfer_counter.unwrap()
199    } else {
200        let mut transfer_counter = transfer_counter.unwrap();
201        transfer_counter.next();
202
203        transfer_counter
204    }
205}
206
207impl<C: ProtocolClient> RemoteTensorHandle<C> {
208    /// Changes the backend of the tensor via a dWebSocket.
209    /// We ask the original server to expose the tensor, then ask the target server to fetch
210    /// the tensor. The target server will open a new network connection to the original server
211    /// to download the data.
212    /// This way the client never sees the tensor's data, and we avoid a bottleneck.
213    pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
214        let transfer_id = get_next_transfer_id();
215        self.client.sender.send(ComputeTask::ExposeTensorRemote {
216            tensor: self.tensor.clone(),
217            count: 1,
218            transfer_id,
219        });
220
221        let target_client = get_client::<RemoteChannel<C>>(target_device);
222
223        let new_id = target_client.sender.new_tensor_id();
224
225        let remote_tensor = TensorRemote {
226            transfer_id,
227            address: self.client.device.address.clone(),
228        };
229        target_client
230            .sender
231            .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
232
233        self.tensor.id = new_id;
234        self.client = target_client;
235
236        self
237    }
238}
239
240impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
241    type TensorHandle = RemoteTensorHandle<C>;
242    type Device = RemoteDevice;
243
244    fn change_backend_float(
245        tensor: Self::TensorHandle,
246        _shape: burn_backend::Shape,
247        target_device: &Self::Device,
248    ) -> Self::TensorHandle {
249        tensor.change_backend(target_device)
250    }
251
252    fn change_backend_int(
253        tensor: Self::TensorHandle,
254        _shape: burn_backend::Shape,
255        target_device: &Self::Device,
256    ) -> Self::TensorHandle {
257        tensor.change_backend(target_device)
258    }
259
260    fn change_backend_bool(
261        tensor: Self::TensorHandle,
262        _shape: burn_backend::Shape,
263        target_device: &Self::Device,
264    ) -> Self::TensorHandle {
265        tensor.change_backend(target_device)
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Distinct addresses get distinct ids, repeated lookups are stable in both directions,
    /// and an id that was never handed out resolves to nothing.
    #[test]
    fn test_address_to_id() {
        let addresses = ["ws://127.0.0.1:3000", "ws://127.0.0.1:3001"];
        let ids = [address_to_id(addresses[0]), address_to_id(addresses[1])];

        assert_ne!(ids[0], ids[1]);

        for (&address, &id) in addresses.iter().zip(ids.iter()) {
            assert_eq!(address_to_id(address), id);
            assert_eq!(id_to_address(id), Some(address.to_string()));
        }

        let unused_id = u32::MAX;

        assert_eq!(id_to_address(unused_id), None);
    }
}
293}