1#![allow(missing_docs)]
3use std::{
4 collections::{BTreeMap, HashMap},
5 net::SocketAddr,
6 path::PathBuf,
7 time::Duration,
8};
9
10use anyhow::{anyhow, bail, ensure, Context, Result};
11use clap::Subcommand;
12use console::{style, Emoji};
13use futures_lite::{Stream, StreamExt};
14use indicatif::{
15 HumanBytes, HumanDuration, MultiProgress, ProgressBar, ProgressDrawTarget, ProgressState,
16 ProgressStyle,
17};
18use iroh::{NodeAddr, PublicKey, RelayUrl};
19use tokio::io::AsyncWriteExt;
20
21use crate::{
22 get::{db::DownloadProgress, progress::BlobProgress, Stats},
23 net_protocol::DownloadMode,
24 provider::AddProgress,
25 rpc::client::blobs::{
26 self, BlobInfo, BlobStatus, CollectionInfo, DownloadOptions, IncompleteBlobInfo, WrapOption,
27 },
28 store::{ConsistencyCheckProgress, ExportFormat, ExportMode, ReportLevel, ValidateProgress},
29 ticket::BlobTicket,
30 util::SetTagOption,
31 BlobFormat, Hash, HashAndFormat, Tag,
32};
33
34pub mod tags;
35
36#[allow(clippy::large_enum_variant)]
38#[derive(Subcommand, Debug, Clone)]
39pub enum BlobCommands {
40 Add {
42 source: BlobSource,
46
47 #[clap(flatten)]
48 options: BlobAddOptions,
49 },
50 Get {
55 #[clap(name = "TICKET OR HASH")]
57 ticket: TicketOrHash,
58 #[clap(long)]
60 address: Vec<SocketAddr>,
61 #[clap(long)]
63 relay_url: Option<RelayUrl>,
64 #[clap(long)]
66 recursive: Option<bool>,
67 #[clap(long)]
69 override_addresses: bool,
70 #[clap(long)]
72 node: Option<PublicKey>,
73 #[clap(long, short)]
79 out: Option<OutputTarget>,
80 #[clap(long, default_value_t = false)]
83 stable: bool,
84 #[clap(long)]
86 tag: Option<String>,
87 #[clap(long)]
92 queued: bool,
93 },
94 Export {
96 hash: Hash,
98 out: OutputTarget,
102 #[clap(long, default_value_t = false)]
105 recursive: bool,
106 #[clap(long, default_value_t = false)]
109 stable: bool,
110 },
111 #[clap(subcommand)]
113 List(ListCommands),
114 Validate {
116 #[clap(short, long, action(clap::ArgAction::Count))]
118 verbose: u8,
119 #[clap(long, default_value_t = false)]
126 repair: bool,
127 },
128 ConsistencyCheck {
130 #[clap(short, long, action(clap::ArgAction::Count))]
132 verbose: u8,
133 #[clap(long, default_value_t = false)]
140 repair: bool,
141 },
142 #[clap(subcommand)]
144 Delete(DeleteCommands),
145 Share {
147 hash: Hash,
149 #[clap(long, default_value_t = false)]
151 recursive: bool,
152 #[clap(long, hide = true)]
154 debug: bool,
155 },
156}
157
158#[derive(Debug, Clone, derive_more::Display)]
160pub enum TicketOrHash {
161 Ticket(BlobTicket),
162 Hash(Hash),
163}
164
165impl std::str::FromStr for TicketOrHash {
166 type Err = anyhow::Error;
167
168 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
169 if let Ok(ticket) = BlobTicket::from_str(s) {
170 return Ok(Self::Ticket(ticket));
171 }
172 if let Ok(hash) = Hash::from_str(s) {
173 return Ok(Self::Hash(hash));
174 }
175 Err(anyhow!("neither a valid ticket or hash"))
176 }
177}
178
179impl BlobCommands {
180 pub async fn run(self, blobs: &blobs::Client, addr: NodeAddr) -> Result<()> {
182 match self {
183 Self::Get {
184 ticket,
185 mut address,
186 relay_url,
187 recursive,
188 override_addresses,
189 node,
190 out,
191 stable,
192 tag,
193 queued,
194 } => {
195 let (node_addr, hash, format) = match ticket {
196 TicketOrHash::Ticket(ticket) => {
197 let (node_addr, hash, blob_format) = ticket.into_parts();
198
199 let node_addr = {
201 let NodeAddr {
202 node_id,
203 relay_url: original_relay_url,
204 direct_addresses,
205 } = node_addr;
206 let addresses = if override_addresses {
207 address
209 } else {
210 address.extend(direct_addresses);
212 address
213 };
214
215 let relay_url = relay_url.or(original_relay_url);
217
218 NodeAddr::from_parts(node_id, relay_url, addresses)
219 };
220
221 let blob_format = match recursive {
223 Some(true) => BlobFormat::HashSeq,
224 Some(false) => BlobFormat::Raw,
225 None => blob_format,
226 };
227
228 (node_addr, hash, blob_format)
229 }
230 TicketOrHash::Hash(hash) => {
231 let blob_format = match recursive {
233 Some(true) => BlobFormat::HashSeq,
234 Some(false) => BlobFormat::Raw,
235 None => BlobFormat::Raw,
236 };
237
238 let Some(node) = node else {
239 bail!("missing NodeId");
240 };
241
242 let node_addr = NodeAddr::from_parts(node, relay_url, address);
243 (node_addr, hash, blob_format)
244 }
245 };
246
247 if format != BlobFormat::Raw && out == Some(OutputTarget::Stdout) {
248 return Err(anyhow::anyhow!("The input arguments refer to a collection of blobs and output is set to STDOUT. Only single blobs may be passed in this case."));
249 }
250
251 let tag = match tag {
252 Some(tag) => SetTagOption::Named(Tag::from(tag)),
253 None => SetTagOption::Auto,
254 };
255
256 let mode = match queued {
257 true => DownloadMode::Queued,
258 false => DownloadMode::Direct,
259 };
260
261 let mut stream = blobs
262 .download_with_opts(
263 hash,
264 DownloadOptions {
265 format,
266 nodes: vec![node_addr],
267 tag,
268 mode,
269 },
270 )
271 .await?;
272
273 show_download_progress(hash, &mut stream).await?;
274
275 match out {
276 None => {}
277 Some(OutputTarget::Stdout) => {
278 let mut blob_read = blobs.read(hash).await?;
281 tokio::io::copy(&mut blob_read, &mut tokio::io::stdout()).await?;
282 }
283 Some(OutputTarget::Path(path)) => {
284 let absolute = std::env::current_dir()?.join(&path);
285 if matches!(format, BlobFormat::HashSeq) {
286 ensure!(!absolute.is_dir(), "output must not be a directory");
287 }
288 let recursive = format == BlobFormat::HashSeq;
289 let mode = match stable {
290 true => ExportMode::TryReference,
291 false => ExportMode::Copy,
292 };
293 let format = match recursive {
294 true => ExportFormat::Collection,
295 false => ExportFormat::Blob,
296 };
297 tracing::info!("exporting to {} -> {}", path.display(), absolute.display());
298 let stream = blobs.export(hash, absolute, format, mode).await?;
299
300 stream.await?;
302 }
303 };
304
305 Ok(())
306 }
307 Self::Export {
308 hash,
309 out,
310 recursive,
311 stable,
312 } => {
313 match out {
314 OutputTarget::Stdout => {
315 ensure!(
316 !recursive,
317 "Recursive option is not supported when exporting to STDOUT"
318 );
319 let mut blob_read = blobs.read(hash).await?;
320 tokio::io::copy(&mut blob_read, &mut tokio::io::stdout()).await?;
321 }
322 OutputTarget::Path(path) => {
323 let absolute = std::env::current_dir()?.join(&path);
324 if !recursive {
325 ensure!(!absolute.is_dir(), "output must not be a directory");
326 }
327 let mode = match stable {
328 true => ExportMode::TryReference,
329 false => ExportMode::Copy,
330 };
331 let format = match recursive {
332 true => ExportFormat::Collection,
333 false => ExportFormat::Blob,
334 };
335 tracing::info!(
336 "exporting {hash} to {} -> {}",
337 path.display(),
338 absolute.display()
339 );
340 let stream = blobs.export(hash, absolute, format, mode).await?;
341 stream.await?;
343 }
344 };
345 Ok(())
346 }
347 Self::List(cmd) => cmd.run(blobs).await,
348 Self::Delete(cmd) => cmd.run(blobs).await,
349 Self::Validate { verbose, repair } => validate(blobs, verbose, repair).await,
350 Self::ConsistencyCheck { verbose, repair } => {
351 consistency_check(blobs, verbose, repair).await
352 }
353 Self::Add {
354 source: path,
355 options,
356 } => add_with_opts(blobs, addr, path, options).await,
357 Self::Share {
358 hash,
359 recursive,
360 debug,
361 } => {
362 let format = if recursive {
363 BlobFormat::HashSeq
364 } else {
365 BlobFormat::Raw
366 };
367 let status = blobs.status(hash).await?;
368 let ticket = BlobTicket::new(addr, hash, format)?;
369
370 let (blob_status, size) = match (status, format) {
371 (BlobStatus::Complete { size }, BlobFormat::Raw) => ("blob", size),
372 (BlobStatus::Partial { size }, BlobFormat::Raw) => {
373 ("incomplete blob", size.value())
374 }
375 (BlobStatus::Complete { size }, BlobFormat::HashSeq) => ("collection", size),
376 (BlobStatus::Partial { size }, BlobFormat::HashSeq) => {
377 ("incomplete collection", size.value())
378 }
379 (BlobStatus::NotFound, _) => {
380 return Err(anyhow!("blob is missing"));
381 }
382 };
383 println!(
384 "Ticket for {blob_status} {hash} ({})\n{ticket}",
385 HumanBytes(size)
386 );
387
388 if debug {
389 println!("{ticket:#?}")
390 }
391 Ok(())
392 }
393 }
394 }
395}
396
397#[derive(clap::Args, Debug, Clone)]
399pub struct BlobAddOptions {
400 #[clap(long, default_value_t = false)]
405 pub in_place: bool,
406
407 #[clap(long)]
409 pub tag: Option<String>,
410
411 #[clap(long, default_value_t = false)]
425 pub wrap: bool,
426
427 #[clap(long, requires = "wrap")]
432 pub filename: Option<String>,
433
434 #[clap(long)]
436 pub no_ticket: bool,
437}
438
439#[derive(Subcommand, Debug, Clone)]
441pub enum ListCommands {
442 Blobs,
444 IncompleteBlobs,
446 Collections,
448}
449
450impl ListCommands {
451 pub async fn run(self, blobs: &blobs::Client) -> Result<()> {
453 match self {
454 Self::Blobs => {
455 let mut response = blobs.list().await?;
456 while let Some(item) = response.next().await {
457 let BlobInfo { path, hash, size } = item?;
458 println!("{} {} ({})", path, hash, HumanBytes(size));
459 }
460 }
461 Self::IncompleteBlobs => {
462 let mut response = blobs.list_incomplete().await?;
463 while let Some(item) = response.next().await {
464 let IncompleteBlobInfo { hash, size, .. } = item?;
465 println!("{} ({})", hash, HumanBytes(size));
466 }
467 }
468 Self::Collections => {
469 let mut response = blobs.list_collections()?;
470 while let Some(item) = response.next().await {
471 let CollectionInfo {
472 tag,
473 hash,
474 total_blobs_count,
475 total_blobs_size,
476 } = item?;
477 let total_blobs_count = total_blobs_count.unwrap_or_default();
478 let total_blobs_size = total_blobs_size.unwrap_or_default();
479 println!(
480 "{}: {} {} {} ({})",
481 tag,
482 hash,
483 total_blobs_count,
484 if total_blobs_count > 1 {
485 "blobs"
486 } else {
487 "blob"
488 },
489 HumanBytes(total_blobs_size),
490 );
491 }
492 }
493 }
494 Ok(())
495 }
496}
497
498#[derive(Subcommand, Debug, Clone)]
500pub enum DeleteCommands {
501 Blob {
503 #[arg(required = true)]
505 hash: Hash,
506 },
507}
508
509impl DeleteCommands {
510 pub async fn run(self, blobs: &blobs::Client) -> Result<()> {
512 match self {
513 Self::Blob { hash } => {
514 let response = blobs.delete_blob(hash).await;
515 if let Err(e) = response {
516 eprintln!("Error: {}", e);
517 }
518 }
519 }
520 Ok(())
521 }
522}
523
524fn get_report_level(verbose: u8) -> ReportLevel {
526 match verbose {
527 0 => ReportLevel::Warn,
528 1 => ReportLevel::Info,
529 _ => ReportLevel::Trace,
530 }
531}
532
533fn apply_report_level(text: String, level: ReportLevel) -> console::StyledObject<String> {
535 match level {
536 ReportLevel::Trace => style(text).dim(),
537 ReportLevel::Info => style(text),
538 ReportLevel::Warn => style(text).yellow(),
539 ReportLevel::Error => style(text).red(),
540 }
541}
542
543pub async fn consistency_check(blobs: &blobs::Client, verbose: u8, repair: bool) -> Result<()> {
545 let mut response = blobs.consistency_check(repair).await?;
546 let verbosity = get_report_level(verbose);
547 let print = |level: ReportLevel, entry: Option<Hash>, message: String| {
548 if level < verbosity {
549 return;
550 }
551 let level_text = level.to_string().to_lowercase();
552 let text = if let Some(hash) = entry {
553 format!("{}: {} ({})", level_text, message, hash.to_hex())
554 } else {
555 format!("{}: {}", level_text, message)
556 };
557 let styled = apply_report_level(text, level);
558 eprintln!("{}", styled);
559 };
560
561 while let Some(item) = response.next().await {
562 match item? {
563 ConsistencyCheckProgress::Start => {
564 eprintln!("Starting consistency check ...");
565 }
566 ConsistencyCheckProgress::Update {
567 message,
568 entry,
569 level,
570 } => {
571 print(level, entry, message);
572 }
573 ConsistencyCheckProgress::Done => {
574 eprintln!("Consistency check done");
575 }
576 ConsistencyCheckProgress::Abort(error) => {
577 eprintln!("Consistency check error {}", error);
578 break;
579 }
580 }
581 }
582 Ok(())
583}
584
585pub async fn validate(blobs: &blobs::Client, verbose: u8, repair: bool) -> Result<()> {
587 let mut state = ValidateProgressState::new();
588 let mut response = blobs.validate(repair).await?;
589 let verbosity = get_report_level(verbose);
590 let print = |level: ReportLevel, entry: Option<Hash>, message: String| {
591 if level < verbosity {
592 return;
593 }
594 let level_text = level.to_string().to_lowercase();
595 let text = if let Some(hash) = entry {
596 format!("{}: {} ({})", level_text, message, hash.to_hex())
597 } else {
598 format!("{}: {}", level_text, message)
599 };
600 let styled = apply_report_level(text, level);
601 eprintln!("{}", styled);
602 };
603
604 let mut partial = BTreeMap::new();
605
606 while let Some(item) = response.next().await {
607 match item? {
608 ValidateProgress::PartialEntry {
609 id,
610 hash,
611 path,
612 size,
613 } => {
614 partial.insert(id, hash);
615 print(
616 ReportLevel::Trace,
617 Some(hash),
618 format!(
619 "Validating partial entry {} {} {}",
620 id,
621 path.unwrap_or_default(),
622 size
623 ),
624 );
625 }
626 ValidateProgress::PartialEntryProgress { id, offset } => {
627 let entry = partial.get(&id).cloned();
628 print(
629 ReportLevel::Trace,
630 entry,
631 format!("Partial entry {} at {}", id, offset),
632 );
633 }
634 ValidateProgress::PartialEntryDone { id, ranges } => {
635 let entry: Option<Hash> = partial.remove(&id);
636 print(
637 ReportLevel::Info,
638 entry,
639 format!("Partial entry {} done {:?}", id, ranges.to_chunk_ranges()),
640 );
641 }
642 ValidateProgress::Starting { total } => {
643 state.starting(total);
644 }
645 ValidateProgress::Entry {
646 id,
647 hash,
648 path,
649 size,
650 } => {
651 state.add_entry(id, hash, path, size);
652 }
653 ValidateProgress::EntryProgress { id, offset } => {
654 state.progress(id, offset);
655 }
656 ValidateProgress::EntryDone { id, error } => {
657 state.done(id, error);
658 }
659 ValidateProgress::Abort(error) => {
660 state.abort(error.to_string());
661 break;
662 }
663 ValidateProgress::AllDone => {
664 break;
665 }
666 }
667 }
668 Ok(())
669}
670
671struct ValidateProgressState {
673 mp: MultiProgress,
674 pbs: HashMap<u64, ProgressBar>,
675 overall: ProgressBar,
676 total: u64,
677 errors: u64,
678 successes: u64,
679}
680
681impl ValidateProgressState {
682 fn new() -> Self {
684 let mp = MultiProgress::new();
685 let overall = mp.add(ProgressBar::new(0));
686 overall.enable_steady_tick(Duration::from_millis(500));
687 Self {
688 mp,
689 pbs: HashMap::new(),
690 overall,
691 total: 0,
692 errors: 0,
693 successes: 0,
694 }
695 }
696
697 fn starting(&mut self, total: u64) {
699 self.total = total;
700 self.errors = 0;
701 self.successes = 0;
702 self.overall.set_position(0);
703 self.overall.set_length(total);
704 self.overall.set_style(
705 ProgressStyle::default_bar()
706 .template("{spinner:.green} [{bar:60.cyan/blue}] {msg}")
707 .unwrap()
708 .progress_chars("=>-"),
709 );
710 }
711
712 fn add_entry(&mut self, id: u64, hash: Hash, path: Option<String>, size: u64) {
714 let pb = self.mp.insert_before(&self.overall, ProgressBar::new(size));
715 pb.set_style(ProgressStyle::default_bar()
716 .template("{spinner:.green} [{bar:40.cyan/blue}] {msg} {bytes}/{total_bytes} ({bytes_per_sec}, eta {eta})").unwrap()
717 .progress_chars("=>-"));
718 let msg = if let Some(path) = path {
719 format!("{} {}", hash.to_hex(), path)
720 } else {
721 hash.to_hex().to_string()
722 };
723 pb.set_message(msg);
724 pb.set_position(0);
725 pb.set_length(size);
726 pb.enable_steady_tick(Duration::from_millis(500));
727 self.pbs.insert(id, pb);
728 }
729
730 fn progress(&mut self, id: u64, progress: u64) {
732 if let Some(pb) = self.pbs.get_mut(&id) {
733 pb.set_position(progress);
734 }
735 }
736
737 fn abort(self, error: String) {
739 let error_line = self.mp.add(ProgressBar::new(0));
740 error_line.set_style(ProgressStyle::default_bar().template("{msg}").unwrap());
741 error_line.set_message(error);
742 }
743
744 fn done(&mut self, id: u64, error: Option<String>) {
746 if let Some(pb) = self.pbs.remove(&id) {
747 let ok_char = style(Emoji("✔", "OK")).green();
748 let fail_char = style(Emoji("✗", "Error")).red();
749 let ok = error.is_none();
750 let msg = match error {
751 Some(error) => format!("{} {} {}", pb.message(), fail_char, error),
752 None => format!("{} {}", pb.message(), ok_char),
753 };
754 if ok {
755 self.successes += 1;
756 } else {
757 self.errors += 1;
758 }
759 self.overall.set_position(self.errors + self.successes);
760 self.overall.set_message(format!(
761 "Overall {} {}, {} {}",
762 self.errors, fail_char, self.successes, ok_char
763 ));
764 if ok {
765 pb.finish_and_clear();
766 } else {
767 pb.set_style(ProgressStyle::default_bar().template("{msg}").unwrap());
768 pb.finish_with_message(msg);
769 }
770 }
771 }
772}
773
774#[derive(Debug, Clone, derive_more::Display, PartialEq, Eq)]
776pub enum BlobSource {
777 #[display("STDIN")]
779 Stdin,
780 #[display("{}", _0.display())]
782 Path(PathBuf),
783}
784
785impl From<String> for BlobSource {
786 fn from(s: String) -> Self {
787 if s == "STDIN" {
788 return BlobSource::Stdin;
789 }
790
791 BlobSource::Path(s.into())
792 }
793}
794
/// Data source for [`add`].
#[derive(Debug, Clone)]
pub enum BlobSourceIroh {
    /// Add the file system entry at `path`; `in_place` is forwarded to the store.
    LocalFs {
        /// Path of the data to add.
        path: PathBuf,
        /// Whether to add the data in place.
        in_place: bool,
    },
    /// Add the data read from standard input.
    Stdin,
}
803
/// Whether [`add`] prints a ticket for the added data.
#[derive(Debug, Clone)]
pub enum TicketOption {
    /// Do not print a ticket.
    None,
    /// Print an all-in-one ticket.
    Print,
}
812
813pub async fn add_with_opts(
815 blobs: &blobs::Client,
816 addr: NodeAddr,
817 source: BlobSource,
818 opts: BlobAddOptions,
819) -> Result<()> {
820 let tag = match opts.tag {
821 Some(tag) => SetTagOption::Named(Tag::from(tag)),
822 None => SetTagOption::Auto,
823 };
824 let ticket = match opts.no_ticket {
825 true => TicketOption::None,
826 false => TicketOption::Print,
827 };
828 let source = match source {
829 BlobSource::Stdin => BlobSourceIroh::Stdin,
830 BlobSource::Path(path) => BlobSourceIroh::LocalFs {
831 path,
832 in_place: opts.in_place,
833 },
834 };
835 let wrap = match (opts.wrap, opts.filename) {
836 (true, None) => WrapOption::Wrap { name: None },
837 (true, Some(filename)) => WrapOption::Wrap {
838 name: Some(filename),
839 },
840 (false, None) => WrapOption::NoWrap,
841 (false, Some(_)) => bail!("`--filename` may not be used without `--wrap`"),
842 };
843
844 add(blobs, addr, source, tag, ticket, wrap).await
845}
846
847pub async fn add(
849 blobs: &blobs::Client,
850 addr: NodeAddr,
851 source: BlobSourceIroh,
852 tag: SetTagOption,
853 ticket: TicketOption,
854 wrap: WrapOption,
855) -> Result<()> {
856 let (hash, format, entries) = match source {
857 BlobSourceIroh::LocalFs { path, in_place } => {
858 let absolute = path.canonicalize()?;
859 println!("Adding {} as {}...", path.display(), absolute.display());
860
861 let stream = blobs.add_from_path(absolute, in_place, tag, wrap).await?;
863 aggregate_add_response(stream).await?
864 }
865 BlobSourceIroh::Stdin => {
866 println!("Adding from STDIN...");
867 let (file, path) = tempfile::NamedTempFile::new()?.into_parts();
869 let mut file = tokio::fs::File::from_std(file);
870 let path_buf = path.to_path_buf();
871 tokio::io::copy(&mut tokio::io::stdin(), &mut file).await?;
873 file.flush().await?;
874 drop(file);
875
876 let stream = blobs.add_from_path(path_buf, false, tag, wrap).await?;
878 aggregate_add_response(stream).await?
879 }
880 };
881
882 print_add_response(hash, format, entries);
883 if let TicketOption::Print = ticket {
884 let ticket = BlobTicket::new(addr, hash, format)?;
885 println!("All-in-one ticket: {ticket}");
886 }
887 Ok(())
888}
889
890#[derive(Debug)]
892pub struct ProvideResponseEntry {
893 pub name: String,
894 pub size: u64,
895 pub hash: Hash,
896}
897
898pub async fn aggregate_add_response(
900 mut stream: impl Stream<Item = Result<AddProgress>> + Unpin,
901) -> Result<(Hash, BlobFormat, Vec<ProvideResponseEntry>)> {
902 let mut hash_and_format = None;
903 let mut collections = BTreeMap::<u64, (String, u64, Option<Hash>)>::new();
904 let mut mp = Some(ProvideProgressState::new());
905 while let Some(item) = stream.next().await {
906 match item? {
907 AddProgress::Found { name, id, size } => {
908 tracing::trace!("Found({id},{name},{size})");
909 if let Some(mp) = mp.as_mut() {
910 mp.found(name.clone(), id, size);
911 }
912 collections.insert(id, (name, size, None));
913 }
914 AddProgress::Progress { id, offset } => {
915 tracing::trace!("Progress({id}, {offset})");
916 if let Some(mp) = mp.as_mut() {
917 mp.progress(id, offset);
918 }
919 }
920 AddProgress::Done { hash, id } => {
921 tracing::trace!("Done({id},{hash:?})");
922 if let Some(mp) = mp.as_mut() {
923 mp.done(id, hash);
924 }
925 match collections.get_mut(&id) {
926 Some((_, _, ref mut h)) => {
927 *h = Some(hash);
928 }
929 None => {
930 anyhow::bail!("Got Done for unknown collection id {id}");
931 }
932 }
933 }
934 AddProgress::AllDone { hash, format, .. } => {
935 tracing::trace!("AllDone({hash:?})");
936 if let Some(mp) = mp.take() {
937 mp.all_done();
938 }
939 hash_and_format = Some(HashAndFormat { hash, format });
940 break;
941 }
942 AddProgress::Abort(e) => {
943 if let Some(mp) = mp.take() {
944 mp.error();
945 }
946 anyhow::bail!("Error while adding data: {e}");
947 }
948 }
949 }
950 let HashAndFormat { hash, format } =
951 hash_and_format.context("Missing hash for collection or blob")?;
952 let entries = collections
953 .into_iter()
954 .map(|(_, (name, size, hash))| {
955 let hash = hash.context(format!("Missing hash for {name}"))?;
956 Ok(ProvideResponseEntry { name, size, hash })
957 })
958 .collect::<Result<Vec<_>>>()?;
959 Ok((hash, format, entries))
960}
961
962pub fn print_add_response(hash: Hash, format: BlobFormat, entries: Vec<ProvideResponseEntry>) {
964 let mut total_size = 0;
965 for ProvideResponseEntry { name, size, hash } in entries {
966 total_size += size;
967 println!("- {}: {} {:#}", name, HumanBytes(size), hash);
968 }
969 println!("Total: {}", HumanBytes(total_size));
970 println!();
971 match format {
972 BlobFormat::Raw => println!("Blob: {}", hash),
973 BlobFormat::HashSeq => println!("Collection: {}", hash),
974 }
975}
976
977#[derive(Debug)]
979pub struct ProvideProgressState {
980 mp: MultiProgress,
981 pbs: HashMap<u64, ProgressBar>,
982}
983
984impl ProvideProgressState {
985 fn new() -> Self {
987 Self {
988 mp: MultiProgress::new(),
989 pbs: HashMap::new(),
990 }
991 }
992
993 fn found(&mut self, name: String, id: u64, size: u64) {
995 let pb = self.mp.add(ProgressBar::new(size));
996 pb.set_style(ProgressStyle::default_bar()
997 .template("{spinner:.green} [{bar:40.cyan/blue}] {msg} {bytes}/{total_bytes} ({bytes_per_sec}, eta {eta})").unwrap()
998 .progress_chars("=>-"));
999 pb.set_message(name);
1000 pb.set_length(size);
1001 pb.set_position(0);
1002 pb.enable_steady_tick(Duration::from_millis(500));
1003 self.pbs.insert(id, pb);
1004 }
1005
1006 fn progress(&mut self, id: u64, progress: u64) {
1008 if let Some(pb) = self.pbs.get_mut(&id) {
1009 pb.set_position(progress);
1010 }
1011 }
1012
1013 fn done(&mut self, id: u64, _hash: Hash) {
1015 if let Some(pb) = self.pbs.remove(&id) {
1016 pb.finish_and_clear();
1017 self.mp.remove(&pb);
1018 }
1019 }
1020
1021 fn all_done(self) {
1023 self.mp.clear().ok();
1024 }
1025
1026 fn error(self) {
1028 self.mp.clear().ok();
1029 }
1030}
1031
1032pub async fn show_download_progress(
1034 hash: Hash,
1035 mut stream: impl Stream<Item = Result<DownloadProgress>> + Unpin,
1036) -> Result<()> {
1037 eprintln!("Fetching: {}", hash);
1038 let mp = MultiProgress::new();
1039 mp.set_draw_target(ProgressDrawTarget::stderr());
1040 let op = mp.add(make_overall_progress());
1041 let ip = mp.add(make_individual_progress());
1042 op.set_message(format!("{} Connecting ...\n", style("[1/3]").bold().dim()));
1043 let mut seq = false;
1044 while let Some(x) = stream.next().await {
1045 match x? {
1046 DownloadProgress::InitialState(state) => {
1047 if state.connected {
1048 op.set_message(format!("{} Requesting ...\n", style("[2/3]").bold().dim()));
1049 }
1050 if let Some(count) = state.root.child_count {
1051 op.set_message(format!(
1052 "{} Downloading {} blob(s)\n",
1053 style("[3/3]").bold().dim(),
1054 count + 1,
1055 ));
1056 op.set_length(count + 1);
1057 op.reset();
1058 op.set_position(state.current.map(u64::from).unwrap_or(0));
1059 seq = true;
1060 }
1061 if let Some(blob) = state.get_current() {
1062 if let Some(size) = blob.size {
1063 ip.set_length(size.value());
1064 ip.reset();
1065 match blob.progress {
1066 BlobProgress::Pending => {}
1067 BlobProgress::Progressing(offset) => ip.set_position(offset),
1068 BlobProgress::Done => ip.finish_and_clear(),
1069 }
1070 if !seq {
1071 op.finish_and_clear();
1072 }
1073 }
1074 }
1075 }
1076 DownloadProgress::FoundLocal { .. } => {}
1077 DownloadProgress::Connected => {
1078 op.set_message(format!("{} Requesting ...\n", style("[2/3]").bold().dim()));
1079 }
1080 DownloadProgress::FoundHashSeq { children, .. } => {
1081 op.set_message(format!(
1082 "{} Downloading {} blob(s)\n",
1083 style("[3/3]").bold().dim(),
1084 children + 1,
1085 ));
1086 op.set_length(children + 1);
1087 op.reset();
1088 seq = true;
1089 }
1090 DownloadProgress::Found { size, child, .. } => {
1091 if seq {
1092 op.set_position(child.into());
1093 } else {
1094 op.finish_and_clear();
1095 }
1096 ip.set_length(size);
1097 ip.reset();
1098 }
1099 DownloadProgress::Progress { offset, .. } => {
1100 ip.set_position(offset);
1101 }
1102 DownloadProgress::Done { .. } => {
1103 ip.finish_and_clear();
1104 }
1105 DownloadProgress::AllDone(Stats {
1106 bytes_read,
1107 elapsed,
1108 ..
1109 }) => {
1110 op.finish_and_clear();
1111 eprintln!(
1112 "Transferred {} in {}, {}/s",
1113 HumanBytes(bytes_read),
1114 HumanDuration(elapsed),
1115 HumanBytes((bytes_read as f64 / elapsed.as_secs_f64()) as u64)
1116 );
1117 break;
1118 }
1119 DownloadProgress::Abort(e) => {
1120 bail!("download aborted: {}", e);
1121 }
1122 }
1123 }
1124 Ok(())
1125}
1126
1127#[derive(Debug, Clone, derive_more::Display, PartialEq, Eq)]
1129pub enum OutputTarget {
1130 #[display("STDOUT")]
1132 Stdout,
1133 #[display("{}", _0.display())]
1135 Path(PathBuf),
1136}
1137
1138impl From<String> for OutputTarget {
1139 fn from(s: String) -> Self {
1140 if s == "STDOUT" {
1141 return OutputTarget::Stdout;
1142 }
1143
1144 OutputTarget::Path(s.into())
1145 }
1146}
1147
1148fn make_overall_progress() -> ProgressBar {
1150 let pb = ProgressBar::hidden();
1151 pb.enable_steady_tick(std::time::Duration::from_millis(100));
1152 pb.set_style(
1153 ProgressStyle::with_template(
1154 "{msg}{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len}",
1155 )
1156 .unwrap()
1157 .progress_chars("#>-"),
1158 );
1159 pb
1160}
1161
1162fn make_individual_progress() -> ProgressBar {
1164 let pb = ProgressBar::hidden();
1165 pb.enable_steady_tick(std::time::Duration::from_millis(100));
1166 pb.set_style(
1167 ProgressStyle::with_template("{msg}{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})")
1168 .unwrap()
1169 .with_key(
1170 "eta",
1171 |state: &ProgressState, w: &mut dyn std::fmt::Write| {
1172 write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
1173 },
1174 )
1175 .progress_chars("#>-"),
1176 );
1177 pb
1178}
1179
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `BlobSource` survives a round trip through its display string.
    #[test]
    fn test_blob_source() {
        let sources = [BlobSource::Stdin, BlobSource::Path("hello/world".into())];
        for source in sources {
            assert_eq!(BlobSource::from(source.to_string()), source);
        }
    }

    /// Every `OutputTarget` survives a round trip through its display string.
    #[test]
    fn test_output_target() {
        let targets = [
            OutputTarget::Stdout,
            OutputTarget::Path("hello/world".into()),
        ];
        for target in targets {
            assert_eq!(OutputTarget::from(target.to_string()), target);
        }
    }
}