1use std::io;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use distant_net::common::ConnectionId;
7use distant_net::server::{Reply, RequestCtx, ServerHandler};
8use log::*;
9
10use crate::protocol::{
11 self, ChangeKind, DirEntry, Environment, Error, Metadata, Permissions, ProcessId, PtySize,
12 SearchId, SearchQuery, SetPermissionsOptions, SystemInfo, Version,
13};
14
15mod reply;
16use reply::DistantSingleReply;
17
/// Per-request context passed to every request method of [`DistantApi`].
pub struct DistantCtx {
    /// Id of the connection on which the request was received.
    pub connection_id: ConnectionId,
    /// Reply handle typed to send individual `protocol::Response` values.
    ///
    /// The dispatcher builds one per request from a clone of the request's queued reply,
    /// wrapped in `DistantSingleReply`.
    /// NOTE(review): presumably used by implementations to push additional responses
    /// (e.g. streamed output) back to the client — confirm against `Reply`.
    pub reply: Box<dyn Reply<Data = protocol::Response>>,
}
23
24pub struct DistantApiServerHandler<T>
26where
27 T: DistantApi,
28{
29 api: Arc<T>,
30}
31
32impl<T> DistantApiServerHandler<T>
33where
34 T: DistantApi,
35{
36 pub fn new(api: T) -> Self {
37 Self { api: Arc::new(api) }
38 }
39}
40
/// Produces the failure returned by every `DistantApi` method that an implementor
/// did not override.
///
/// The error kind is [`io::ErrorKind::Unsupported`] and the message reads
/// `"<label> is unsupported"`.
#[inline]
fn unsupported<T>(label: &str) -> io::Result<T> {
    let message = format!("{label} is unsupported");
    let error = io::Error::new(io::ErrorKind::Unsupported, message);
    Err(error)
}
48
/// Server-side interface for the distant protocol.
///
/// Every request method has a default body that fails with an
/// [`io::ErrorKind::Unsupported`] error whose message is `"<method> is unsupported"`, so an
/// implementor only overrides the operations it supports. The connection hooks
/// (`on_connect` / `on_disconnect`) default to `Ok(())`.
///
/// Request methods receive a [`DistantCtx`] carrying the originating connection id and a
/// reply handle. When a method returns `Err`, the dispatcher turns the `io::Error` into a
/// `protocol::Response` via `From`, so errors are reported to the client, not dropped.
#[async_trait]
pub trait DistantApi {
    // Default bodies ignore their arguments, hence `#[allow(unused_variables)]` on every
    // method below.

    /// Invoked from the handler's `ServerHandler::on_connect` with the connection's id.
    #[allow(unused_variables)]
    async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
        Ok(())
    }

    /// Invoked from the handler's `ServerHandler::on_disconnect` with the connection's id.
    #[allow(unused_variables)]
    async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
        Ok(())
    }

    /// Handles `protocol::Request::Version`; the result is sent as
    /// `protocol::Response::Version`.
    #[allow(unused_variables)]
    async fn version(&self, ctx: DistantCtx) -> io::Result<Version> {
        unsupported("version")
    }

    /// Handles `protocol::Request::FileRead { path }`; the returned bytes are sent as
    /// `protocol::Response::Blob { data }`.
    #[allow(unused_variables)]
    async fn read_file(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
        unsupported("read_file")
    }

    /// Handles `protocol::Request::FileReadText { path }`; the returned string is sent as
    /// `protocol::Response::Text { data }`.
    #[allow(unused_variables)]
    async fn read_file_text(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<String> {
        unsupported("read_file_text")
    }

    /// Handles `protocol::Request::FileWrite { path, data }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn write_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
        unsupported("write_file")
    }

    /// Handles `protocol::Request::FileWriteText { path, text }` (the request's `text`
    /// arrives as `data`); success is sent as `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn write_file_text(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        data: String,
    ) -> io::Result<()> {
        unsupported("write_file_text")
    }

    /// Handles `protocol::Request::FileAppend { path, data }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn append_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
        unsupported("append_file")
    }

    /// Handles `protocol::Request::FileAppendText { path, text }` (the request's `text`
    /// arrives as `data`); success is sent as `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn append_file_text(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        data: String,
    ) -> io::Result<()> {
        unsupported("append_file_text")
    }

    /// Handles `protocol::Request::DirRead`.
    ///
    /// On success the entries are sent in `protocol::Response::DirEntries`, with each
    /// returned `io::Error` converted into a protocol `Error` for the `errors` field, so
    /// per-entry failures can be reported without failing the whole request.
    #[allow(unused_variables)]
    async fn read_dir(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        depth: usize,
        absolute: bool,
        canonicalize: bool,
        include_root: bool,
    ) -> io::Result<(Vec<DirEntry>, Vec<io::Error>)> {
        unsupported("read_dir")
    }

    /// Handles `protocol::Request::DirCreate { path, all }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn create_dir(&self, ctx: DistantCtx, path: PathBuf, all: bool) -> io::Result<()> {
        unsupported("create_dir")
    }

    /// Handles `protocol::Request::Copy { src, dst }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn copy(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
        unsupported("copy")
    }

    /// Handles `protocol::Request::Remove { path, force }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn remove(&self, ctx: DistantCtx, path: PathBuf, force: bool) -> io::Result<()> {
        unsupported("remove")
    }

    /// Handles `protocol::Request::Rename { src, dst }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn rename(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
        unsupported("rename")
    }

    /// Handles `protocol::Request::Watch { path, recursive, only, except }`; success is
    /// sent as `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn watch(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        recursive: bool,
        only: Vec<ChangeKind>,
        except: Vec<ChangeKind>,
    ) -> io::Result<()> {
        unsupported("watch")
    }

    /// Handles `protocol::Request::Unwatch { path }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn unwatch(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<()> {
        unsupported("unwatch")
    }

    /// Handles `protocol::Request::Exists { path }`; the result is sent as
    /// `protocol::Response::Exists { value }`.
    #[allow(unused_variables)]
    async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<bool> {
        unsupported("exists")
    }

    /// Handles `protocol::Request::Metadata { path, canonicalize, resolve_file_type }`;
    /// the result is sent as `protocol::Response::Metadata`.
    #[allow(unused_variables)]
    async fn metadata(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        canonicalize: bool,
        resolve_file_type: bool,
    ) -> io::Result<Metadata> {
        unsupported("metadata")
    }

    /// Handles `protocol::Request::SetPermissions { path, permissions, options }`;
    /// success is sent as `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn set_permissions(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        permissions: Permissions,
        options: SetPermissionsOptions,
    ) -> io::Result<()> {
        unsupported("set_permissions")
    }

    /// Handles `protocol::Request::Search { query }`; the returned id is sent as
    /// `protocol::Response::SearchStarted { id }`.
    #[allow(unused_variables)]
    async fn search(&self, ctx: DistantCtx, query: SearchQuery) -> io::Result<SearchId> {
        unsupported("search")
    }

    /// Handles `protocol::Request::CancelSearch { id }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn cancel_search(&self, ctx: DistantCtx, id: SearchId) -> io::Result<()> {
        unsupported("cancel_search")
    }

    /// Handles `protocol::Request::ProcSpawn { cmd, environment, current_dir, pty }`.
    ///
    /// The request's `cmd` is converted into a `String` via `Into` before this call; the
    /// returned id is sent as `protocol::Response::ProcSpawned { id }`.
    #[allow(unused_variables)]
    async fn proc_spawn(
        &self,
        ctx: DistantCtx,
        cmd: String,
        environment: Environment,
        current_dir: Option<PathBuf>,
        pty: Option<PtySize>,
    ) -> io::Result<ProcessId> {
        unsupported("proc_spawn")
    }

    /// Handles `protocol::Request::ProcKill { id }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> {
        unsupported("proc_kill")
    }

    /// Handles `protocol::Request::ProcStdin { id, data }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn proc_stdin(&self, ctx: DistantCtx, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
        unsupported("proc_stdin")
    }

    /// Handles `protocol::Request::ProcResizePty { id, size }`; success is sent as
    /// `protocol::Response::Ok`.
    #[allow(unused_variables)]
    async fn proc_resize_pty(
        &self,
        ctx: DistantCtx,
        id: ProcessId,
        size: PtySize,
    ) -> io::Result<()> {
        unsupported("proc_resize_pty")
    }

    /// Handles `protocol::Request::SystemInfo`; the result is sent as
    /// `protocol::Response::SystemInfo`.
    #[allow(unused_variables)]
    async fn system_info(&self, ctx: DistantCtx) -> io::Result<SystemInfo> {
        unsupported("system_info")
    }
}
374
375#[async_trait]
376impl<T> ServerHandler for DistantApiServerHandler<T>
377where
378 T: DistantApi + Send + Sync + 'static,
379{
380 type Request = protocol::Msg<protocol::Request>;
381 type Response = protocol::Msg<protocol::Response>;
382
383 async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
385 T::on_connect(&self.api, id).await
386 }
387
388 async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
390 T::on_disconnect(&self.api, id).await
391 }
392
393 async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
394 let RequestCtx {
395 connection_id,
396 request,
397 reply,
398 } = ctx;
399
400 let reply = reply.queue();
403
404 let response = match request.payload {
406 protocol::Msg::Single(data) => {
407 let ctx = DistantCtx {
408 connection_id,
409 reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
410 };
411
412 let data = handle_request(Arc::clone(&self.api), ctx, data).await;
413
414 if let protocol::Response::Error(x) = &data {
416 debug!("[Conn {}] {}", connection_id, x);
417 }
418
419 protocol::Msg::Single(data)
420 }
421 protocol::Msg::Batch(list)
422 if matches!(request.header.get_as("sequence"), Some(Ok(true))) =>
423 {
424 let mut out = Vec::new();
425 let mut has_failed = false;
426
427 for data in list {
428 if has_failed {
430 out.push(protocol::Response::Error(protocol::Error {
431 kind: protocol::ErrorKind::Interrupted,
432 description: String::from("Canceled due to earlier error"),
433 }));
434 continue;
435 }
436
437 let ctx = DistantCtx {
438 connection_id,
439 reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
440 };
441
442 let data = handle_request(Arc::clone(&self.api), ctx, data).await;
443
444 if let protocol::Response::Error(x) = &data {
447 debug!("[Conn {}] {}", connection_id, x);
448 has_failed = true;
449 }
450
451 out.push(data);
452 }
453
454 protocol::Msg::Batch(out)
455 }
456 protocol::Msg::Batch(list) => {
457 let mut tasks = Vec::new();
458
459 for data in list {
463 let api = Arc::clone(&self.api);
464 let ctx = DistantCtx {
465 connection_id,
466 reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
467 };
468
469 let task = tokio::spawn(async move {
470 let data = handle_request(api, ctx, data).await;
471
472 if let protocol::Response::Error(x) = &data {
474 debug!("[Conn {}] {}", connection_id, x);
475 }
476
477 data
478 });
479
480 tasks.push(task);
481 }
482
483 let out = futures::future::join_all(tasks)
484 .await
485 .into_iter()
486 .map(|x| match x {
487 Ok(x) => x,
488 Err(x) => protocol::Response::Error(x.to_string().into()),
489 })
490 .collect();
491 protocol::Msg::Batch(out)
492 }
493 };
494
495 if let Err(x) = reply.send_before(response) {
499 error!("[Conn {}] Failed to send response: {}", connection_id, x);
500 }
501
502 if let Err(x) = reply.flush(false) {
504 error!(
505 "[Conn {}] Failed to flush response queue: {}",
506 connection_id, x
507 );
508 }
509 }
510}
511
512async fn handle_request<T>(
514 api: Arc<T>,
515 ctx: DistantCtx,
516 request: protocol::Request,
517) -> protocol::Response
518where
519 T: DistantApi + Send + Sync,
520{
521 match request {
522 protocol::Request::Version {} => api
523 .version(ctx)
524 .await
525 .map(protocol::Response::Version)
526 .unwrap_or_else(protocol::Response::from),
527 protocol::Request::FileRead { path } => api
528 .read_file(ctx, path)
529 .await
530 .map(|data| protocol::Response::Blob { data })
531 .unwrap_or_else(protocol::Response::from),
532 protocol::Request::FileReadText { path } => api
533 .read_file_text(ctx, path)
534 .await
535 .map(|data| protocol::Response::Text { data })
536 .unwrap_or_else(protocol::Response::from),
537 protocol::Request::FileWrite { path, data } => api
538 .write_file(ctx, path, data)
539 .await
540 .map(|_| protocol::Response::Ok)
541 .unwrap_or_else(protocol::Response::from),
542 protocol::Request::FileWriteText { path, text } => api
543 .write_file_text(ctx, path, text)
544 .await
545 .map(|_| protocol::Response::Ok)
546 .unwrap_or_else(protocol::Response::from),
547 protocol::Request::FileAppend { path, data } => api
548 .append_file(ctx, path, data)
549 .await
550 .map(|_| protocol::Response::Ok)
551 .unwrap_or_else(protocol::Response::from),
552 protocol::Request::FileAppendText { path, text } => api
553 .append_file_text(ctx, path, text)
554 .await
555 .map(|_| protocol::Response::Ok)
556 .unwrap_or_else(protocol::Response::from),
557 protocol::Request::DirRead {
558 path,
559 depth,
560 absolute,
561 canonicalize,
562 include_root,
563 } => api
564 .read_dir(ctx, path, depth, absolute, canonicalize, include_root)
565 .await
566 .map(|(entries, errors)| protocol::Response::DirEntries {
567 entries,
568 errors: errors.into_iter().map(Error::from).collect(),
569 })
570 .unwrap_or_else(protocol::Response::from),
571 protocol::Request::DirCreate { path, all } => api
572 .create_dir(ctx, path, all)
573 .await
574 .map(|_| protocol::Response::Ok)
575 .unwrap_or_else(protocol::Response::from),
576 protocol::Request::Remove { path, force } => api
577 .remove(ctx, path, force)
578 .await
579 .map(|_| protocol::Response::Ok)
580 .unwrap_or_else(protocol::Response::from),
581 protocol::Request::Copy { src, dst } => api
582 .copy(ctx, src, dst)
583 .await
584 .map(|_| protocol::Response::Ok)
585 .unwrap_or_else(protocol::Response::from),
586 protocol::Request::Rename { src, dst } => api
587 .rename(ctx, src, dst)
588 .await
589 .map(|_| protocol::Response::Ok)
590 .unwrap_or_else(protocol::Response::from),
591 protocol::Request::Watch {
592 path,
593 recursive,
594 only,
595 except,
596 } => api
597 .watch(ctx, path, recursive, only, except)
598 .await
599 .map(|_| protocol::Response::Ok)
600 .unwrap_or_else(protocol::Response::from),
601 protocol::Request::Unwatch { path } => api
602 .unwatch(ctx, path)
603 .await
604 .map(|_| protocol::Response::Ok)
605 .unwrap_or_else(protocol::Response::from),
606 protocol::Request::Exists { path } => api
607 .exists(ctx, path)
608 .await
609 .map(|value| protocol::Response::Exists { value })
610 .unwrap_or_else(protocol::Response::from),
611 protocol::Request::Metadata {
612 path,
613 canonicalize,
614 resolve_file_type,
615 } => api
616 .metadata(ctx, path, canonicalize, resolve_file_type)
617 .await
618 .map(protocol::Response::Metadata)
619 .unwrap_or_else(protocol::Response::from),
620 protocol::Request::SetPermissions {
621 path,
622 permissions,
623 options,
624 } => api
625 .set_permissions(ctx, path, permissions, options)
626 .await
627 .map(|_| protocol::Response::Ok)
628 .unwrap_or_else(protocol::Response::from),
629 protocol::Request::Search { query } => api
630 .search(ctx, query)
631 .await
632 .map(|id| protocol::Response::SearchStarted { id })
633 .unwrap_or_else(protocol::Response::from),
634 protocol::Request::CancelSearch { id } => api
635 .cancel_search(ctx, id)
636 .await
637 .map(|_| protocol::Response::Ok)
638 .unwrap_or_else(protocol::Response::from),
639 protocol::Request::ProcSpawn {
640 cmd,
641 environment,
642 current_dir,
643 pty,
644 } => api
645 .proc_spawn(ctx, cmd.into(), environment, current_dir, pty)
646 .await
647 .map(|id| protocol::Response::ProcSpawned { id })
648 .unwrap_or_else(protocol::Response::from),
649 protocol::Request::ProcKill { id } => api
650 .proc_kill(ctx, id)
651 .await
652 .map(|_| protocol::Response::Ok)
653 .unwrap_or_else(protocol::Response::from),
654 protocol::Request::ProcStdin { id, data } => api
655 .proc_stdin(ctx, id, data)
656 .await
657 .map(|_| protocol::Response::Ok)
658 .unwrap_or_else(protocol::Response::from),
659 protocol::Request::ProcResizePty { id, size } => api
660 .proc_resize_pty(ctx, id, size)
661 .await
662 .map(|_| protocol::Response::Ok)
663 .unwrap_or_else(protocol::Response::from),
664 protocol::Request::SystemInfo {} => api
665 .system_info(ctx)
666 .await
667 .map(protocol::Response::SystemInfo)
668 .unwrap_or_else(protocol::Response::from),
669 }
670}