cog_rust/
spec.rs

1use anyhow::Result;
2use core::fmt::Debug;
3use mime_guess::Mime;
4use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
5use serde::Serialize;
6use std::{
7	env::{self, temp_dir},
8	fs::File,
9	path::PathBuf,
10	str::FromStr,
11};
12use url::Url;
13use uuid::Uuid;
14
15use crate::helpers::{base64_decode, base64_encode, url_join};
16
/// A file on the local filesystem, stored as the [`PathBuf`] that points at it.
///
/// The `Drop` implementation in this file removes the file from disk when the value goes
/// out of scope.
#[derive(Debug)]
pub struct Path(PathBuf);
19
20impl Path {
21	/// Create a new path from a url
22	///
23	/// # Errors
24	///
25	/// Returns an error if the url cannot be downloaded or a temporary file cannot be created.
26	pub(crate) fn new(url: &Url) -> Result<Self> {
27		if url.scheme() == "data" {
28			return Self::from_dataurl(url);
29		}
30
31		tracing::debug!("Downloading file from {url}");
32		let file_path = temp_dir().join(url.path().split('/').last().unwrap_or_else(|| url.path()));
33		let request = reqwest::blocking::get(url.as_str())?.bytes()?;
34
35		std::io::copy(&mut request.as_ref(), &mut File::create(&file_path)?)?;
36		tracing::debug!("Downloaded file to {}", file_path.display());
37
38		Ok(Self(file_path))
39	}
40
41	/// Create a new path from a data url
42	///
43	/// # Errors
44	///
45	/// Returns an error if the url cannot be decoded or a temporary file cannot be created.
46	pub(crate) fn from_dataurl(url: &Url) -> Result<Self> {
47		let data = url.path().split(',').last().unwrap_or_else(|| url.path());
48
49		let file_bytes = base64_decode(data)?;
50		let mime_type = Mime::from_str(tree_magic_mini::from_u8(&file_bytes))
51			.unwrap_or(mime_guess::mime::APPLICATION_OCTET_STREAM);
52		let file_ext = mime_guess::get_mime_extensions(&mime_type)
53			.and_then(<[&str]>::last)
54			.map_or_else(String::new, |e| format!(".{e}"));
55
56		let file_path = temp_dir().join(format!("{}{file_ext}", Uuid::new_v4()));
57
58		std::fs::write(&file_path, file_bytes)?;
59		Ok(Self(file_path))
60	}
61
62	/// PUT the file to the given endpoint and return the url
63	///
64	/// # Errors
65	///
66	/// Returns an error if the file cannot be read or the upload fails.
67	///
68	/// # Panics
69	///
70	/// Panics if the file name is not valid unicode.
71	pub(crate) fn upload_put(&self, upload_url: &Url) -> Result<String> {
72		let url = url_join(upload_url, self.0.file_name().unwrap().to_str().unwrap());
73		tracing::debug!("Uploading file to {url}");
74
75		let file_bytes = std::fs::read(&self.0)?;
76		let mime_type = tree_magic_mini::from_u8(&file_bytes);
77
78		let response = reqwest::blocking::Client::new()
79			.put(url.clone())
80			.header("Content-Type", mime_type)
81			.body(file_bytes)
82			.send()?;
83
84		if !response.status().is_success() {
85			anyhow::bail!(
86				"Failed to upload file to {url}: got {}. {}",
87				response.status(),
88				response.text().unwrap_or_default()
89			);
90		}
91
92		let mut url = response.url().clone();
93		url.set_query(None);
94
95		tracing::debug!("Uploaded file to {url}");
96		Ok(url.to_string())
97	}
98
99	/// Convert the file to a data url
100	///
101	/// # Errors
102	///
103	/// Returns an error if the file cannot be read.
104	pub(crate) fn to_dataurl(&self) -> Result<String> {
105		let file_bytes = std::fs::read(&self.0)?;
106		let mime_type = tree_magic_mini::from_u8(&file_bytes);
107
108		Ok(format!(
109			"data:{mime_type};base64,{base64}",
110			base64 = base64_encode(&file_bytes)
111		))
112	}
113}
114
115impl AsRef<std::path::Path> for Path {
116	fn as_ref(&self) -> &std::path::Path {
117		self.0.as_ref()
118	}
119}
120
121impl JsonSchema for Path {
122	fn schema_name() -> String {
123		"Path".to_string()
124	}
125
126	fn json_schema(gen: &mut SchemaGenerator) -> Schema {
127		Url::json_schema(gen)
128	}
129}
130
131impl Drop for Path {
132	fn drop(&mut self) {
133		tracing::debug!("Removing temporary file at path {:?}", self.0);
134
135		std::fs::remove_file(&self.0).unwrap();
136	}
137}
138
139impl<'de> serde::Deserialize<'de> for Path {
140	fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
141	where
142		D: serde::Deserializer<'de>,
143	{
144		let url = String::deserialize(deserializer)?;
145
146		Self::new(&Url::parse(&url).map_err(serde::de::Error::custom)?)
147			.map_err(serde::de::Error::custom)
148	}
149}
150
151impl Serialize for Path {
152	fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
153	where
154		S: serde::Serializer,
155	{
156		let url = env::var("UPLOAD_URL")
157			.map(|url| url.parse().ok())
158			.ok()
159			.flatten()
160			.map_or_else(
161				|| self.to_dataurl(),
162				|upload_url| self.upload_put(&upload_url),
163			);
164
165		serializer.serialize_str(&url.map_err(serde::ser::Error::custom)?)
166	}
167}
168
169impl From<PathBuf> for Path {
170	fn from(path: PathBuf) -> Self {
171		Self(path)
172	}
173}
174
#[cfg(test)]
mod tests {
	use super::*;
	use serde_json::json;

	/// Minimal wrapper so a `Path` can be deserialized from a JSON object field.
	#[derive(Debug, serde::Deserialize)]
	struct StructWithPath {
		file: Path,
	}

	/// Deserializing a remote url yields a non-empty local file that is removed on drop.
	#[test]
	fn test_path_deserialize() {
		let payload = json!({
			"file": "https://raw.githubusercontent.com/m1guelpf/cog-rust/main/README.md"
		});
		let StructWithPath { file: path } = serde_json::from_value(payload).unwrap();

		// Keep an independent copy of the location so it can be checked after the drop.
		let underlying_path = path.0.clone();

		assert!(
			underlying_path.exists(),
			"File does not exist at path {:?}",
			path.0
		);

		let size = underlying_path.metadata().unwrap().len();
		assert!(size > 0, "File is empty");

		drop(path);

		assert!(
			!underlying_path.exists(),
			"File still exists at path {underlying_path:?}",
		);
	}

	/// A downloaded PNG is rendered as a data url carrying the PNG media type.
	#[test]
	fn test_dataurl_serialize() {
		let payload = json!({
			"file": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1b/Square_200x200.png/120px-Square_200x200.png"
		});
		let StructWithPath { file: path } = serde_json::from_value(payload).unwrap();

		let dataurl = path.to_dataurl().unwrap();

		assert!(dataurl.starts_with("data:image/png;base64,"));
	}
}