rsql_driver_https/
driver.rs1use async_trait::async_trait;
2use file_type::FileType;
3use futures_util::StreamExt;
4use reqwest::header::HeaderMap;
5use rsql_driver::Error::{ConversionError, IoError};
6use rsql_driver::{DriverManager, Result};
7use std::collections::HashMap;
8use std::fs::create_dir_all;
9use std::path::{Path, PathBuf};
10use tempfile::TempDir;
11use tokio::fs::File;
12use tokio::io::AsyncWriteExt;
13use tracing::debug;
14use url::Url;
15
/// rsql driver for the `https` URL scheme.
///
/// It does not read data itself: it downloads the remote resource into a temporary file,
/// detects the file type, and delegates to whichever registered driver supports that type.
#[derive(Debug)]
pub struct Driver;
18
19#[async_trait]
20impl rsql_driver::Driver for Driver {
21 fn identifier(&self) -> &'static str {
22 "https"
23 }
24
25 async fn connect(&self, url: &str) -> Result<Box<dyn rsql_driver::Connection>> {
26 let temp_dir = TempDir::new()?;
27 let (request_headers, file_path, file_type, response_headers) =
28 self.retrieve_file(url, temp_dir.path()).await?;
29 let file_path = file_path.to_string_lossy().to_string();
30 #[cfg(target_os = "windows")]
31 let file_path = file_path.replace(':', "%3A").replace('\\', "/");
32
33 debug!("temp_dir: {temp_dir:?}; file_path: {file_path}");
34 let driver = DriverManager::get_by_file_type(file_type)?;
35 match driver {
36 Some(driver) => {
37 let (_url, parameters) = url.split_once('?').unwrap_or((url, ""));
38 let url = format!("{}://{file_path}?{parameters}", driver.identifier());
39 let mut connection = driver.connect(url.as_str()).await?;
40 create_header_tables(&mut connection, &request_headers, &response_headers).await?;
41 Ok(connection)
42 }
43 None => Err(IoError(format!(
44 "{file_path:?}: {:?}",
45 file_type.media_types()
46 ))),
47 }
48 }
49
50 fn supports_file_type(&self, _file_type: &FileType) -> bool {
51 false
52 }
53}
54
55impl Driver {
56 async fn retrieve_file(
57 &self,
58 url: &str,
59 temp_dir: &Path,
60 ) -> Result<(
61 HashMap<String, String>,
62 PathBuf,
63 &FileType,
64 HashMap<String, String>,
65 )> {
66 let mut parsed_url = Url::parse(url)?;
67 let file_path = PathBuf::from(parsed_url.path());
68 let file_name = match file_path.file_name() {
70 Some(file_name) => file_name.to_string_lossy().to_string(),
71 None => "response".to_string(),
72 };
73
74 let mut request_headers: HashMap<String, String> =
75 parsed_url.query_pairs().into_owned().collect();
76 if let Some(headers) = request_headers.remove("_headers") {
77 let headers = headers
79 .split(';')
80 .map(|header| {
81 let mut parts = header.split('=');
82 let key = parts.next().unwrap_or_default().to_string();
83 let value = parts.next().unwrap_or_default().to_string();
84 (key, value)
85 })
86 .collect::<HashMap<String, String>>();
87 request_headers.extend(headers);
88 }
89
90 parsed_url.set_query(None);
91 let url = parsed_url.to_string();
92 let parameters: HashMap<&str, &str> = request_headers
93 .iter()
94 .map(|(k, v)| (k.as_str(), v.as_str()))
95 .collect();
96 let parsed_url = Url::parse_with_params(url.as_str(), parameters)?;
97
98 if !request_headers
99 .keys()
100 .any(|key| key.eq_ignore_ascii_case("user-agent"))
101 {
102 let package_name = env!("CARGO_PKG_NAME");
103 let version = env!("CARGO_PKG_VERSION");
104 let os = std::env::consts::OS;
105 let arch = std::env::consts::ARCH;
106 let user_agent = format!("{package_name}/{version} ({os}; {arch})");
107 request_headers.insert("User-Agent".to_string(), user_agent);
108 }
109
110 let header_map: HeaderMap = (&request_headers)
111 .try_into()
112 .map_err(|_| ConversionError("MalformedHeaders".into()))?;
113 let client = reqwest::ClientBuilder::new()
114 .default_headers(header_map)
115 .build()
116 .map_err(|error| IoError(error.to_string()))?;
117
118 let response = client
119 .get(parsed_url.as_str())
120 .send()
121 .await
122 .map_err(|error| IoError(error.to_string()))?;
123 let response_headers = response.headers();
124 let response_headers: HashMap<String, String> = response_headers
125 .iter()
126 .map(|(key, value)| {
127 (
128 key.as_str().to_string(),
129 value.to_str().unwrap_or_default().to_string(),
130 )
131 })
132 .collect();
133 let content_type = response_headers
134 .iter()
135 .find(|(key, _value)| key.eq_ignore_ascii_case("content-type"))
136 .map(|(_key, value)| value.split(';').next().unwrap_or_default())
137 .unwrap_or_default();
138 create_dir_all(temp_dir)?;
139 let file_path = temp_dir.join(file_name);
140 let mut file = File::create_new(&file_path)
141 .await
142 .map_err(|error| IoError(error.to_string()))?;
143 let mut stream = response.bytes_stream();
144 while let Some(item) = stream.next().await {
145 let item = item.map_err(|error| IoError(error.to_string()))?;
146 file.write_all(&item)
147 .await
148 .map_err(|error| IoError(error.to_string()))?;
149 }
150
151 let file_type = Self::file_type(content_type, &file_path)?;
152 Ok((request_headers, file_path, file_type, response_headers))
153 }
154
155 fn file_type(content_type: &str, file_path: &PathBuf) -> Result<&'static FileType> {
156 let content_type = content_type.trim().to_lowercase();
159 if !["text/plain", "application/octet-stream"].contains(&content_type.as_str()) {
160 let file_types = FileType::from_media_type(content_type.to_lowercase());
161 if !file_types.is_empty() {
162 if let Some(file_type) = file_types.first() {
163 return Ok(file_type);
164 }
165 }
166 }
167 let file_type =
168 FileType::try_from_file(file_path).map_err(|error| IoError(error.to_string()))?;
169 Ok(file_type)
170 }
171}
172
173async fn create_header_tables(
174 connection: &mut Box<dyn rsql_driver::Connection>,
175 request_headers: &HashMap<String, String>,
176 response_headers: &HashMap<String, String>,
177) -> Result<()> {
178 let request_header_sql = create_table_sql("request_headers", request_headers);
179 connection.execute(&request_header_sql).await?;
180 let response_header_sql = create_table_sql("response_headers", response_headers);
181 connection.execute(&response_header_sql).await?;
182 Ok(())
183}
184
/// Builds a `CREATE TABLE <table_name> AS SELECT ...` statement holding `headers`.
///
/// The table has two text columns, `header` (lowercased header name) and `value`. Single
/// quotes are escaped by doubling them. Rows are emitted in sorted order so the generated SQL
/// is deterministic regardless of `HashMap` iteration order.
///
/// When `headers` is empty, the statement still creates the two-column table, with no rows
/// (`LIMIT 0`), instead of producing invalid SQL.
///
/// `table_name` is interpolated unquoted and is expected to be a trusted identifier.
fn create_table_sql(table_name: &str, headers: &HashMap<String, String>) -> String {
    let mut rows: Vec<(String, String)> = headers
        .iter()
        .map(|(key, value)| {
            (
                key.replace('\'', "''").to_lowercase(),
                value.replace('\'', "''"),
            )
        })
        .collect();
    if rows.is_empty() {
        return format!(
            "CREATE TABLE {table_name} AS SELECT '' AS \"header\", '' AS \"value\" LIMIT 0"
        );
    }
    rows.sort_unstable();
    let selects = rows
        .iter()
        .map(|(key, value)| format!("SELECT '{key}' AS \"header\", '{value}' AS \"value\""))
        .collect::<Vec<String>>()
        .join(" UNION ");
    format!("CREATE TABLE {table_name} AS {selects}")
}