burn_remote/server/
base.rs1use burn_communication::{
2 CommunicationChannel, Message, Protocol, ProtocolServer,
3 data_service::{TensorDataServer, TensorDataService},
4 util::os_shutdown_signal,
5 websocket::{WebSocket, WsServer},
6};
7use std::{marker::PhantomData, sync::Arc};
8use tokio_util::sync::CancellationToken;
9
10use burn_ir::BackendIr;
11use burn_tensor::Device;
12
13use crate::shared::{ComputeTask, Task};
14
15use super::session::SessionManager;
16
17pub struct RemoteServer<B, P>
18where
19 B: BackendIr,
20 P: Protocol,
21{
22 _b: PhantomData<B>,
23 _n: PhantomData<P>,
24}
25
26impl<B, P> RemoteServer<B, P>
27where
28 B: BackendIr,
29 P: Protocol,
30{
31 pub async fn start(device: Device<B>, server: P::Server) {
33 let cancel_token = CancellationToken::new();
34 let data_service = Arc::new(TensorDataService::<B, P>::new(cancel_token));
35 let session_manager = Arc::new(SessionManager::<B, P>::new(device, data_service.clone()));
36
37 let _server = server
38 .route("/response", {
39 let session_manager = session_manager.clone();
40 move |stream| Self::handle_socket_response(session_manager, stream)
41 })
42 .route("/request", {
43 let session_manager = session_manager.clone();
44 move |stream| Self::handle_socket_request(session_manager, stream)
45 })
46 .route_tensor_data_service(data_service)
47 .serve(os_shutdown_signal())
48 .await;
49 }
50
51 async fn handle_socket_response(
52 session_manager: Arc<SessionManager<B, P>>,
53 mut socket: <P::Server as ProtocolServer>::Channel,
54 ) {
55 log::info!("[Response Handler] On new connection.");
56
57 let packet = socket.recv().await;
58 let msg = match packet {
59 Ok(Some(msg)) => msg,
60 Ok(None) => {
61 log::info!("Response stream closed");
62 return;
63 }
64 Err(e) => {
65 log::info!("Response stream error on init: {e:?}");
66 return;
67 }
68 };
69
70 let id = match rmp_serde::from_slice::<Task>(&msg.data) {
71 Ok(Task::Init(session_id)) => session_id,
72 msg => {
73 log::error!("Message is not a valid initialization task {msg:?}");
74 return;
75 }
76 };
77
78 let mut receiver = session_manager.register_responder(id).await;
79
80 log::info!("Response handler connection active");
81
82 while let Some(mut callback) = receiver.recv().await {
83 let response = callback.recv().await.unwrap();
84 let bytes = rmp_serde::to_vec(&response).unwrap();
85
86 socket.send(Message::new(bytes.into())).await.unwrap();
87 }
88 }
89
90 async fn handle_socket_request(
91 session_manager: Arc<SessionManager<B, P>>,
92 mut socket: <P::Server as ProtocolServer>::Channel,
93 ) {
94 log::info!("[Request Handler] On new connection.");
95 let mut session_id = None;
96
97 loop {
98 let packet = socket.recv().await;
99 let msg = match packet {
100 Ok(Some(msg)) => msg,
101 Ok(None) => {
102 log::info!("Request stream closed");
103 break;
104 }
105 Err(e) => {
106 log::info!("Request stream error: {e:?}, Closing.");
107 break;
108 }
109 };
110
111 let task = match rmp_serde::from_slice::<Task>(&msg.data) {
112 Ok(val) => val,
113 Err(err) => {
114 log::info!("Only bytes message in the json format are supported {err:?}");
115 break;
116 }
117 };
118
119 if let Task::Close(id) = task {
120 session_id = Some(id);
121 break;
122 }
123
124 let (stream, connection_id, task) =
125 match session_manager.stream(&mut session_id, task).await {
126 Some(val) => val,
127 None => {
128 log::info!("Ops session activated {session_id:?}");
129 continue;
130 }
131 };
132
133 match task {
134 ComputeTask::RegisterOperation(op) => {
135 stream.register_operation(op).await;
136 }
137 ComputeTask::RegisterTensor(id, data) => {
138 stream.register_tensor(id, data).await;
139 }
140 ComputeTask::ReadTensor(tensor) => {
141 stream.read_tensor(connection_id, tensor).await;
142 }
143 ComputeTask::SyncBackend => {
144 stream.sync(connection_id).await;
145 }
146 ComputeTask::RegisterTensorRemote(tensor, new_id) => {
147 stream.register_tensor_remote(tensor, new_id).await;
148 }
149 ComputeTask::ExposeTensorRemote {
150 tensor,
151 count,
152 transfer_id,
153 } => {
154 stream
155 .expose_tensor_remote(tensor, count, transfer_id)
156 .await;
157 }
158 ComputeTask::Seed(seed) => {
159 stream.seed(seed).await;
160 }
161 ComputeTask::SupportsDType(dtype) => {
162 stream.supports_dtype(connection_id, dtype).await
163 }
164 }
165 }
166
167 log::info!("Closing session {session_id:?}");
168 session_manager.close(session_id).await;
169 }
170}
171
172pub async fn start_websocket_async<B: BackendIr>(device: Device<B>, port: u16) {
174 let server = WsServer::new(port);
175 RemoteServer::<B, WebSocket>::start(device, server).await;
176}
177
178#[tokio::main]
179pub async fn start_websocket<B: BackendIr>(device: Device<B>, port: u16) {
181 start_websocket_async::<B>(device, port).await;
182}