distant_net/manager/
server.rs

1use std::collections::HashMap;
2use std::io;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use distant_auth::msg::AuthenticationResponse;
7use log::*;
8use tokio::sync::{oneshot, RwLock};
9
10use crate::common::{ConnectionId, Destination, Map};
11use crate::manager::{
12    ConnectionInfo, ConnectionList, ManagerAuthenticationId, ManagerChannelId, ManagerRequest,
13    ManagerResponse, SemVer,
14};
15use crate::server::{RequestCtx, Server, ServerHandler};
16
17mod authentication;
18pub use authentication::*;
19
20mod config;
21pub use config::*;
22
23mod connection;
24pub use connection::*;
25
26mod handler;
27pub use handler::*;
28
29/// Represents a manager of multiple server connections.
30pub struct ManagerServer {
31    /// Configuration settings for the server
32    config: Config,
33
34    /// Holds on to open channels feeding data back from a server to some connected client,
35    /// enabling us to cancel the tasks on demand
36    channels: RwLock<HashMap<ManagerChannelId, ManagerChannel>>,
37
38    /// Mapping of connection id -> connection
39    connections: RwLock<HashMap<ConnectionId, ManagerConnection>>,
40
41    /// Mapping of auth id -> callback
42    registry:
43        Arc<RwLock<HashMap<ManagerAuthenticationId, oneshot::Sender<AuthenticationResponse>>>>,
44}
45
46impl ManagerServer {
47    /// Creates a new [`Server`] starting with a default configuration and no authentication
48    /// methods. The provided `config` will be used to configure the launch and connect handlers
49    /// for the server as well as provide other defaults.
50    pub fn new(config: Config) -> Server<Self> {
51        Server::new().handler(Self {
52            config,
53            channels: RwLock::new(HashMap::new()),
54            connections: RwLock::new(HashMap::new()),
55            registry: Arc::new(RwLock::new(HashMap::new())),
56        })
57    }
58
59    /// Launches a new server at the specified `destination` using the given `options` information
60    /// and authentication client (if needed) to retrieve additional information needed to
61    /// enter the destination prior to starting the server, returning the destination of the
62    /// launched server
63    async fn launch(
64        &self,
65        destination: Destination,
66        options: Map,
67        mut authenticator: ManagerAuthenticator,
68    ) -> io::Result<Destination> {
69        let scheme = match destination.scheme.as_deref() {
70            Some(scheme) => {
71                trace!("Using scheme {}", scheme);
72                scheme
73            }
74            None => {
75                trace!(
76                    "Using fallback scheme of {}",
77                    self.config.launch_fallback_scheme.as_str()
78                );
79                self.config.launch_fallback_scheme.as_str()
80            }
81        }
82        .to_lowercase();
83
84        let credentials = {
85            let handler = self.config.launch_handlers.get(&scheme).ok_or_else(|| {
86                io::Error::new(
87                    io::ErrorKind::InvalidInput,
88                    format!("No launch handler registered for {scheme}"),
89                )
90            })?;
91            handler
92                .launch(&destination, &options, &mut authenticator)
93                .await?
94        };
95
96        Ok(credentials)
97    }
98
99    /// Connects to a new server at the specified `destination` using the given `options` information
100    /// and authentication client (if needed) to retrieve additional information needed to
101    /// establish the connection to the server
102    async fn connect(
103        &self,
104        destination: Destination,
105        options: Map,
106        mut authenticator: ManagerAuthenticator,
107    ) -> io::Result<ConnectionId> {
108        let scheme = match destination.scheme.as_deref() {
109            Some(scheme) => {
110                trace!("Using scheme {}", scheme);
111                scheme
112            }
113            None => {
114                trace!(
115                    "Using fallback scheme of {}",
116                    self.config.connect_fallback_scheme.as_str()
117                );
118                self.config.connect_fallback_scheme.as_str()
119            }
120        }
121        .to_lowercase();
122
123        let client = {
124            let handler = self.config.connect_handlers.get(&scheme).ok_or_else(|| {
125                io::Error::new(
126                    io::ErrorKind::InvalidInput,
127                    format!("No connect handler registered for {scheme}"),
128                )
129            })?;
130            handler
131                .connect(&destination, &options, &mut authenticator)
132                .await?
133        };
134
135        let connection = ManagerConnection::spawn(destination, options, client).await?;
136        let id = connection.id;
137        self.connections.write().await.insert(id, connection);
138        Ok(id)
139    }
140
141    /// Retrieves the manager's version.
142    async fn version(&self) -> io::Result<SemVer> {
143        env!("CARGO_PKG_VERSION")
144            .parse()
145            .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
146    }
147
148    /// Retrieves information about the connection to the server with the specified `id`
149    async fn info(&self, id: ConnectionId) -> io::Result<ConnectionInfo> {
150        match self.connections.read().await.get(&id) {
151            Some(connection) => Ok(ConnectionInfo {
152                id: connection.id,
153                destination: connection.destination.clone(),
154                options: connection.options.clone(),
155            }),
156            None => Err(io::Error::new(
157                io::ErrorKind::NotConnected,
158                "No connection found",
159            )),
160        }
161    }
162
163    /// Retrieves a list of connections to servers
164    async fn list(&self) -> io::Result<ConnectionList> {
165        Ok(ConnectionList(
166            self.connections
167                .read()
168                .await
169                .values()
170                .map(|conn| (conn.id, conn.destination.clone()))
171                .collect(),
172        ))
173    }
174
175    /// Kills the connection to the server with the specified `id`
176    async fn kill(&self, id: ConnectionId) -> io::Result<()> {
177        match self.connections.write().await.remove(&id) {
178            Some(connection) => {
179                // Close any open channels
180                if let Ok(ids) = connection.channel_ids().await {
181                    let mut channels_lock = self.channels.write().await;
182                    for id in ids {
183                        if let Some(channel) = channels_lock.remove(&id) {
184                            if let Err(x) = channel.close() {
185                                error!("[Conn {id}] {x}");
186                            }
187                        }
188                    }
189                }
190
191                // Make sure the connection is aborted so nothing new can happen
192                debug!("[Conn {id}] Aborting");
193                connection.abort();
194
195                Ok(())
196            }
197            None => Err(io::Error::new(
198                io::ErrorKind::NotConnected,
199                "No connection found",
200            )),
201        }
202    }
203}
204
205#[async_trait]
206impl ServerHandler for ManagerServer {
207    type Request = ManagerRequest;
208    type Response = ManagerResponse;
209
210    async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
211        debug!("manager::on_request({ctx:?})");
212        let RequestCtx {
213            connection_id,
214            request,
215            reply,
216        } = ctx;
217
218        let response = match request.payload {
219            ManagerRequest::Version {} => {
220                debug!("Looking up version");
221                match self.version().await {
222                    Ok(version) => ManagerResponse::Version { version },
223                    Err(x) => ManagerResponse::from(x),
224                }
225            }
226            ManagerRequest::Launch {
227                destination,
228                options,
229            } => {
230                info!("Launching {destination} with {options}");
231                match self
232                    .launch(
233                        *destination,
234                        options,
235                        ManagerAuthenticator {
236                            reply: reply.clone(),
237                            registry: Arc::clone(&self.registry),
238                        },
239                    )
240                    .await
241                {
242                    Ok(destination) => ManagerResponse::Launched { destination },
243                    Err(x) => ManagerResponse::from(x),
244                }
245            }
246            ManagerRequest::Connect {
247                destination,
248                options,
249            } => {
250                info!("Connecting to {destination} with {options}");
251                match self
252                    .connect(
253                        *destination,
254                        options,
255                        ManagerAuthenticator {
256                            reply: reply.clone(),
257                            registry: Arc::clone(&self.registry),
258                        },
259                    )
260                    .await
261                {
262                    Ok(id) => ManagerResponse::Connected { id },
263                    Err(x) => ManagerResponse::from(x),
264                }
265            }
266            ManagerRequest::Authenticate { id, msg } => {
267                trace!("Retrieving authentication callback registry");
268                match self.registry.write().await.remove(&id) {
269                    Some(cb) => {
270                        trace!("Sending {msg:?} through authentication callback");
271                        match cb.send(msg) {
272                            Ok(_) => return,
273                            Err(_) => ManagerResponse::Error {
274                                description: "Unable to forward authentication callback"
275                                    .to_string(),
276                            },
277                        }
278                    }
279                    None => ManagerResponse::from(io::Error::new(
280                        io::ErrorKind::InvalidInput,
281                        "Invalid authentication id",
282                    )),
283                }
284            }
285            ManagerRequest::OpenChannel { id } => {
286                debug!("Attempting to retrieve connection {id}");
287                match self.connections.read().await.get(&id) {
288                    Some(connection) => {
289                        debug!("Opening channel through connection {id}");
290                        match connection.open_channel(reply.clone()) {
291                            Ok(channel) => {
292                                info!("[Conn {id}] Channel {} has been opened", channel.id());
293                                let id = channel.id();
294                                self.channels.write().await.insert(id, channel);
295                                ManagerResponse::ChannelOpened { id }
296                            }
297                            Err(x) => ManagerResponse::from(x),
298                        }
299                    }
300                    None => ManagerResponse::from(io::Error::new(
301                        io::ErrorKind::NotConnected,
302                        "Connection does not exist",
303                    )),
304                }
305            }
306            ManagerRequest::Channel { id, request } => {
307                debug!("Attempting to retrieve channel {id}");
308                match self.channels.read().await.get(&id) {
309                    // TODO: For now, we are NOT sending back a response to acknowledge
310                    //       a successful channel send. We could do this in order for
311                    //       the client to listen for a complete send, but is it worth it?
312                    Some(channel) => {
313                        debug!("Sending {request:?} through channel {id}");
314                        match channel.send(request) {
315                            Ok(_) => return,
316                            Err(x) => ManagerResponse::from(x),
317                        }
318                    }
319                    None => ManagerResponse::from(io::Error::new(
320                        io::ErrorKind::NotConnected,
321                        "Channel is not open or does not exist",
322                    )),
323                }
324            }
325            ManagerRequest::CloseChannel { id } => {
326                debug!("Attempting to remove channel {id}");
327                match self.channels.write().await.remove(&id) {
328                    Some(channel) => {
329                        debug!("Removed channel {}", channel.id());
330                        match channel.close() {
331                            Ok(_) => {
332                                info!("Channel {id} has been closed");
333                                ManagerResponse::ChannelClosed { id }
334                            }
335                            Err(x) => ManagerResponse::from(x),
336                        }
337                    }
338                    None => ManagerResponse::from(io::Error::new(
339                        io::ErrorKind::NotConnected,
340                        "Channel is not open or does not exist",
341                    )),
342                }
343            }
344            ManagerRequest::Info { id } => {
345                debug!("Attempting to retrieve information for connection {id}");
346                match self.info(id).await {
347                    Ok(info) => {
348                        info!("Retrieved information for connection {id}");
349                        ManagerResponse::Info(info)
350                    }
351                    Err(x) => ManagerResponse::from(x),
352                }
353            }
354            ManagerRequest::List => {
355                debug!("Attempting to retrieve the list of connections");
356                match self.list().await {
357                    Ok(list) => {
358                        info!("Retrieved list of connections");
359                        ManagerResponse::List(list)
360                    }
361                    Err(x) => ManagerResponse::from(x),
362                }
363            }
364            ManagerRequest::Kill { id } => {
365                debug!("Attempting to kill connection {id}");
366                match self.kill(id).await {
367                    Ok(()) => {
368                        info!("Killed connection {id}");
369                        ManagerResponse::Killed
370                    }
371                    Err(x) => ManagerResponse::from(x),
372                }
373            }
374        };
375
376        if let Err(x) = reply.send(response) {
377            error!("[Conn {}] {}", connection_id, x);
378        }
379    }
380}
381
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use super::*;
    use crate::client::UntypedClient;
    use crate::common::FramedTransport;
    use crate::server::ServerReply;
    use crate::{boxed_connect_handler, boxed_launch_handler};

    /// Configuration with "ssh"/"distant" fallback schemes and no handlers registered.
    fn test_config() -> Config {
        Config {
            launch_fallback_scheme: String::from("ssh"),
            connect_fallback_scheme: String::from("distant"),
            connection_buffer_size: 100,
            user: false,
            launch_handlers: HashMap::new(),
            connect_handlers: HashMap::new(),
        }
    }

    /// Builds an untyped client whose peer transport is dropped, so reads and writes through it
    /// fail.
    fn detached_untyped_client() -> UntypedClient {
        let (transport, _peer) = FramedTransport::pair(1);
        UntypedClient::spawn_inmemory(transport, Default::default())
    }

    /// Builds a manager server from `config` together with an authenticator that shares the
    /// server's callback registry and replies into a channel nobody reads.
    fn setup(config: Config) -> (ManagerServer, ManagerAuthenticator) {
        let registry = Arc::new(RwLock::new(HashMap::new()));
        let (tx, _) = mpsc::unbounded_channel();

        let authenticator = ManagerAuthenticator {
            reply: ServerReply {
                origin_id: rand::random::<u8>().to_string(),
                tx,
            },
            registry: Arc::clone(&registry),
        };

        let server = ManagerServer {
            config,
            channels: RwLock::new(HashMap::new()),
            connections: RwLock::new(HashMap::new()),
            registry,
        };

        (server, authenticator)
    }

    /// Spawns a connection over a detached client, stores it in `server`, and returns its id.
    async fn register_connection(
        server: &ManagerServer,
        destination: &str,
        options: &str,
    ) -> ConnectionId {
        let connection = ManagerConnection::spawn(
            destination.parse().unwrap(),
            options.parse().unwrap(),
            detached_untyped_client(),
        )
        .await
        .unwrap();
        let id = connection.id;
        server.connections.write().await.insert(id, connection);
        id
    }

    #[tokio::test]
    async fn launch_should_fail_if_destination_scheme_is_unsupported() {
        let (server, authenticator) = setup(test_config());

        let error = server
            .launch(
                "scheme://host".parse::<Destination>().unwrap(),
                "".parse::<Map>().unwrap(),
                authenticator,
            )
            .await
            .unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::InvalidInput, "{:?}", error);
    }

    #[tokio::test]
    async fn launch_should_fail_if_handler_tied_to_scheme_fails() {
        let mut config = test_config();
        config.launch_handlers.insert(
            "scheme".to_string(),
            boxed_launch_handler!(|_a, _b, _c| {
                Err(io::Error::new(io::ErrorKind::Other, "test failure"))
            }),
        );
        let (server, authenticator) = setup(config);

        let error = server
            .launch(
                "scheme://host".parse::<Destination>().unwrap(),
                "".parse::<Map>().unwrap(),
                authenticator,
            )
            .await
            .unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::Other);
        assert_eq!(error.to_string(), "test failure");
    }

    #[tokio::test]
    async fn launch_should_return_new_destination_on_success() {
        let mut config = test_config();
        config.launch_handlers.insert(
            "scheme".to_string(),
            boxed_launch_handler!(|_a, _b, _c| {
                Ok("scheme2://host2".parse::<Destination>().unwrap())
            }),
        );
        let (server, authenticator) = setup(config);

        let launched = server
            .launch(
                "scheme://host".parse::<Destination>().unwrap(),
                "key=value".parse::<Map>().unwrap(),
                authenticator,
            )
            .await
            .unwrap();
        let expected = "scheme2://host2".parse::<Destination>().unwrap();
        assert_eq!(launched, expected);
    }

    #[tokio::test]
    async fn connect_should_fail_if_destination_scheme_is_unsupported() {
        let (server, authenticator) = setup(test_config());

        let error = server
            .connect(
                "scheme://host".parse::<Destination>().unwrap(),
                "".parse::<Map>().unwrap(),
                authenticator,
            )
            .await
            .unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::InvalidInput, "{:?}", error);
    }

    #[tokio::test]
    async fn connect_should_fail_if_handler_tied_to_scheme_fails() {
        let mut config = test_config();
        config.connect_handlers.insert(
            "scheme".to_string(),
            boxed_connect_handler!(|_a, _b, _c| {
                Err(io::Error::new(io::ErrorKind::Other, "test failure"))
            }),
        );
        let (server, authenticator) = setup(config);

        let error = server
            .connect(
                "scheme://host".parse::<Destination>().unwrap(),
                "".parse::<Map>().unwrap(),
                authenticator,
            )
            .await
            .unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::Other);
        assert_eq!(error.to_string(), "test failure");
    }

    #[tokio::test]
    async fn connect_should_return_id_of_new_connection_on_success() {
        let mut config = test_config();
        config.connect_handlers.insert(
            "scheme".to_string(),
            boxed_connect_handler!(|_a, _b, _c| { Ok(detached_untyped_client()) }),
        );
        let (server, authenticator) = setup(config);

        let new_id = server
            .connect(
                "scheme://host".parse::<Destination>().unwrap(),
                "key=value".parse::<Map>().unwrap(),
                authenticator,
            )
            .await
            .unwrap();

        let connections = server.connections.read().await;
        let stored = connections.get(&new_id).unwrap();
        assert_eq!(stored.id, new_id);
        assert_eq!(stored.destination, "scheme://host");
        assert_eq!(stored.options, "key=value".parse().unwrap());
    }

    #[tokio::test]
    async fn info_should_fail_if_no_connection_found_for_specified_id() {
        let (server, _) = setup(test_config());

        let error = server.info(999).await.unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::NotConnected, "{:?}", error);
    }

    #[tokio::test]
    async fn info_should_return_information_about_established_connection() {
        let (server, _) = setup(test_config());
        let id = register_connection(&server, "scheme://host", "key=value").await;

        let actual = server.info(id).await.unwrap();
        let expected = ConnectionInfo {
            id,
            destination: "scheme://host".parse().unwrap(),
            options: "key=value".parse().unwrap(),
        };
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn list_should_return_empty_connection_list_if_no_established_connections() {
        let (server, _) = setup(test_config());

        let listing = server.list().await.unwrap();
        assert_eq!(listing, ConnectionList(HashMap::new()));
    }

    #[tokio::test]
    async fn list_should_return_a_list_of_established_connections() {
        let (server, _) = setup(test_config());
        let first = register_connection(&server, "scheme://host", "key=value").await;
        let second = register_connection(&server, "other://host2", "key=value").await;

        let listing = server.list().await.unwrap();
        for (id, expected) in [(first, "scheme://host"), (second, "other://host2")] {
            assert_eq!(
                listing.get(&id).unwrap(),
                &expected.parse::<Destination>().unwrap()
            );
        }
    }

    #[tokio::test]
    async fn kill_should_fail_if_no_connection_found_for_specified_id() {
        let (server, _) = setup(test_config());

        let error = server.kill(999).await.unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::NotConnected, "{:?}", error);
    }

    #[tokio::test]
    async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() {
        let (server, _) = setup(test_config());
        let id = register_connection(&server, "scheme://host", "key=value").await;

        server.kill(id).await.unwrap();

        let connections = server.connections.read().await;
        assert!(!connections.contains_key(&id), "Connection still exists");
    }
}