gemini_client_api/
utils.rs1use base64::{Engine, engine::general_purpose::STANDARD};
2use futures::future::join_all;
3pub use mime;
4use mime::Mime;
5use regex::Regex;
6#[cfg(feature = "reqwest")]
7use reqwest::Client;
8#[cfg(feature = "reqwest")]
9pub use reqwest::header::{HeaderMap, HeaderValue};
10use std::{str::FromStr, time::Duration};
11
12#[derive(Clone)]
13pub struct MatchedFiles {
14 pub index: usize,
15 pub length: usize,
16 pub mime_type: Option<Mime>,
17 pub base64: Option<String>,
18}
#[cfg(feature = "reqwest")]
/// Finds file references in `markdown` and loads each one as base64.
///
/// Every match of `regex` yields exactly one [`MatchedFiles`] entry, in match order. Capture
/// group 1 of each match is taken as the file location:
///
/// * `http://` / `https://` URLs are fetched with a client that uses `timeout`. The body is only
///   downloaded when the response status is a success (2xx) and `decide_download` approves the
///   response headers. The MIME type comes from the `Content-Type` header, falling back to
///   `guess_mime_type(url)` when the header is missing or unparseable.
/// * Anything else is treated as a local filesystem path and read with `tokio::fs::read`; its
///   MIME type is always `guess_mime_type(path)`.
///
/// All files are loaded concurrently. `index` and `length` of each entry are the byte offset
/// and byte length of the whole match (group 0) within `markdown`.
///
/// Loading failures are not reported as errors. When a file cannot be fetched or read, or the
/// match has no capture group 1, the entry has both `mime_type` and `base64` set to `None`.
///
/// # Security
///
/// Non-URL captures are read from the local filesystem without any restriction. If `markdown`
/// can come from an untrusted source, arbitrary readable local files may be embedded; choose
/// `regex` accordingly.
///
/// # Panics
///
/// Panics if the HTTP client cannot be constructed (`reqwest::ClientBuilder::build` failing,
/// e.g. when the TLS backend cannot be initialized).
pub async fn get_file_base64s(
    markdown: impl AsRef<str>,
    regex: Regex,
    guess_mime_type: fn(url: &str) -> Mime,
    decide_download: fn(headers: &HeaderMap) -> bool,
    timeout: Duration,
) -> Vec<MatchedFiles> {
    let client = Client::builder()
        .timeout(timeout)
        .build()
        .expect("failed to build the reqwest HTTP client");
    // Every task borrows the same client; `join_all` completes before `client` is dropped.
    let client = &client;

    let tasks = regex.captures_iter(markdown.as_ref()).map(|caps| {
        let whole = caps
            .get(0)
            .expect("capture group 0 is always present for a match");
        let (index, length) = (whole.start(), whole.len());
        let location = caps.get(1).map(|group| group.as_str().to_owned());
        async move {
            let (mime_type, base64) = match location {
                Some(url) if url.starts_with("https://") || url.starts_with("http://") => {
                    fetch_remote_base64(client, &url, guess_mime_type, decide_download).await
                }
                Some(path) => read_local_base64(&path, guess_mime_type).await,
                // The pattern matched but group 1 did not participate: nothing to load.
                None => (None, None),
            };
            MatchedFiles {
                index,
                length,
                mime_type,
                base64,
            }
        }
    });
    join_all(tasks).await
}

/// Fetches `url` and returns its MIME type and base64-encoded body.
///
/// Returns `(None, None)` if any of these happen:
/// * the request fails;
/// * the status is not a success;
/// * `decide_download` rejects the headers;
/// * the body cannot be read.
///
/// Otherwise the MIME type comes from a parseable `Content-Type` header, falling back to
/// `guess_mime_type(url)`.
#[cfg(feature = "reqwest")]
async fn fetch_remote_base64(
    client: &Client,
    url: &str,
    guess_mime_type: fn(url: &str) -> Mime,
    decide_download: fn(headers: &HeaderMap) -> bool,
) -> (Option<Mime>, Option<String>) {
    // Error pages (404, 500, ...) also carry a body; never embed them as the requested file.
    let Some(response) = client
        .get(url)
        .send()
        .await
        .ok()
        .filter(|response| {
            response.status().is_success() && decide_download(response.headers())
        })
    else {
        return (None, None);
    };

    let header_mime = response
        .headers()
        .get("Content-Type")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| Mime::from_str(value).ok());

    match response.bytes().await {
        Ok(bytes) => {
            let mime_type = header_mime.unwrap_or_else(|| guess_mime_type(url));
            (Some(mime_type), Some(STANDARD.encode(bytes)))
        }
        Err(_) => (None, None),
    }
}

/// Reads the local file at `path`; returns its guessed MIME type and base64-encoded contents.
///
/// Returns `(None, None)` if the file cannot be read.
#[cfg(feature = "reqwest")]
async fn read_local_base64(
    path: &str,
    guess_mime_type: fn(url: &str) -> Mime,
) -> (Option<Mime>, Option<String>) {
    match tokio::fs::read(path).await {
        Ok(bytes) => (Some(guess_mime_type(path)), Some(STANDARD.encode(bytes))),
        Err(_) => (None, None),
    }
}