burn_remote/
lib.rs

1#[macro_use]
2extern crate derive_new;
3
4#[cfg(feature = "client")]
5pub(crate) mod client;
6
7#[cfg(feature = "server")]
8pub mod server;
9
10pub(crate) mod shared;
11
#[cfg(feature = "client")]
mod __client {
    use super::*;

    use crate::{client::RemoteChannel, shared::RemoteProtocol};
    use burn_communication::Protocol;
    use burn_router::BackendRouter;

    /// The remote backend allows you to run computation on a remote device.
    ///
    /// Operations on tensors created on a [`RemoteDevice`] are sent to the server found at that
    /// device's address, so make sure a server is running before trying to connect to it.
    ///
    /// # Examples
    ///
    /// Start a server (you need to activate the `server` feature flag to have access to the
    /// server functions):
    ///
    /// ```rust, ignore
    /// fn main() {
    ///     let device = Default::default();
    ///     let port = 3000;
    ///
    ///     let runtime = tokio::runtime::Runtime::new().unwrap();
    ///     runtime.block_on(burn::server::start_websocket_async::<burn::backend::Wgpu>(
    ///         device, port,
    ///     ));
    /// }
    /// ```
    ///
    /// Then, from a client (requires the `client` feature flag), create tensors on a
    /// [`RemoteDevice`] pointing at that server:
    ///
    /// ```rust, ignore
    /// use burn_remote::{RemoteBackend, RemoteDevice};
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// let device = RemoteDevice::new("ws://localhost:3000");
    /// let tensor = Tensor::<RemoteBackend, 2>::random([2, 3], Distribution::Default, &device);
    /// let values: Vec<f32> = tensor.to_data().to_vec().unwrap();
    /// ```
    pub type RemoteBackend = BackendRouter<RemoteChannel<<RemoteProtocol as Protocol>::Client>>;

    // Device type of `RemoteBackend`, built from a server address such as
    // "ws://localhost:3000".
    pub use client::RemoteDevice;
}
37#[cfg(feature = "client")]
38pub use __client::*;
39
#[cfg(all(test, feature = "client", feature = "server"))]
mod tests {
    use crate::RemoteBackend;
    use burn_ndarray::NdArray;
    use burn_tensor::{Distribution, Tensor};

    /// Port the first test server listens on (and the first device connects to).
    const PORT_1: u16 = 3000;
    /// Port the second test server listens on (and the second device connects to).
    const PORT_2: u16 = 3010;

    /// Websocket address of a server listening on `localhost:port`.
    ///
    /// Deriving the URL from the same constant used to spawn the server keeps both sides in sync.
    fn ws_url(port: u16) -> String {
        format!("ws://localhost:{port}")
    }

    /// Reads a tensor back from its device as a flat `Vec<f32>`.
    fn read_values(tensor: &Tensor<RemoteBackend, 3>) -> Vec<f32> {
        tensor.to_data().to_vec().unwrap()
    }

    /// Moves a random tensor between two websocket servers and back, checking that its values
    /// are unchanged after each transfer.
    #[test]
    fn test_to_device_over_websocket() {
        // Enable every runtime driver (IO and time) so the spawned server tasks are not
        // restricted to socket IO.
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        for port in [PORT_1, PORT_2] {
            rt.spawn(crate::server::start_websocket_async::<NdArray>(
                Default::default(),
                port,
            ));
        }

        // NOTE(review): the devices are created right after spawning the servers; this relies on
        // the client tolerating a server that is not listening yet — confirm.
        let remote_device_1 = super::RemoteDevice::new(&ws_url(PORT_1));
        let remote_device_2 = super::RemoteDevice::new(&ws_url(PORT_2));

        // Some random input
        let input_shape = [1, 28, 28];
        let input = Tensor::<RemoteBackend, 3>::random(
            input_shape,
            Distribution::Default,
            &remote_device_1,
        );
        let numbers_expected = read_values(&input);

        // Move tensor to device 2
        let input = input.to_device(&remote_device_2);
        assert_eq!(
            read_values(&input),
            numbers_expected,
            "tensor values changed after moving to device 2"
        );

        // Move tensor back to device 1
        let input = input.to_device(&remote_device_1);
        assert_eq!(
            read_values(&input),
            numbers_expected,
            "tensor values changed after moving back to device 1"
        );

        rt.shutdown_background();
    }
}