use async_trait::async_trait;
use file_type::FileType;
use futures_util::StreamExt;
use reqwest::header::HeaderMap;
use rsql_driver::Error::{ConversionError, IoError};
use rsql_driver::{DriverManager, Result};
use std::collections::HashMap;
use std::fs::create_dir_all;
use std::path::{Path, PathBuf};
use tempfile::TempDir;
use tokio::io::AsyncWriteExt;
use tracing::debug;
use url::Url;
/// Driver registered under the `https` identifier.
///
/// It downloads the target of the connection URL into a temporary directory
/// and hands the downloaded file to the driver that matches its file type.
#[derive(Debug)]
pub struct Driver;
#[async_trait]
impl rsql_driver::Driver for Driver {
    /// Returns the identifier of this driver (`https`).
    fn identifier(&self) -> &'static str {
        "https"
    }

    /// Downloads the resource behind `url` and opens it with the driver that
    /// handles the detected file type.
    ///
    /// The query string of `url` is forwarded unchanged to the delegated
    /// driver. After the delegated connection is established, the request and
    /// response headers are stored in it as tables.
    ///
    /// # Errors
    /// Propagates any error from creating the temporary directory, retrieving
    /// the file, looking up or connecting the delegated driver, or creating the
    /// header tables. Returns an `IoError` naming the file and its media types
    /// when no driver is available for the detected file type.
    async fn connect(&self, url: &str) -> Result<Box<dyn rsql_driver::Connection>> {
        // NOTE(review): `temp_dir` goes out of scope when this function returns;
        // presumably the delegated driver has finished reading the file by then
        // — confirm.
        let temp_dir = TempDir::new()?;
        let (request_headers, file_path, file_type, response_headers) =
            self.retrieve_file(url, temp_dir.path()).await?;

        let file_path = file_path.to_string_lossy().into_owned();
        // Rewrite drive letters and separators so the path fits into a URL.
        #[cfg(target_os = "windows")]
        let file_path = file_path.replace(':', "%3A").replace('\\', "/");
        debug!("temp_dir: {temp_dir:?}; file_path: {file_path}");

        // Bail out early when no driver is available for this file type.
        let Some(driver) = DriverManager::get_by_file_type(file_type)? else {
            return Err(IoError(format!(
                "{file_path:?}: {:?}",
                file_type.media_types()
            )));
        };

        // Everything after the first `?` is handed to the delegated driver.
        let parameters = url.split_once('?').map_or("", |(_, query)| query);
        let url = format!("{}://{file_path}?{parameters}", driver.identifier());
        let mut connection = driver.connect(url.as_str()).await?;
        create_header_tables(&mut connection, &request_headers, &response_headers).await?;
        Ok(connection)
    }

    /// This driver never claims a file type itself; always returns `false`.
    fn supports_file_type(&self, _file_type: &FileType) -> bool {
        false
    }
}
impl Driver {
    /// Downloads the resource at `url` into `temp_dir` and determines its file type.
    ///
    /// All query parameters of `url` are collected into the request header map.
    /// The special `_headers` parameter is removed and expanded: it is a
    /// `;`-separated list of `name=value` pairs, where everything after the first
    /// `=` belongs to the value. The resulting map is sent both as query
    /// parameters and as HTTP headers. A `User-Agent` header is added when none
    /// was supplied (compared case-insensitively).
    ///
    /// The file is named after the last segment of the URL path, or `response`
    /// when the path has no file name.
    ///
    /// Returns, in order: the request headers that were sent, the path of the
    /// downloaded file, the detected file type, and the response headers.
    ///
    /// # Errors
    /// * The URL cannot be parsed.
    /// * `ConversionError("MalformedHeaders")` when the header map cannot be
    ///   converted into a `HeaderMap`.
    /// * `IoError` when building the client, sending the request, reading the
    ///   body, or creating/writing/flushing the file fails.
    /// * Any error from creating `temp_dir` or from file type detection.
    async fn retrieve_file(
        &self,
        url: &str,
        temp_dir: &Path,
    ) -> Result<(
        HashMap<String, String>,
        PathBuf,
        &FileType,
        HashMap<String, String>,
    )> {
        let mut parsed_url = Url::parse(url)?;
        let file_path = PathBuf::from(parsed_url.path());
        let file_name = match file_path.file_name() {
            Some(file_name) => file_name.to_string_lossy().to_string(),
            None => "response".to_string(),
        };

        let mut request_headers: HashMap<String, String> =
            parsed_url.query_pairs().into_owned().collect();
        if let Some(headers) = request_headers.remove("_headers") {
            let headers = headers
                .split(';')
                // A trailing or doubled `;` yields an empty segment; skipping it
                // avoids an empty header name that can never be sent.
                .filter(|header| !header.is_empty())
                .map(|header| {
                    // Split on the first `=` only so that values which contain
                    // `=` themselves (e.g. base64 padding) are kept intact.
                    let (key, value) = header.split_once('=').unwrap_or((header, ""));
                    (key.to_string(), value.to_string())
                })
                .collect::<HashMap<String, String>>();
            request_headers.extend(headers);
        }

        // Rebuild the URL so that the expanded `_headers` entries replace the
        // original `_headers` parameter in the query string.
        parsed_url.set_query(None);
        let url = parsed_url.to_string();
        let parameters: HashMap<&str, &str> = request_headers
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect();
        let parsed_url = Url::parse_with_params(url.as_str(), parameters)?;

        if !request_headers
            .keys()
            .any(|key| key.eq_ignore_ascii_case("user-agent"))
        {
            let package_name = env!("CARGO_PKG_NAME");
            let version = env!("CARGO_PKG_VERSION");
            let os = std::env::consts::OS;
            let arch = std::env::consts::ARCH;
            let user_agent = format!("{package_name}/{version} ({os}; {arch})");
            request_headers.insert("User-Agent".to_string(), user_agent);
        }

        let header_map: HeaderMap = (&request_headers)
            .try_into()
            .map_err(|_| ConversionError("MalformedHeaders".into()))?;
        let client = reqwest::ClientBuilder::new()
            .default_headers(header_map)
            .build()
            .map_err(|error| IoError(error.to_string()))?;
        // NOTE(review): the response status is not inspected here, so the body of
        // an error response is downloaded like any other — confirm this is intended.
        let response = client
            .get(parsed_url.as_str())
            .send()
            .await
            .map_err(|error| IoError(error.to_string()))?;

        // Header values that are not valid strings are recorded as empty.
        let response_headers: HashMap<String, String> = response
            .headers()
            .iter()
            .map(|(key, value)| {
                (
                    key.as_str().to_string(),
                    value.to_str().unwrap_or_default().to_string(),
                )
            })
            .collect();
        // Only the media type is of interest; drop parameters such as `charset`.
        let content_type = response_headers
            .iter()
            .find(|(key, _value)| key.eq_ignore_ascii_case("content-type"))
            .map(|(_key, value)| value.split(';').next().unwrap_or_default())
            .unwrap_or_default();

        create_dir_all(temp_dir)?;
        let file_path = temp_dir.join(file_name);
        let mut file = tokio::fs::File::create_new(&file_path)
            .await
            .map_err(|error| IoError(error.to_string()))?;
        let mut stream = response.bytes_stream();
        while let Some(item) = stream.next().await {
            let item = item.map_err(|error| IoError(error.to_string()))?;
            file.write_all(&item)
                .await
                .map_err(|error| IoError(error.to_string()))?;
        }
        // tokio performs file writes in the background; without an explicit flush
        // the last chunk may not have reached the file when it is read below.
        file.flush()
            .await
            .map_err(|error| IoError(error.to_string()))?;

        let file_type = Self::file_type(content_type, &file_path)?;
        Ok((request_headers, file_path, file_type, response_headers))
    }

    /// Determines the file type from the response media type, falling back to
    /// inspecting the downloaded file.
    ///
    /// `content_type` is trimmed and lower-cased before the lookup. The media
    /// types `text/plain` and `application/octet-stream` are never looked up;
    /// for those, and for media types without a known file type, the file at
    /// `file_path` is examined instead.
    ///
    /// # Errors
    /// Returns an `IoError` when the file type cannot be determined from the file.
    fn file_type(content_type: &str, file_path: &Path) -> Result<&'static FileType> {
        let content_type = content_type.trim().to_lowercase();
        if !["text/plain", "application/octet-stream"].contains(&content_type.as_str()) {
            let file_types = FileType::from_media_type(content_type);
            if let Some(file_type) = file_types.first() {
                return Ok(file_type);
            }
        }
        let file_type =
            FileType::try_from_file(file_path).map_err(|error| IoError(error.to_string()))?;
        Ok(file_type)
    }
}
/// Stores the request and response headers in `connection` as the tables
/// `request_headers` and `response_headers`.
///
/// The request table is created first; the first failing statement aborts the
/// function and its error is returned.
async fn create_header_tables(
    connection: &mut Box<dyn rsql_driver::Connection>,
    request_headers: &HashMap<String, String>,
    response_headers: &HashMap<String, String>,
) -> Result<()> {
    let tables = [
        ("request_headers", request_headers),
        ("response_headers", response_headers),
    ];
    for (table_name, headers) in tables {
        let sql = create_table_sql(table_name, headers);
        connection.execute(&sql, &[]).await?;
    }
    Ok(())
}
/// Builds a `CREATE TABLE … AS` statement that materializes `headers` as rows
/// with the columns `header` and `value`.
///
/// Each entry becomes one `SELECT`, and the selects are combined with `UNION`.
/// Header names are lower-cased, and single quotes in names and values are
/// doubled so they can be embedded in SQL string literals. Rows appear in the
/// iteration order of the map.
fn create_table_sql(table_name: &str, headers: &HashMap<String, String>) -> String {
    let mut selects: Vec<String> = Vec::with_capacity(headers.len());
    for (key, value) in headers {
        let key = key.replace('\'', "''").to_lowercase();
        let value = value.replace('\'', "''");
        selects.push(format!("SELECT '{key}' AS \"header\", '{value}' AS \"value\""));
    }
    let columns = selects.join(" UNION ");
    format!("CREATE TABLE {table_name} AS {columns}")
}