1use std::collections::HashMap;
2use std::io;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use distant_auth::msg::AuthenticationResponse;
7use log::*;
8use tokio::sync::{oneshot, RwLock};
9
10use crate::common::{ConnectionId, Destination, Map};
11use crate::manager::{
12 ConnectionInfo, ConnectionList, ManagerAuthenticationId, ManagerChannelId, ManagerRequest,
13 ManagerResponse, SemVer,
14};
15use crate::server::{RequestCtx, Server, ServerHandler};
16
17mod authentication;
18pub use authentication::*;
19
20mod config;
21pub use config::*;
22
23mod connection;
24pub use connection::*;
25
26mod handler;
27pub use handler::*;
28
29pub struct ManagerServer {
31 config: Config,
33
34 channels: RwLock<HashMap<ManagerChannelId, ManagerChannel>>,
37
38 connections: RwLock<HashMap<ConnectionId, ManagerConnection>>,
40
41 registry:
43 Arc<RwLock<HashMap<ManagerAuthenticationId, oneshot::Sender<AuthenticationResponse>>>>,
44}
45
46impl ManagerServer {
47 pub fn new(config: Config) -> Server<Self> {
51 Server::new().handler(Self {
52 config,
53 channels: RwLock::new(HashMap::new()),
54 connections: RwLock::new(HashMap::new()),
55 registry: Arc::new(RwLock::new(HashMap::new())),
56 })
57 }
58
59 async fn launch(
64 &self,
65 destination: Destination,
66 options: Map,
67 mut authenticator: ManagerAuthenticator,
68 ) -> io::Result<Destination> {
69 let scheme = match destination.scheme.as_deref() {
70 Some(scheme) => {
71 trace!("Using scheme {}", scheme);
72 scheme
73 }
74 None => {
75 trace!(
76 "Using fallback scheme of {}",
77 self.config.launch_fallback_scheme.as_str()
78 );
79 self.config.launch_fallback_scheme.as_str()
80 }
81 }
82 .to_lowercase();
83
84 let credentials = {
85 let handler = self.config.launch_handlers.get(&scheme).ok_or_else(|| {
86 io::Error::new(
87 io::ErrorKind::InvalidInput,
88 format!("No launch handler registered for {scheme}"),
89 )
90 })?;
91 handler
92 .launch(&destination, &options, &mut authenticator)
93 .await?
94 };
95
96 Ok(credentials)
97 }
98
99 async fn connect(
103 &self,
104 destination: Destination,
105 options: Map,
106 mut authenticator: ManagerAuthenticator,
107 ) -> io::Result<ConnectionId> {
108 let scheme = match destination.scheme.as_deref() {
109 Some(scheme) => {
110 trace!("Using scheme {}", scheme);
111 scheme
112 }
113 None => {
114 trace!(
115 "Using fallback scheme of {}",
116 self.config.connect_fallback_scheme.as_str()
117 );
118 self.config.connect_fallback_scheme.as_str()
119 }
120 }
121 .to_lowercase();
122
123 let client = {
124 let handler = self.config.connect_handlers.get(&scheme).ok_or_else(|| {
125 io::Error::new(
126 io::ErrorKind::InvalidInput,
127 format!("No connect handler registered for {scheme}"),
128 )
129 })?;
130 handler
131 .connect(&destination, &options, &mut authenticator)
132 .await?
133 };
134
135 let connection = ManagerConnection::spawn(destination, options, client).await?;
136 let id = connection.id;
137 self.connections.write().await.insert(id, connection);
138 Ok(id)
139 }
140
141 async fn version(&self) -> io::Result<SemVer> {
143 env!("CARGO_PKG_VERSION")
144 .parse()
145 .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
146 }
147
148 async fn info(&self, id: ConnectionId) -> io::Result<ConnectionInfo> {
150 match self.connections.read().await.get(&id) {
151 Some(connection) => Ok(ConnectionInfo {
152 id: connection.id,
153 destination: connection.destination.clone(),
154 options: connection.options.clone(),
155 }),
156 None => Err(io::Error::new(
157 io::ErrorKind::NotConnected,
158 "No connection found",
159 )),
160 }
161 }
162
163 async fn list(&self) -> io::Result<ConnectionList> {
165 Ok(ConnectionList(
166 self.connections
167 .read()
168 .await
169 .values()
170 .map(|conn| (conn.id, conn.destination.clone()))
171 .collect(),
172 ))
173 }
174
175 async fn kill(&self, id: ConnectionId) -> io::Result<()> {
177 match self.connections.write().await.remove(&id) {
178 Some(connection) => {
179 if let Ok(ids) = connection.channel_ids().await {
181 let mut channels_lock = self.channels.write().await;
182 for id in ids {
183 if let Some(channel) = channels_lock.remove(&id) {
184 if let Err(x) = channel.close() {
185 error!("[Conn {id}] {x}");
186 }
187 }
188 }
189 }
190
191 debug!("[Conn {id}] Aborting");
193 connection.abort();
194
195 Ok(())
196 }
197 None => Err(io::Error::new(
198 io::ErrorKind::NotConnected,
199 "No connection found",
200 )),
201 }
202 }
203}
204
205#[async_trait]
206impl ServerHandler for ManagerServer {
207 type Request = ManagerRequest;
208 type Response = ManagerResponse;
209
210 async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
211 debug!("manager::on_request({ctx:?})");
212 let RequestCtx {
213 connection_id,
214 request,
215 reply,
216 } = ctx;
217
218 let response = match request.payload {
219 ManagerRequest::Version {} => {
220 debug!("Looking up version");
221 match self.version().await {
222 Ok(version) => ManagerResponse::Version { version },
223 Err(x) => ManagerResponse::from(x),
224 }
225 }
226 ManagerRequest::Launch {
227 destination,
228 options,
229 } => {
230 info!("Launching {destination} with {options}");
231 match self
232 .launch(
233 *destination,
234 options,
235 ManagerAuthenticator {
236 reply: reply.clone(),
237 registry: Arc::clone(&self.registry),
238 },
239 )
240 .await
241 {
242 Ok(destination) => ManagerResponse::Launched { destination },
243 Err(x) => ManagerResponse::from(x),
244 }
245 }
246 ManagerRequest::Connect {
247 destination,
248 options,
249 } => {
250 info!("Connecting to {destination} with {options}");
251 match self
252 .connect(
253 *destination,
254 options,
255 ManagerAuthenticator {
256 reply: reply.clone(),
257 registry: Arc::clone(&self.registry),
258 },
259 )
260 .await
261 {
262 Ok(id) => ManagerResponse::Connected { id },
263 Err(x) => ManagerResponse::from(x),
264 }
265 }
266 ManagerRequest::Authenticate { id, msg } => {
267 trace!("Retrieving authentication callback registry");
268 match self.registry.write().await.remove(&id) {
269 Some(cb) => {
270 trace!("Sending {msg:?} through authentication callback");
271 match cb.send(msg) {
272 Ok(_) => return,
273 Err(_) => ManagerResponse::Error {
274 description: "Unable to forward authentication callback"
275 .to_string(),
276 },
277 }
278 }
279 None => ManagerResponse::from(io::Error::new(
280 io::ErrorKind::InvalidInput,
281 "Invalid authentication id",
282 )),
283 }
284 }
285 ManagerRequest::OpenChannel { id } => {
286 debug!("Attempting to retrieve connection {id}");
287 match self.connections.read().await.get(&id) {
288 Some(connection) => {
289 debug!("Opening channel through connection {id}");
290 match connection.open_channel(reply.clone()) {
291 Ok(channel) => {
292 info!("[Conn {id}] Channel {} has been opened", channel.id());
293 let id = channel.id();
294 self.channels.write().await.insert(id, channel);
295 ManagerResponse::ChannelOpened { id }
296 }
297 Err(x) => ManagerResponse::from(x),
298 }
299 }
300 None => ManagerResponse::from(io::Error::new(
301 io::ErrorKind::NotConnected,
302 "Connection does not exist",
303 )),
304 }
305 }
306 ManagerRequest::Channel { id, request } => {
307 debug!("Attempting to retrieve channel {id}");
308 match self.channels.read().await.get(&id) {
309 Some(channel) => {
313 debug!("Sending {request:?} through channel {id}");
314 match channel.send(request) {
315 Ok(_) => return,
316 Err(x) => ManagerResponse::from(x),
317 }
318 }
319 None => ManagerResponse::from(io::Error::new(
320 io::ErrorKind::NotConnected,
321 "Channel is not open or does not exist",
322 )),
323 }
324 }
325 ManagerRequest::CloseChannel { id } => {
326 debug!("Attempting to remove channel {id}");
327 match self.channels.write().await.remove(&id) {
328 Some(channel) => {
329 debug!("Removed channel {}", channel.id());
330 match channel.close() {
331 Ok(_) => {
332 info!("Channel {id} has been closed");
333 ManagerResponse::ChannelClosed { id }
334 }
335 Err(x) => ManagerResponse::from(x),
336 }
337 }
338 None => ManagerResponse::from(io::Error::new(
339 io::ErrorKind::NotConnected,
340 "Channel is not open or does not exist",
341 )),
342 }
343 }
344 ManagerRequest::Info { id } => {
345 debug!("Attempting to retrieve information for connection {id}");
346 match self.info(id).await {
347 Ok(info) => {
348 info!("Retrieved information for connection {id}");
349 ManagerResponse::Info(info)
350 }
351 Err(x) => ManagerResponse::from(x),
352 }
353 }
354 ManagerRequest::List => {
355 debug!("Attempting to retrieve the list of connections");
356 match self.list().await {
357 Ok(list) => {
358 info!("Retrieved list of connections");
359 ManagerResponse::List(list)
360 }
361 Err(x) => ManagerResponse::from(x),
362 }
363 }
364 ManagerRequest::Kill { id } => {
365 debug!("Attempting to kill connection {id}");
366 match self.kill(id).await {
367 Ok(()) => {
368 info!("Killed connection {id}");
369 ManagerResponse::Killed
370 }
371 Err(x) => ManagerResponse::from(x),
372 }
373 }
374 };
375
376 if let Err(x) = reply.send(response) {
377 error!("[Conn {}] {}", connection_id, x);
378 }
379 }
380}
381
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use super::*;
    use crate::client::UntypedClient;
    use crate::common::FramedTransport;
    use crate::server::ServerReply;
    use crate::{boxed_connect_handler, boxed_launch_handler};

    /// Config with "ssh" as the launch fallback scheme, "distant" as the connect fallback
    /// scheme, and no handlers registered.
    fn test_config() -> Config {
        Config {
            launch_fallback_scheme: "ssh".to_string(),
            connect_fallback_scheme: "distant".to_string(),
            connection_buffer_size: 100,
            user: false,
            launch_handlers: HashMap::new(),
            connect_handlers: HashMap::new(),
        }
    }

    /// Spawns an in-memory client over one half of a transport pair.
    ///
    /// The other half is dropped immediately, so the client is not wired to any server.
    fn detached_untyped_client() -> UntypedClient {
        UntypedClient::spawn_inmemory(FramedTransport::pair(1).0, Default::default())
    }

    /// Builds a server plus an authenticator that share the same callback registry.
    ///
    /// The authenticator's reply receiver is dropped immediately.
    fn setup(config: Config) -> (ManagerServer, ManagerAuthenticator) {
        let registry = Arc::new(RwLock::new(HashMap::new()));

        let authenticator = ManagerAuthenticator {
            reply: ServerReply {
                origin_id: format!("{}", rand::random::<u8>()),
                tx: mpsc::unbounded_channel().0,
            },
            registry: Arc::clone(&registry),
        };

        let server = ManagerServer {
            config,
            channels: RwLock::new(HashMap::new()),
            connections: RwLock::new(HashMap::new()),
            registry,
        };

        (server, authenticator)
    }

    #[tokio::test]
    async fn launch_should_fail_if_destination_scheme_is_unsupported() {
        let (server, authenticator) = setup(test_config());

        let destination = "scheme://host".parse::<Destination>().unwrap();
        let options = "".parse::<Map>().unwrap();
        let err = server
            .launch(destination, options, authenticator)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
    }

    #[tokio::test]
    async fn launch_should_fail_if_handler_tied_to_scheme_fails() {
        let mut config = test_config();

        let handler = boxed_launch_handler!(|_a, _b, _c| {
            Err(io::Error::new(io::ErrorKind::Other, "test failure"))
        });

        config.launch_handlers.insert("scheme".to_string(), handler);

        let (server, authenticator) = setup(config);
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let options = "".parse::<Map>().unwrap();
        let err = server
            .launch(destination, options, authenticator)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "test failure");
    }

    #[tokio::test]
    async fn launch_should_return_new_destination_on_success() {
        let mut config = test_config();

        let handler = boxed_launch_handler!(|_a, _b, _c| {
            Ok("scheme2://host2".parse::<Destination>().unwrap())
        });

        config.launch_handlers.insert("scheme".to_string(), handler);

        let (server, authenticator) = setup(config);
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let options = "key=value".parse::<Map>().unwrap();
        let destination = server
            .launch(destination, options, authenticator)
            .await
            .unwrap();

        assert_eq!(
            destination,
            "scheme2://host2".parse::<Destination>().unwrap()
        );
    }

    #[tokio::test]
    async fn launch_should_use_fallback_scheme_if_destination_has_no_scheme() {
        let mut config = test_config();

        let handler = boxed_launch_handler!(|_a, _b, _c| {
            Ok("scheme2://host2".parse::<Destination>().unwrap())
        });

        // "ssh" is the launch fallback scheme in `test_config`.
        config.launch_handlers.insert("ssh".to_string(), handler);

        let (server, authenticator) = setup(config);
        let destination = "host".parse::<Destination>().unwrap();
        let options = "".parse::<Map>().unwrap();
        let destination = server
            .launch(destination, options, authenticator)
            .await
            .unwrap();

        assert_eq!(
            destination,
            "scheme2://host2".parse::<Destination>().unwrap()
        );
    }

    #[tokio::test]
    async fn launch_should_match_scheme_case_insensitively() {
        let mut config = test_config();

        let handler = boxed_launch_handler!(|_a, _b, _c| {
            Ok("scheme2://host2".parse::<Destination>().unwrap())
        });

        config.launch_handlers.insert("scheme".to_string(), handler);

        let (server, authenticator) = setup(config);
        let destination = "SCHEME://host".parse::<Destination>().unwrap();
        let options = "".parse::<Map>().unwrap();
        let destination = server
            .launch(destination, options, authenticator)
            .await
            .unwrap();

        assert_eq!(
            destination,
            "scheme2://host2".parse::<Destination>().unwrap()
        );
    }

    #[tokio::test]
    async fn connect_should_fail_if_destination_scheme_is_unsupported() {
        let (server, authenticator) = setup(test_config());

        let destination = "scheme://host".parse::<Destination>().unwrap();
        let options = "".parse::<Map>().unwrap();
        let err = server
            .connect(destination, options, authenticator)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
    }

    #[tokio::test]
    async fn connect_should_fail_if_handler_tied_to_scheme_fails() {
        let mut config = test_config();

        let handler = boxed_connect_handler!(|_a, _b, _c| {
            Err(io::Error::new(io::ErrorKind::Other, "test failure"))
        });

        config
            .connect_handlers
            .insert("scheme".to_string(), handler);

        let (server, authenticator) = setup(config);
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let options = "".parse::<Map>().unwrap();
        let err = server
            .connect(destination, options, authenticator)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert_eq!(err.to_string(), "test failure");
    }

    #[tokio::test]
    async fn connect_should_return_id_of_new_connection_on_success() {
        let mut config = test_config();

        let handler = boxed_connect_handler!(|_a, _b, _c| { Ok(detached_untyped_client()) });

        config
            .connect_handlers
            .insert("scheme".to_string(), handler);

        let (server, authenticator) = setup(config);
        let destination = "scheme://host".parse::<Destination>().unwrap();
        let options = "key=value".parse::<Map>().unwrap();
        let id = server
            .connect(destination, options, authenticator)
            .await
            .unwrap();

        let lock = server.connections.read().await;
        let connection = lock.get(&id).unwrap();
        assert_eq!(connection.id, id);
        assert_eq!(connection.destination, "scheme://host");
        assert_eq!(connection.options, "key=value".parse().unwrap());
    }

    #[tokio::test]
    async fn connect_should_use_fallback_scheme_if_destination_has_no_scheme() {
        let mut config = test_config();

        let handler = boxed_connect_handler!(|_a, _b, _c| { Ok(detached_untyped_client()) });

        // "distant" is the connect fallback scheme in `test_config`.
        config
            .connect_handlers
            .insert("distant".to_string(), handler);

        let (server, authenticator) = setup(config);
        let destination = "host".parse::<Destination>().unwrap();
        let options = "".parse::<Map>().unwrap();
        let id = server
            .connect(destination, options, authenticator)
            .await
            .unwrap();

        let lock = server.connections.read().await;
        assert!(lock.contains_key(&id), "Connection was not stored");
    }

    #[tokio::test]
    async fn info_should_fail_if_no_connection_found_for_specified_id() {
        let (server, _) = setup(test_config());

        let err = server.info(999).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
    }

    #[tokio::test]
    async fn info_should_return_information_about_established_connection() {
        let (server, _) = setup(test_config());

        let connection = ManagerConnection::spawn(
            "scheme://host".parse().unwrap(),
            "key=value".parse().unwrap(),
            detached_untyped_client(),
        )
        .await
        .unwrap();
        let id = connection.id;
        server.connections.write().await.insert(id, connection);

        let info = server.info(id).await.unwrap();
        assert_eq!(
            info,
            ConnectionInfo {
                id,
                destination: "scheme://host".parse().unwrap(),
                options: "key=value".parse().unwrap(),
            }
        );
    }

    #[tokio::test]
    async fn list_should_return_empty_connection_list_if_no_established_connections() {
        let (server, _) = setup(test_config());

        let list = server.list().await.unwrap();
        assert_eq!(list, ConnectionList(HashMap::new()));
    }

    #[tokio::test]
    async fn list_should_return_a_list_of_established_connections() {
        let (server, _) = setup(test_config());

        let connection = ManagerConnection::spawn(
            "scheme://host".parse().unwrap(),
            "key=value".parse().unwrap(),
            detached_untyped_client(),
        )
        .await
        .unwrap();
        let id_1 = connection.id;
        server.connections.write().await.insert(id_1, connection);

        let connection = ManagerConnection::spawn(
            "other://host2".parse().unwrap(),
            "key=value".parse().unwrap(),
            detached_untyped_client(),
        )
        .await
        .unwrap();
        let id_2 = connection.id;
        server.connections.write().await.insert(id_2, connection);

        let list = server.list().await.unwrap();
        assert_eq!(
            list.get(&id_1).unwrap(),
            &"scheme://host".parse::<Destination>().unwrap()
        );
        assert_eq!(
            list.get(&id_2).unwrap(),
            &"other://host2".parse::<Destination>().unwrap()
        );
    }

    #[tokio::test]
    async fn kill_should_fail_if_no_connection_found_for_specified_id() {
        let (server, _) = setup(test_config());

        let err = server.kill(999).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
    }

    #[tokio::test]
    async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() {
        let (server, _) = setup(test_config());

        let connection = ManagerConnection::spawn(
            "scheme://host".parse().unwrap(),
            "key=value".parse().unwrap(),
            detached_untyped_client(),
        )
        .await
        .unwrap();
        let id = connection.id;
        server.connections.write().await.insert(id, connection);

        server.kill(id).await.unwrap();

        let lock = server.connections.read().await;
        assert!(!lock.contains_key(&id), "Connection still exists");
    }
}