burn_remote/client/
runner.rs1use burn_common::future::DynFut;
2use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
3use burn_ir::TensorIr;
4use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
5use burn_tensor::{
6 Shape, TensorData,
7 backend::{DeviceId, DeviceOps},
8};
9use std::{
10 hash::{DefaultHasher, Hash, Hasher},
11 marker::PhantomData,
12 str::FromStr,
13 sync::Mutex,
14};
15
16use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
17
18use super::{RemoteChannel, RemoteClient};
19
20impl RunnerClient for RemoteClient {
26 type Device = RemoteDevice;
27
28 fn register(&self, op: burn_ir::OperationIr) {
29 self.sender
30 .send(ComputeTask::RegisterOperation(Box::new(op)));
31 }
32
33 fn read_tensor(&self, tensor: burn_ir::TensorIr) -> DynFut<TensorData> {
34 let fut = self.sender.send_callback(ComputeTask::ReadTensor(tensor));
36
37 Box::pin(async move {
38 match fut.await {
39 TaskResponseContent::ReadTensor(data) => data,
40 _ => panic!("Invalid message type"),
41 }
42 })
43 }
44
45 fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
46 let id = self.sender.new_tensor_id();
47 let shape = data.shape.clone();
48 let dtype = data.dtype;
49
50 self.sender.send(ComputeTask::RegisterTensor(id, data));
51
52 RouterTensor::new(id, Shape::from(shape), dtype, self.clone())
53 }
54
55 fn register_empty_tensor(&self, shape: Shape, dtype: burn_tensor::DType) -> RouterTensor<Self> {
56 let id = self.sender.new_tensor_id();
57
58 RouterTensor::new(id, shape, dtype, self.clone())
59 }
60
61 fn register_float_tensor(
62 &self,
63 shape: Shape,
64 dtype: burn_tensor::FloatDType,
65 ) -> RouterTensor<Self> {
66 self.register_empty_tensor(shape, dtype.into())
67 }
68
69 fn device(&self) -> Self::Device {
70 self.device.clone()
71 }
72
73 fn sync(&self) {
74 let fut = self.sender.send_callback(ComputeTask::SyncBackend);
76
77 let runtime = self.runtime.clone();
78
79 match runtime.block_on(fut) {
80 TaskResponseContent::SyncBackend => {}
81 _ => panic!("Invalid message type"),
82 };
83 }
84
85 fn seed(&self, seed: u64) {
86 self.sender.send(ComputeTask::Seed(seed));
87 }
88}
89
90#[derive(Clone, PartialEq, Eq, Debug)]
91pub struct RemoteDevice {
93 pub(crate) address: Address,
94 pub(crate) id: u32,
96}
97
98impl RemoteDevice {
99 pub fn new(address: &str) -> Self {
101 let mut hasher = DefaultHasher::new();
102 address.hash(&mut hasher);
103 let id = hasher.finish() as u32;
104
105 Self {
106 address: Address::from_str(address).unwrap(),
107 id,
108 }
109 }
110}
111
112impl Default for RemoteDevice {
113 fn default() -> Self {
114 let address = match std::env::var("BURN_REMOTE_ADDRESS") {
115 Ok(address) => address,
116 Err(_) => String::from("ws://127.0.0.1:3000"),
117 };
118
119 Self::new(&address)
120 }
121}
122
123impl burn_common::device::Device for RemoteDevice {
124 fn from_id(_device_id: DeviceId) -> Self {
125 todo!("Should keep the address as ints, host should be type, port should be index.")
126 }
127
128 fn to_id(&self) -> DeviceId {
129 DeviceId {
130 type_id: 0,
131 index_id: self.id,
132 }
133 }
134
135 fn device_count(_type_id: u16) -> usize {
136 1
137 }
138}
139
140impl DeviceOps for RemoteDevice {}
141
142pub struct RemoteBridge<C: ProtocolClient> {
143 _p: PhantomData<C>,
144}
145
146pub struct RemoteTensorHandle<C: ProtocolClient> {
147 pub(crate) client: RemoteClient,
148 pub(crate) tensor: TensorIr,
149 pub(crate) _p: PhantomData<C>,
150}
151
152static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);
153
154fn get_next_transfer_id() -> TensorTransferId {
155 let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();
156 if transfer_counter.is_none() {
157 *transfer_counter = Some(0.into());
158
159 transfer_counter.unwrap()
160 } else {
161 let mut transfer_counter = transfer_counter.unwrap();
162 transfer_counter.next();
163
164 transfer_counter
165 }
166}
167
168impl<C: ProtocolClient> RemoteTensorHandle<C> {
169 pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
175 let transfer_id = get_next_transfer_id();
176 self.client.sender.send(ComputeTask::ExposeTensorRemote {
177 tensor: self.tensor.clone(),
178 count: 1,
179 transfer_id,
180 });
181
182 let target_client = get_client::<RemoteChannel<C>>(target_device);
183
184 let new_id = target_client.sender.new_tensor_id();
185
186 let remote_tensor = TensorRemote {
187 transfer_id,
188 address: self.client.device.address.clone(),
189 };
190 target_client
191 .sender
192 .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
193
194 self.tensor.id = new_id;
195 self.client = target_client;
196
197 self
198 }
199}
200
201impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
202 type TensorHandle = RemoteTensorHandle<C>;
203 type Device = RemoteDevice;
204
205 fn change_backend_float(
206 tensor: Self::TensorHandle,
207 _shape: burn_tensor::Shape,
208 target_device: &Self::Device,
209 ) -> Self::TensorHandle {
210 tensor.change_backend(target_device)
211 }
212
213 fn change_backend_int(
214 tensor: Self::TensorHandle,
215 _shape: burn_tensor::Shape,
216 target_device: &Self::Device,
217 ) -> Self::TensorHandle {
218 tensor.change_backend(target_device)
219 }
220
221 fn change_backend_bool(
222 tensor: Self::TensorHandle,
223 _shape: burn_tensor::Shape,
224 target_device: &Self::Device,
225 ) -> Self::TensorHandle {
226 tensor.change_backend(target_device)
227 }
228}