1#[macro_use]
2extern crate derive_new;
3
4#[cfg(feature = "client")]
5pub(crate) mod client;
6
7#[cfg(feature = "server")]
8pub mod server;
9
10pub(crate) mod shared;
11
#[cfg(feature = "client")]
mod __client {
    use burn_communication::Protocol;

    /// Tensor backend available with the `client` feature.
    ///
    /// Operations are dispatched through a `burn_router::BackendRouter` whose channel is a
    /// `RemoteChannel` parameterised by the client side (`Protocol::Client`) of the crate's
    /// `RemoteProtocol`.
    pub type RemoteBackend = burn_router::BackendRouter<
        crate::client::RemoteChannel<<crate::shared::RemoteProtocol as Protocol>::Client>,
    >;

    // Re-exported so callers can name the device type from the crate root.
    pub use crate::client::RemoteDevice;
}

// Surface the client-side API (`RemoteBackend`, `RemoteDevice`) at the crate root.
#[cfg(feature = "client")]
pub use __client::*;
39
#[cfg(all(test, feature = "client", feature = "server"))]
mod tests {
    use crate::{RemoteBackend, RemoteDevice};
    use burn_ndarray::NdArray;
    use burn_tensor::{Distribution, Tensor};

    /// Reads every element of `tensor` back to the host as `f32`.
    fn read_values(tensor: &Tensor<RemoteBackend, 3>) -> Vec<f32> {
        tensor
            .to_data()
            .to_vec()
            .expect("remote tensor data should convert to Vec<f32>")
    }

    /// Moves a tensor between two websocket servers and back, checking the values
    /// are unchanged after each hop.
    #[test]
    fn test_to_device_over_websocket() {
        // Enable every driver (I/O *and* time) so the spawned server tasks cannot
        // panic on a missing tokio driver; such a panic would be swallowed by the
        // detached task and only surface later as a confusing client failure.
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("failed to build the tokio runtime for the test servers");

        // NOTE(review): the servers are spawned and used right away with no readiness
        // handshake; this relies on the client tolerating a server that is still
        // starting up — confirm, or add an explicit readiness check.
        rt.spawn(crate::server::start_websocket_async::<NdArray>(
            Default::default(),
            3000,
        ));
        rt.spawn(crate::server::start_websocket_async::<NdArray>(
            Default::default(),
            3010,
        ));

        let remote_device_1 = RemoteDevice::new("ws://localhost:3000");
        let remote_device_2 = RemoteDevice::new("ws://localhost:3010");

        let input_shape = [1, 28, 28];
        let input = Tensor::<RemoteBackend, 3>::random(
            input_shape,
            Distribution::Default,
            &remote_device_1,
        );
        let numbers_expected = read_values(&input);

        // First hop: server 1 -> server 2.
        let input = input.to_device(&remote_device_2);
        assert_eq!(read_values(&input), numbers_expected);

        // Second hop: back to server 1.
        let input = input.to_device(&remote_device_1);
        assert_eq!(read_values(&input), numbers_expected);

        rt.shutdown_background();
    }
}