1#[macro_use]
2extern crate derive_new;
3
4#[cfg(feature = "client")]
5pub(crate) mod client;
6
7#[cfg(feature = "server")]
8pub mod server;
9
10pub(crate) mod shared;
11
#[cfg(feature = "client")]
mod __client {
    use super::client;

    /// Backend whose tensor operations go through a `BackendRouter` driven by
    /// the WebSocket client channel (`client::WsChannel`).
    ///
    /// Tensors are created on a [`RemoteDevice`] and can be moved between
    /// remote devices with `Tensor::to_device`.
    pub type RemoteBackend = burn_router::BackendRouter<client::WsChannel>;

    /// Device handle for [`RemoteBackend`], re-exported from the WebSocket client.
    ///
    /// Constructed from a WebSocket URL, e.g. `RemoteDevice::new("ws://localhost:3000")`.
    pub use client::WsDevice as RemoteDevice;
}
36#[cfg(feature = "client")]
37pub use __client::*;
38
#[cfg(all(test, feature = "client", feature = "server"))]
mod tests {
    use crate::{RemoteBackend, RemoteDevice};
    use burn_ndarray::NdArray;
    use burn_tensor::{Distribution, Tensor};
    use std::net::{TcpListener, TcpStream};
    use std::thread;
    use std::time::{Duration, Instant};

    /// Upper bound on how long a freshly spawned server may take to start listening.
    const SERVER_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);

    /// Asks the OS for `count` distinct, currently unused TCP ports.
    ///
    /// All probe listeners are kept alive until every port has been read, so the
    /// OS cannot hand out the same port twice. They are dropped before returning,
    /// which leaves a small window in which another process could take a port.
    /// That is still far less collision-prone than fixed ports such as 3000.
    fn free_ports(count: usize) -> Vec<u16> {
        let listeners: Vec<TcpListener> = (0..count)
            .map(|_| {
                TcpListener::bind(("127.0.0.1", 0))
                    .expect("failed to reserve an ephemeral TCP port")
            })
            .collect();

        listeners
            .iter()
            .map(|listener| {
                listener
                    .local_addr()
                    .expect("failed to read the reserved port")
                    .port()
            })
            .collect()
    }

    /// Blocks until something accepts TCP connections on `localhost:port`.
    ///
    /// # Panics
    ///
    /// Panics if nothing is listening within [`SERVER_STARTUP_TIMEOUT`], e.g.
    /// because the spawned server task failed to start.
    fn wait_until_listening(port: u16) {
        let deadline = Instant::now() + SERVER_STARTUP_TIMEOUT;
        // The probe connection is dropped immediately; only reachability matters.
        while TcpStream::connect(("localhost", port)).is_err() {
            assert!(
                Instant::now() < deadline,
                "server on port {port} did not start within {SERVER_STARTUP_TIMEOUT:?}"
            );
            thread::sleep(Duration::from_millis(20));
        }
    }

    /// Round-trips a random tensor between two WebSocket servers and checks
    /// that its values are unchanged after each transfer.
    #[test]
    fn test_to_device_over_websocket() {
        // `enable_all` (not just IO) so server code relying on Tokio timers
        // cannot panic inside this runtime.
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("failed to build the Tokio runtime for the test servers");

        let ports = free_ports(2);
        let (port_1, port_2) = (ports[0], ports[1]);

        rt.spawn(crate::server::start_async::<NdArray>(
            Default::default(),
            port_1,
        ));
        rt.spawn(crate::server::start_async::<NdArray>(
            Default::default(),
            port_2,
        ));

        // The servers start asynchronously on the runtime's worker threads;
        // make sure both are reachable before the client connects.
        wait_until_listening(port_1);
        wait_until_listening(port_2);

        let remote_device_1 = RemoteDevice::new(&format!("ws://localhost:{port_1}"));
        let remote_device_2 = RemoteDevice::new(&format!("ws://localhost:{port_2}"));

        let input_shape = [1, 28, 28];
        let input = Tensor::<RemoteBackend, 3>::random(
            input_shape,
            Distribution::Default,
            &remote_device_1,
        );
        let numbers_expected: Vec<f32> = input.to_data().to_vec().unwrap();

        // First server -> second server.
        let input = input.to_device(&remote_device_2);
        let numbers: Vec<f32> = input.to_data().to_vec().unwrap();
        assert_eq!(numbers, numbers_expected);

        // Second server -> back to the first one.
        let input = input.to_device(&remote_device_1);
        let numbers: Vec<f32> = input.to_data().to_vec().unwrap();
        assert_eq!(numbers, numbers_expected);

        rt.shutdown_background();
    }
}