burn_remote/client/
runner.rs1use burn_common::future::DynFut;
2use burn_ir::TensorIr;
3use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
4use burn_tensor::{
5 DType, TensorData,
6 backend::{DeviceId, DeviceOps},
7};
8use std::{
9 hash::{DefaultHasher, Hash, Hasher},
10 sync::Arc,
11};
12
13use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
14
15use super::{WsChannel, WsClient};
16
17impl RunnerClient for WsClient {
23 type Device = WsDevice;
24
25 fn register(&self, op: burn_ir::OperationIr) {
26 self.sender
27 .send(ComputeTask::RegisterOperation(Box::new(op)));
28 }
29
30 fn read_tensor(&self, tensor: burn_ir::TensorIr) -> DynFut<TensorData> {
31 let fut = self.sender.send_callback(ComputeTask::ReadTensor(tensor));
33
34 Box::pin(async move {
35 match fut.await {
36 TaskResponseContent::ReadTensor(data) => data,
37 _ => panic!("Invalid message type"),
38 }
39 })
40 }
41
42 fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
43 let id = self.sender.new_tensor_id();
44 let shape = data.shape.clone();
45 let dtype = data.dtype;
46
47 self.sender.send(ComputeTask::RegisterTensor(id, data));
48
49 RouterTensor::new(id, shape, dtype, self.clone())
50 }
51
52 fn register_empty_tensor(
53 &self,
54 shape: Vec<usize>,
55 dtype: burn_tensor::DType,
56 ) -> RouterTensor<Self> {
57 let id = self.sender.new_tensor_id();
58
59 RouterTensor::new(id, shape, dtype, self.clone())
60 }
61
62 fn register_float_tensor(
63 &self,
64 shape: Vec<usize>,
65 _dtype: burn_tensor::FloatDType,
66 ) -> RouterTensor<Self> {
67 self.register_empty_tensor(shape, DType::F32)
68 }
69
70 fn device(&self) -> Self::Device {
71 self.device.clone()
72 }
73
74 fn sync(&self) {
75 let fut = self.sender.send_callback(ComputeTask::SyncBackend);
77
78 let runtime = self.runtime.clone();
79
80 match runtime.block_on(fut) {
81 TaskResponseContent::SyncBackend => {}
82 _ => panic!("Invalid message type"),
83 };
84 }
85
86 fn seed(&self, _seed: u64) {
87 }
89}
90
/// A remote compute device, identified by the WebSocket address of the server hosting it.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct WsDevice {
    /// Normalized WebSocket URL of the server (always starts with `ws://` or `wss://`).
    pub(crate) address: Arc<String>,
    /// 32-bit hash of `address`, used as the device index.
    pub(crate) id: u32,
}

impl WsDevice {
    /// Creates a device pointing at the server reachable at `url`.
    ///
    /// `url` may carry an explicit `ws://` or `wss://` scheme, which is kept as-is.
    /// It may also be a bare `host:port`, in which case `ws://` is prepended.
    pub fn new(url: &str) -> Self {
        // Keep secure (`wss://`) URLs intact. Checking only for `ws://` would turn
        // `wss://host` into the invalid `ws://wss://host`.
        let address = if url.starts_with("ws://") || url.starts_with("wss://") {
            url.to_string()
        } else {
            format!("ws://{url}")
        };

        let mut hasher = DefaultHasher::new();
        address.hash(&mut hasher);
        // Truncation to 32 bits is deliberate: the id is stored as a `u32` device index.
        let id = hasher.finish() as u32;

        Self {
            address: Arc::new(address),
            id,
        }
    }
}

impl Default for WsDevice {
    /// Uses the address in the `BURN_REMOTE_ADDRESS` environment variable.
    ///
    /// Falls back to `ws://127.0.0.1:3000` when that variable is unset or is not
    /// valid Unicode.
    fn default() -> Self {
        let address = std::env::var("BURN_REMOTE_ADDRESS")
            .unwrap_or_else(|_| String::from("ws://127.0.0.1:3000"));

        Self::new(&address)
    }
}
132
133impl DeviceOps for WsDevice {
134 fn id(&self) -> DeviceId {
135 DeviceId {
136 type_id: 0,
137 index_id: self.id,
138 }
139 }
140}
141
142pub struct WsBridge;
143
144pub struct RemoteTensorHandle {
145 pub(crate) client: WsClient,
146 pub(crate) tensor: TensorIr,
147}
148
149impl RemoteTensorHandle {
150 pub(crate) fn change_backend(mut self, target_device: &WsDevice) -> Self {
156 self.client.sender.send(ComputeTask::ExposeTensorRemote {
157 tensor: self.tensor.clone(),
158 count: 1,
159 });
160
161 let target_client: WsClient = get_client::<WsChannel>(target_device);
162
163 let new_id = target_client.sender.new_tensor_id();
164
165 let remote_tensor = TensorRemote {
166 id: self.tensor.id,
167 address: self.client.device.address.to_string(),
168 };
169 target_client
170 .sender
171 .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
172
173 self.tensor.id = new_id;
174 self.client = target_client;
175
176 self
177 }
178}
179
180impl MultiBackendBridge for WsBridge {
181 type TensorHandle = RemoteTensorHandle;
182 type Device = WsDevice;
183
184 fn change_backend_float(
185 tensor: Self::TensorHandle,
186 _shape: burn_tensor::Shape,
187 target_device: &Self::Device,
188 ) -> Self::TensorHandle {
189 tensor.change_backend(target_device)
190 }
191
192 fn change_backend_int(
193 tensor: Self::TensorHandle,
194 _shape: burn_tensor::Shape,
195 target_device: &Self::Device,
196 ) -> Self::TensorHandle {
197 tensor.change_backend(target_device)
198 }
199
200 fn change_backend_bool(
201 tensor: Self::TensorHandle,
202 _shape: burn_tensor::Shape,
203 target_device: &Self::Device,
204 ) -> Self::TensorHandle {
205 tensor.change_backend(target_device)
206 }
207}