use super::{RemoteChannel, RemoteClient};
use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};
use burn_backend::{DeviceId, DeviceOps, ExecutionError, TensorData};
use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};
use burn_ir::TensorIr;
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
use burn_std::{backtrace::BackTrace, future::DynFut};
use std::sync::OnceLock;
use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};
/// Process-wide table mapping a remote address string to its numeric id.
///
/// Created lazily on first access through [`get_address_registry`].
static ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();

/// Returns the shared address registry, creating an empty one on first use.
fn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {
    // `Mutex::default` builds a mutex around an empty `HashMap`.
    ADDRESS_REGISTRY.get_or_init(Mutex::default)
}
/// Returns the numeric id associated with `address`, registering it if needed.
///
/// The first time an address is seen it receives the current number of registered
/// addresses as its id; later calls with the same string return that same id.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {
    let address = address.as_ref();
    let mut registry = get_address_registry().lock().unwrap();

    // Fast path: a known address is looked up by `&str`, without allocating a key.
    if let Some(&id) = registry.get(address) {
        return id;
    }

    // Ids are handed out sequentially from the registry size.
    // NOTE(review): `as u32` would wrap past `u32::MAX` entries; not reachable in practice.
    let id = registry.len() as u32;
    registry.insert(address.to_string(), id);
    id
}
/// Reverse lookup: returns the address registered under `id`, if any.
///
/// Scans the registry entries and clones the key of the first entry whose value
/// equals `id`; returns `None` when no entry matches.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn id_to_address(id: u32) -> Option<String> {
    let registry = get_address_registry().lock().unwrap();
    registry
        .iter()
        .find(|(_, registered)| **registered == id)
        .map(|(address, _)| address.clone())
}
/// Implements [`RunnerClient`] for [`RemoteClient`] by turning every call into a
/// [`ComputeTask`] handed to the client's `sender`.
///
/// Calls that need no answer use `sender.send`; calls that need one use
/// `sender.send_async` and unpack the matching [`TaskResponseContent`] variant.
impl RunnerClient for RemoteClient {
    type Device = RemoteDevice;

    /// Sends `op` as a boxed [`ComputeTask::RegisterOperation`].
    fn register_op(&self, op: burn_ir::OperationIr) {
        let task = ComputeTask::RegisterOperation(Box::new(op));
        self.sender.send(task);
    }

    /// Requests the data of `tensor` and returns a future resolving to it.
    ///
    /// The request is issued immediately; only the wait for the reply is deferred.
    ///
    /// # Panics
    ///
    /// The returned future panics if the reply is not a `ReadTensor` response.
    fn read_tensor_async(
        &self,
        tensor: burn_ir::TensorIr,
    ) -> DynFut<Result<TensorData, ExecutionError>> {
        let pending = self.sender.send_async(ComputeTask::ReadTensor(tensor));
        Box::pin(async move {
            match pending.await {
                Ok(TaskResponseContent::ReadTensor(res)) => res,
                Ok(_) => panic!("Invalid message type"),
                Err(err) => Err(ExecutionError::Generic {
                    reason: format!("Failed to read tensor: {:?}", err),
                    backtrace: BackTrace::capture(),
                }),
            }
        })
    }

    /// Allocates a new tensor id, sends `data` under that id, and returns a
    /// [`RouterTensor`] carrying the id, shape, dtype and a clone of this client.
    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
        let id = self.sender.new_tensor_id();
        // Capture the metadata before `data` is moved into the task.
        let (shape, dtype) = (data.shape.clone(), data.dtype);
        self.sender.send(ComputeTask::RegisterTensor(id, data));
        RouterTensor::new(id, shape, dtype, self.clone())
    }

    /// Returns a clone of the device this client is bound to.
    fn device(&self) -> Self::Device {
        self.device.clone()
    }

    /// Sends a `SyncBackend` task and waits for its reply via `runtime.block_on`.
    ///
    /// # Panics
    ///
    /// Panics if the reply is not a `SyncBackend` response.
    fn sync(&self) -> Result<(), ExecutionError> {
        let pending = self.sender.send_async(ComputeTask::SyncBackend);
        match self.runtime.block_on(pending) {
            Ok(TaskResponseContent::SyncBackend(res)) => res,
            Ok(_) => panic!("Invalid message type"),
            Err(err) => Err(ExecutionError::Generic {
                reason: format!("Failed to sync: {:?}", err),
                backtrace: BackTrace::capture(),
            }),
        }
    }

    /// Sends `seed` as a [`ComputeTask::Seed`].
    fn seed(&self, seed: u64) {
        let task = ComputeTask::Seed(seed);
        self.sender.send(task);
    }

    /// Reserves a new tensor id from the sender without sending any task.
    fn create_empty_handle(&self) -> burn_ir::TensorId {
        self.sender.new_tensor_id()
    }

    /// Sends a `DTypeUsage` task for `dtype` and waits for the reply via
    /// `runtime.block_on`.
    ///
    /// # Panics
    ///
    /// Panics if the request fails or if the reply is not a `DTypeUsage` response.
    fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
        let pending = self.sender.send_async(ComputeTask::DTypeUsage(dtype));
        match self.runtime.block_on(pending) {
            Ok(TaskResponseContent::DTypeUsage(res)) => res,
            Ok(other) => panic!("Invalid message type {other:?}"),
            Err(err) => panic!("Failed to check dtype support: {:?}", err),
        }
    }
}
/// Device value naming a remote endpoint by its address.
///
/// `RemoteDevice::new` fills `id` from `address_to_id`, so devices built from the
/// same address string carry the same id.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct RemoteDevice {
/// Parsed address of the remote endpoint.
pub(crate) address: Address,
/// Registry id of the address; used as the `index_id` when converting to a `DeviceId`.
pub(crate) id: u32,
}
impl RemoteDevice {
    /// Creates a device for the remote endpoint at `address`.
    ///
    /// The address is parsed into an [`Address`] and then registered with
    /// [`address_to_id`] to obtain the device id.
    ///
    /// # Panics
    ///
    /// Panics if `address` cannot be parsed as an [`Address`]; the message includes
    /// the offending string and the parse error.
    pub fn new(address: &str) -> Self {
        // Validate before registering, so an unparsable address never consumes an id
        // in the global registry.
        let parsed = Address::from_str(address)
            .unwrap_or_else(|err| panic!("Invalid remote address {address:?}: {err:?}"));
        Self {
            address: parsed,
            id: address_to_id(address),
        }
    }
}
impl Default for RemoteDevice {
    /// Builds the device from the `BURN_REMOTE_ADDRESS` environment variable.
    ///
    /// When the variable cannot be read, `ws://127.0.0.1:3000` is used.
    fn default() -> Self {
        let address = std::env::var("BURN_REMOTE_ADDRESS")
            .unwrap_or_else(|_| String::from("ws://127.0.0.1:3000"));
        Self::new(&address)
    }
}
impl burn_std::device::Device for RemoteDevice {
    /// Rebuilds a device from `device_id` by looking its `index_id` up in the
    /// address registry.
    ///
    /// # Panics
    ///
    /// Panics if `type_id` is not `0`, or if no address is registered under
    /// `index_id`.
    fn from_id(device_id: DeviceId) -> Self {
        assert!(
            device_id.type_id == 0,
            "Invalid device id: {device_id} (expected type 0)"
        );
        let Some(address) = id_to_address(device_id.index_id as u32) else {
            panic!("Invalid device id: {device_id}");
        };
        Self::new(&address)
    }

    /// Converts this device into a `DeviceId` with `type_id` 0.
    fn to_id(&self) -> DeviceId {
        // NOTE(review): `as u16` keeps only the low 16 bits of the `u32` registry id.
        let index_id = self.id as u16;
        DeviceId {
            type_id: 0,
            index_id,
        }
    }
}
// The impl body is empty: `RemoteDevice` overrides nothing from `DeviceOps`.
impl DeviceOps for RemoteDevice {}
/// Bridge type that implements `MultiBackendBridge` for remote tensors.
///
/// Generic over the protocol client `C`; it stores no data besides the marker.
pub struct RemoteBridge<C: ProtocolClient> {
// Ties the otherwise unused type parameter `C` to the struct.
_p: PhantomData<C>,
}
/// Handle pairing a tensor description with the remote client it is registered on.
///
/// `change_backend` rewrites both `client` and `tensor.id` when the tensor is moved
/// to another device.
pub struct RemoteTensorHandle<C: ProtocolClient> {
/// Client through which tasks for this tensor are sent.
pub(crate) client: RemoteClient,
/// Intermediate-representation description of the tensor, including its id.
pub(crate) tensor: TensorIr,
/// Ties the protocol client type `C` to the handle.
pub(crate) _p: PhantomData<C>,
}
/// Most recently issued transfer id; `None` until the first id has been handed out.
static TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);

/// Returns a transfer id that has not been handed out before in this process.
///
/// The first call yields the id converted from `0`; every later call advances the
/// stored id with `TensorTransferId::next` and returns the advanced value.
///
/// # Panics
///
/// Panics if the counter mutex has been poisoned.
fn get_next_transfer_id() -> TensorTransferId {
    let mut last_issued = TRANSFER_COUNTER.lock().unwrap();
    let next: TensorTransferId = match *last_issued {
        // `id` is a copy of the stored value; advance it, then write it back below.
        Some(mut id) => {
            id.next();
            id
        }
        None => 0.into(),
    };
    // Persist the id just issued. Without this write-back the stored value never
    // changes and every call after the first would return the same id.
    *last_issued = Some(next);
    next
}
impl<C: ProtocolClient> RemoteTensorHandle<C> {
    /// Moves this tensor to `target_device` and returns the updated handle.
    ///
    /// Steps, in order:
    /// 1. obtain a transfer id;
    /// 2. send the current client an `ExposeTensorRemote` task for the tensor
    ///    (with `count: 1`);
    /// 3. send the target device's client a `RegisterTensorRemote` task that names
    ///    the current client's address, the transfer id, and a fresh tensor id;
    /// 4. rebind the handle to the target client and the fresh tensor id.
    pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {
        let transfer_id = get_next_transfer_id();
        // Must be read before `self.client` is replaced at the end.
        let source_address = self.client.device.address.clone();

        // Source side: expose the tensor under `transfer_id`.
        let expose = ComputeTask::ExposeTensorRemote {
            tensor: self.tensor.clone(),
            count: 1,
            transfer_id,
        };
        self.client.sender.send(expose);

        // Target side: register the remote tensor under a freshly allocated id.
        let destination = get_client::<RemoteChannel<C>>(target_device);
        let new_id = destination.sender.new_tensor_id();
        let register = ComputeTask::RegisterTensorRemote(
            TensorRemote {
                transfer_id,
                address: source_address,
            },
            new_id,
        );
        destination.sender.send(register);

        self.tensor.id = new_id;
        self.client = destination;
        self
    }
}
impl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {
type TensorHandle = RemoteTensorHandle<C>;
type Device = RemoteDevice;
fn change_backend_float(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
fn change_backend_int(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
fn change_backend_bool(
tensor: Self::TensorHandle,
_shape: burn_backend::Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
tensor.change_backend(target_device)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two distinct addresses get distinct ids, repeated registration is stable,
    /// the reverse lookup round-trips, and an unregistered id maps to `None`.
    #[test]
    fn test_address_to_id() {
        let first = "ws://127.0.0.1:3000";
        let second = "ws://127.0.0.1:3001";

        let first_id = address_to_id(first);
        let second_id = address_to_id(second);
        assert_ne!(first_id, second_id);

        for (address, id) in [(first, first_id), (second, second_id)] {
            assert_eq!(address_to_id(address), id);
            assert_eq!(id_to_address(id), Some(address.to_string()));
        }

        assert_eq!(id_to_address(u32::MAX), None);
    }
}