burn_remote/client/
runner.rs

1use burn_common::future::DynFut;
2use burn_ir::TensorIr;
3use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
4use burn_tensor::{
5    DType, TensorData,
6    backend::{DeviceId, DeviceOps},
7};
8use std::{
9    hash::{DefaultHasher, Hash, Hasher},
10    sync::Arc,
11};
12
13use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
14
15use super::{WsChannel, WsClient};
16
// It is very important to block on any request made with the sender, since ordering is crucial
// when registering operations or creating tensors.
//
// The overhead is minimal, since we only wait for the task to be sent to the async
// channel, not for it to be sent to the websocket server, and even less for it to be
// processed by the server.
22impl RunnerClient for WsClient {
23    type Device = WsDevice;
24
25    fn register(&self, op: burn_ir::OperationIr) {
26        self.sender
27            .send(ComputeTask::RegisterOperation(Box::new(op)));
28    }
29
30    fn read_tensor(&self, tensor: burn_ir::TensorIr) -> DynFut<TensorData> {
31        // Important for ordering to call the creation of the future sync.
32        let fut = self.sender.send_callback(ComputeTask::ReadTensor(tensor));
33
34        Box::pin(async move {
35            match fut.await {
36                TaskResponseContent::ReadTensor(data) => data,
37                _ => panic!("Invalid message type"),
38            }
39        })
40    }
41
42    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
43        let id = self.sender.new_tensor_id();
44        let shape = data.shape.clone();
45        let dtype = data.dtype;
46
47        self.sender.send(ComputeTask::RegisterTensor(id, data));
48
49        RouterTensor::new(id, shape, dtype, self.clone())
50    }
51
52    fn register_empty_tensor(
53        &self,
54        shape: Vec<usize>,
55        dtype: burn_tensor::DType,
56    ) -> RouterTensor<Self> {
57        let id = self.sender.new_tensor_id();
58
59        RouterTensor::new(id, shape, dtype, self.clone())
60    }
61
62    fn register_float_tensor(
63        &self,
64        shape: Vec<usize>,
65        _dtype: burn_tensor::FloatDType,
66    ) -> RouterTensor<Self> {
67        self.register_empty_tensor(shape, DType::F32)
68    }
69
70    fn device(&self) -> Self::Device {
71        self.device.clone()
72    }
73
74    fn sync(&self) {
75        // Important for ordering to call the creation of the future sync.
76        let fut = self.sender.send_callback(ComputeTask::SyncBackend);
77
78        let runtime = self.runtime.clone();
79
80        match runtime.block_on(fut) {
81            TaskResponseContent::SyncBackend => {}
82            _ => panic!("Invalid message type"),
83        };
84    }
85
86    fn seed(&self, _seed: u64) {
87        // TODO
88    }
89}
90
/// The device contains the connection information of the server.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct WsDevice {
    /// WebSocket URL of the server; [`WsDevice::new`] guarantees it carries a `ws://` or
    /// `wss://` scheme.
    pub(crate) address: Arc<String>,
    /// Unique ID generated from the hash of the address.
    pub(crate) id: u32,
}

impl WsDevice {
    /// Create a device from an url.
    ///
    /// When `url` has no WebSocket scheme, `ws://` is prepended. Urls that already start with
    /// `ws://` or `wss://` are kept untouched, so secure endpoints are not turned into the
    /// invalid `ws://wss://...` form.
    pub fn new(url: &str) -> Self {
        let has_scheme = url.starts_with("ws://") || url.starts_with("wss://");
        let address = if has_scheme {
            url.to_owned()
        } else {
            format!("ws://{url}")
        };

        let mut hasher = DefaultHasher::new();
        address.hash(&mut hasher);
        // Truncation is intentional: only the low 32 bits of the hash are kept as the id.
        let id = hasher.finish() as u32;

        Self {
            address: Arc::new(address),
            id,
        }
    }
}
121
122impl Default for WsDevice {
123    fn default() -> Self {
124        let address = match std::env::var("BURN_REMOTE_ADDRESS") {
125            Ok(address) => address,
126            Err(_) => String::from("ws://127.0.0.1:3000"),
127        };
128
129        Self::new(&address)
130    }
131}
132
133impl DeviceOps for WsDevice {
134    fn id(&self) -> DeviceId {
135        DeviceId {
136            type_id: 0,
137            index_id: self.id,
138        }
139    }
140}
141
/// Unit type implementing `MultiBackendBridge` for tensors held through a [`RemoteTensorHandle`].
pub struct WsBridge;
143
144pub struct RemoteTensorHandle {
145    pub(crate) client: WsClient,
146    pub(crate) tensor: TensorIr,
147}
148
149impl RemoteTensorHandle {
150    /// Changes the backend of the tensor via a WebSocket.
151    /// We ask the original server to expose the tensor, then ask the target server to fetch
152    /// the tensor. The target server will open a new websocket connection to the original server
153    /// to download the data.
154    /// This way the client never sees the tensor's data, and we avoid a bottleneck.
155    pub(crate) fn change_backend(mut self, target_device: &WsDevice) -> Self {
156        self.client.sender.send(ComputeTask::ExposeTensorRemote {
157            tensor: self.tensor.clone(),
158            count: 1,
159        });
160
161        let target_client: WsClient = get_client::<WsChannel>(target_device);
162
163        let new_id = target_client.sender.new_tensor_id();
164
165        let remote_tensor = TensorRemote {
166            id: self.tensor.id,
167            address: self.client.device.address.to_string(),
168        };
169        target_client
170            .sender
171            .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
172
173        self.tensor.id = new_id;
174        self.client = target_client;
175
176        self
177    }
178}
179
180impl MultiBackendBridge for WsBridge {
181    type TensorHandle = RemoteTensorHandle;
182    type Device = WsDevice;
183
184    fn change_backend_float(
185        tensor: Self::TensorHandle,
186        _shape: burn_tensor::Shape,
187        target_device: &Self::Device,
188    ) -> Self::TensorHandle {
189        tensor.change_backend(target_device)
190    }
191
192    fn change_backend_int(
193        tensor: Self::TensorHandle,
194        _shape: burn_tensor::Shape,
195        target_device: &Self::Device,
196    ) -> Self::TensorHandle {
197        tensor.change_backend(target_device)
198    }
199
200    fn change_backend_bool(
201        tensor: Self::TensorHandle,
202        _shape: burn_tensor::Shape,
203        target_device: &Self::Device,
204    ) -> Self::TensorHandle {
205        tensor.change_backend(target_device)
206    }
207}