use crate::Message;
use crate::base::Protocol;
use crate::base::{Address, CommunicationChannel, ProtocolClient, ProtocolServer};
use burn_tensor::{TensorData, backend::Backend};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
/// Identifier naming a single tensor transfer; a thin wrapper around a `u64`.
///
/// It is the key under which a tensor is exposed for download and the payload
/// of a tensor request message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TensorTransferId(u64);
impl From<u64> for TensorTransferId {
fn from(value: u64) -> Self {
Self(value)
}
}
impl TensorTransferId {
pub fn next(&mut self) {
self.0 += 1;
}
}
/// Messages exchanged over a data channel.
///
/// In this file they are encoded and decoded with `rmp_serde` before being
/// wrapped in / extracted from a `Message`.
#[derive(Debug, Serialize, Deserialize)]
enum DataServiceMessage {
/// Sent by the downloading side: asks for the tensor exposed under this id.
TensorRequest(TensorTransferId),
/// Sent in reply to a request: carries the tensor data.
Tensor(TensorData),
}
/// Shared, mutex-guarded handle to a client-side channel of the protocol client `C`.
type ClientChannelRef<C> = Arc<Mutex<<C as ProtocolClient>::Channel>>;
/// Service that both exposes local tensors for download and downloads tensors
/// exposed by remote peers, over the protocol `P` and for the backend `B`.
pub struct TensorDataService<B: Backend, P: Protocol<Client: ProtocolClient>> {
/// Tensors currently exposed for download, keyed by their transfer id.
pub exposed_tensors: Mutex<HashMap<TensorTransferId, TensorExposeState>>,
/// Client channels to remote data servers, created on first use and cached per address.
pub channels: Mutex<HashMap<Address, ClientChannelRef<P::Client>>>,
/// Signalled every time a new tensor is exposed, to wake requests waiting for it.
pub new_tensor_notify: Arc<Notify>,
/// Checked by the server-side handler loop to decide when to stop serving.
cancel_token: CancellationToken,
/// Ties the service to the backend type `B` without storing a value of it.
_phantom_data: PhantomData<B>,
}
/// Bookkeeping for one exposed tensor.
pub struct TensorExposeState {
/// The already-serialized message that is sent to each downloader.
pub bytes: bytes::Bytes,
/// Download count at which the entry is removed from the exposed map.
pub max_downloads: u32,
/// Number of downloads served so far.
pub cur_download_count: u32,
}
/// Extension trait for attaching a [`TensorDataService`] to a server.
pub trait TensorDataServer<B: Backend, P: Protocol> {
/// Consumes the server, registers the data service on it using `state`, and
/// returns the server so that further calls can be chained.
fn route_tensor_data_service(self, state: Arc<TensorDataService<B, P>>) -> Self;
}
/// Implements [`TensorDataServer`] for every protocol server `S` that is the
/// server side of the protocol `P`.
impl<B: Backend, S: ProtocolServer + Sized, P: Protocol<Server = S> + 'static>
TensorDataServer<B, P> for S
{
/// Registers the `"/data"` route; each channel accepted on it is served by
/// `TensorDataService::handle_data_channel` until that handler returns.
fn route_tensor_data_service(self, state: Arc<TensorDataService<B, P>>) -> Self {
// The closure takes ownership of the shared service handle (`async move`).
// NOTE(review): clients connect with the path "data" (no leading slash) —
// presumably the protocol normalises the two; confirm.
self.route("/data", async move |stream: S::Channel| {
state.handle_data_channel(stream).await;
})
}
}
impl<B: Backend, P: Protocol> TensorDataService<B, P> {
/// Creates a service with no exposed tensors and no open client channels.
///
/// `cancel_token` is consulted by the server-side handler loop to know when
/// to stop serving requests.
pub fn new(cancel_token: CancellationToken) -> Self {
    let exposed_tensors = Mutex::new(HashMap::new());
    let channels = Mutex::new(HashMap::new());
    let new_tensor_notify = Arc::new(Notify::new());

    Self {
        exposed_tensors,
        channels,
        new_tensor_notify,
        cancel_token,
        _phantom_data: PhantomData,
    }
}
/// Exposes a float tensor for download under `transfer_id`.
///
/// The tensor is first converted to `TensorData` through the backend and
/// then handed to `expose_data`.
///
/// # Panics
///
/// Panics if the backend fails to read the tensor data.
pub async fn expose(
    &self,
    tensor: B::FloatTensorPrimitive,
    max_downloads: u32,
    transfer_id: TensorTransferId,
) {
    let tensor_data = B::float_into_data(tensor).await.unwrap();
    self.expose_data(tensor_data, max_downloads, transfer_id)
        .await;
}
/// Exposes already-materialised tensor data for download under `transfer_id`.
///
/// The data is serialized once, up front, stored with a download count of
/// zero, and every task currently waiting for a new tensor is woken up.
/// An existing entry with the same id is replaced.
///
/// # Panics
///
/// Panics if the tensor data cannot be serialized.
pub async fn expose_data(
    &self,
    tensor_data: TensorData,
    max_downloads: u32,
    transfer_id: TensorTransferId,
) {
    let message = DataServiceMessage::Tensor(tensor_data);
    let encoded = rmp_serde::to_vec(&message).unwrap();
    let state = TensorExposeState {
        bytes: bytes::Bytes::from(encoded),
        max_downloads,
        cur_download_count: 0,
    };

    {
        // Scope the guard so the lock is released before waiters are woken.
        let mut exposed = self.exposed_tensors.lock().await;
        exposed.insert(transfer_id, state);
    }

    self.new_tensor_notify.notify_waiters();
}
/// Closes every cached client channel and empties the channel cache.
///
/// The cache lock is held for the whole operation.
///
/// # Panics
///
/// Panics if closing any of the channels fails.
pub async fn close(&self) {
    let mut channels = self.channels.lock().await;
    let drained = channels.drain().map(|(_, channel)| channel);
    for channel in drained {
        let mut guard = channel.lock().await;
        let outcome = guard.close().await;
        outcome.expect("Failed to close WebSocket stream");
    }
}
pub async fn download_tensor(
&self,
remote: Address,
transfer_id: TensorTransferId,
) -> Option<TensorData> {
log::info!("Downloading tensor from {remote:?}");
let stream = self.get_data_stream(remote).await;
let mut stream = stream.lock().await;
let bytes: bytes::Bytes =
rmp_serde::to_vec(&DataServiceMessage::TensorRequest(transfer_id))
.unwrap()
.into();
stream
.send(Message::new(bytes))
.await
.expect("Failed to send download id");
if let Ok(msg) = stream.recv().await {
let Some(msg) = msg else {
log::warn!("Received None message from the websocket, closing connection.");
return None;
};
let DataServiceMessage::Tensor(data) = rmp_serde::from_slice(&msg.data)
.expect("Can deserialize messages from the websocket.")
else {
panic!("Message should have been TensorData")
};
return Some(data);
}
log::warn!("Closed connection");
None
}
/// Returns the client channel to the data server at `address`, connecting
/// and caching it on first use.
///
/// The cache lock is held while connecting, so at most one connection
/// attempt runs at a time.
///
/// # Panics
///
/// Panics if the connection cannot be established.
async fn get_data_stream(
    &self,
    address: Address,
) -> Arc<Mutex<<P::Client as ProtocolClient>::Channel>> {
    let mut channels = self.channels.lock().await;

    if let Some(existing) = channels.get(&address) {
        return Arc::clone(existing);
    }

    let Some(connection) = P::Client::connect(address.clone(), "data").await else {
        panic!("Failed to connect to data server at {address:?}");
    };

    let shared = Arc::new(Mutex::new(connection));
    channels.insert(address, Arc::clone(&shared));
    shared
}
/// Waits until a tensor is exposed under `transfer_id` and returns its
/// serialized bytes, counting the call as one download.
///
/// When the download count reaches the entry's `max_downloads`, the entry is
/// removed from the exposed map; otherwise it stays in place and the bytes
/// are handed out as a cheap reference-counted clone.
///
/// As written this never returns `None`: it keeps waiting until the tensor
/// shows up.
///
/// NOTE(review): for an entry whose count starts at 0 and whose
/// `max_downloads` is 0, the count never equals the limit, so the entry is
/// never removed — confirm whether that is intended.
async fn get_exposed_tensor_bytes(
    &self,
    transfer_id: TensorTransferId,
) -> Option<bytes::Bytes> {
    loop {
        // Register for the notification *before* looking at the map.
        // `notify_waiters` only wakes `Notified` futures that already exist,
        // so creating it after the lock is released would let an expose that
        // lands in between go unnoticed and stall this request.
        let notified = self.new_tensor_notify.notified();

        {
            let mut exposed_tensors = self.exposed_tensors.lock().await;
            if let Some(exposed_state) = exposed_tensors.get_mut(&transfer_id) {
                exposed_state.cur_download_count += 1;
                if exposed_state.cur_download_count != exposed_state.max_downloads {
                    return Some(exposed_state.bytes.clone());
                }
                // Last allowed download: drop the entry and give away its bytes.
                return exposed_tensors
                    .remove(&transfer_id)
                    .map(|state| state.bytes);
            }
        }

        // Not exposed yet: sleep until the next expose, then check again.
        notified.await;
    }
}
pub(crate) async fn handle_data_channel(
&self,
mut channel: <P::Server as ProtocolServer>::Channel,
) {
log::info!("[Data Handler] New connection for download.");
while !self.cancel_token.is_cancelled() {
match channel.recv().await {
Ok(message) => {
if let Some(msg) = message {
let bytes = msg.data;
let msg: DataServiceMessage = rmp_serde::from_slice(&bytes)
.expect("Can deserialize messages from the websocket.");
let DataServiceMessage::TensorRequest(transfer_id) = msg else {
panic!("Received a message that wasn't a tensor request! {msg:?}");
};
let bytes = self.get_exposed_tensor_bytes(transfer_id).await.unwrap();
channel.send(Message::new(bytes)).await.unwrap();
} else {
log::info!("Closed connection");
return;
}
}
Err(err) => panic!("Failed to receive message from websocket: {err:?}"),
};
}
log::info!("[Data Service] Closing connection for download.");
}
}