distant_core/
api.rs

1use std::io;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use distant_net::common::ConnectionId;
7use distant_net::server::{Reply, RequestCtx, ServerHandler};
8use log::*;
9
10use crate::protocol::{
11    self, ChangeKind, DirEntry, Environment, Error, Metadata, Permissions, ProcessId, PtySize,
12    SearchId, SearchQuery, SetPermissionsOptions, SystemInfo, Version,
13};
14
15mod reply;
16use reply::DistantSingleReply;
17
/// Represents the context provided to the [`DistantApi`] for incoming requests
///
/// One context is created per individual request (including each entry of a batch)
/// and is handed by value to the matching [`DistantApi`] method.
pub struct DistantCtx {
    /// Id of the connection on which the request was received
    pub connection_id: ConnectionId,
    /// Reply handle whose payload type is [`protocol::Response`]; presumably used by API
    /// implementations to send follow-up responses tied to the request (verify against
    /// the `Reply` trait in `distant_net`)
    pub reply: Box<dyn Reply<Data = protocol::Response>>,
}
23
24/// Represents a [`ServerHandler`] that leverages an API compliant with `distant`
25pub struct DistantApiServerHandler<T>
26where
27    T: DistantApi,
28{
29    api: Arc<T>,
30}
31
32impl<T> DistantApiServerHandler<T>
33where
34    T: DistantApi,
35{
36    pub fn new(api: T) -> Self {
37        Self { api: Arc::new(api) }
38    }
39}
40
/// Produces the error returned by API methods that have not been overridden.
///
/// * `label` - name of the operation, placed at the front of the error message
///
/// Always returns `Err` with kind [`io::ErrorKind::Unsupported`] and the message
/// `"{label} is unsupported"`.
#[inline]
fn unsupported<T>(label: &str) -> io::Result<T> {
    let message = format!("{label} is unsupported");
    let error = io::Error::new(io::ErrorKind::Unsupported, message);
    Err(error)
}
48
/// Interface to support the suite of functionality available with distant,
/// which can be used to build other servers that are compatible with distant
///
/// Every request-handling method receives a [`DistantCtx`] holding the id of the
/// originating connection and a reply handle. All request-handling methods have a
/// default implementation that fails with an "unsupported" error, so implementors
/// only need to override the operations they actually provide.
// NOTE: each method carries `#[allow(unused_variables)]` because the default bodies
// intentionally ignore their arguments.
#[async_trait]
pub trait DistantApi {
    /// Invoked whenever a new connection is established.
    ///
    /// * `id` - the id of the newly-established connection
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    #[allow(unused_variables)]
    async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
        Ok(())
    }

    /// Invoked whenever an existing connection is dropped.
    ///
    /// * `id` - the id of the connection that was dropped
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    #[allow(unused_variables)]
    async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
        Ok(())
    }

    /// Retrieves information about the server's capabilities.
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn version(&self, ctx: DistantCtx) -> io::Result<Version> {
        unsupported("version")
    }

    /// Reads bytes from a file.
    ///
    /// * `path` - the path to the file
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn read_file(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
        unsupported("read_file")
    }

    /// Reads bytes from a file as text.
    ///
    /// * `path` - the path to the file
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn read_file_text(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<String> {
        unsupported("read_file_text")
    }

    /// Writes bytes to a file, overwriting the file if it exists.
    ///
    /// * `path` - the path to the file
    /// * `data` - the data to write
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn write_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
        unsupported("write_file")
    }

    /// Writes text to a file, overwriting the file if it exists.
    ///
    /// * `path` - the path to the file
    /// * `data` - the text to write
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn write_file_text(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        data: String,
    ) -> io::Result<()> {
        unsupported("write_file_text")
    }

    /// Writes bytes to the end of a file, creating it if it is missing.
    ///
    /// * `path` - the path to the file
    /// * `data` - the data to append
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn append_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
        unsupported("append_file")
    }

    /// Writes text to the end of a file, creating it if it is missing.
    ///
    /// * `path` - the path to the file
    /// * `data` - the text to append
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn append_file_text(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        data: String,
    ) -> io::Result<()> {
        unsupported("append_file_text")
    }

    /// Reads entries from a directory.
    ///
    /// * `path` - the path to the directory
    /// * `depth` - how far to traverse the directory, 0 being unlimited
    /// * `absolute` - if true, will return absolute paths instead of relative paths
    /// * `canonicalize` - if true, will canonicalize entry paths before returned
    /// * `include_root` - if true, will include the directory specified in the entries
    ///
    /// On success, returns the entries together with a list of errors (presumably
    /// non-fatal errors encountered for individual entries during traversal).
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn read_dir(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        depth: usize,
        absolute: bool,
        canonicalize: bool,
        include_root: bool,
    ) -> io::Result<(Vec<DirEntry>, Vec<io::Error>)> {
        unsupported("read_dir")
    }

    /// Creates a directory.
    ///
    /// * `path` - the path to the directory
    /// * `all` - if true, will create all missing parent components
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn create_dir(&self, ctx: DistantCtx, path: PathBuf, all: bool) -> io::Result<()> {
        unsupported("create_dir")
    }

    /// Copies some file or directory.
    ///
    /// * `src` - the path to the file or directory to copy
    /// * `dst` - the path where the copy will be placed
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn copy(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
        unsupported("copy")
    }

    /// Removes some file or directory.
    ///
    /// * `path` - the path to a file or directory
    /// * `force` - if true, will remove non-empty directories
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn remove(&self, ctx: DistantCtx, path: PathBuf, force: bool) -> io::Result<()> {
        unsupported("remove")
    }

    /// Renames some file or directory.
    ///
    /// * `src` - the path to the file or directory to rename
    /// * `dst` - the new name for the file or directory
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn rename(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
        unsupported("rename")
    }

    /// Watches a file or directory for changes.
    ///
    /// * `path` - the path to the file or directory
    /// * `recursive` - if true, will watch for changes within subdirectories and beyond
    /// * `only` - if non-empty, will limit reported changes to those included in this list
    /// * `except` - if non-empty, will limit reported changes to those not included in this list
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn watch(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        recursive: bool,
        only: Vec<ChangeKind>,
        except: Vec<ChangeKind>,
    ) -> io::Result<()> {
        unsupported("watch")
    }

    /// Removes a file or directory from being watched.
    ///
    /// * `path` - the path to the file or directory
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn unwatch(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<()> {
        unsupported("unwatch")
    }

    /// Checks if the specified path exists.
    ///
    /// * `path` - the path to the file or directory
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<bool> {
        unsupported("exists")
    }

    /// Reads metadata for a file or directory.
    ///
    /// * `path` - the path to the file or directory
    /// * `canonicalize` - if true, will include a canonicalized path in the metadata
    /// * `resolve_file_type` - if true, will resolve symlinks to underlying type (file or dir)
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn metadata(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        canonicalize: bool,
        resolve_file_type: bool,
    ) -> io::Result<Metadata> {
        unsupported("metadata")
    }

    /// Sets permissions for a file, directory, or symlink.
    ///
    /// * `path` - the path to the file, directory, or symlink
    /// * `permissions` - the new permissions to apply
    /// * `options` - additional options controlling how the permissions are applied
    ///   (presumably including whether symlinks are resolved; see [`SetPermissionsOptions`])
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn set_permissions(
        &self,
        ctx: DistantCtx,
        path: PathBuf,
        permissions: Permissions,
        options: SetPermissionsOptions,
    ) -> io::Result<()> {
        unsupported("set_permissions")
    }

    /// Searches files for matches based on a query.
    ///
    /// * `query` - the specific query to perform
    ///
    /// On success, returns the id assigned to the search.
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn search(&self, ctx: DistantCtx, query: SearchQuery) -> io::Result<SearchId> {
        unsupported("search")
    }

    /// Cancels an actively-ongoing search.
    ///
    /// * `id` - the id of the search to cancel
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn cancel_search(&self, ctx: DistantCtx, id: SearchId) -> io::Result<()> {
        unsupported("cancel_search")
    }

    /// Spawns a new process, returning its id.
    ///
    /// * `cmd` - the full command to run as a new process (including arguments)
    /// * `environment` - the environment variables to associate with the process
    /// * `current_dir` - the alternative current directory to use with the process
    /// * `pty` - if provided, will run the process within a PTY of the given size
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn proc_spawn(
        &self,
        ctx: DistantCtx,
        cmd: String,
        environment: Environment,
        current_dir: Option<PathBuf>,
        pty: Option<PtySize>,
    ) -> io::Result<ProcessId> {
        unsupported("proc_spawn")
    }

    /// Kills a running process by its id.
    ///
    /// * `id` - the unique id of the process
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> {
        unsupported("proc_kill")
    }

    /// Sends data to the stdin of the process with the specified id.
    ///
    /// * `id` - the unique id of the process
    /// * `data` - the bytes to send to stdin
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn proc_stdin(&self, ctx: DistantCtx, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
        unsupported("proc_stdin")
    }

    /// Resizes the PTY of the process with the specified id.
    ///
    /// * `id` - the unique id of the process
    /// * `size` - the new size of the pty
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn proc_resize_pty(
        &self,
        ctx: DistantCtx,
        id: ProcessId,
        size: PtySize,
    ) -> io::Result<()> {
        unsupported("proc_resize_pty")
    }

    /// Retrieves information about the system.
    ///
    /// *Override this, otherwise it will return "unsupported" as an error.*
    #[allow(unused_variables)]
    async fn system_info(&self, ctx: DistantCtx) -> io::Result<SystemInfo> {
        unsupported("system_info")
    }
}
374
375#[async_trait]
376impl<T> ServerHandler for DistantApiServerHandler<T>
377where
378    T: DistantApi + Send + Sync + 'static,
379{
380    type Request = protocol::Msg<protocol::Request>;
381    type Response = protocol::Msg<protocol::Response>;
382
383    /// Overridden to leverage [`DistantApi`] implementation of `on_connect`.
384    async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
385        T::on_connect(&self.api, id).await
386    }
387
388    /// Overridden to leverage [`DistantApi`] implementation of `on_disconnect`.
389    async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
390        T::on_disconnect(&self.api, id).await
391    }
392
393    async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
394        let RequestCtx {
395            connection_id,
396            request,
397            reply,
398        } = ctx;
399
400        // Convert our reply to a queued reply so we can ensure that the result
401        // of an API function is sent back before anything else
402        let reply = reply.queue();
403
404        // Process single vs batch requests
405        let response = match request.payload {
406            protocol::Msg::Single(data) => {
407                let ctx = DistantCtx {
408                    connection_id,
409                    reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
410                };
411
412                let data = handle_request(Arc::clone(&self.api), ctx, data).await;
413
414                // Report outgoing errors in our debug logs
415                if let protocol::Response::Error(x) = &data {
416                    debug!("[Conn {}] {}", connection_id, x);
417                }
418
419                protocol::Msg::Single(data)
420            }
421            protocol::Msg::Batch(list)
422                if matches!(request.header.get_as("sequence"), Some(Ok(true))) =>
423            {
424                let mut out = Vec::new();
425                let mut has_failed = false;
426
427                for data in list {
428                    // Once we hit a failure, all remaining requests return interrupted
429                    if has_failed {
430                        out.push(protocol::Response::Error(protocol::Error {
431                            kind: protocol::ErrorKind::Interrupted,
432                            description: String::from("Canceled due to earlier error"),
433                        }));
434                        continue;
435                    }
436
437                    let ctx = DistantCtx {
438                        connection_id,
439                        reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
440                    };
441
442                    let data = handle_request(Arc::clone(&self.api), ctx, data).await;
443
444                    // Report outgoing errors in our debug logs and mark as failed
445                    // to cancel any future tasks being run
446                    if let protocol::Response::Error(x) = &data {
447                        debug!("[Conn {}] {}", connection_id, x);
448                        has_failed = true;
449                    }
450
451                    out.push(data);
452                }
453
454                protocol::Msg::Batch(out)
455            }
456            protocol::Msg::Batch(list) => {
457                let mut tasks = Vec::new();
458
459                // If sequence specified as true, we want to process in order, otherwise we can
460                // process in any order
461
462                for data in list {
463                    let api = Arc::clone(&self.api);
464                    let ctx = DistantCtx {
465                        connection_id,
466                        reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
467                    };
468
469                    let task = tokio::spawn(async move {
470                        let data = handle_request(api, ctx, data).await;
471
472                        // Report outgoing errors in our debug logs
473                        if let protocol::Response::Error(x) = &data {
474                            debug!("[Conn {}] {}", connection_id, x);
475                        }
476
477                        data
478                    });
479
480                    tasks.push(task);
481                }
482
483                let out = futures::future::join_all(tasks)
484                    .await
485                    .into_iter()
486                    .map(|x| match x {
487                        Ok(x) => x,
488                        Err(x) => protocol::Response::Error(x.to_string().into()),
489                    })
490                    .collect();
491                protocol::Msg::Batch(out)
492            }
493        };
494
495        // Queue up our result to go before ANY of the other messages that might be sent.
496        // This is important to avoid situations such as when a process is started, but before
497        // the confirmation can be sent some stdout or stderr is captured and sent first.
498        if let Err(x) = reply.send_before(response) {
499            error!("[Conn {}] Failed to send response: {}", connection_id, x);
500        }
501
502        // Flush out all of our replies thus far and toggle to no longer hold submissions
503        if let Err(x) = reply.flush(false) {
504            error!(
505                "[Conn {}] Failed to flush response queue: {}",
506                connection_id, x
507            );
508        }
509    }
510}
511
512/// Processes an incoming request
513async fn handle_request<T>(
514    api: Arc<T>,
515    ctx: DistantCtx,
516    request: protocol::Request,
517) -> protocol::Response
518where
519    T: DistantApi + Send + Sync,
520{
521    match request {
522        protocol::Request::Version {} => api
523            .version(ctx)
524            .await
525            .map(protocol::Response::Version)
526            .unwrap_or_else(protocol::Response::from),
527        protocol::Request::FileRead { path } => api
528            .read_file(ctx, path)
529            .await
530            .map(|data| protocol::Response::Blob { data })
531            .unwrap_or_else(protocol::Response::from),
532        protocol::Request::FileReadText { path } => api
533            .read_file_text(ctx, path)
534            .await
535            .map(|data| protocol::Response::Text { data })
536            .unwrap_or_else(protocol::Response::from),
537        protocol::Request::FileWrite { path, data } => api
538            .write_file(ctx, path, data)
539            .await
540            .map(|_| protocol::Response::Ok)
541            .unwrap_or_else(protocol::Response::from),
542        protocol::Request::FileWriteText { path, text } => api
543            .write_file_text(ctx, path, text)
544            .await
545            .map(|_| protocol::Response::Ok)
546            .unwrap_or_else(protocol::Response::from),
547        protocol::Request::FileAppend { path, data } => api
548            .append_file(ctx, path, data)
549            .await
550            .map(|_| protocol::Response::Ok)
551            .unwrap_or_else(protocol::Response::from),
552        protocol::Request::FileAppendText { path, text } => api
553            .append_file_text(ctx, path, text)
554            .await
555            .map(|_| protocol::Response::Ok)
556            .unwrap_or_else(protocol::Response::from),
557        protocol::Request::DirRead {
558            path,
559            depth,
560            absolute,
561            canonicalize,
562            include_root,
563        } => api
564            .read_dir(ctx, path, depth, absolute, canonicalize, include_root)
565            .await
566            .map(|(entries, errors)| protocol::Response::DirEntries {
567                entries,
568                errors: errors.into_iter().map(Error::from).collect(),
569            })
570            .unwrap_or_else(protocol::Response::from),
571        protocol::Request::DirCreate { path, all } => api
572            .create_dir(ctx, path, all)
573            .await
574            .map(|_| protocol::Response::Ok)
575            .unwrap_or_else(protocol::Response::from),
576        protocol::Request::Remove { path, force } => api
577            .remove(ctx, path, force)
578            .await
579            .map(|_| protocol::Response::Ok)
580            .unwrap_or_else(protocol::Response::from),
581        protocol::Request::Copy { src, dst } => api
582            .copy(ctx, src, dst)
583            .await
584            .map(|_| protocol::Response::Ok)
585            .unwrap_or_else(protocol::Response::from),
586        protocol::Request::Rename { src, dst } => api
587            .rename(ctx, src, dst)
588            .await
589            .map(|_| protocol::Response::Ok)
590            .unwrap_or_else(protocol::Response::from),
591        protocol::Request::Watch {
592            path,
593            recursive,
594            only,
595            except,
596        } => api
597            .watch(ctx, path, recursive, only, except)
598            .await
599            .map(|_| protocol::Response::Ok)
600            .unwrap_or_else(protocol::Response::from),
601        protocol::Request::Unwatch { path } => api
602            .unwatch(ctx, path)
603            .await
604            .map(|_| protocol::Response::Ok)
605            .unwrap_or_else(protocol::Response::from),
606        protocol::Request::Exists { path } => api
607            .exists(ctx, path)
608            .await
609            .map(|value| protocol::Response::Exists { value })
610            .unwrap_or_else(protocol::Response::from),
611        protocol::Request::Metadata {
612            path,
613            canonicalize,
614            resolve_file_type,
615        } => api
616            .metadata(ctx, path, canonicalize, resolve_file_type)
617            .await
618            .map(protocol::Response::Metadata)
619            .unwrap_or_else(protocol::Response::from),
620        protocol::Request::SetPermissions {
621            path,
622            permissions,
623            options,
624        } => api
625            .set_permissions(ctx, path, permissions, options)
626            .await
627            .map(|_| protocol::Response::Ok)
628            .unwrap_or_else(protocol::Response::from),
629        protocol::Request::Search { query } => api
630            .search(ctx, query)
631            .await
632            .map(|id| protocol::Response::SearchStarted { id })
633            .unwrap_or_else(protocol::Response::from),
634        protocol::Request::CancelSearch { id } => api
635            .cancel_search(ctx, id)
636            .await
637            .map(|_| protocol::Response::Ok)
638            .unwrap_or_else(protocol::Response::from),
639        protocol::Request::ProcSpawn {
640            cmd,
641            environment,
642            current_dir,
643            pty,
644        } => api
645            .proc_spawn(ctx, cmd.into(), environment, current_dir, pty)
646            .await
647            .map(|id| protocol::Response::ProcSpawned { id })
648            .unwrap_or_else(protocol::Response::from),
649        protocol::Request::ProcKill { id } => api
650            .proc_kill(ctx, id)
651            .await
652            .map(|_| protocol::Response::Ok)
653            .unwrap_or_else(protocol::Response::from),
654        protocol::Request::ProcStdin { id, data } => api
655            .proc_stdin(ctx, id, data)
656            .await
657            .map(|_| protocol::Response::Ok)
658            .unwrap_or_else(protocol::Response::from),
659        protocol::Request::ProcResizePty { id, size } => api
660            .proc_resize_pty(ctx, id, size)
661            .await
662            .map(|_| protocol::Response::Ok)
663            .unwrap_or_else(protocol::Response::from),
664        protocol::Request::SystemInfo {} => api
665            .system_info(ctx)
666            .await
667            .map(protocol::Response::SystemInfo)
668            .unwrap_or_else(protocol::Response::from),
669    }
670}