iroh_blobs/
cli.rs

1//! Define blob-related commands.
2#![allow(missing_docs)]
3use std::{
4    collections::{BTreeMap, HashMap},
5    net::SocketAddr,
6    path::PathBuf,
7    time::Duration,
8};
9
10use anyhow::{anyhow, bail, ensure, Context, Result};
11use clap::Subcommand;
12use console::{style, Emoji};
13use futures_lite::{Stream, StreamExt};
14use indicatif::{
15    HumanBytes, HumanDuration, MultiProgress, ProgressBar, ProgressDrawTarget, ProgressState,
16    ProgressStyle,
17};
18use iroh::{NodeAddr, PublicKey, RelayUrl};
19use tokio::io::AsyncWriteExt;
20
21use crate::{
22    get::{db::DownloadProgress, progress::BlobProgress, Stats},
23    net_protocol::DownloadMode,
24    provider::AddProgress,
25    rpc::client::blobs::{
26        self, BlobInfo, BlobStatus, CollectionInfo, DownloadOptions, IncompleteBlobInfo, WrapOption,
27    },
28    store::{ConsistencyCheckProgress, ExportFormat, ExportMode, ReportLevel, ValidateProgress},
29    ticket::BlobTicket,
30    util::SetTagOption,
31    BlobFormat, Hash, HashAndFormat, Tag,
32};
33
34pub mod tags;
35
/// Subcommands for the blob command.
// NOTE: this enum derives clap's `Subcommand`, so every `///` comment below is also the
// `--help` text shown to users. Maintainer-only notes are kept in plain `//` comments.
// The lint is allowed presumably because `Get` carries far more fields than the other
// variants — TODO confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Subcommand, Debug, Clone)]
pub enum BlobCommands {
    // Handled by `add_with_opts`; `source` is presumably parsed through `BlobSource`'s
    // `From<String>` impl, which treats the literal `STDIN` specially.
    /// Add data from PATH to the running node.
    Add {
        /// Path to a file or folder.
        ///
        /// If set to `STDIN`, the data will be read from stdin.
        source: BlobSource,

        // Flattened: the `BlobAddOptions` flags appear directly on `blob add`.
        #[clap(flatten)]
        options: BlobAddOptions,
    },
    // With a bare hash, `--node` is required and the format defaults to a raw blob unless
    // `--recursive` is given. With a ticket, `--node` is ignored and the ticket's own format is
    // used unless `--recursive` overrides it. `--out STDOUT` is rejected for hash sequences.
    /// Download data to the running node's database and provide it.
    ///
    /// In addition to downloading the data, you can also specify an optional output directory
    /// where the data will be exported to after it has been downloaded.
    Get {
        /// Ticket or Hash to use.
        #[clap(name = "TICKET OR HASH")]
        ticket: TicketOrHash,
        // Combined with the ticket's direct addresses unless `--override-addresses` is set.
        /// Additional socket address to use to contact the node. Can be used multiple times.
        #[clap(long)]
        address: Vec<SocketAddr>,
        // Takes precedence over the relay URL embedded in a ticket.
        /// Override the relay URL to use to contact the node.
        #[clap(long)]
        relay_url: Option<RelayUrl>,
        // `true` selects `BlobFormat::HashSeq`, `false` selects `BlobFormat::Raw`.
        /// Override to treat the blob as a raw blob or a hash sequence.
        #[clap(long)]
        recursive: Option<bool>,
        // Only meaningful for tickets; with a bare hash only `--address` values are used anyway.
        /// If set, the ticket's direct addresses will not be used.
        #[clap(long)]
        override_addresses: bool,
        // Required when TICKET OR HASH is a bare hash; ignored when it is a ticket.
        /// NodeId of the provider.
        #[clap(long)]
        node: Option<PublicKey>,
        /// Directory or file in which to save the file(s).
        ///
        /// If set to `STDOUT` the output will be redirected to stdout.
        ///
        /// If not specified, the data will only be stored internally.
        #[clap(long, short)]
        out: Option<OutputTarget>,
        // Selects `ExportMode::TryReference` instead of `ExportMode::Copy` when exporting.
        /// If set, the data will be moved to the output directory, and iroh will assume that it
        /// will not change.
        #[clap(long, default_value_t = false)]
        stable: bool,
        // Without a tag, the download uses `SetTagOption::Auto`.
        /// Tag to tag the data with.
        #[clap(long)]
        tag: Option<String>,
        // Maps to `DownloadMode::Queued`; otherwise `DownloadMode::Direct` is used.
        /// If set, will queue the download in the download queue.
        ///
        /// Use this if you are doing many downloads in parallel and want to limit the number of
        /// downloads running concurrently.
        #[clap(long)]
        queued: bool,
    },
    // `STDOUT` cannot be combined with `--recursive`. A non-recursive export to a path
    // refuses an existing directory as the target.
    /// Export a blob from the internal blob store to the local filesystem.
    Export {
        /// The hash to export.
        hash: Hash,
        /// Directory or file in which to save the file(s).
        ///
        /// If set to `STDOUT` the output will be redirected to stdout.
        out: OutputTarget,
        /// Set to true if the hash refers to a collection and you want to export all children of
        /// the collection.
        #[clap(long, default_value_t = false)]
        recursive: bool,
        /// If set, the data will be moved to the output directory, and iroh will assume that it
        /// will not change.
        #[clap(long, default_value_t = false)]
        stable: bool,
    },
    /// List available content on the node.
    #[clap(subcommand)]
    List(ListCommands),
    /// Validate hashes on the running node.
    Validate {
        // Counted flag: none → warn, one → info, two or more → trace (see `get_report_level`).
        /// Verbosity level.
        #[clap(short, long, action(clap::ArgAction::Count))]
        verbose: u8,
        /// Repair the store by removing invalid data
        ///
        /// Caution: this will remove data to make the store consistent, even
        /// if the data might be salvageable. E.g. for an entry for which the
        /// outboard data is missing, the entry will be removed, even if the
        /// data is complete.
        #[clap(long, default_value_t = false)]
        repair: bool,
    },
    /// Perform a database consistency check on the running node.
    ConsistencyCheck {
        // Counted flag: none → warn, one → info, two or more → trace (see `get_report_level`).
        /// Verbosity level.
        #[clap(short, long, action(clap::ArgAction::Count))]
        verbose: u8,
        /// Repair the store by removing invalid data
        ///
        /// Caution: this will remove data to make the store consistent, even
        /// if the data might be salvageable. E.g. for an entry for which the
        /// outboard data is missing, the entry will be removed, even if the
        /// data is complete.
        #[clap(long, default_value_t = false)]
        repair: bool,
    },
    /// Delete content on the node.
    #[clap(subcommand)]
    Delete(DeleteCommands),
    // The ticket embeds the node's own address; the command fails with "blob is missing"
    // when the hash is not present in the local store.
    /// Get a ticket to share this blob.
    Share {
        /// Hash of the blob to share.
        hash: Hash,
        // Selects `BlobFormat::HashSeq` for the ticket instead of `BlobFormat::Raw`.
        /// If the blob is a collection, the requester will also fetch the listed blobs.
        #[clap(long, default_value_t = false)]
        recursive: bool,
        // Hidden from help; also prints the ticket's `Debug` representation.
        /// Display the contents of this ticket too.
        #[clap(long, hide = true)]
        debug: bool,
    },
}
157
158/// Possible outcomes of an input.
159#[derive(Debug, Clone, derive_more::Display)]
160pub enum TicketOrHash {
161    Ticket(BlobTicket),
162    Hash(Hash),
163}
164
165impl std::str::FromStr for TicketOrHash {
166    type Err = anyhow::Error;
167
168    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
169        if let Ok(ticket) = BlobTicket::from_str(s) {
170            return Ok(Self::Ticket(ticket));
171        }
172        if let Ok(hash) = Hash::from_str(s) {
173            return Ok(Self::Hash(hash));
174        }
175        Err(anyhow!("neither a valid ticket or hash"))
176    }
177}
178
179impl BlobCommands {
180    /// Runs the blob command given the iroh client.
181    pub async fn run(self, blobs: &blobs::Client, addr: NodeAddr) -> Result<()> {
182        match self {
183            Self::Get {
184                ticket,
185                mut address,
186                relay_url,
187                recursive,
188                override_addresses,
189                node,
190                out,
191                stable,
192                tag,
193                queued,
194            } => {
195                let (node_addr, hash, format) = match ticket {
196                    TicketOrHash::Ticket(ticket) => {
197                        let (node_addr, hash, blob_format) = ticket.into_parts();
198
199                        // create the node address with the appropriate overrides
200                        let node_addr = {
201                            let NodeAddr {
202                                node_id,
203                                relay_url: original_relay_url,
204                                direct_addresses,
205                            } = node_addr;
206                            let addresses = if override_addresses {
207                                // use only the cli supplied ones
208                                address
209                            } else {
210                                // use both the cli supplied ones and the ticket ones
211                                address.extend(direct_addresses);
212                                address
213                            };
214
215                            // prefer direct arg over ticket
216                            let relay_url = relay_url.or(original_relay_url);
217
218                            NodeAddr::from_parts(node_id, relay_url, addresses)
219                        };
220
221                        // check if the blob format has an override
222                        let blob_format = match recursive {
223                            Some(true) => BlobFormat::HashSeq,
224                            Some(false) => BlobFormat::Raw,
225                            None => blob_format,
226                        };
227
228                        (node_addr, hash, blob_format)
229                    }
230                    TicketOrHash::Hash(hash) => {
231                        // check if the blob format has an override
232                        let blob_format = match recursive {
233                            Some(true) => BlobFormat::HashSeq,
234                            Some(false) => BlobFormat::Raw,
235                            None => BlobFormat::Raw,
236                        };
237
238                        let Some(node) = node else {
239                            bail!("missing NodeId");
240                        };
241
242                        let node_addr = NodeAddr::from_parts(node, relay_url, address);
243                        (node_addr, hash, blob_format)
244                    }
245                };
246
247                if format != BlobFormat::Raw && out == Some(OutputTarget::Stdout) {
248                    return Err(anyhow::anyhow!("The input arguments refer to a collection of blobs and output is set to STDOUT. Only single blobs may be passed in this case."));
249                }
250
251                let tag = match tag {
252                    Some(tag) => SetTagOption::Named(Tag::from(tag)),
253                    None => SetTagOption::Auto,
254                };
255
256                let mode = match queued {
257                    true => DownloadMode::Queued,
258                    false => DownloadMode::Direct,
259                };
260
261                let mut stream = blobs
262                    .download_with_opts(
263                        hash,
264                        DownloadOptions {
265                            format,
266                            nodes: vec![node_addr],
267                            tag,
268                            mode,
269                        },
270                    )
271                    .await?;
272
273                show_download_progress(hash, &mut stream).await?;
274
275                match out {
276                    None => {}
277                    Some(OutputTarget::Stdout) => {
278                        // we asserted above that `OutputTarget::Stdout` is only permitted if getting a
279                        // single hash and not a hashseq.
280                        let mut blob_read = blobs.read(hash).await?;
281                        tokio::io::copy(&mut blob_read, &mut tokio::io::stdout()).await?;
282                    }
283                    Some(OutputTarget::Path(path)) => {
284                        let absolute = std::env::current_dir()?.join(&path);
285                        if matches!(format, BlobFormat::HashSeq) {
286                            ensure!(!absolute.is_dir(), "output must not be a directory");
287                        }
288                        let recursive = format == BlobFormat::HashSeq;
289                        let mode = match stable {
290                            true => ExportMode::TryReference,
291                            false => ExportMode::Copy,
292                        };
293                        let format = match recursive {
294                            true => ExportFormat::Collection,
295                            false => ExportFormat::Blob,
296                        };
297                        tracing::info!("exporting to {} -> {}", path.display(), absolute.display());
298                        let stream = blobs.export(hash, absolute, format, mode).await?;
299
300                        // TODO: report export progress
301                        stream.await?;
302                    }
303                };
304
305                Ok(())
306            }
307            Self::Export {
308                hash,
309                out,
310                recursive,
311                stable,
312            } => {
313                match out {
314                    OutputTarget::Stdout => {
315                        ensure!(
316                            !recursive,
317                            "Recursive option is not supported when exporting to STDOUT"
318                        );
319                        let mut blob_read = blobs.read(hash).await?;
320                        tokio::io::copy(&mut blob_read, &mut tokio::io::stdout()).await?;
321                    }
322                    OutputTarget::Path(path) => {
323                        let absolute = std::env::current_dir()?.join(&path);
324                        if !recursive {
325                            ensure!(!absolute.is_dir(), "output must not be a directory");
326                        }
327                        let mode = match stable {
328                            true => ExportMode::TryReference,
329                            false => ExportMode::Copy,
330                        };
331                        let format = match recursive {
332                            true => ExportFormat::Collection,
333                            false => ExportFormat::Blob,
334                        };
335                        tracing::info!(
336                            "exporting {hash} to {} -> {}",
337                            path.display(),
338                            absolute.display()
339                        );
340                        let stream = blobs.export(hash, absolute, format, mode).await?;
341                        // TODO: report export progress
342                        stream.await?;
343                    }
344                };
345                Ok(())
346            }
347            Self::List(cmd) => cmd.run(blobs).await,
348            Self::Delete(cmd) => cmd.run(blobs).await,
349            Self::Validate { verbose, repair } => validate(blobs, verbose, repair).await,
350            Self::ConsistencyCheck { verbose, repair } => {
351                consistency_check(blobs, verbose, repair).await
352            }
353            Self::Add {
354                source: path,
355                options,
356            } => add_with_opts(blobs, addr, path, options).await,
357            Self::Share {
358                hash,
359                recursive,
360                debug,
361            } => {
362                let format = if recursive {
363                    BlobFormat::HashSeq
364                } else {
365                    BlobFormat::Raw
366                };
367                let status = blobs.status(hash).await?;
368                let ticket = BlobTicket::new(addr, hash, format)?;
369
370                let (blob_status, size) = match (status, format) {
371                    (BlobStatus::Complete { size }, BlobFormat::Raw) => ("blob", size),
372                    (BlobStatus::Partial { size }, BlobFormat::Raw) => {
373                        ("incomplete blob", size.value())
374                    }
375                    (BlobStatus::Complete { size }, BlobFormat::HashSeq) => ("collection", size),
376                    (BlobStatus::Partial { size }, BlobFormat::HashSeq) => {
377                        ("incomplete collection", size.value())
378                    }
379                    (BlobStatus::NotFound, _) => {
380                        return Err(anyhow!("blob is missing"));
381                    }
382                };
383                println!(
384                    "Ticket for {blob_status} {hash} ({})\n{ticket}",
385                    HumanBytes(size)
386                );
387
388                if debug {
389                    println!("{ticket:#?}")
390                }
391                Ok(())
392            }
393        }
394    }
395}
396
397/// Options for the `blob add` command.
398#[derive(clap::Args, Debug, Clone)]
399pub struct BlobAddOptions {
400    /// Add in place
401    ///
402    /// Set this to true only if you are sure that the data in its current location
403    /// will not change.
404    #[clap(long, default_value_t = false)]
405    pub in_place: bool,
406
407    /// Tag to tag the data with.
408    #[clap(long)]
409    pub tag: Option<String>,
410
411    /// Wrap the added file or directory in a collection.
412    ///
413    /// When adding a single file, without `wrap` the file is added as a single blob and no
414    /// collection is created. When enabling `wrap` it also creates a collection with a
415    /// single entry, where the entry's name is the filename and the entry's content is blob.
416    ///
417    /// When adding a directory, a collection is always created.
418    /// Without `wrap`, the collection directly contains the entries from the added directory.
419    /// With `wrap`, the directory will be nested so that all names in the collection are
420    /// prefixed with the directory name, thus preserving the name of the directory.
421    ///
422    /// When adding content from STDIN and setting `wrap` you also need to set `filename` to name
423    /// the entry pointing to the content from STDIN.
424    #[clap(long, default_value_t = false)]
425    pub wrap: bool,
426
427    /// Override the filename used for the entry in the created collection.
428    ///
429    /// Only supported `wrap` is set.
430    /// Required when adding content from STDIN and setting `wrap`.
431    #[clap(long, requires = "wrap")]
432    pub filename: Option<String>,
433
434    /// Do not print the all-in-one ticket to get the added data from this node.
435    #[clap(long)]
436    pub no_ticket: bool,
437}
438
/// Possible list subcommands.
// Dispatched in `ListCommands::run`; the `///` lines below are also clap `--help` text.
#[derive(Subcommand, Debug, Clone)]
pub enum ListCommands {
    // Backed by `blobs::Client::list`; prints path, hash and size per blob.
    /// List the available blobs on the running provider.
    Blobs,
    // Backed by `blobs::Client::list_incomplete`; prints hash and size per blob.
    /// List the blobs on the running provider that are not full files.
    IncompleteBlobs,
    // Backed by `blobs::Client::list_collections`; prints tag, hash, blob count and total size.
    /// List the available collections on the running provider.
    Collections,
}
449
450impl ListCommands {
451    /// Runs a list subcommand.
452    pub async fn run(self, blobs: &blobs::Client) -> Result<()> {
453        match self {
454            Self::Blobs => {
455                let mut response = blobs.list().await?;
456                while let Some(item) = response.next().await {
457                    let BlobInfo { path, hash, size } = item?;
458                    println!("{} {} ({})", path, hash, HumanBytes(size));
459                }
460            }
461            Self::IncompleteBlobs => {
462                let mut response = blobs.list_incomplete().await?;
463                while let Some(item) = response.next().await {
464                    let IncompleteBlobInfo { hash, size, .. } = item?;
465                    println!("{} ({})", hash, HumanBytes(size));
466                }
467            }
468            Self::Collections => {
469                let mut response = blobs.list_collections()?;
470                while let Some(item) = response.next().await {
471                    let CollectionInfo {
472                        tag,
473                        hash,
474                        total_blobs_count,
475                        total_blobs_size,
476                    } = item?;
477                    let total_blobs_count = total_blobs_count.unwrap_or_default();
478                    let total_blobs_size = total_blobs_size.unwrap_or_default();
479                    println!(
480                        "{}: {} {} {} ({})",
481                        tag,
482                        hash,
483                        total_blobs_count,
484                        if total_blobs_count > 1 {
485                            "blobs"
486                        } else {
487                            "blob"
488                        },
489                        HumanBytes(total_blobs_size),
490                    );
491                }
492            }
493        }
494        Ok(())
495    }
496}
497
/// Possible delete subcommands.
// Dispatched in `DeleteCommands::run`; the `///` lines below are also clap `--help` text.
#[derive(Subcommand, Debug, Clone)]
pub enum DeleteCommands {
    // NOTE(review): the help text says "blobs", but the field is a single `Hash`, so exactly
    // one blob is deleted per invocation.
    /// Delete the given blobs
    Blob {
        /// Blobs to delete
        #[arg(required = true)]
        hash: Hash,
    },
}
508
509impl DeleteCommands {
510    /// Runs the delete command.
511    pub async fn run(self, blobs: &blobs::Client) -> Result<()> {
512        match self {
513            Self::Blob { hash } => {
514                let response = blobs.delete_blob(hash).await;
515                if let Err(e) = response {
516                    eprintln!("Error: {}", e);
517                }
518            }
519        }
520        Ok(())
521    }
522}
523
524/// Returns the corresponding [`ReportLevel`] given the verbosity level.
525fn get_report_level(verbose: u8) -> ReportLevel {
526    match verbose {
527        0 => ReportLevel::Warn,
528        1 => ReportLevel::Info,
529        _ => ReportLevel::Trace,
530    }
531}
532
533/// Applies the report level to the given text.
534fn apply_report_level(text: String, level: ReportLevel) -> console::StyledObject<String> {
535    match level {
536        ReportLevel::Trace => style(text).dim(),
537        ReportLevel::Info => style(text),
538        ReportLevel::Warn => style(text).yellow(),
539        ReportLevel::Error => style(text).red(),
540    }
541}
542
543/// Checks the consistency of the blobs on the running node, and repairs inconsistencies if instructed.
544pub async fn consistency_check(blobs: &blobs::Client, verbose: u8, repair: bool) -> Result<()> {
545    let mut response = blobs.consistency_check(repair).await?;
546    let verbosity = get_report_level(verbose);
547    let print = |level: ReportLevel, entry: Option<Hash>, message: String| {
548        if level < verbosity {
549            return;
550        }
551        let level_text = level.to_string().to_lowercase();
552        let text = if let Some(hash) = entry {
553            format!("{}: {} ({})", level_text, message, hash.to_hex())
554        } else {
555            format!("{}: {}", level_text, message)
556        };
557        let styled = apply_report_level(text, level);
558        eprintln!("{}", styled);
559    };
560
561    while let Some(item) = response.next().await {
562        match item? {
563            ConsistencyCheckProgress::Start => {
564                eprintln!("Starting consistency check ...");
565            }
566            ConsistencyCheckProgress::Update {
567                message,
568                entry,
569                level,
570            } => {
571                print(level, entry, message);
572            }
573            ConsistencyCheckProgress::Done => {
574                eprintln!("Consistency check done");
575            }
576            ConsistencyCheckProgress::Abort(error) => {
577                eprintln!("Consistency check error {}", error);
578                break;
579            }
580        }
581    }
582    Ok(())
583}
584
585/// Checks the validity of the blobs on the running node, and repairs anything invalid if instructed.
586pub async fn validate(blobs: &blobs::Client, verbose: u8, repair: bool) -> Result<()> {
587    let mut state = ValidateProgressState::new();
588    let mut response = blobs.validate(repair).await?;
589    let verbosity = get_report_level(verbose);
590    let print = |level: ReportLevel, entry: Option<Hash>, message: String| {
591        if level < verbosity {
592            return;
593        }
594        let level_text = level.to_string().to_lowercase();
595        let text = if let Some(hash) = entry {
596            format!("{}: {} ({})", level_text, message, hash.to_hex())
597        } else {
598            format!("{}: {}", level_text, message)
599        };
600        let styled = apply_report_level(text, level);
601        eprintln!("{}", styled);
602    };
603
604    let mut partial = BTreeMap::new();
605
606    while let Some(item) = response.next().await {
607        match item? {
608            ValidateProgress::PartialEntry {
609                id,
610                hash,
611                path,
612                size,
613            } => {
614                partial.insert(id, hash);
615                print(
616                    ReportLevel::Trace,
617                    Some(hash),
618                    format!(
619                        "Validating partial entry {} {} {}",
620                        id,
621                        path.unwrap_or_default(),
622                        size
623                    ),
624                );
625            }
626            ValidateProgress::PartialEntryProgress { id, offset } => {
627                let entry = partial.get(&id).cloned();
628                print(
629                    ReportLevel::Trace,
630                    entry,
631                    format!("Partial entry {} at {}", id, offset),
632                );
633            }
634            ValidateProgress::PartialEntryDone { id, ranges } => {
635                let entry: Option<Hash> = partial.remove(&id);
636                print(
637                    ReportLevel::Info,
638                    entry,
639                    format!("Partial entry {} done {:?}", id, ranges.to_chunk_ranges()),
640                );
641            }
642            ValidateProgress::Starting { total } => {
643                state.starting(total);
644            }
645            ValidateProgress::Entry {
646                id,
647                hash,
648                path,
649                size,
650            } => {
651                state.add_entry(id, hash, path, size);
652            }
653            ValidateProgress::EntryProgress { id, offset } => {
654                state.progress(id, offset);
655            }
656            ValidateProgress::EntryDone { id, error } => {
657                state.done(id, error);
658            }
659            ValidateProgress::Abort(error) => {
660                state.abort(error.to_string());
661                break;
662            }
663            ValidateProgress::AllDone => {
664                break;
665            }
666        }
667    }
668    Ok(())
669}
670
671/// Collection of all the validation progress state.
672struct ValidateProgressState {
673    mp: MultiProgress,
674    pbs: HashMap<u64, ProgressBar>,
675    overall: ProgressBar,
676    total: u64,
677    errors: u64,
678    successes: u64,
679}
680
681impl ValidateProgressState {
682    /// Creates a new validation progress state collection.
683    fn new() -> Self {
684        let mp = MultiProgress::new();
685        let overall = mp.add(ProgressBar::new(0));
686        overall.enable_steady_tick(Duration::from_millis(500));
687        Self {
688            mp,
689            pbs: HashMap::new(),
690            overall,
691            total: 0,
692            errors: 0,
693            successes: 0,
694        }
695    }
696
697    /// Sets the total number to the provided value and style the progress bar to starting.
698    fn starting(&mut self, total: u64) {
699        self.total = total;
700        self.errors = 0;
701        self.successes = 0;
702        self.overall.set_position(0);
703        self.overall.set_length(total);
704        self.overall.set_style(
705            ProgressStyle::default_bar()
706                .template("{spinner:.green} [{bar:60.cyan/blue}] {msg}")
707                .unwrap()
708                .progress_chars("=>-"),
709        );
710    }
711
712    /// Adds a message to the progress bar in the given `id`.
713    fn add_entry(&mut self, id: u64, hash: Hash, path: Option<String>, size: u64) {
714        let pb = self.mp.insert_before(&self.overall, ProgressBar::new(size));
715        pb.set_style(ProgressStyle::default_bar()
716            .template("{spinner:.green} [{bar:40.cyan/blue}] {msg} {bytes}/{total_bytes} ({bytes_per_sec}, eta {eta})").unwrap()
717            .progress_chars("=>-"));
718        let msg = if let Some(path) = path {
719            format!("{} {}", hash.to_hex(), path)
720        } else {
721            hash.to_hex().to_string()
722        };
723        pb.set_message(msg);
724        pb.set_position(0);
725        pb.set_length(size);
726        pb.enable_steady_tick(Duration::from_millis(500));
727        self.pbs.insert(id, pb);
728    }
729
730    /// Progresses the progress bar with `id` by `progress` amount.
731    fn progress(&mut self, id: u64, progress: u64) {
732        if let Some(pb) = self.pbs.get_mut(&id) {
733            pb.set_position(progress);
734        }
735    }
736
737    /// Set an error in the progress bar. Consumes the [`ValidateProgressState`].
738    fn abort(self, error: String) {
739        let error_line = self.mp.add(ProgressBar::new(0));
740        error_line.set_style(ProgressStyle::default_bar().template("{msg}").unwrap());
741        error_line.set_message(error);
742    }
743
744    /// Finishes a progress bar with a given error message.
745    fn done(&mut self, id: u64, error: Option<String>) {
746        if let Some(pb) = self.pbs.remove(&id) {
747            let ok_char = style(Emoji("✔", "OK")).green();
748            let fail_char = style(Emoji("✗", "Error")).red();
749            let ok = error.is_none();
750            let msg = match error {
751                Some(error) => format!("{} {} {}", pb.message(), fail_char, error),
752                None => format!("{} {}", pb.message(), ok_char),
753            };
754            if ok {
755                self.successes += 1;
756            } else {
757                self.errors += 1;
758            }
759            self.overall.set_position(self.errors + self.successes);
760            self.overall.set_message(format!(
761                "Overall {} {}, {} {}",
762                self.errors, fail_char, self.successes, ok_char
763            ));
764            if ok {
765                pb.finish_and_clear();
766            } else {
767                pb.set_style(ProgressStyle::default_bar().template("{msg}").unwrap());
768                pb.finish_with_message(msg);
769            }
770        }
771    }
772}
773
774/// Where the data should be read from.
775#[derive(Debug, Clone, derive_more::Display, PartialEq, Eq)]
776pub enum BlobSource {
777    /// Reads from stdin
778    #[display("STDIN")]
779    Stdin,
780    /// Reads from the provided path
781    #[display("{}", _0.display())]
782    Path(PathBuf),
783}
784
785impl From<String> for BlobSource {
786    fn from(s: String) -> Self {
787        if s == "STDIN" {
788            return BlobSource::Stdin;
789        }
790
791        BlobSource::Path(s.into())
792    }
793}
794
/// Data source for adding data to iroh.
///
/// Unlike [`BlobSource`], the local-path variant also carries the `in_place` flag.
#[derive(Debug, Clone)]
pub enum BlobSourceIroh {
    /// A file or directory on the node's local file system.
    LocalFs {
        /// Path to add; `add` canonicalizes it before passing it to `add_from_path`.
        path: PathBuf,
        /// Forwarded to `add_from_path`; `add_with_opts` sets it from `--in-place`.
        in_place: bool,
    },
    /// Data passed via STDIN.
    // `add` first copies this data into a temporary file.
    Stdin,
}
803
/// Whether to print an all-in-one ticket.
///
/// `add_with_opts` derives it from the `--no-ticket` flag.
// NOTE: the `None` variant is unrelated to `Option::None`; always spell it `TicketOption::None`.
#[derive(Debug, Clone)]
pub enum TicketOption {
    /// Do not print an all-in-one ticket
    None,
    /// Print an all-in-one ticket.
    Print,
}
812
813/// Adds a [`BlobSource`] given some [`BlobAddOptions`].
814pub async fn add_with_opts(
815    blobs: &blobs::Client,
816    addr: NodeAddr,
817    source: BlobSource,
818    opts: BlobAddOptions,
819) -> Result<()> {
820    let tag = match opts.tag {
821        Some(tag) => SetTagOption::Named(Tag::from(tag)),
822        None => SetTagOption::Auto,
823    };
824    let ticket = match opts.no_ticket {
825        true => TicketOption::None,
826        false => TicketOption::Print,
827    };
828    let source = match source {
829        BlobSource::Stdin => BlobSourceIroh::Stdin,
830        BlobSource::Path(path) => BlobSourceIroh::LocalFs {
831            path,
832            in_place: opts.in_place,
833        },
834    };
835    let wrap = match (opts.wrap, opts.filename) {
836        (true, None) => WrapOption::Wrap { name: None },
837        (true, Some(filename)) => WrapOption::Wrap {
838            name: Some(filename),
839        },
840        (false, None) => WrapOption::NoWrap,
841        (false, Some(_)) => bail!("`--filename` may not be used without `--wrap`"),
842    };
843
844    add(blobs, addr, source, tag, ticket, wrap).await
845}
846
847/// Adds data to iroh, either from a path or, if path is `None`, from STDIN.
848pub async fn add(
849    blobs: &blobs::Client,
850    addr: NodeAddr,
851    source: BlobSourceIroh,
852    tag: SetTagOption,
853    ticket: TicketOption,
854    wrap: WrapOption,
855) -> Result<()> {
856    let (hash, format, entries) = match source {
857        BlobSourceIroh::LocalFs { path, in_place } => {
858            let absolute = path.canonicalize()?;
859            println!("Adding {} as {}...", path.display(), absolute.display());
860
861            // tell the node to add the data
862            let stream = blobs.add_from_path(absolute, in_place, tag, wrap).await?;
863            aggregate_add_response(stream).await?
864        }
865        BlobSourceIroh::Stdin => {
866            println!("Adding from STDIN...");
867            // Store STDIN content into a temporary file
868            let (file, path) = tempfile::NamedTempFile::new()?.into_parts();
869            let mut file = tokio::fs::File::from_std(file);
870            let path_buf = path.to_path_buf();
871            // Copy from stdin to the file, until EOF
872            tokio::io::copy(&mut tokio::io::stdin(), &mut file).await?;
873            file.flush().await?;
874            drop(file);
875
876            // tell the node to add the data
877            let stream = blobs.add_from_path(path_buf, false, tag, wrap).await?;
878            aggregate_add_response(stream).await?
879        }
880    };
881
882    print_add_response(hash, format, entries);
883    if let TicketOption::Print = ticket {
884        let ticket = BlobTicket::new(addr, hash, format)?;
885        println!("All-in-one ticket: {ticket}");
886    }
887    Ok(())
888}
889
/// Entry with a given name, size, and hash.
///
/// [`aggregate_add_response`] produces one of these for every item the node reported while
/// adding data.
#[derive(Debug)]
pub struct ProvideResponseEntry {
    /// Name reported for this entry when it was first found.
    pub name: String,
    /// Size of the entry in bytes.
    pub size: u64,
    /// Hash reported once the entry finished importing.
    pub hash: Hash,
}
897
898/// Combines the [`AddProgress`] outputs from a [`Stream`] into a single tuple.
899pub async fn aggregate_add_response(
900    mut stream: impl Stream<Item = Result<AddProgress>> + Unpin,
901) -> Result<(Hash, BlobFormat, Vec<ProvideResponseEntry>)> {
902    let mut hash_and_format = None;
903    let mut collections = BTreeMap::<u64, (String, u64, Option<Hash>)>::new();
904    let mut mp = Some(ProvideProgressState::new());
905    while let Some(item) = stream.next().await {
906        match item? {
907            AddProgress::Found { name, id, size } => {
908                tracing::trace!("Found({id},{name},{size})");
909                if let Some(mp) = mp.as_mut() {
910                    mp.found(name.clone(), id, size);
911                }
912                collections.insert(id, (name, size, None));
913            }
914            AddProgress::Progress { id, offset } => {
915                tracing::trace!("Progress({id}, {offset})");
916                if let Some(mp) = mp.as_mut() {
917                    mp.progress(id, offset);
918                }
919            }
920            AddProgress::Done { hash, id } => {
921                tracing::trace!("Done({id},{hash:?})");
922                if let Some(mp) = mp.as_mut() {
923                    mp.done(id, hash);
924                }
925                match collections.get_mut(&id) {
926                    Some((_, _, ref mut h)) => {
927                        *h = Some(hash);
928                    }
929                    None => {
930                        anyhow::bail!("Got Done for unknown collection id {id}");
931                    }
932                }
933            }
934            AddProgress::AllDone { hash, format, .. } => {
935                tracing::trace!("AllDone({hash:?})");
936                if let Some(mp) = mp.take() {
937                    mp.all_done();
938                }
939                hash_and_format = Some(HashAndFormat { hash, format });
940                break;
941            }
942            AddProgress::Abort(e) => {
943                if let Some(mp) = mp.take() {
944                    mp.error();
945                }
946                anyhow::bail!("Error while adding data: {e}");
947            }
948        }
949    }
950    let HashAndFormat { hash, format } =
951        hash_and_format.context("Missing hash for collection or blob")?;
952    let entries = collections
953        .into_iter()
954        .map(|(_, (name, size, hash))| {
955            let hash = hash.context(format!("Missing hash for {name}"))?;
956            Ok(ProvideResponseEntry { name, size, hash })
957        })
958        .collect::<Result<Vec<_>>>()?;
959    Ok((hash, format, entries))
960}
961
962/// Prints out the add response.
963pub fn print_add_response(hash: Hash, format: BlobFormat, entries: Vec<ProvideResponseEntry>) {
964    let mut total_size = 0;
965    for ProvideResponseEntry { name, size, hash } in entries {
966        total_size += size;
967        println!("- {}: {} {:#}", name, HumanBytes(size), hash);
968    }
969    println!("Total: {}", HumanBytes(total_size));
970    println!();
971    match format {
972        BlobFormat::Raw => println!("Blob: {}", hash),
973        BlobFormat::HashSeq => println!("Collection: {}", hash),
974    }
975}
976
/// Progress state for providing.
///
/// Tracks one progress bar per entry while data is being added.
#[derive(Debug)]
pub struct ProvideProgressState {
    /// Container that renders all entry bars together.
    mp: MultiProgress,
    /// Currently displayed bars, keyed by entry id.
    pbs: HashMap<u64, ProgressBar>,
}
983
984impl ProvideProgressState {
985    /// Creates a new provide progress state.
986    fn new() -> Self {
987        Self {
988            mp: MultiProgress::new(),
989            pbs: HashMap::new(),
990        }
991    }
992
993    /// Inserts a new progress bar with the given id, name, and size.
994    fn found(&mut self, name: String, id: u64, size: u64) {
995        let pb = self.mp.add(ProgressBar::new(size));
996        pb.set_style(ProgressStyle::default_bar()
997            .template("{spinner:.green} [{bar:40.cyan/blue}] {msg} {bytes}/{total_bytes} ({bytes_per_sec}, eta {eta})").unwrap()
998            .progress_chars("=>-"));
999        pb.set_message(name);
1000        pb.set_length(size);
1001        pb.set_position(0);
1002        pb.enable_steady_tick(Duration::from_millis(500));
1003        self.pbs.insert(id, pb);
1004    }
1005
1006    /// Adds some progress to the progress bar with the given id.
1007    fn progress(&mut self, id: u64, progress: u64) {
1008        if let Some(pb) = self.pbs.get_mut(&id) {
1009            pb.set_position(progress);
1010        }
1011    }
1012
1013    /// Sets the multiprogress bar with the given id as finished and clear it.
1014    fn done(&mut self, id: u64, _hash: Hash) {
1015        if let Some(pb) = self.pbs.remove(&id) {
1016            pb.finish_and_clear();
1017            self.mp.remove(&pb);
1018        }
1019    }
1020
1021    /// Sets the multiprogress bar as finished and clear them.
1022    fn all_done(self) {
1023        self.mp.clear().ok();
1024    }
1025
1026    /// Clears the multiprogress bar.
1027    fn error(self) {
1028        self.mp.clear().ok();
1029    }
1030}
1031
1032/// Displays the download progress for a given stream.
1033pub async fn show_download_progress(
1034    hash: Hash,
1035    mut stream: impl Stream<Item = Result<DownloadProgress>> + Unpin,
1036) -> Result<()> {
1037    eprintln!("Fetching: {}", hash);
1038    let mp = MultiProgress::new();
1039    mp.set_draw_target(ProgressDrawTarget::stderr());
1040    let op = mp.add(make_overall_progress());
1041    let ip = mp.add(make_individual_progress());
1042    op.set_message(format!("{} Connecting ...\n", style("[1/3]").bold().dim()));
1043    let mut seq = false;
1044    while let Some(x) = stream.next().await {
1045        match x? {
1046            DownloadProgress::InitialState(state) => {
1047                if state.connected {
1048                    op.set_message(format!("{} Requesting ...\n", style("[2/3]").bold().dim()));
1049                }
1050                if let Some(count) = state.root.child_count {
1051                    op.set_message(format!(
1052                        "{} Downloading {} blob(s)\n",
1053                        style("[3/3]").bold().dim(),
1054                        count + 1,
1055                    ));
1056                    op.set_length(count + 1);
1057                    op.reset();
1058                    op.set_position(state.current.map(u64::from).unwrap_or(0));
1059                    seq = true;
1060                }
1061                if let Some(blob) = state.get_current() {
1062                    if let Some(size) = blob.size {
1063                        ip.set_length(size.value());
1064                        ip.reset();
1065                        match blob.progress {
1066                            BlobProgress::Pending => {}
1067                            BlobProgress::Progressing(offset) => ip.set_position(offset),
1068                            BlobProgress::Done => ip.finish_and_clear(),
1069                        }
1070                        if !seq {
1071                            op.finish_and_clear();
1072                        }
1073                    }
1074                }
1075            }
1076            DownloadProgress::FoundLocal { .. } => {}
1077            DownloadProgress::Connected => {
1078                op.set_message(format!("{} Requesting ...\n", style("[2/3]").bold().dim()));
1079            }
1080            DownloadProgress::FoundHashSeq { children, .. } => {
1081                op.set_message(format!(
1082                    "{} Downloading {} blob(s)\n",
1083                    style("[3/3]").bold().dim(),
1084                    children + 1,
1085                ));
1086                op.set_length(children + 1);
1087                op.reset();
1088                seq = true;
1089            }
1090            DownloadProgress::Found { size, child, .. } => {
1091                if seq {
1092                    op.set_position(child.into());
1093                } else {
1094                    op.finish_and_clear();
1095                }
1096                ip.set_length(size);
1097                ip.reset();
1098            }
1099            DownloadProgress::Progress { offset, .. } => {
1100                ip.set_position(offset);
1101            }
1102            DownloadProgress::Done { .. } => {
1103                ip.finish_and_clear();
1104            }
1105            DownloadProgress::AllDone(Stats {
1106                bytes_read,
1107                elapsed,
1108                ..
1109            }) => {
1110                op.finish_and_clear();
1111                eprintln!(
1112                    "Transferred {} in {}, {}/s",
1113                    HumanBytes(bytes_read),
1114                    HumanDuration(elapsed),
1115                    HumanBytes((bytes_read as f64 / elapsed.as_secs_f64()) as u64)
1116                );
1117                break;
1118            }
1119            DownloadProgress::Abort(e) => {
1120                bail!("download aborted: {}", e);
1121            }
1122        }
1123    }
1124    Ok(())
1125}
1126
/// Where the data should be stored.
///
/// Displays as `STDOUT` for [`OutputTarget::Stdout`] and as the path otherwise. The
/// `From<String>` conversion maps the exact string `STDOUT` back to [`OutputTarget::Stdout`]
/// and any other string to an [`OutputTarget::Path`].
#[derive(Debug, Clone, derive_more::Display, PartialEq, Eq)]
pub enum OutputTarget {
    /// Writes the data to standard output.
    #[display("STDOUT")]
    Stdout,
    /// Writes the data to the provided path.
    #[display("{}", _0.display())]
    Path(PathBuf),
}
1137
1138impl From<String> for OutputTarget {
1139    fn from(s: String) -> Self {
1140        if s == "STDOUT" {
1141            return OutputTarget::Stdout;
1142        }
1143
1144        OutputTarget::Path(s.into())
1145    }
1146}
1147
1148/// Creates a [`ProgressBar`] with some defaults for the overall progress.
1149fn make_overall_progress() -> ProgressBar {
1150    let pb = ProgressBar::hidden();
1151    pb.enable_steady_tick(std::time::Duration::from_millis(100));
1152    pb.set_style(
1153        ProgressStyle::with_template(
1154            "{msg}{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len}",
1155        )
1156        .unwrap()
1157        .progress_chars("#>-"),
1158    );
1159    pb
1160}
1161
1162/// Creates a [`ProgressBar`] with some defaults for the individual progress.
1163fn make_individual_progress() -> ProgressBar {
1164    let pb = ProgressBar::hidden();
1165    pb.enable_steady_tick(std::time::Duration::from_millis(100));
1166    pb.set_style(
1167        ProgressStyle::with_template("{msg}{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})")
1168            .unwrap()
1169            .with_key(
1170                "eta",
1171                |state: &ProgressState, w: &mut dyn std::fmt::Write| {
1172                    write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
1173                },
1174            )
1175            .progress_chars("#>-"),
1176    );
1177    pb
1178}
1179
#[cfg(test)]
mod tests {
    use super::*;

    /// Every [`BlobSource`] must survive a `Display` -> `From<String>` round trip.
    #[test]
    fn test_blob_source() {
        for source in [BlobSource::Stdin, BlobSource::Path("hello/world".into())] {
            let parsed = BlobSource::from(source.to_string());
            assert_eq!(parsed, source);
        }
    }

    /// Every [`OutputTarget`] must survive a `Display` -> `From<String>` round trip.
    #[test]
    fn test_output_target() {
        for target in [OutputTarget::Stdout, OutputTarget::Path("hello/world".into())] {
            let parsed = OutputTarget::from(target.to_string());
            assert_eq!(parsed, target);
        }
    }
}