burn_remote/client/
runner.rs

1use burn_common::future::DynFut;
2use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
3use burn_ir::TensorIr;
4use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
5use burn_tensor::{
6    Shape, TensorData,
7    backend::{DeviceId, DeviceOps},
8};
9use std::{
10    hash::{DefaultHasher, Hash, Hasher},
11    marker::PhantomData,
12    str::FromStr,
13    sync::Mutex,
14};
15
16use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
17
18use super::{RemoteChannel, RemoteClient};
19
// It is very important to block on any request made with the sender, since ordering is crucial
// when registering operations or creating tensors.
//
// The overhead is minimal, since we only wait for the task to be pushed onto the async
// channel, not for it to be sent to the server, let alone processed by the server.
25impl RunnerClient for RemoteClient {
26    type Device = RemoteDevice;
27
28    fn register(&self, op: burn_ir::OperationIr) {
29        self.sender
30            .send(ComputeTask::RegisterOperation(Box::new(op)));
31    }
32
33    fn read_tensor(&self, tensor: burn_ir::TensorIr) -> DynFut<TensorData> {
34        // Important for ordering to call the creation of the future sync.
35        let fut = self.sender.send_callback(ComputeTask::ReadTensor(tensor));
36
37        Box::pin(async move {
38            match fut.await {
39                TaskResponseContent::ReadTensor(data) => data,
40                _ => panic!("Invalid message type"),
41            }
42        })
43    }
44
45    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
46        let id = self.sender.new_tensor_id();
47        let shape = data.shape.clone();
48        let dtype = data.dtype;
49
50        self.sender.send(ComputeTask::RegisterTensor(id, data));
51
52        RouterTensor::new(id, Shape::from(shape), dtype, self.clone())
53    }
54
55    fn register_empty_tensor(&self, shape: Shape, dtype: burn_tensor::DType) -> RouterTensor<Self> {
56        let id = self.sender.new_tensor_id();
57
58        RouterTensor::new(id, shape, dtype, self.clone())
59    }
60
61    fn register_float_tensor(
62        &self,
63        shape: Shape,
64        dtype: burn_tensor::FloatDType,
65    ) -> RouterTensor<Self> {
66        self.register_empty_tensor(shape, dtype.into())
67    }
68
69    fn device(&self) -> Self::Device {
70        self.device.clone()
71    }
72
73    fn sync(&self) {
74        // Important for ordering to call the creation of the future sync.
75        let fut = self.sender.send_callback(ComputeTask::SyncBackend);
76
77        let runtime = self.runtime.clone();
78
79        match runtime.block_on(fut) {
80            TaskResponseContent::SyncBackend => {}
81            _ => panic!("Invalid message type"),
82        };
83    }
84
85    fn seed(&self, seed: u64) {
86        self.sender.send(ComputeTask::Seed(seed));
87    }
88}
89
90#[derive(Clone, PartialEq, Eq, Debug)]
91/// The device contains the connection information of the server.
92pub struct RemoteDevice {
93    pub(crate) address: Address,
94    // Unique ID generated from hash of the address
95    pub(crate) id: u32,
96}
97
98impl RemoteDevice {
99    /// Create a device from an url.
100    pub fn new(address: &str) -> Self {
101        let mut hasher = DefaultHasher::new();
102        address.hash(&mut hasher);
103        let id = hasher.finish() as u32;
104
105        Self {
106            address: Address::from_str(address).unwrap(),
107            id,
108        }
109    }
110}
111
112impl Default for RemoteDevice {
113    fn default() -> Self {
114        let address = match std::env::var("BURN_REMOTE_ADDRESS") {
115            Ok(address) => address,
116            Err(_) => String::from("ws://127.0.0.1:3000"),
117        };
118
119        Self::new(&address)
120    }
121}
122
123impl burn_common::device::Device for RemoteDevice {
124    fn from_id(_device_id: DeviceId) -> Self {
125        todo!("Should keep the address as ints, host should be type, port should be index.")
126    }
127
128    fn to_id(&self) -> DeviceId {
129        DeviceId {
130            type_id: 0,
131            index_id: self.id,
132        }
133    }
134
135    fn device_count(_type_id: u16) -> usize {
136        1
137    }
138}
139
140impl DeviceOps for RemoteDevice {}
141
142pub struct RemoteBridge<C: ProtocolClient> {
143    _p: PhantomData<C>,
144}
145
146pub struct RemoteTensorHandle<C: ProtocolClient> {
147    pub(crate) client: RemoteClient,
148    pub(crate) tensor: TensorIr,
149    pub(crate) _p: PhantomData<C>,
150}
151
152static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
153
154fn get_next_transfer_id() -> TensorTransferId {
155    let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
156    if transfer_counter.is_none() {
157        *transfer_counter = Some(0.into());
158
159        transfer_counter.unwrap()
160    } else {
161        let mut transfer_counter = transfer_counter.unwrap();
162        transfer_counter.next();
163
164        transfer_counter
165    }
166}
167
168impl<C: ProtocolClient> RemoteTensorHandle<C> {
169    /// Changes the backend of the tensor via a dWebSocket.
170    /// We ask the original server to expose the tensor, then ask the target server to fetch
171    /// the tensor. The target server will open a new network connection to the original server
172    /// to download the data.
173    /// This way the client never sees the tensor's data, and we avoid a bottleneck.
174    pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
175        let transfer_id = get_next_transfer_id();
176        self.client.sender.send(ComputeTask::ExposeTensorRemote {
177            tensor: self.tensor.clone(),
178            count: 1,
179            transfer_id,
180        });
181
182        let target_client = get_client::<RemoteChannel<C>>(target_device);
183
184        let new_id = target_client.sender.new_tensor_id();
185
186        let remote_tensor = TensorRemote {
187            transfer_id,
188            address: self.client.device.address.clone(),
189        };
190        target_client
191            .sender
192            .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
193
194        self.tensor.id = new_id;
195        self.client = target_client;
196
197        self
198    }
199}
200
201impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
202    type TensorHandle = RemoteTensorHandle<C>;
203    type Device = RemoteDevice;
204
205    fn change_backend_float(
206        tensor: Self::TensorHandle,
207        _shape: burn_tensor::Shape,
208        target_device: &Self::Device,
209    ) -> Self::TensorHandle {
210        tensor.change_backend(target_device)
211    }
212
213    fn change_backend_int(
214        tensor: Self::TensorHandle,
215        _shape: burn_tensor::Shape,
216        target_device: &Self::Device,
217    ) -> Self::TensorHandle {
218        tensor.change_backend(target_device)
219    }
220
221    fn change_backend_bool(
222        tensor: Self::TensorHandle,
223        _shape: burn_tensor::Shape,
224        target_device: &Self::Device,
225    ) -> Self::TensorHandle {
226        tensor.change_backend(target_device)
227    }
228}