intelli_shell/service/
export.rs

use std::{
    collections::{HashMap, HashSet},
    env,
    io::ErrorKind,
};

use async_stream::try_stream;
use color_eyre::Report;
use futures_util::StreamExt;
use regex::Regex;
use reqwest::{
    StatusCode, Url,
    header::{self, HeaderName, HeaderValue},
};
use tokio::{
    fs::File,
    io::{AsyncWriteExt, BufWriter},
};
use tracing::instrument;

use super::IntelliShellService;
use crate::{
    cli::{ExportItemsProcess, HttpMethod},
    config::GistConfig,
    errors::{AppError, Result, UserFacingError},
    model::{ExportStats, ImportExportItem, ImportExportStream},
    utils::{
        ShellType,
        dto::{GIST_README_FILENAME, GIST_README_FILENAME_UPPER, GistDto, GistFileDto, ImportExportItemDto},
        extract_gist_data, extract_variables, flatten_str, get_export_gist_token, get_shell_type,
    },
};
30
31impl IntelliShellService {
32    /// Prepare a stream of items to export, optionally filtering commands
33    pub async fn prepare_items_export(&self, filter: Option<Regex>) -> Result<ImportExportStream> {
34        if let Some(ref filter) = filter {
35            tracing::info!("Exporting commands matching `{filter}` and their related completions");
36        } else {
37            tracing::info!("Exporting all commands and completions");
38        }
39
40        let storage = self.storage.clone();
41        let export_stream = try_stream! {
42            // This set will accumulate the unique variable identifiers as we stream commands
43            let mut unique_flat_vars = HashSet::new();
44
45            // Get the initial stream of commands from the storage layer
46            let is_filtered = filter.is_some();
47            let mut command_stream = storage.export_user_commands(filter).await;
48
49            // Process each command from the stream one by one
50            while let Some(command_result) = command_stream.next().await {
51                let command = command_result?;
52
53                // Extract all variables from the command and accumulate them for later
54                if is_filtered {
55                    let flat_root_cmd = flatten_str(command.cmd.split_whitespace().next().unwrap_or(""));
56                    if !flat_root_cmd.is_empty() {
57                        let variables = extract_variables(&command.cmd);
58                        for variable in variables {
59                            for flat_name in variable.flat_names {
60                                unique_flat_vars.insert((flat_root_cmd.clone(), flat_name));
61                            }
62                        }
63                    }
64                }
65
66                // Yield the command immediately
67                yield ImportExportItem::Command(command);
68            }
69
70            // Once the command stream is exhausted, export completions
71            let completions = if is_filtered {
72                // When filtering commands, export only related completions
73                storage.export_user_variable_completions(unique_flat_vars).await?
74            } else {
75                // Otherwise, export all completions
76                storage.list_variable_completions(None, None, true).await?
77            };
78            // Yield each completion
79            for completion in completions {
80                yield ImportExportItem::Completion(completion);
81            }
82        };
83
84        Ok(Box::pin(export_stream))
85    }
86
87    /// Exports given commands and completions
88    pub async fn export_items(
89        &self,
90        items: ImportExportStream,
91        args: ExportItemsProcess,
92        gist_config: GistConfig,
93    ) -> Result<ExportStats> {
94        let ExportItemsProcess {
95            location,
96            file,
97            http,
98            gist,
99            filter: _,
100            headers,
101            method,
102        } = args;
103
104        if file {
105            if location == "-" {
106                self.export_stdout_items(items).await
107            } else {
108                self.export_file_items(items, location).await
109            }
110        } else if http {
111            self.export_http_items(items, location, headers, method).await
112        } else if gist {
113            self.export_gist_items(items, location, gist_config).await
114        } else {
115            // Determine which mode based on the location
116            if location == "gist"
117                || location.starts_with("https://gist.github.com")
118                || location.starts_with("https://gist.githubusercontent.com")
119                || location.starts_with("https://api.github.com/gists")
120            {
121                self.export_gist_items(items, location, gist_config).await
122            } else if location.starts_with("http://") || location.starts_with("https://") {
123                self.export_http_items(items, location, headers, method).await
124            } else if location == "-" {
125                self.export_stdout_items(items).await
126            } else {
127                self.export_file_items(items, location).await
128            }
129        }
130    }
131
132    #[instrument(skip_all)]
133    async fn export_stdout_items(&self, mut items: ImportExportStream) -> Result<ExportStats> {
134        tracing::info!("Writing items to stdout");
135        let mut stats = ExportStats::default();
136        let mut stdout = String::new();
137        while let Some(item) = items.next().await {
138            stdout += &match item? {
139                ImportExportItem::Command(c) => {
140                    stats.commands_exported += 1;
141                    c.to_string()
142                }
143                ImportExportItem::Completion(c) => {
144                    stats.completions_exported += 1;
145                    c.to_string()
146                }
147            };
148            stdout += "\n";
149        }
150        stats.stdout = Some(stdout);
151        Ok(stats)
152    }
153
154    #[instrument(skip_all)]
155    async fn export_file_items(&self, mut items: ImportExportStream, path: String) -> Result<ExportStats> {
156        let mut file = match File::create(&path).await {
157            Ok(f) => f,
158            Err(err) if err.kind() == ErrorKind::PermissionDenied => {
159                return Err(UserFacingError::FileNotAccessible("write").into());
160            }
161            Err(err) if err.kind() == ErrorKind::NotFound => {
162                return Err(UserFacingError::ExportFileParentNotFound.into());
163            }
164            Err(err) if err.kind() == ErrorKind::IsADirectory => {
165                return Err(UserFacingError::ExportLocationNotAFile.into());
166            }
167            Err(err) => return Err(Report::from(err).into()),
168        };
169        tracing::info!("Writing items to file: {path}");
170
171        let mut stats = ExportStats::default();
172        while let Some(item) = items.next().await {
173            let content = match item? {
174                ImportExportItem::Command(c) => {
175                    stats.commands_exported += 1;
176                    format!("{c}\n")
177                }
178                ImportExportItem::Completion(c) => {
179                    stats.completions_exported += 1;
180                    format!("{c}\n")
181                }
182            };
183            file.write_all(content.as_bytes()).await.map_err(|err| {
184                if err.kind() == ErrorKind::BrokenPipe {
185                    AppError::from(UserFacingError::FileBrokenPipe)
186                } else {
187                    AppError::from(err)
188                }
189            })?;
190        }
191        file.flush().await?;
192        Ok(stats)
193    }
194
195    #[instrument(skip_all)]
196    async fn export_http_items(
197        &self,
198        mut items: ImportExportStream,
199        url: String,
200        headers: Vec<(HeaderName, HeaderValue)>,
201        method: HttpMethod,
202    ) -> Result<ExportStats> {
203        // Parse the URL
204        let url = Url::parse(&url).map_err(|err| {
205            tracing::error!("Couldn't parse url: {err}");
206            UserFacingError::HttpInvalidUrl
207        })?;
208
209        let method = method.into();
210        tracing::info!("Writing items to http: {method} {url}");
211
212        // Collect items to export
213        let mut stats = ExportStats::default();
214        let mut items_to_export = Vec::new();
215        while let Some(item) = items.next().await {
216            items_to_export.push(match item? {
217                ImportExportItem::Command(c) => {
218                    stats.commands_exported += 1;
219                    ImportExportItemDto::Command(c.into())
220                }
221                ImportExportItem::Completion(c) => {
222                    stats.completions_exported += 1;
223                    ImportExportItemDto::Completion(c.into())
224                }
225            });
226        }
227
228        // Build the request
229        let client = reqwest::Client::new();
230        let mut req = client.request(method, url);
231
232        // Add headers
233        for (name, value) in headers {
234            tracing::debug!("Appending '{name}' header");
235            req = req.header(name, value);
236        }
237
238        // Set JSON body
239        req = req.json(&items_to_export);
240
241        // Send the request
242        let res = req.send().await.map_err(|err| {
243            tracing::error!("{err:?}");
244            UserFacingError::HttpRequestFailed(err.to_string())
245        })?;
246
247        // Check the response status
248        if !res.status().is_success() {
249            let status = res.status();
250            let status_str = status.as_str();
251            let body = res.text().await.unwrap_or_default();
252            if let Some(reason) = status.canonical_reason() {
253                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
254                return Err(
255                    UserFacingError::HttpRequestFailed(format!("received {status_str} {reason} response")).into(),
256                );
257            } else {
258                tracing::error!("Got response [{status_str}]:\n{body}");
259                return Err(UserFacingError::HttpRequestFailed(format!("received {status_str} response")).into());
260            }
261        }
262
263        Ok(stats)
264    }
265
266    #[instrument(skip_all)]
267    async fn export_gist_items(
268        &self,
269        mut items: ImportExportStream,
270        gist: String,
271        gist_config: GistConfig,
272    ) -> Result<ExportStats> {
273        // Retrieve the gist id and optional sha and file
274        let (gist_id, gist_sha, gist_file) = extract_gist_data(&gist, &gist_config)?;
275
276        // If a sha is found, return an error as we can't modify it
277        if gist_sha.is_some() {
278            return Err(UserFacingError::ExportGistLocationHasSha.into());
279        }
280
281        // Retrieve the gist token to be used
282        let gist_token = get_export_gist_token(&gist_config)?;
283
284        let url = format!("https://api.github.com/gists/{gist_id}");
285        tracing::info!("Writing items to gist: {url}");
286
287        // Retrieve the gist to verify its existence
288        let client = reqwest::Client::new();
289        let res = client
290            .get(&url)
291            .header(header::ACCEPT, "application/vnd.github+json")
292            .header(header::USER_AGENT, "intelli-shell")
293            .header("X-GitHub-Api-Version", "2022-11-28")
294            .send()
295            .await
296            .map_err(|err| {
297                tracing::error!("{err:?}");
298                UserFacingError::GistRequestFailed(err.to_string())
299            })?;
300
301        // Check the response status
302        if !res.status().is_success() {
303            let status = res.status();
304            let status_str = status.as_str();
305            let body = res.text().await.unwrap_or_default();
306            if let Some(reason) = status.canonical_reason() {
307                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
308                return Err(
309                    UserFacingError::GistRequestFailed(format!("received {status_str} {reason} response")).into(),
310                );
311            } else {
312                tracing::error!("Got response [{status_str}]:\n{body}");
313                return Err(UserFacingError::GistRequestFailed(format!("received {status_str} response")).into());
314            }
315        }
316
317        // Parse the body as a json
318        let actual_gist: GistDto = match res.json().await {
319            Ok(b) => b,
320            Err(err) if err.is_decode() => {
321                tracing::error!("Couldn't parse api response: {err}");
322                return Err(UserFacingError::GistRequestFailed(String::from("couldn't parse api response")).into());
323            }
324            Err(err) => {
325                tracing::error!("{err:?}");
326                return Err(UserFacingError::GistRequestFailed(err.to_string()).into());
327            }
328        };
329
330        // Determine the extension based on the file or shell
331        let extension = if let Some(ref gist_file) = gist_file
332            && let Some((_, ext)) = gist_file.rfind('.').map(|i| gist_file.split_at(i))
333        {
334            ext.to_owned()
335        } else {
336            match get_shell_type() {
337                ShellType::Cmd => ".cmd",
338                ShellType::WindowsPowerShell | ShellType::PowerShellCore => ".ps1",
339                _ => ".sh",
340            }
341            .to_owned()
342        };
343
344        // Collect items to export
345        let mut stats = ExportStats::default();
346        let mut content = String::new();
347        while let Some(item) = items.next().await {
348            match item? {
349                ImportExportItem::Command(c) => {
350                    stats.commands_exported += 1;
351                    content.push_str(&c.to_string());
352                }
353                ImportExportItem::Completion(c) => {
354                    stats.completions_exported += 1;
355                    content.push_str(&c.to_string());
356                }
357            }
358            content.push('\n');
359        }
360
361        // Prepare the data to be sent
362        let explicit_file = gist_file.is_some();
363        let mut files = vec![(
364            gist_file
365                .or_else(|| {
366                    let command_files = actual_gist
367                        .files
368                        .keys()
369                        .filter(|f| f.ends_with(&extension))
370                        .collect::<Vec<_>>();
371                    if command_files.len() == 1 {
372                        Some(command_files[0].to_string())
373                    } else {
374                        None
375                    }
376                })
377                .unwrap_or_else(|| format!("commands{extension}")),
378            GistFileDto { content },
379        )];
380        if !explicit_file
381            && !actual_gist.files.contains_key(GIST_README_FILENAME)
382            && !actual_gist.files.contains_key(GIST_README_FILENAME_UPPER)
383        {
384            files.push((
385                String::from(GIST_README_FILENAME),
386                GistFileDto {
387                    content: format!(
388                        r"# IntelliShell Commands
389
390These commands have been exported using [intelli-shell]({}), a command-line tool to bookmark and search commands.
391
392You can easily import all the commands by running:
393
394```sh
395intelli-shell import --gist {gist_id}
396```",
397                        env!("CARGO_PKG_REPOSITORY")
398                    ),
399                },
400            ));
401        }
402        let gist = GistDto {
403            files: HashMap::from_iter(files),
404        };
405
406        // Call the API
407        let client = reqwest::Client::new();
408        let res = client
409            .patch(url)
410            .header(header::ACCEPT, "application/vnd.github+json")
411            .header(header::USER_AGENT, "intelli-shell")
412            .header("X-GitHub-Api-Version", "2022-11-28")
413            .bearer_auth(gist_token)
414            .json(&gist)
415            .send()
416            .await
417            .map_err(|err| {
418                tracing::error!("{err:?}");
419                UserFacingError::GistRequestFailed(err.to_string())
420            })?;
421
422        // Check the response status
423        if !res.status().is_success() {
424            let status = res.status();
425            let status_str = status.as_str();
426            let body = res.text().await.unwrap_or_default();
427            if status == StatusCode::NOT_FOUND {
428                tracing::error!("Update got not found after a succesful get request");
429                return Err(
430                    UserFacingError::GistRequestFailed("token missing permissions to update the gist".into()).into(),
431                );
432            } else if let Some(reason) = status.canonical_reason() {
433                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
434                return Err(
435                    UserFacingError::GistRequestFailed(format!("received {status_str} {reason} response")).into(),
436                );
437            } else {
438                tracing::error!("Got response [{status_str}]:\n{body}");
439                return Err(UserFacingError::GistRequestFailed(format!("received {status_str} response")).into());
440            }
441        }
442
443        Ok(stats)
444    }
445}