Skip to main content

walker_common/scoop/
source.rs

1use anyhow::bail;
2use bytes::Bytes;
3use std::{
4    borrow::Cow,
5    path::{Path, PathBuf},
6};
7use url::Url;
8
9#[cfg(feature = "s3")]
10use aws_config::{BehaviorVersion, Region, meta::region::RegionProviderChain};
11#[cfg(feature = "s3")]
12use aws_sdk_s3::{
13    Client,
14    config::{AppName, Credentials},
15};
16
17#[derive(Clone, Debug, PartialEq, Eq)]
18#[non_exhaustive]
19pub enum Source {
20    Path(PathBuf),
21    Http(Url),
22    #[cfg(feature = "s3")]
23    S3(S3),
24}
25
26impl TryFrom<&str> for Source {
27    type Error = anyhow::Error;
28
29    fn try_from(value: &str) -> Result<Self, Self::Error> {
30        if value.starts_with("http://") || value.starts_with("https://") {
31            return Ok(Self::Http(Url::parse(value)?));
32        }
33
34        #[cfg(feature = "s3")]
35        if value.starts_with("s3://") {
36            return Ok(Self::S3(S3::try_from(value)?));
37        }
38        #[cfg(not(feature = "s3"))]
39        if value.starts_with("s3://") {
40            bail!("S3 URLs are not supported");
41        }
42
43        Ok(Self::Path(value.into()))
44    }
45}
46
47impl Source {
48    pub async fn discover(self) -> anyhow::Result<Vec<Self>> {
49        match self {
50            Self::Path(path) => Ok(Self::discover_path(path)?
51                .into_iter()
52                .map(Self::Path)
53                .collect()),
54            #[cfg(feature = "s3")]
55            Self::S3(s3) if s3.key.is_none() => Ok(Self::discover_s3(s3)
56                .await?
57                .into_iter()
58                .map(Self::S3)
59                .collect()),
60            value => Ok(vec![value]),
61        }
62    }
63
64    fn discover_path(path: PathBuf) -> anyhow::Result<Vec<PathBuf>> {
65        log::debug!("Discovering: {}", path.display());
66
67        if !path.exists() {
68            bail!("{} does not exist", path.display());
69        } else if path.is_file() {
70            log::debug!("Is a file");
71            Ok(vec![path])
72        } else if path.is_dir() {
73            log::debug!("Is a directory");
74            let mut result = Vec::new();
75
76            for path in walkdir::WalkDir::new(path).into_iter() {
77                let path = path?;
78                if path.file_type().is_file() {
79                    result.push(path.path().to_path_buf());
80                }
81            }
82
83            Ok(result)
84        } else {
85            log::warn!("Is something unknown: {}", path.display());
86            Ok(vec![])
87        }
88    }
89
90    #[cfg(feature = "s3")]
91    async fn discover_s3(s3: S3) -> anyhow::Result<Vec<S3>> {
92        let client = s3.client().await?;
93
94        let mut response = client
95            .list_objects_v2()
96            .bucket(s3.bucket.clone())
97            .max_keys(100)
98            .into_paginator()
99            .send();
100
101        let mut result = vec![];
102        while let Some(next) = response.next().await {
103            let next = next?;
104            for object in next.contents() {
105                if let Some(key) = object.key.clone() {
106                    result.push(key);
107                }
108            }
109        }
110
111        Ok(result
112            .into_iter()
113            .map(|key| S3 {
114                key: Some(key),
115                ..(s3.clone())
116            })
117            .collect())
118    }
119
120    pub fn name(&self) -> Cow<'_, str> {
121        match self {
122            Self::Path(path) => path.to_string_lossy(),
123            Self::Http(url) => url.as_str().into(),
124            #[cfg(feature = "s3")]
125            Self::S3(s3) => format!(
126                "s3://{}/{}/{}",
127                s3.region,
128                s3.bucket,
129                s3.key.as_deref().unwrap_or_default()
130            )
131            .into(),
132        }
133    }
134
135    /// Load the content of the source
136    pub async fn load(&self) -> Result<Bytes, anyhow::Error> {
137        Ok(match self {
138            Self::Path(path) => tokio::fs::read(path).await?.into(),
139            Self::Http(url) => {
140                reqwest::get(url.clone())
141                    .await?
142                    .error_for_status()?
143                    .bytes()
144                    .await?
145            }
146            #[cfg(feature = "s3")]
147            Self::S3(s3) => {
148                let client = s3.client();
149                client
150                    .await?
151                    .get_object()
152                    .key(s3.key.clone().unwrap_or_default())
153                    .bucket(s3.bucket.clone())
154                    .send()
155                    .await?
156                    .body
157                    .collect()
158                    .await?
159                    .into_bytes()
160            }
161        })
162    }
163
164    /// Delete the source
165    pub async fn delete(&self) -> anyhow::Result<()> {
166        match self {
167            Self::Path(file) => {
168                // just delete the file
169                tokio::fs::remove_file(&file).await?;
170            }
171            Self::Http(url) => {
172                // issue a DELETE request
173                reqwest::Client::builder()
174                    .build()?
175                    .delete(url.clone())
176                    .send()
177                    .await?;
178            }
179            #[cfg(feature = "s3")]
180            Self::S3(s3) => {
181                // delete the object from the bucket
182                let client = s3.client();
183                client
184                    .await?
185                    .delete_object()
186                    .key(s3.key.clone().unwrap_or_default())
187                    .bucket(s3.bucket.clone())
188                    .send()
189                    .await?;
190            }
191        }
192
193        Ok(())
194    }
195
196    /// move the source
197    ///
198    /// NOTE: This is a no-op for HTTP sources.
199    pub async fn r#move(&self, path: &str) -> anyhow::Result<()> {
200        match self {
201            Self::Path(file) => {
202                let path = Path::new(&path);
203                tokio::fs::create_dir_all(path).await?;
204                tokio::fs::copy(&file, path.join(file)).await?;
205                tokio::fs::remove_file(&file).await?;
206            }
207            Self::Http(url) => {
208                // no-op, but warn
209                log::warn!("Unable to move HTTP source ({url}), skipping!");
210            }
211            #[cfg(feature = "s3")]
212            Self::S3(s3) => {
213                let client = s3.client();
214                client
215                    .await?
216                    .copy_object()
217                    .copy_source(s3.key.clone().unwrap_or_default())
218                    .key(format!("{path}/{}", s3.key.as_deref().unwrap_or_default()))
219                    .bucket(s3.bucket.clone())
220                    .send()
221                    .await?;
222            }
223        }
224
225        Ok(())
226    }
227}
228
/// An S3 location: a bucket, optionally narrowed to a single object.
///
/// Parsed from `s3://[<key-id>:<secret>@]<region>/<bucket>[/<key>]`.
#[cfg(feature = "s3")]
#[derive(Clone, PartialEq, Eq)]
pub struct S3 {
    /// Host part of the URI; used as the AWS region when building the client.
    region: String,
    /// Optional explicit credentials as `(access key id, secret access key)`.
    credentials: Option<(String, String)>,
    /// Bucket name (first path segment).
    bucket: String,
    /// Object key (remaining path), or `None` when the whole bucket is meant.
    key: Option<String>,
}

#[cfg(feature = "s3")]
impl std::fmt::Debug for S3 {
    /// Format like a derived `Debug`, but with the secret access key redacted.
    ///
    /// `Source` derives `Debug` and embeds this type, so this output may end up in
    /// logs or error messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let credentials = self
            .credentials
            .as_ref()
            .map(|(key_id, _secret)| (key_id.as_str(), "<redacted>"));
        f.debug_struct("S3")
            .field("region", &self.region)
            .field("credentials", &credentials)
            .field("bucket", &self.bucket)
            .field("key", &self.key)
            .finish()
    }
}
237
#[cfg(feature = "s3")]
impl TryFrom<&str> for S3 {
    type Error = anyhow::Error;

    /// Parse an `s3://[<key-id>:<secret>@]<region>/<bucket>[/<key>]` URI.
    ///
    /// Userinfo, bucket and key are percent-decoded. Credentials containing reserved
    /// characters, such as a `/` in a secret access key, must therefore be written
    /// percent-encoded (`%2F`). An empty key (trailing slash after the bucket) is
    /// treated as "no key".
    ///
    /// # Errors
    ///
    /// Fails if the URI is invalid, has no authority or bucket, or contains a
    /// malformed percent-escape.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let uri = fluent_uri::Uri::try_from(value)?;

        let Some(auth) = uri.authority() else {
            bail!("Missing authority");
        };

        let path = uri.path().to_string();
        let path = path.trim_start_matches('/');
        if path.is_empty() {
            bail!("Missing bucket");
        }

        let (bucket, key) = match path.split_once('/') {
            Some((bucket, key)) => (bucket, Some(key)),
            None => (path, None),
        };
        let bucket = percent_decode(bucket)?;
        let key = key
            .filter(|key| !key.is_empty())
            .map(percent_decode)
            .transpose()?;

        let region = auth.host().to_string();

        // Userinfo without a ':' separator is ignored, as before.
        let credentials = match auth.userinfo().and_then(|userinfo| {
            userinfo
                .split_once(':')
                .map(|(username, password)| (username.to_string(), password.to_string()))
        }) {
            Some((username, password)) => {
                Some((percent_decode(&username)?, percent_decode(&password)?))
            }
            None => None,
        };

        Ok(S3 {
            region,
            credentials,
            bucket,
            key,
        })
    }
}

/// Decode `%XX` escapes in a URI component.
///
/// The error deliberately does not echo the input, because it may be a secret.
///
/// # Errors
///
/// Fails on an escape not followed by two hex digits, or if the decoded bytes are
/// not valid UTF-8.
#[cfg(feature = "s3")]
fn percent_decode(value: &str) -> anyhow::Result<String> {
    let bytes = value.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] != b'%' {
            decoded.push(bytes[i]);
            i += 1;
            continue;
        }
        let hex_digit = |offset: usize| {
            bytes
                .get(i + offset)
                .and_then(|b| char::from(*b).to_digit(16))
        };
        let (Some(high), Some(low)) = (hex_digit(1), hex_digit(2)) else {
            bail!("Invalid percent-encoding in URI component");
        };
        // both digits are < 16, so the combined value always fits into a byte
        decoded.push((high * 16 + low) as u8);
        i += 3;
    }
    Ok(String::from_utf8(decoded)?)
}
276
#[cfg(feature = "s3")]
impl S3 {
    /// Build an S3 [`Client`] for this location.
    ///
    /// The configuration starts from `aws_config::defaults` and has the region set
    /// from this location and the crate's user agent as the app name. If the URI
    /// carried credentials, they are installed as the credentials provider.
    ///
    /// # Errors
    ///
    /// Fails if `crate::USER_AGENT` is not a valid [`AppName`].
    pub async fn client(&self) -> anyhow::Result<Client> {
        let region = Region::new(self.region.clone());
        let loader = aws_config::defaults(BehaviorVersion::latest())
            .region(RegionProviderChain::first_try(region))
            .app_name(AppName::new(crate::USER_AGENT)?);

        // Without explicit credentials the loader keeps its default provider.
        let loader = match &self.credentials {
            Some((key_id, access_key)) => loader.credentials_provider(Credentials::new(
                key_id, access_key, None, None, "config",
            )),
            None => loader,
        };

        Ok(Client::new(&loader.load().await))
    }
}
296
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    /// Build an expected [`S3`] value from borrowed parts.
    #[cfg(feature = "s3")]
    fn s3(region: &str, credentials: Option<(&str, &str)>, bucket: &str, key: Option<&str>) -> S3 {
        S3 {
            region: region.to_string(),
            credentials: credentials.map(|(user, pass)| (user.to_string(), pass.to_string())),
            bucket: bucket.to_string(),
            key: key.map(str::to_string),
        }
    }

    /// Parse each input and compare it with its expected value.
    #[cfg(feature = "s3")]
    fn assert_parses(cases: &[(&str, S3)]) {
        for (input, expected) in cases {
            assert_eq!(expected, &S3::try_from(*input).unwrap(), "parsing {input}");
        }
    }

    #[cfg(feature = "s3")]
    #[test]
    fn parse_s3() {
        assert_parses(&[
            ("s3://us-east-1/b1", s3("us-east-1", None, "b1", None)),
            (
                "s3://foo:bar@us-east-1/b1",
                s3("us-east-1", Some(("foo", "bar")), "b1", None),
            ),
            (
                "s3://foo:bar@us-east-1/b1/path/to/file",
                s3("us-east-1", Some(("foo", "bar")), "b1", Some("path/to/file")),
            ),
        ]);
    }

    #[cfg(feature = "s3")]
    #[test]
    fn parse_s3_custom_region() {
        assert_parses(&[
            ("s3://my.own.endpoint/b1", s3("my.own.endpoint", None, "b1", None)),
            (
                "s3://foo:bar@my.own.endpoint/b1",
                s3("my.own.endpoint", Some(("foo", "bar")), "b1", None),
            ),
            (
                "s3://foo:bar@my.own.endpoint/b1/path/to/file",
                s3(
                    "my.own.endpoint",
                    Some(("foo", "bar")),
                    "b1",
                    Some("path/to/file"),
                ),
            ),
        ]);
    }
}