Skip to main content

gemini_client_api/
utils.rs

1use base64::{Engine, engine::general_purpose::STANDARD};
2use futures::future::join_all;
3pub use mime;
4use mime::Mime;
5use regex::Regex;
6use reqwest::Client;
7pub use reqwest::header::{HeaderMap, HeaderValue};
8use std::{str::FromStr, time::Duration};
9
10#[derive(Clone)]
11pub struct MatchedFiles {
12    pub index: usize,
13    pub length: usize,
14    pub mime_type: Option<Mime>,
15    pub base64: Option<String>,
16}
17/// # Panics
18/// `regex` must have a Regex with atleast 1 capture group with file URL as first capture group, else it PANICS
19/// # Arguments
20/// `guess_mime_type` is used to detect mimi_type of URL pointing to file system or web resource
21/// with no "Content-Type" header.
22pub async fn get_file_base64s(
23    markdown: impl AsRef<str>,
24    regex: Regex,
25    guess_mime_type: fn(url: &str) -> Mime,
26    decide_download: fn(headers: &HeaderMap) -> bool,
27    timeout: Duration,
28) -> Vec<MatchedFiles> {
29    let client = Client::builder().timeout(timeout).build().unwrap();
30    let mut tasks = Vec::new();
31
32    for file in regex.captures_iter(markdown.as_ref()) {
33        let capture = file.get(0).unwrap();
34        let url = file[1].to_string();
35        tasks.push((async |capture: regex::Match<'_>, url: String| {
36            let (mime_type, base64) = if url.starts_with("https://") || url.starts_with("http://") {
37                let response = client.get(&url).send().await;
38                match response {
39                    Ok(response) if (decide_download)(response.headers()) => {
40                        let mime_type = response
41                            .headers()
42                            .get("Content-Type")
43                            .map(|mime| mime.to_str().ok())
44                            .flatten()
45                            .map(|mime| Mime::from_str(mime).ok())
46                            .flatten();
47
48                        let base64 = response
49                            .bytes()
50                            .await
51                            .ok()
52                            .map(|bytes| STANDARD.encode(bytes));
53                        let mime_type = match base64 {
54                            Some(_) => mime_type.or_else(|| Some(guess_mime_type(&url))),
55                            None => None,
56                        };
57                        (mime_type, base64)
58                    }
59                    _ => (None, None),
60                }
61            } else {
62                let base64 = tokio::fs::read(url.clone())
63                    .await
64                    .ok()
65                    .map(|bytes| STANDARD.encode(&bytes));
66                match base64 {
67                    Some(base64) => (Some(guess_mime_type(&url)), Some(base64)),
68                    None => (None, None),
69                }
70            };
71            MatchedFiles {
72                index: capture.start(),
73                length: capture.len(),
74                mime_type,
75                base64,
76            }
77        })(capture, url));
78    }
79    join_all(tasks).await
80}