rsql_driver_https/
driver.rs

1use async_trait::async_trait;
2use file_type::FileType;
3use futures_util::StreamExt;
4use reqwest::header::HeaderMap;
5use rsql_driver::Error::{ConversionError, IoError};
6use rsql_driver::{DriverManager, Result};
7use std::collections::HashMap;
8use std::fs::create_dir_all;
9use std::path::{Path, PathBuf};
10use tempfile::TempDir;
11use tokio::fs::File;
12use tokio::io::AsyncWriteExt;
13use tracing::debug;
14use url::Url;
15
/// Driver for `https` URLs.
///
/// It downloads the remote resource into a temporary directory and then delegates to the
/// driver registered for the downloaded file's detected type.
#[derive(Debug)]
pub struct Driver;
18
19#[async_trait]
20impl rsql_driver::Driver for Driver {
21    fn identifier(&self) -> &'static str {
22        "https"
23    }
24
25    async fn connect(&self, url: &str) -> Result<Box<dyn rsql_driver::Connection>> {
26        let temp_dir = TempDir::new()?;
27        let (request_headers, file_path, file_type, response_headers) =
28            self.retrieve_file(url, temp_dir.path()).await?;
29        let file_path = file_path.to_string_lossy().to_string();
30        #[cfg(target_os = "windows")]
31        let file_path = file_path.replace(':', "%3A").replace('\\', "/");
32
33        debug!("temp_dir: {temp_dir:?}; file_path: {file_path}");
34        let driver = DriverManager::get_by_file_type(file_type)?;
35        match driver {
36            Some(driver) => {
37                let (_url, parameters) = url.split_once('?').unwrap_or((url, ""));
38                let url = format!("{}://{file_path}?{parameters}", driver.identifier());
39                let mut connection = driver.connect(url.as_str()).await?;
40                create_header_tables(&mut connection, &request_headers, &response_headers).await?;
41                Ok(connection)
42            }
43            None => Err(IoError(format!(
44                "{file_path:?}: {:?}",
45                file_type.media_types()
46            ))),
47        }
48    }
49
50    fn supports_file_type(&self, _file_type: &FileType) -> bool {
51        false
52    }
53}
54
55impl Driver {
56    async fn retrieve_file(
57        &self,
58        url: &str,
59        temp_dir: &Path,
60    ) -> Result<(
61        HashMap<String, String>,
62        PathBuf,
63        &FileType,
64        HashMap<String, String>,
65    )> {
66        let mut parsed_url = Url::parse(url)?;
67        let file_path = PathBuf::from(parsed_url.path());
68        // Extract the last segment of the path as a file name
69        let file_name = match file_path.file_name() {
70            Some(file_name) => file_name.to_string_lossy().to_string(),
71            None => "response".to_string(),
72        };
73
74        let mut request_headers: HashMap<String, String> =
75            parsed_url.query_pairs().into_owned().collect();
76        if let Some(headers) = request_headers.remove("_headers") {
77            // Split individual headers by ; with key=value pairs
78            let headers = headers
79                .split(';')
80                .map(|header| {
81                    let mut parts = header.split('=');
82                    let key = parts.next().unwrap_or_default().to_string();
83                    let value = parts.next().unwrap_or_default().to_string();
84                    (key, value)
85                })
86                .collect::<HashMap<String, String>>();
87            request_headers.extend(headers);
88        }
89
90        parsed_url.set_query(None);
91        let url = parsed_url.to_string();
92        let parameters: HashMap<&str, &str> = request_headers
93            .iter()
94            .map(|(k, v)| (k.as_str(), v.as_str()))
95            .collect();
96        let parsed_url = Url::parse_with_params(url.as_str(), parameters)?;
97
98        if !request_headers
99            .keys()
100            .any(|key| key.eq_ignore_ascii_case("user-agent"))
101        {
102            let package_name = env!("CARGO_PKG_NAME");
103            let version = env!("CARGO_PKG_VERSION");
104            let os = std::env::consts::OS;
105            let arch = std::env::consts::ARCH;
106            let user_agent = format!("{package_name}/{version} ({os}; {arch})");
107            request_headers.insert("User-Agent".to_string(), user_agent);
108        }
109
110        let header_map: HeaderMap = (&request_headers)
111            .try_into()
112            .map_err(|_| ConversionError("MalformedHeaders".into()))?;
113        let client = reqwest::ClientBuilder::new()
114            .default_headers(header_map)
115            .build()
116            .map_err(|error| IoError(error.to_string()))?;
117
118        let response = client
119            .get(parsed_url.as_str())
120            .send()
121            .await
122            .map_err(|error| IoError(error.to_string()))?;
123        let response_headers = response.headers();
124        let response_headers: HashMap<String, String> = response_headers
125            .iter()
126            .map(|(key, value)| {
127                (
128                    key.as_str().to_string(),
129                    value.to_str().unwrap_or_default().to_string(),
130                )
131            })
132            .collect();
133        let content_type = response_headers
134            .iter()
135            .find(|(key, _value)| key.eq_ignore_ascii_case("content-type"))
136            .map(|(_key, value)| value.split(';').next().unwrap_or_default())
137            .unwrap_or_default();
138        create_dir_all(temp_dir)?;
139        let file_path = temp_dir.join(file_name);
140        let mut file = File::create_new(&file_path)
141            .await
142            .map_err(|error| IoError(error.to_string()))?;
143        let mut stream = response.bytes_stream();
144        while let Some(item) = stream.next().await {
145            let item = item.map_err(|error| IoError(error.to_string()))?;
146            file.write_all(&item)
147                .await
148                .map_err(|error| IoError(error.to_string()))?;
149        }
150
151        let file_type = Self::file_type(content_type, &file_path)?;
152        Ok((request_headers, file_path, file_type, response_headers))
153    }
154
155    fn file_type(content_type: &str, file_path: &PathBuf) -> Result<&'static FileType> {
156        // Ignore generic content types and try to determine the file type from the extension
157        // or bytes
158        let content_type = content_type.trim().to_lowercase();
159        if !["text/plain", "application/octet-stream"].contains(&content_type.as_str()) {
160            let file_types = FileType::from_media_type(content_type.to_lowercase());
161            if !file_types.is_empty() {
162                if let Some(file_type) = file_types.first() {
163                    return Ok(file_type);
164                }
165            }
166        }
167        let file_type =
168            FileType::try_from_file(file_path).map_err(|error| IoError(error.to_string()))?;
169        Ok(file_type)
170    }
171}
172
173async fn create_header_tables(
174    connection: &mut Box<dyn rsql_driver::Connection>,
175    request_headers: &HashMap<String, String>,
176    response_headers: &HashMap<String, String>,
177) -> Result<()> {
178    let request_header_sql = create_table_sql("request_headers", request_headers);
179    connection.execute(&request_header_sql).await?;
180    let response_header_sql = create_table_sql("response_headers", response_headers);
181    connection.execute(&response_header_sql).await?;
182    Ok(())
183}
184
/// Builds a `CREATE TABLE <table_name> AS ...` statement with `header` and `value` columns.
///
/// The statement contains one `SELECT` per map entry, joined with `UNION`. Header names are
/// lowercased. Single quotes in names and values are doubled so they stay inside the SQL
/// string literals. Rows follow the map's iteration order.
///
/// NOTE(review): an empty map yields `CREATE TABLE <table_name> AS ` with no query, which is
/// presumably invalid SQL. The request-header map always gets at least a `User-Agent`.
fn create_table_sql(table_name: &str, headers: &HashMap<String, String>) -> String {
    let mut selects: Vec<String> = Vec::with_capacity(headers.len());
    for (name, value) in headers {
        let escaped_name = name.replace('\'', "''").to_lowercase();
        let escaped_value = value.replace('\'', "''");
        selects.push(format!(
            "SELECT '{escaped_name}' AS \"header\", '{escaped_value}' AS \"value\""
        ));
    }
    let query = selects.join(" UNION ");
    format!("CREATE TABLE {table_name} AS {query}")
}