Skip to main content

burn_remote/server/
base.rs

1use burn_communication::{
2    CommunicationChannel, Message, Protocol, ProtocolServer,
3    data_service::{TensorDataServer, TensorDataService},
4    util::os_shutdown_signal,
5    websocket::{WebSocket, WsServer},
6};
7use std::{marker::PhantomData, sync::Arc};
8use tokio_util::sync::CancellationToken;
9
10use burn_ir::BackendIr;
11use burn_tensor::Device;
12
13use crate::shared::{ComputeTask, Task};
14
15use super::session::SessionManager;
16
17pub struct RemoteServer<B, P>
18where
19    B: BackendIr,
20    P: Protocol,
21{
22    _b: PhantomData<B>,
23    _n: PhantomData<P>,
24}
25
26impl<B, P> RemoteServer<B, P>
27where
28    B: BackendIr,
29    P: Protocol,
30{
31    /// Start the server on the given address.
32    pub async fn start(device: Device<B>, server: P::Server) {
33        let cancel_token = CancellationToken::new();
34        let data_service = Arc::new(TensorDataService::<B, P>::new(cancel_token));
35        let session_manager = Arc::new(SessionManager::<B, P>::new(device, data_service.clone()));
36
37        let _server = server
38            .route("/response", {
39                let session_manager = session_manager.clone();
40                move |stream| Self::handle_socket_response(session_manager, stream)
41            })
42            .route("/request", {
43                let session_manager = session_manager.clone();
44                move |stream| Self::handle_socket_request(session_manager, stream)
45            })
46            .route_tensor_data_service(data_service)
47            .serve(os_shutdown_signal())
48            .await;
49    }
50
51    async fn handle_socket_response(
52        session_manager: Arc<SessionManager<B, P>>,
53        mut socket: <P::Server as ProtocolServer>::Channel,
54    ) {
55        log::info!("[Response Handler] On new connection.");
56
57        let packet = socket.recv().await;
58        let msg = match packet {
59            Ok(Some(msg)) => msg,
60            Ok(None) => {
61                log::info!("Response stream closed");
62                return;
63            }
64            Err(e) => {
65                log::info!("Response stream error on init: {e:?}");
66                return;
67            }
68        };
69
70        let id = match rmp_serde::from_slice::<Task>(&msg.data) {
71            Ok(Task::Init(session_id)) => session_id,
72            msg => {
73                log::error!("Message is not a valid initialization task {msg:?}");
74                return;
75            }
76        };
77
78        let mut receiver = session_manager.register_responder(id).await;
79
80        log::info!("Response handler connection active");
81
82        while let Some(mut callback) = receiver.recv().await {
83            let response = callback.recv().await.unwrap();
84            let bytes = rmp_serde::to_vec(&response).unwrap();
85
86            socket.send(Message::new(bytes.into())).await.unwrap();
87        }
88    }
89
90    async fn handle_socket_request(
91        session_manager: Arc<SessionManager<B, P>>,
92        mut socket: <P::Server as ProtocolServer>::Channel,
93    ) {
94        log::info!("[Request Handler] On new connection.");
95        let mut session_id = None;
96
97        loop {
98            let packet = socket.recv().await;
99            let msg = match packet {
100                Ok(Some(msg)) => msg,
101                Ok(None) => {
102                    log::info!("Request stream closed");
103                    break;
104                }
105                Err(e) => {
106                    log::info!("Request stream error: {e:?}, Closing.");
107                    break;
108                }
109            };
110
111            let task = match rmp_serde::from_slice::<Task>(&msg.data) {
112                Ok(val) => val,
113                Err(err) => {
114                    log::info!("Only bytes message in the json format are supported {err:?}");
115                    break;
116                }
117            };
118
119            if let Task::Close(id) = task {
120                session_id = Some(id);
121                break;
122            }
123
124            let (stream, connection_id, task) =
125                match session_manager.stream(&mut session_id, task).await {
126                    Some(val) => val,
127                    None => {
128                        log::info!("Ops session activated {session_id:?}");
129                        continue;
130                    }
131                };
132
133            match task {
134                ComputeTask::RegisterOperation(op) => {
135                    stream.register_operation(op).await;
136                }
137                ComputeTask::RegisterTensor(id, data) => {
138                    stream.register_tensor(id, data).await;
139                }
140                ComputeTask::ReadTensor(tensor) => {
141                    stream.read_tensor(connection_id, tensor).await;
142                }
143                ComputeTask::SyncBackend => {
144                    stream.sync(connection_id).await;
145                }
146                ComputeTask::RegisterTensorRemote(tensor, new_id) => {
147                    stream.register_tensor_remote(tensor, new_id).await;
148                }
149                ComputeTask::ExposeTensorRemote {
150                    tensor,
151                    count,
152                    transfer_id,
153                } => {
154                    stream
155                        .expose_tensor_remote(tensor, count, transfer_id)
156                        .await;
157                }
158                ComputeTask::Seed(seed) => {
159                    stream.seed(seed).await;
160                }
161                ComputeTask::SupportsDType(dtype) => {
162                    stream.supports_dtype(connection_id, dtype).await
163                }
164            }
165        }
166
167        log::info!("Closing session {session_id:?}");
168        session_manager.close(session_id).await;
169    }
170}
171
172/// Start the server on the given port and [device](Device).
173pub async fn start_websocket_async<B: BackendIr>(device: Device<B>, port: u16) {
174    let server = WsServer::new(port);
175    RemoteServer::<B, WebSocket>::start(device, server).await;
176}
177
178#[tokio::main]
179/// Start the server on the given port and [device](Device).
180pub async fn start_websocket<B: BackendIr>(device: Device<B>, port: u16) {
181    start_websocket_async::<B>(device, port).await;
182}