burn_remote/
lib.rs

1#[macro_use]
2extern crate derive_new;
3
4#[cfg(feature = "client")]
5pub(crate) mod client;
6
7#[cfg(feature = "server")]
8pub mod server;
9
10pub(crate) mod shared;
11
#[cfg(feature = "client")]
mod __client {
    use burn_router::BackendRouter;

    use crate::client::WsChannel;

    /// The remote backend allows you to run computation on a remote device.
    ///
    /// Every operation is routed through a [`BackendRouter`] backed by a WebSocket
    /// channel, and tensors are placed on a [`RemoteDevice`] that points at a server.
    ///
    /// Make sure there is a running server before trying to connect to it.
    ///
    /// # Example
    ///
    /// Start a server (you need to activate the `server` feature flag to have access to
    /// this function):
    ///
    /// ```rust, ignore
    /// fn main() {
    ///     let device = Default::default();
    ///     let port = 3000;
    ///
    ///     burn::server::start::<burn::backend::Wgpu>(device, port);
    /// }
    /// ```
    ///
    /// Then, on the client side, create tensors on a [`RemoteDevice`] that points at
    /// that server:
    ///
    /// ```rust, ignore
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// let device = RemoteDevice::new("ws://localhost:3000");
    /// let tensor = Tensor::<RemoteBackend, 3>::random([1, 28, 28], Distribution::Default, &device);
    /// ```
    pub type RemoteBackend = BackendRouter<WsChannel>;

    /// A device living on a remote server, identified by the server's WebSocket address
    /// (for example `ws://localhost:3000`).
    pub use crate::client::WsDevice as RemoteDevice;
}
36#[cfg(feature = "client")]
37pub use __client::*;
38
#[cfg(all(test, feature = "client", feature = "server"))]
mod tests {
    use crate::{RemoteBackend, RemoteDevice};
    use burn_ndarray::NdArray;
    use burn_tensor::{Distribution, Tensor};
    use std::net::{TcpListener, TcpStream};
    use std::time::{Duration, Instant};

    /// Upper bound on how long a freshly spawned server may take to start listening.
    const SERVER_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);

    /// Reserves `N` distinct ephemeral TCP ports from the OS and returns their numbers.
    ///
    /// All probe listeners are held until every port number is known, so the returned
    /// ports are distinct. They are released on return so the servers under test can bind
    /// them. Another process could still grab one in between, but that is far less likely
    /// than colliding on a hard-coded port such as 3000.
    fn free_ports<const N: usize>() -> [u16; N] {
        let listeners: Vec<TcpListener> = (0..N)
            .map(|_| TcpListener::bind(("127.0.0.1", 0)).expect("failed to reserve a free port"))
            .collect();
        std::array::from_fn(|i| {
            listeners[i]
                .local_addr()
                .expect("a bound listener has a local address")
                .port()
        })
    }

    /// Blocks until a TCP connection to `localhost:port` succeeds.
    ///
    /// `Runtime::spawn` returns before the server task has bound its socket. Without this
    /// wait, the client could try to connect before the server is listening.
    ///
    /// # Panics
    ///
    /// Panics if nothing accepts connections within [`SERVER_STARTUP_TIMEOUT`].
    fn wait_until_listening(port: u16) {
        let deadline = Instant::now() + SERVER_STARTUP_TIMEOUT;
        while TcpStream::connect(("localhost", port)).is_err() {
            assert!(
                Instant::now() < deadline,
                "no server listening on port {port} after {SERVER_STARTUP_TIMEOUT:?}"
            );
            std::thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_to_device_over_websocket() {
        // `enable_all` turns on both the I/O and the time drivers, so the servers also work
        // if they rely on tokio timers.
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let [port_1, port_2] = free_ports::<2>();
        for port in [port_1, port_2] {
            rt.spawn(crate::server::start_async::<NdArray>(
                Default::default(),
                port,
            ));
        }
        // Do not let the client race the servers' socket binding.
        for port in [port_1, port_2] {
            wait_until_listening(port);
        }

        let remote_device_1 = RemoteDevice::new(&format!("ws://localhost:{port_1}"));
        let remote_device_2 = RemoteDevice::new(&format!("ws://localhost:{port_2}"));

        // Some random input
        let input_shape = [1, 28, 28];
        let input = Tensor::<RemoteBackend, 3>::random(
            input_shape,
            Distribution::Default,
            &remote_device_1,
        );
        let numbers_expected: Vec<f32> = input.to_data().to_vec().unwrap();

        // Move tensor to device 2
        let input = input.to_device(&remote_device_2);
        let numbers: Vec<f32> = input.to_data().to_vec().unwrap();
        assert_eq!(numbers, numbers_expected);

        // Move tensor back to device 1
        let input = input.to_device(&remote_device_1);
        let numbers: Vec<f32> = input.to_data().to_vec().unwrap();
        assert_eq!(numbers, numbers_expected);

        rt.shutdown_background();
    }
}