1use anyhow::bail;
2use bytes::Bytes;
3use std::{
4 borrow::Cow,
5 path::{Path, PathBuf},
6};
7use url::Url;
8
9#[cfg(feature = "s3")]
10use aws_config::{BehaviorVersion, Region, meta::region::RegionProviderChain};
11#[cfg(feature = "s3")]
12use aws_sdk_s3::{
13 Client,
14 config::{AppName, Credentials},
15};
16
/// A location that content can be discovered at, loaded from, deleted from or moved within.
///
/// Values are usually created from a string through `TryFrom<&str>`, which picks the variant
/// based on the string's prefix.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Source {
    /// A path on the local file system, either a single file or a directory.
    Path(PathBuf),
    /// An `http://` or `https://` URL.
    Http(Url),
    /// A bucket, or a single object inside a bucket, on S3.
    ///
    /// Only available when the `s3` feature is enabled.
    #[cfg(feature = "s3")]
    S3(S3),
}
25
26impl TryFrom<&str> for Source {
27 type Error = anyhow::Error;
28
29 fn try_from(value: &str) -> Result<Self, Self::Error> {
30 if value.starts_with("http://") || value.starts_with("https://") {
31 return Ok(Self::Http(Url::parse(value)?));
32 }
33
34 #[cfg(feature = "s3")]
35 if value.starts_with("s3://") {
36 return Ok(Self::S3(S3::try_from(value)?));
37 }
38 #[cfg(not(feature = "s3"))]
39 if value.starts_with("s3://") {
40 bail!("S3 URLs are not supported");
41 }
42
43 Ok(Self::Path(value.into()))
44 }
45}
46
47impl Source {
48 pub async fn discover(self) -> anyhow::Result<Vec<Self>> {
49 match self {
50 Self::Path(path) => Ok(Self::discover_path(path)?
51 .into_iter()
52 .map(Self::Path)
53 .collect()),
54 #[cfg(feature = "s3")]
55 Self::S3(s3) if s3.key.is_none() => Ok(Self::discover_s3(s3)
56 .await?
57 .into_iter()
58 .map(Self::S3)
59 .collect()),
60 value => Ok(vec![value]),
61 }
62 }
63
64 fn discover_path(path: PathBuf) -> anyhow::Result<Vec<PathBuf>> {
65 log::debug!("Discovering: {}", path.display());
66
67 if !path.exists() {
68 bail!("{} does not exist", path.display());
69 } else if path.is_file() {
70 log::debug!("Is a file");
71 Ok(vec![path])
72 } else if path.is_dir() {
73 log::debug!("Is a directory");
74 let mut result = Vec::new();
75
76 for path in walkdir::WalkDir::new(path).into_iter() {
77 let path = path?;
78 if path.file_type().is_file() {
79 result.push(path.path().to_path_buf());
80 }
81 }
82
83 Ok(result)
84 } else {
85 log::warn!("Is something unknown: {}", path.display());
86 Ok(vec![])
87 }
88 }
89
90 #[cfg(feature = "s3")]
91 async fn discover_s3(s3: S3) -> anyhow::Result<Vec<S3>> {
92 let client = s3.client().await?;
93
94 let mut response = client
95 .list_objects_v2()
96 .bucket(s3.bucket.clone())
97 .max_keys(100)
98 .into_paginator()
99 .send();
100
101 let mut result = vec![];
102 while let Some(next) = response.next().await {
103 let next = next?;
104 for object in next.contents() {
105 if let Some(key) = object.key.clone() {
106 result.push(key);
107 }
108 }
109 }
110
111 Ok(result
112 .into_iter()
113 .map(|key| S3 {
114 key: Some(key),
115 ..(s3.clone())
116 })
117 .collect())
118 }
119
120 pub fn name(&self) -> Cow<'_, str> {
121 match self {
122 Self::Path(path) => path.to_string_lossy(),
123 Self::Http(url) => url.as_str().into(),
124 #[cfg(feature = "s3")]
125 Self::S3(s3) => format!(
126 "s3://{}/{}/{}",
127 s3.region,
128 s3.bucket,
129 s3.key.as_deref().unwrap_or_default()
130 )
131 .into(),
132 }
133 }
134
135 pub async fn load(&self) -> Result<Bytes, anyhow::Error> {
137 Ok(match self {
138 Self::Path(path) => tokio::fs::read(path).await?.into(),
139 Self::Http(url) => {
140 reqwest::get(url.clone())
141 .await?
142 .error_for_status()?
143 .bytes()
144 .await?
145 }
146 #[cfg(feature = "s3")]
147 Self::S3(s3) => {
148 let client = s3.client();
149 client
150 .await?
151 .get_object()
152 .key(s3.key.clone().unwrap_or_default())
153 .bucket(s3.bucket.clone())
154 .send()
155 .await?
156 .body
157 .collect()
158 .await?
159 .into_bytes()
160 }
161 })
162 }
163
164 pub async fn delete(&self) -> anyhow::Result<()> {
166 match self {
167 Self::Path(file) => {
168 tokio::fs::remove_file(&file).await?;
170 }
171 Self::Http(url) => {
172 reqwest::Client::builder()
174 .build()?
175 .delete(url.clone())
176 .send()
177 .await?;
178 }
179 #[cfg(feature = "s3")]
180 Self::S3(s3) => {
181 let client = s3.client();
183 client
184 .await?
185 .delete_object()
186 .key(s3.key.clone().unwrap_or_default())
187 .bucket(s3.bucket.clone())
188 .send()
189 .await?;
190 }
191 }
192
193 Ok(())
194 }
195
196 pub async fn r#move(&self, path: &str) -> anyhow::Result<()> {
200 match self {
201 Self::Path(file) => {
202 let path = Path::new(&path);
203 tokio::fs::create_dir_all(path).await?;
204 tokio::fs::copy(&file, path.join(file)).await?;
205 tokio::fs::remove_file(&file).await?;
206 }
207 Self::Http(url) => {
208 log::warn!("Unable to move HTTP source ({url}), skipping!");
210 }
211 #[cfg(feature = "s3")]
212 Self::S3(s3) => {
213 let client = s3.client();
214 client
215 .await?
216 .copy_object()
217 .copy_source(s3.key.clone().unwrap_or_default())
218 .key(format!("{path}/{}", s3.key.as_deref().unwrap_or_default()))
219 .bucket(s3.bucket.clone())
220 .send()
221 .await?;
222 }
223 }
224
225 Ok(())
226 }
227}
228
/// A location on S3: either a whole bucket or a single object in a bucket.
///
/// Parsed from a URL of the form `s3://[<key id>:<secret>@]<region>/<bucket>[/<key>]`.
#[cfg(feature = "s3")]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct S3 {
    /// The region used for the client, taken from the host part of the URL.
    region: String,
    /// Optional static credentials as `(access key id, secret access key)`; when absent the
    /// client is built without an explicit credentials provider.
    credentials: Option<(String, String)>,
    /// The name of the bucket.
    bucket: String,
    /// The object key; `None` refers to the bucket as a whole.
    key: Option<String>,
}
237
#[cfg(feature = "s3")]
impl TryFrom<&str> for S3 {
    type Error = anyhow::Error;

    /// Parse an S3 location of the form `s3://[<key id>:<secret>@]<region>/<bucket>[/<key>]`.
    ///
    /// * The host becomes the region.
    /// * The first path segment becomes the bucket, the remainder (if any and non-empty) the
    ///   object key. A trailing slash after the bucket therefore still refers to the bucket as
    ///   a whole.
    /// * User information is only used when it has the form `<key id>:<secret>`; user
    ///   information without a `:` is ignored.
    ///
    /// The scheme itself is not checked here.
    ///
    /// NOTE(review): path and user information are used as returned by `fluent_uri`, without
    /// an explicit percent-decoding step — confirm whether keys or credentials containing
    /// reserved characters need to be decoded.
    ///
    /// # Errors
    ///
    /// Fails when the value is not a valid URI, has no authority, or has no bucket.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let uri = fluent_uri::Uri::try_from(value)?;

        let Some(auth) = uri.authority() else {
            bail!("Missing authority");
        };

        let path = uri.path().to_string();
        let path = path.trim_start_matches('/');
        if path.is_empty() {
            bail!("Missing bucket");
        }

        let (bucket, key) = match path.split_once('/') {
            // `s3://region/bucket/` must not turn into an empty object key: that would neither
            // be expanded as a bucket nor name a real object.
            Some((bucket, key)) => (
                bucket.to_string(),
                (!key.is_empty()).then(|| key.to_string()),
            ),
            None => (path.to_string(), None),
        };

        let region = auth.host().to_string();

        let credentials = auth.userinfo().and_then(|userinfo| {
            userinfo
                .split_once(':')
                .map(|(username, password)| (username.to_string(), password.to_string()))
        });

        Ok(S3 {
            region,
            credentials,
            bucket,
            key,
        })
    }
}
276
#[cfg(feature = "s3")]
impl S3 {
    /// Create an S3 client for this location.
    ///
    /// The client is configured with this location's region and the crate's user agent as
    /// application name. When credentials are present they are installed as the credentials
    /// provider; otherwise the loader's defaults are left untouched.
    ///
    /// # Errors
    ///
    /// Fails when the user agent is not accepted as an application name.
    pub async fn client(&self) -> anyhow::Result<Client> {
        let region = Region::new(self.region.clone());
        let app_name = AppName::new(crate::USER_AGENT)?;

        let loader = aws_config::defaults(BehaviorVersion::latest())
            .region(RegionProviderChain::first_try(region))
            .app_name(app_name);

        let loader = match &self.credentials {
            Some((key_id, access_key)) => loader.credentials_provider(Credentials::new(
                key_id, access_key, None, None, "config",
            )),
            None => loader,
        };

        let sdk_config = loader.load().await;

        Ok(Client::new(&sdk_config))
    }
}
296
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    /// Build the [`S3`] value a parser assertion expects.
    #[cfg(feature = "s3")]
    fn expected(
        region: &str,
        credentials: Option<(&str, &str)>,
        bucket: &str,
        key: Option<&str>,
    ) -> S3 {
        S3 {
            region: String::from(region),
            credentials: credentials.map(|(id, secret)| (String::from(id), String::from(secret))),
            bucket: String::from(bucket),
            key: key.map(String::from),
        }
    }

    /// Bucket only, bucket with credentials, and object with credentials for a plain region.
    #[cfg(feature = "s3")]
    #[test]
    fn parse_s3() {
        let cases = [
            ("s3://us-east-1/b1", expected("us-east-1", None, "b1", None)),
            (
                "s3://foo:bar@us-east-1/b1",
                expected("us-east-1", Some(("foo", "bar")), "b1", None),
            ),
            (
                "s3://foo:bar@us-east-1/b1/path/to/file",
                expected(
                    "us-east-1",
                    Some(("foo", "bar")),
                    "b1",
                    Some("path/to/file"),
                ),
            ),
        ];

        for (input, want) in cases {
            assert_eq!(want, S3::try_from(input).unwrap());
        }
    }

    /// Same three shapes, with a dotted host name in the region position.
    #[cfg(feature = "s3")]
    #[test]
    fn parse_s3_custom_region() {
        let cases = [
            (
                "s3://my.own.endpoint/b1",
                expected("my.own.endpoint", None, "b1", None),
            ),
            (
                "s3://foo:bar@my.own.endpoint/b1",
                expected("my.own.endpoint", Some(("foo", "bar")), "b1", None),
            ),
            (
                "s3://foo:bar@my.own.endpoint/b1/path/to/file",
                expected(
                    "my.own.endpoint",
                    Some(("foo", "bar")),
                    "b1",
                    Some("path/to/file"),
                ),
            ),
        ];

        for (input, want) in cases {
            assert_eq!(want, S3::try_from(input).unwrap());
        }
    }
}