genkit/helpers.rs

1use anyhow::Result;
2use hyper::{
3    body::{self, Buf},
4    http::HeaderValue,
5    Client, Request, Uri,
6};
7use hyper_tls::HttpsConnector;
8use rayon::iter::{ParallelBridge, ParallelIterator};
9use std::{
10    collections::HashMap,
11    fs,
12    io::{self, ErrorKind, Read},
13    path::Path,
14    process::Command,
15};
16use time::{format_description, Date};
17
/// Run `program` with `args` and return its captured stdout with leading and
/// trailing whitespace trimmed.
///
/// Stderr is captured by [`Command::output`] but is not inspected.
///
/// # Errors
///
/// Returns an [`io::Error`] if the process cannot be spawned, if it exits with
/// a non-success status, or if its stdout is not valid UTF-8.
pub fn run_command(program: &str, args: &[&str]) -> Result<String, io::Error> {
    let out = Command::new(program).args(args).output()?;
    if !out.status.success() {
        return Err(io::Error::new(
            ErrorKind::Other,
            format!("run command `{program} {}` failed.", args.join(" ")),
        ));
    }
    // Report non-UTF-8 output as an error instead of panicking: callers already
    // handle `Err`, and the bytes come from an arbitrary external process.
    let stdout = String::from_utf8(out.stdout).map_err(|e| {
        io::Error::new(
            ErrorKind::InvalidData,
            format!(
                "run command `{program} {}` produced non-UTF-8 output: {e}",
                args.join(" ")
            ),
        )
    })?;
    Ok(stdout.trim().to_string())
}
28
/// Return `text` with its first character uppercased and every following
/// character lowercased.
///
/// An empty input yields an empty string. Uppercasing the first character may
/// produce more than one character (for example `'ß'` becomes `"SS"`).
pub fn capitalize(text: &str) -> String {
    let mut remainder = text.chars();
    let first = match remainder.next() {
        Some(c) => c,
        None => return String::new(),
    };
    let mut capitalized = String::new();
    capitalized.extend(first.to_uppercase());
    capitalized.push_str(&remainder.as_str().to_lowercase());
    capitalized
}
36
37pub fn format_date(date: &Date) -> String {
38    let format = format_description::parse("[year]-[month]-[day]").expect("Shouldn't happen");
39    date.format(&format).expect("Serialize date error")
40}
41
/// Split an inline CSS style string into `property -> value` pairs.
///
/// Declarations are separated by `;`. Each declaration is split at its
/// **first** `:`, so values that contain colons themselves (such as absolute
/// URLs) are kept whole. Keys and values are trimmed, and segments without a
/// `:` are skipped. If a property appears more than once, the last occurrence
/// wins. Splitting on `;` is naive: a `;` inside a value is not supported.
///
/// ```rust
/// use genkit::helpers::split_styles;
///
/// let pair = split_styles("color: #abcdef; font-size: 14px; background-image: url('/test.png');");
/// assert_eq!(pair.get("color").unwrap(), &"#abcdef");
/// assert_eq!(pair.get("font-size").unwrap(), &"14px");
/// assert_eq!(pair.get("background-image").unwrap(), &"url('/test.png')");
/// assert_eq!(pair.get("width"), None);
///
/// let pair = split_styles("background-image: url('https://example.com/a.png')");
/// assert_eq!(
///     pair.get("background-image").unwrap(),
///     &"url('https://example.com/a.png')"
/// );
///
/// let pair = split_styles("invalid");
/// assert!(pair.is_empty());
/// ```
pub fn split_styles(style: &str) -> HashMap<&str, &str> {
    style
        .split(';')
        // `split_once` keeps everything after the first ':' as the value,
        // unlike `split(':').take(2)` which cut values such as `url(https://…)`.
        .filter_map(|declaration| declaration.split_once(':'))
        .map(|(key, value)| (key.trim(), value.trim()))
        .collect()
}
68
69pub async fn fetch_url(url: &str) -> Result<impl Read> {
70    let client = Client::builder().build::<_, hyper::Body>(HttpsConnector::new());
71    let mut req = Request::new(Default::default());
72    *req.uri_mut() = url.parse::<Uri>()?;
73    req.headers_mut().insert(
74        "User-Agent",
75        HeaderValue::from_static(
76            "Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36",
77        ),
78    );
79    let resp = client.request(req).await?;
80    if resp.status().is_redirection() {
81        if let Some(location) = resp.headers().get("Location") {
82            println!(
83                "Warning: url `{url}` has been redirected to `{}`",
84                location.to_str()?,
85            );
86        } else {
87            println!("Warning: url `{url}` has been redirected");
88        }
89    } else if !resp.status().is_success() {
90        let warning = format!(
91            "Warning: failed to fetch url `{url}`, status code: {status}",
92            url = url,
93            status = resp.status()
94        );
95        println!("{warning}");
96        anyhow::bail!(warning);
97    }
98    let bytes = body::to_bytes(resp.into_body()).await?;
99    Ok(bytes.reader())
100}
101
102/// Copy directory recursively.
103/// Note: the empty directory is ignored.
104pub fn copy_dir(source: &Path, dest: &Path) -> Result<()> {
105    let source_parent = source.parent().expect("Can not copy the root dir");
106    walkdir::WalkDir::new(source)
107        .into_iter()
108        .par_bridge()
109        .try_for_each(|entry| {
110            let entry = entry?;
111            let path = entry.path();
112            // `path` would be a file or directory. However, we are
113            // in a rayon's parallel thread, there is no guarantee
114            // that parent directory iterated before the file.
115            // So we just ignore the `path.is_dir()` case, when coming
116            // across the first file we'll create the parent directory.
117            if path.is_file() {
118                if let Some(parent) = path.parent() {
119                    let dest_parent = dest.join(parent.strip_prefix(source_parent)?);
120                    if !dest_parent.exists() {
121                        // Create the same dir concurrently is ok according to the docs.
122                        fs::create_dir_all(dest_parent)?;
123                    }
124                }
125                let to = dest.join(path.strip_prefix(source_parent)?);
126                fs::copy(path, to)?;
127            }
128
129            anyhow::Ok(())
130        })?;
131    Ok(())
132}
133
134/// A serde module to serialize and deserialize [`time::Date`] type.
135pub mod serde_date {
136    use super::*;
137    use serde::{de, Serialize, Serializer};
138
139    pub fn serialize<S: Serializer>(date: &Date, serializer: S) -> Result<S::Ok, S::Error> {
140        super::format_date(date).serialize(serializer)
141    }
142
143    pub fn deserialize<'de, D>(d: D) -> Result<Date, D::Error>
144    where
145        D: de::Deserializer<'de>,
146    {
147        d.deserialize_any(DateVisitor)
148    }
149
150    struct DateVisitor;
151
152    impl<'de> de::Visitor<'de> for DateVisitor {
153        type Value = Date;
154
155        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
156            formatter.write_str("The date format is YYYY-MM-DD")
157        }
158
159        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
160        where
161            E: de::Error,
162        {
163            let format =
164                format_description::parse("[year]-[month]-[day]").expect("Shouldn't happen");
165            Date::parse(v, &format)
166                .map_err(|e| E::custom(format!("The date value {} is invalid: {}", v, e)))
167        }
168    }
169
170    pub mod options {
171        use super::*;
172
173        struct OptionDateVisitor;
174
175        pub fn serialize<S: Serializer>(
176            date: &Option<Date>,
177            serializer: S,
178        ) -> Result<S::Ok, S::Error> {
179            if let Some(date) = date {
180                super::serialize(date, serializer)
181            } else {
182                None::<Date>.serialize(serializer)
183            }
184        }
185
186        pub fn deserialize<'de, D>(d: D) -> Result<Option<Date>, D::Error>
187        where
188            D: de::Deserializer<'de>,
189        {
190            d.deserialize_option(OptionDateVisitor)
191        }
192
193        impl<'de> de::Visitor<'de> for OptionDateVisitor {
194            type Value = Option<Date>;
195
196            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
197                formatter.write_str("a YYYY-MM-DD date or none")
198            }
199
200            fn visit_some<D>(self, d: D) -> Result<Self::Value, D::Error>
201            where
202                D: de::Deserializer<'de>,
203            {
204                d.deserialize_str(DateVisitor).map(Some)
205            }
206
207            fn visit_none<E>(self) -> Result<Self::Value, E>
208            where
209                E: de::Error,
210            {
211                Ok(None)
212            }
213        }
214    }
215}