Skip to main content

gemini_client_api/
utils.rs

1use base64::{Engine, engine::general_purpose::STANDARD};
2use futures::future::join_all;
3pub use mime;
4use mime::Mime;
5use regex::Regex;
6#[cfg(feature = "reqwest")]
7use reqwest::Client;
8#[cfg(feature = "reqwest")]
9pub use reqwest::header::{HeaderMap, HeaderValue};
10use std::{str::FromStr, time::Duration};
11
12#[derive(Clone)]
13pub struct MatchedFiles {
14    pub index: usize,
15    pub length: usize,
16    pub mime_type: Option<Mime>,
17    pub base64: Option<String>,
18}
/// Finds every file reference in `markdown` and loads each referenced file as base64.
///
/// Every match of `regex` yields one [`MatchedFiles`] entry, in match order. A URL
/// starting with `http://` or `https://` is downloaded; anything else is treated as a
/// path on the local file system. All files are loaded concurrently.
///
/// A file that cannot be loaded (network error, non-success HTTP status, download
/// declined by `decide_download`, unreadable body, or unreadable local file) still
/// produces an entry, with both `mime_type` and `base64` set to `None`.
///
/// # Panics
/// `regex` must contain at least 1 capture group, with the file URL as the first
/// capture group, else this function PANICS. It also panics if the HTTP client
/// cannot be constructed.
///
/// # Arguments
/// * `markdown` - text to scan for file references.
/// * `regex` - pattern whose first capture group is the file URL or path.
/// * `guess_mime_type` - used to detect the MIME type of a URL pointing to the file
///   system, or to a web resource with no usable "Content-Type" header.
/// * `decide_download` - receives the response headers and returns whether the body
///   should be downloaded.
/// * `timeout` - timeout applied to the HTTP client used for the downloads.
#[cfg(feature = "reqwest")]
pub async fn get_file_base64s(
    markdown: impl AsRef<str>,
    regex: Regex,
    guess_mime_type: fn(url: &str) -> Mime,
    decide_download: fn(headers: &HeaderMap) -> bool,
    timeout: Duration,
) -> Vec<MatchedFiles> {
    let client = Client::builder()
        .timeout(timeout)
        .build()
        .expect("failed to build the reqwest HTTP client");
    // Shared by reference so every `async move` block below can use the same client.
    let client = &client;

    let tasks = regex.captures_iter(markdown.as_ref()).map(|captures| {
        let whole = captures
            .get(0)
            .expect("capture group 0 is always present in a match");
        // Indexing panics when the regex has no first capture group (see `# Panics`).
        let url = captures[1].to_string();
        async move {
            let (mime_type, base64) =
                if url.starts_with("https://") || url.starts_with("http://") {
                    fetch_remote_file(client, &url, guess_mime_type, decide_download).await
                } else {
                    read_local_file(&url, guess_mime_type).await
                };
            MatchedFiles {
                index: whole.start(),
                length: whole.len(),
                mime_type,
                base64,
            }
        }
    });
    join_all(tasks).await
}

/// Downloads `url` and returns its MIME type and base64-encoded body, or
/// `(None, None)` when the download fails or is declined.
#[cfg(feature = "reqwest")]
async fn fetch_remote_file(
    client: &Client,
    url: &str,
    guess_mime_type: fn(url: &str) -> Mime,
    decide_download: fn(headers: &HeaderMap) -> bool,
) -> (Option<Mime>, Option<String>) {
    let response = match client.get(url).send().await {
        // Reject non-2xx responses: otherwise an error page (e.g. a 404 HTML body)
        // would be encoded and returned as if it were the requested file.
        Ok(response)
            if response.status().is_success() && decide_download(response.headers()) =>
        {
            response
        }
        _ => return (None, None),
    };

    // Read the header before `bytes()` consumes the response.
    let declared_mime = response
        .headers()
        .get("Content-Type")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| Mime::from_str(value).ok());

    match response.bytes().await {
        Ok(bytes) => (
            // Fall back to guessing only when the header is missing or unparsable.
            Some(declared_mime.unwrap_or_else(|| guess_mime_type(url))),
            Some(STANDARD.encode(bytes)),
        ),
        Err(_) => (None, None),
    }
}

/// Reads the local file at `path` and returns its guessed MIME type and
/// base64-encoded contents, or `(None, None)` when the file cannot be read.
#[cfg(feature = "reqwest")]
async fn read_local_file(
    path: &str,
    guess_mime_type: fn(url: &str) -> Mime,
) -> (Option<Mime>, Option<String>) {
    match tokio::fs::read(path).await {
        Ok(bytes) => (Some(guess_mime_type(path)), Some(STANDARD.encode(bytes))),
        Err(_) => (None, None),
    }
}