goat_cli/progress/
mod.rs

1//!
2//! Module for progress bar addition to searches.
3//!
4//! Add with `--progress-bar` in `goat-cli search` and
5//! `goat-cli newick`.
6
7use anyhow::{bail, ensure, Result};
8use async_std::task;
9use futures::StreamExt;
10use indicatif;
11use reqwest;
12use reqwest::header::ACCEPT;
13use serde_json::Value;
14use std::time::Duration;
15
16use crate::utils::cli_matches;
17use crate::{count, IndexType};
18use crate::{GOAT_URL, UPPER_CLI_SIZE_LIMIT};
19
20// a function to create and display a progress bar
21// for large requests. Currently limited to single large requests.
22
23/// Adds a progress bar to large requests.
24pub async fn progress_bar(
25    matches: &clap::ArgMatches,
26    api: &str,
27    unique_ids: Vec<String>,
28    index_type: IndexType,
29) -> Result<()> {
30    // wait briefly before submitting
31    // so we are sure the API has recieved and set the queryId
32    task::sleep(Duration::from_secs(2)).await;
33    // TODO: clean this up.
34    let (size_int, _url_vector, url_vector_api) = match api {
35        // doesn't matter what is in the vecs, they just need to be length 1
36        // as newick only supports single url calls right now.
37        // this is really bad coding...
38        "newick" => (0u64, vec!["init".to_string()], vec!["init".to_string()]),
39        other => cli_matches::process_cli_args(matches, other, unique_ids.clone(), index_type)?,
40    };
41
42    ensure!(
43        unique_ids.len() == url_vector_api.len(),
44        "No reason these lengths should be different."
45    );
46
47    let concurrent_requests = url_vector_api.len();
48
49    // should be fine to always unwrap this
50    let no_query_hits = count::count(matches, false, false, unique_ids.clone(), index_type)
51        .await?
52        .unwrap();
53    // might need tweaking...
54    // special case newick
55    if api != "newick" {
56        // I think these actually need to be
57        // 10,000... but that's our upper limit for search
58        if no_query_hits < 10000 || size_int < 10000 {
59            return Ok(());
60        }
61    }
62
63    // add the query ID's to a vec
64    let mut query_id_vec = Vec::new();
65    for i in unique_ids.iter().take(concurrent_requests) {
66        let query_id = format!("{}progress?queryId=goat_cli_{}", *GOAT_URL, i);
67        query_id_vec.push(query_id);
68    }
69
70    // we want to wrap this in a loop
71    // and break when sum(progress_x) == sum(progress_total)
72    let bar = indicatif::ProgressBar::new(512);
73    let bar_style = ("█▓▓▒░░░ ", "magenta");
74    bar.set_style(
75        indicatif::ProgressStyle::default_bar()
76            .template(&format!(
77                "{{prefix:.bold}}▕{{bar:57.{}}}▏{{pos}}/{{len}} {{wide_msg}}",
78                bar_style.1
79            ))?
80            .progress_chars(bar_style.0),
81    );
82    bar.set_prefix("Fetching from GoaT: ");
83
84    loop {
85        // main body
86        let fetches =
87            futures::stream::iter(query_id_vec.clone().into_iter().map(|path| async move {
88                // possibly make a again::RetryPolicy
89                // to catch all the values in a *very* large request.
90                let client = reqwest::Client::new();
91
92                match again::retry(|| client.get(&path).header(ACCEPT, "application/json").send())
93                    .await
94                {
95                    Ok(resp) => match resp.text().await {
96                        Ok(body) => {
97                            let v: Value = serde_json::from_str(&body)?;
98
99                            match &v["progress"] {
100                                Value::Object(_o) => {
101                                    let progress_total = v["progress"]["total"].as_u64();
102                                    let progress_x = v["progress"]["x"].as_u64();
103                                    Ok(Some((progress_x, progress_total)))
104                                }
105                                _ => Ok(None),
106                            }
107                        }
108                        Err(_) => bail!("ERROR reading {}", path),
109                    },
110                    Err(_) => bail!("ERROR downloading {}", path),
111                }
112            }))
113            .buffered(concurrent_requests)
114            // complicated. Each u64 can be an option, as some
115            // queries will finish before others
116            // entire tuple is an option, as other progress enums evaluate to None.
117            .collect::<Vec<Result<Option<(Option<u64>, Option<u64>)>>>>();
118
119        let awaited_fetches = fetches.await;
120        // what's going on here?
121        let progress_total: Result<Vec<_>, _> = awaited_fetches.into_iter().collect();
122
123        let mut progress_x_total = 0;
124        let mut progress_total_total = 0;
125        for el in progress_total.unwrap() {
126            let x_tot_tup = match el {
127                Some(t) => t,
128                None => (None, None),
129            };
130            progress_x_total += x_tot_tup.0.unwrap_or(0);
131            progress_total_total += x_tot_tup.1.unwrap_or(0);
132        }
133
134        // special case newick
135        match api {
136            "newick" => bar.set_length(progress_total_total),
137            _ => match progress_total_total > *UPPER_CLI_SIZE_LIMIT as u64 {
138                true => bar.set_length(size_int),
139                false => bar.set_length(progress_total_total),
140            },
141        }
142
143        bar.set_position(progress_x_total);
144
145        if progress_x_total >= progress_total_total {
146            break;
147        }
148
149        task::sleep(Duration::from_millis(1)).await;
150    }
151
152    Ok(())
153}