use dragonfly_api::common::v2::{Hdfs, ObjectStorage, Range};
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{
error::{ErrorType, OrErr},
Error, Result,
};
use libloading::Library;
use reqwest::header::HeaderMap;
use rustls_pki_types::CertificateDer;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::{collections::HashMap, pin::Pin, time::Duration};
use std::{fmt::Debug, fs};
use tokio::io::{AsyncRead, AsyncReadExt};
use tracing::{error, info, warn};
use url::Url;
pub mod hdfs;
pub mod http;
pub mod object_storage;
// NOTE(review): the connection-tuning constants below are not referenced in this
// part of the file; presumably they are consumed by the `http` submodule's client
// builder. Confirm there before changing any value.

/// Cap on idle pooled connections kept per host (presumably for the HTTP client).
const POOL_MAX_IDLE_PER_HOST: usize = 1024;

/// Keep-alive interval (60 seconds).
const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(60);

/// HTTP/2 keep-alive ping interval (5 minutes).
const HTTP2_KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(5 * 60);

/// HTTP/2 keep-alive ping timeout (20 seconds).
const HTTP2_KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(20);

/// HTTP/2 per-stream flow-control window: 16 MiB.
const HTTP2_STREAM_WINDOW_SIZE: u32 = 16 << 20;

/// HTTP/2 per-connection flow-control window: 16 MiB.
const HTTP2_CONNECTION_WINDOW_SIZE: u32 = 16 << 20;

/// Maximum number of retries (presumably applied by the backends' request logic).
const MAX_RETRY_TIMES: u32 = 1;

/// Name of this plugin kind; backend plugins are loaded from `<plugin_dir>/backend`.
pub const NAME: &str = "backend";
/// Type-erased async reader used as the body type of `Backend::get` responses
/// (`GetResponse<Body>`), so every backend can return its own reader type.
pub type Body = Box<dyn AsyncRead + Send + Unpin>;
/// Request passed to `Backend::stat`.
pub struct StatRequest {
/// Identifier of the task issuing the request.
pub task_id: String,
/// URL of the resource to stat.
pub url: String,
/// Optional HTTP headers to send with the request.
pub http_header: Option<HeaderMap>,
/// Timeout for the request.
pub timeout: Duration,
/// Optional DER-encoded certificates for TLS.
// NOTE(review): presumably trusted roots or a client chain; confirm in the `http` backend.
pub client_cert: Option<Vec<CertificateDer<'static>>>,
/// Object-storage parameters (presumably only read by the object-storage backends).
pub object_storage: Option<ObjectStorage>,
/// HDFS parameters (presumably only read by the HDFS backend).
pub hdfs: Option<Hdfs>,
}
/// Result of `Backend::stat`.
#[derive(Debug)]
pub struct StatResponse {
/// Whether the backend reported success; presumably `error_message` is set when `false`.
pub success: bool,
/// Length of the resource content, if the backend knows it.
pub content_length: Option<u64>,
/// Response headers, for backends that have them (e.g. HTTP).
pub http_header: Option<HeaderMap>,
/// HTTP status code, for backends that have one.
pub http_status_code: Option<reqwest::StatusCode>,
/// Entries found under the URL; presumably only populated when it names a directory.
pub entries: Vec<DirEntry>,
/// Error description reported by the backend, if any.
pub error_message: Option<String>,
}
/// Request passed to `Backend::get`.
// NOTE(review): the derived `Debug` prints every field, including headers and
// `object_storage`, which may carry credentials; avoid logging this verbatim.
#[derive(Debug, Clone)]
pub struct GetRequest {
/// Identifier of the task issuing the request.
pub task_id: String,
/// Identifier of the piece being fetched (presumably piece within `task_id`).
pub piece_id: String,
/// URL of the resource to fetch.
pub url: String,
/// Optional range to fetch; presumably a byte range of the resource.
pub range: Option<Range>,
/// Optional HTTP headers to send with the request.
pub http_header: Option<HeaderMap>,
/// Timeout for the request.
pub timeout: Duration,
/// Optional DER-encoded certificates for TLS.
// NOTE(review): presumably trusted roots or a client chain; confirm in the `http` backend.
pub client_cert: Option<Vec<CertificateDer<'static>>>,
/// Object-storage parameters (presumably only read by the object-storage backends).
pub object_storage: Option<ObjectStorage>,
/// HDFS parameters (presumably only read by the HDFS backend).
pub hdfs: Option<Hdfs>,
}
/// Result of `Backend::get`, generic over the body reader `R`
/// (backends return `GetResponse<Body>`).
pub struct GetResponse<R>
where
R: AsyncRead + Unpin,
{
/// Whether the backend reported success; presumably `error_message` is set when `false`.
pub success: bool,
/// Response headers, for backends that have them (e.g. HTTP).
pub http_header: Option<HeaderMap>,
/// HTTP status code, for backends that have one.
pub http_status_code: Option<reqwest::StatusCode>,
/// Async reader over the response body; `text()` drains it into a `String`.
pub reader: R,
/// Error description reported by the backend, if any.
pub error_message: Option<String>,
}
impl<R> GetResponse<R>
where
    R: AsyncRead + Unpin,
{
    /// Reads everything remaining in `reader` and returns it as a `String`.
    ///
    /// The reader is consumed in the process, so a second call yields only
    /// whatever was not read the first time.
    ///
    /// # Errors
    ///
    /// Propagates the read error converted into the crate's `Error`. Tokio's
    /// `read_to_string` also reports non-UTF-8 content as an I/O error.
    pub async fn text(&mut self) -> Result<String> {
        let mut text = String::new();
        let mut pinned_reader = Pin::new(&mut self.reader);
        pinned_reader.read_to_string(&mut text).await?;
        Ok(text)
    }
}
/// One entry listed under a directory URL (see `StatResponse::entries`).
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct DirEntry {
/// URL of the entry.
pub url: String,
/// Size of the entry (presumably in bytes).
pub content_length: usize,
/// Whether the entry is itself a directory.
pub is_dir: bool,
}
/// Request passed to `Backend::exists`.
pub struct ExistsRequest {
/// Identifier of the task issuing the request.
pub task_id: String,
/// URL of the resource to check.
pub url: String,
/// Optional HTTP headers to send with the request.
pub http_header: Option<HeaderMap>,
/// Timeout for the request.
pub timeout: Duration,
/// Optional DER-encoded certificates for TLS.
// NOTE(review): presumably trusted roots or a client chain; confirm in the `http` backend.
pub client_cert: Option<Vec<CertificateDer<'static>>>,
/// Object-storage parameters (presumably only read by the object-storage backends).
pub object_storage: Option<ObjectStorage>,
/// HDFS parameters (presumably only read by the HDFS backend).
pub hdfs: Option<Hdfs>,
}
/// Request passed to `Backend::put`.
pub struct PutRequest {
/// Identifier of the task issuing the request.
pub task_id: String,
/// Destination URL.
pub url: String,
/// Local path; presumably the file whose content is uploaded to `url`.
pub path: PathBuf,
/// Optional HTTP headers to send with the request.
pub http_header: Option<HeaderMap>,
/// Timeout for the request.
pub timeout: Duration,
/// Optional DER-encoded certificates for TLS.
// NOTE(review): presumably trusted roots or a client chain; confirm in the `http` backend.
pub client_cert: Option<Vec<CertificateDer<'static>>>,
/// Object-storage parameters (presumably only read by the object-storage backends).
pub object_storage: Option<ObjectStorage>,
/// HDFS parameters (presumably only read by the HDFS backend).
pub hdfs: Option<Hdfs>,
}
/// Result of `Backend::put`.
#[derive(Debug)]
pub struct PutResponse {
/// Whether the backend reported success; presumably `error_message` is set when `false`.
pub success: bool,
/// Length of the content, if the backend reports it.
pub content_length: Option<u64>,
/// Response headers, for backends that have them (e.g. HTTP).
pub http_header: Option<HeaderMap>,
/// HTTP status code, for backends that have one.
pub http_status_code: Option<reqwest::StatusCode>,
/// Error description reported by the backend, if any.
pub error_message: Option<String>,
}
/// A storage backend that `BackendFactory` dispatches to by URL scheme.
///
/// Builtin backends are registered by the factory itself. Plugins provide one by
/// exporting a `register_plugin` symbol that returns `Box<dyn Backend + Send + Sync>`.
#[tonic::async_trait]
pub trait Backend {
/// Scheme this backend serves; the factory tests expect it to equal the
/// scheme the backend is registered under (e.g. `"s3"`).
fn scheme(&self) -> String;
/// Retrieves metadata for `request.url` (length, headers, directory entries).
async fn stat(&self, request: StatRequest) -> Result<StatResponse>;
/// Opens the content of `request.url` (optionally limited by `request.range`)
/// as a `Body` reader.
async fn get(&self, request: GetRequest) -> Result<GetResponse<Body>>;
/// Writes to `request.url`; presumably uploads the local file at `request.path`.
async fn put(&self, request: PutRequest) -> Result<PutResponse>;
/// Presumably returns whether the resource at `request.url` exists.
async fn exists(&self, request: ExistsRequest) -> Result<bool>;
}
/// Registry of backends keyed by URL scheme: the builtin backends plus any
/// backends loaded from dynamic-library plugins.
#[derive(Default)]
pub struct BackendFactory {
/// Daemon configuration; handed to the builtin backends on construction.
config: Arc<Config>,
/// Backends keyed by URL scheme (e.g. "http", "s3") or plugin name.
backends: HashMap<String, Box<dyn Backend + Send + Sync>>,
/// Dynamically loaded plugin libraries.
// Keep this field declared AFTER `backends`: struct fields drop in declaration
// order, and plugin backends must be dropped before the libraries that hold
// their code are unloaded.
libraries: Vec<Library>,
}
impl BackendFactory {
pub fn new(config: Arc<Config>, plugin_dir: Option<&Path>) -> Result<Self> {
let mut backend_factory = Self {
config: config.clone(),
backends: HashMap::new(),
libraries: Vec::new(),
};
backend_factory.load_builtin_backends(
config.backend.enable_cache_temporary_redirect,
config.backend.cache_temporary_redirect_ttl,
)?;
if let Some(plugin_dir) = plugin_dir {
backend_factory
.load_plugin_backends(plugin_dir)
.inspect_err(|err| {
error!("failed to load plugin backends: {}", err);
})?;
}
Ok(backend_factory)
}
pub fn unsupported_download_directory(scheme: &str) -> bool {
scheme == http::HTTP_SCHEME || scheme == http::HTTPS_SCHEME
}
pub fn build(&self, url: &str) -> Result<&(dyn Backend + Send + Sync)> {
let url = Url::parse(url).or_err(ErrorType::ParseError)?;
let scheme = url.scheme();
self.backends
.get(scheme)
.map(|boxed_backend| &**boxed_backend)
.ok_or(Error::InvalidParameter)
.inspect_err(|_err| {
error!("unsupported backend scheme: {}", scheme);
})
}
fn load_builtin_backends(
&mut self,
enable_cache_temporary_redirect: bool,
cache_temporary_redirect_ttl: Duration,
) -> Result<()> {
self.backends.insert(
"http".to_string(),
Box::new(http::HTTP::new(
http::HTTP_SCHEME,
self.config.backend.clone().request_header,
enable_cache_temporary_redirect,
cache_temporary_redirect_ttl,
)?),
);
info!("load [http] builtin backend");
self.backends.insert(
"https".to_string(),
Box::new(http::HTTP::new(
http::HTTPS_SCHEME,
self.config.backend.clone().request_header,
enable_cache_temporary_redirect,
cache_temporary_redirect_ttl,
)?),
);
info!("load [https] builtin backend");
self.backends.insert(
"s3".to_string(),
Box::new(object_storage::ObjectStorage::new(
object_storage::Scheme::S3,
self.config.clone(),
)?),
);
info!("load [s3] builtin backend");
self.backends.insert(
"gs".to_string(),
Box::new(object_storage::ObjectStorage::new(
object_storage::Scheme::GCS,
self.config.clone(),
)?),
);
info!("load [gcs] builtin backend");
self.backends.insert(
"abs".to_string(),
Box::new(object_storage::ObjectStorage::new(
object_storage::Scheme::ABS,
self.config.clone(),
)?),
);
info!("load [abs] builtin backend");
self.backends.insert(
"oss".to_string(),
Box::new(object_storage::ObjectStorage::new(
object_storage::Scheme::OSS,
self.config.clone(),
)?),
);
info!("load [oss] builtin backend");
self.backends.insert(
"obs".to_string(),
Box::new(object_storage::ObjectStorage::new(
object_storage::Scheme::OBS,
self.config.clone(),
)?),
);
info!("load [obs] builtin backend");
self.backends.insert(
"cos".to_string(),
Box::new(object_storage::ObjectStorage::new(
object_storage::Scheme::COS,
self.config.clone(),
)?),
);
info!("load [cos] builtin backend");
self.backends
.insert("hdfs".to_string(), Box::new(hdfs::Hdfs::new()));
info!("load [hdfs] builtin backend");
Ok(())
}
fn load_plugin_backends(&mut self, plugin_dir: &Path) -> Result<()> {
let backend_plugin_dir = plugin_dir.join(NAME);
if !backend_plugin_dir.exists() {
warn!(
"skip loading plugin backends, because the plugin directory {} does not exist",
backend_plugin_dir.display()
);
return Ok(());
}
for entry in fs::read_dir(backend_plugin_dir)? {
let path = entry?.path();
unsafe {
self.libraries
.push(Library::new(path.as_os_str()).or_err(ErrorType::PluginError)?);
let lib = &self.libraries[self.libraries.len() - 1];
let register_plugin: libloading::Symbol<
unsafe extern "C" fn() -> Box<dyn Backend + Send + Sync>,
> = lib.get(b"register_plugin").or_err(ErrorType::PluginError)?;
if let Some(file_stem) = path.file_stem() {
if let Some(plugin_name) =
file_stem.to_string_lossy().to_string().strip_prefix("lib")
{
self.backends
.insert(plugin_name.to_string(), register_plugin());
info!("load [{}] plugin backend", plugin_name);
}
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// URL schemes registered by `load_builtin_backends`.
    const BUILTIN_SCHEMES: [&str; 9] = [
        "http", "https", "s3", "gs", "abs", "oss", "obs", "cos", "hdfs",
    ];

    #[test]
    fn should_create_backend_factory_without_plugin_dir() {
        let result = BackendFactory::new(Arc::new(Config::default()), None);
        assert!(result.is_ok());
    }

    #[test]
    fn should_load_builtin_backends() {
        let factory = BackendFactory::new(Arc::new(Config::default()), None).unwrap();
        for backend in BUILTIN_SCHEMES {
            assert!(
                factory.backends.contains_key(backend),
                "missing builtin backend {backend}"
            );
        }
    }

    #[test]
    fn should_load_plugin_backends() {
        let dir = tempdir().unwrap();
        let plugin_dir = dir.path().join("plugin");
        std::fs::create_dir(&plugin_dir).unwrap();

        let backend_dir = plugin_dir.join(NAME);
        std::fs::create_dir(&backend_dir).unwrap();

        build_example_plugin(&backend_dir);

        let result = BackendFactory::new(Arc::new(Config::default()), Some(&plugin_dir));
        assert!(result.is_ok());

        let factory = result.unwrap();
        assert!(factory.backends.contains_key("hdfs"));
    }

    #[test]
    fn should_skip_loading_plugins_when_plugin_dir_is_invalid() {
        let dir = tempdir().unwrap();
        let plugin_dir = dir.path().join("non_existent_plugin_dir");

        let factory = BackendFactory::new(Arc::new(Config::default()), Some(&plugin_dir)).unwrap();
        assert_eq!(factory.backends.len(), BUILTIN_SCHEMES.len());
    }

    #[test]
    fn should_return_error_when_plugin_loading_fails() {
        let dir = tempdir().unwrap();
        let plugin_dir = dir.path().join("plugin");
        std::fs::create_dir(&plugin_dir).unwrap();

        let backend_dir = plugin_dir.join(NAME);
        std::fs::create_dir(&backend_dir).unwrap();

        let lib_path = backend_dir.join("libinvalid_plugin.so");
        std::fs::write(&lib_path, b"invalid content").unwrap();

        let result = BackendFactory::new(Arc::new(Config::default()), Some(&plugin_dir));
        assert!(result.is_err());

        let err_msg = format!("{}", result.err().unwrap());
        assert!(
            err_msg.starts_with("PluginError cause:"),
            "error message should start with 'PluginError cause:'"
        );
        assert!(
            err_msg.contains(&lib_path.display().to_string()),
            "error message should contain library path"
        );
    }

    #[test]
    fn should_build_correct_backend() {
        let dir = tempdir().unwrap();
        let plugin_dir = dir.path().join("plugin");
        std::fs::create_dir(&plugin_dir).unwrap();

        let backend_dir = plugin_dir.join(NAME);
        std::fs::create_dir(&backend_dir).unwrap();

        build_example_plugin(&backend_dir);

        let factory = BackendFactory::new(Arc::new(Config::default()), Some(&plugin_dir)).unwrap();
        for scheme in BUILTIN_SCHEMES {
            let result = factory.build(&format!("{}://example.com/key", scheme));
            assert!(result.is_ok());

            let backend = result.unwrap();
            assert_eq!(backend.scheme(), scheme);
        }
    }

    #[test]
    fn should_return_error_when_backend_scheme_is_not_support() {
        let factory = BackendFactory::new(Arc::new(Config::default()), None).unwrap();
        let result = factory.build("github://example.com");
        assert!(result.is_err());
        assert_eq!(format!("{}", result.err().unwrap()), "invalid parameter");
    }

    #[test]
    fn should_return_error_when_backend_scheme_is_invalid() {
        let factory = BackendFactory::new(Arc::new(Config::default()), None).unwrap();
        let result = factory.build("invalid_scheme://example.com");
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.err().unwrap()),
            "ParseError cause: relative URL without a base",
        );
    }

    /// Builds the example plugin crate and copies its shared library into
    /// `backend_dir`.
    ///
    /// The artifact is copied, not renamed. `rename` fails with a cross-device
    /// error when the temp dir is on a different filesystem than `target/`.
    /// Tests also run in parallel, so a second test would find the artifact
    /// already moved away by the first.
    fn build_example_plugin(backend_dir: &Path) {
        let status = std::process::Command::new("cargo")
            .arg("build")
            .current_dir("./examples/plugin")
            .status()
            .expect("failed to run `cargo build` for the example plugin");
        assert!(status.success(), "building the example plugin failed");

        let plugin_file = if cfg!(target_os = "macos") {
            "libhdfs.dylib"
        } else {
            "libhdfs.so"
        };

        let built_plugin = Path::new("../target/debug").join(plugin_file);
        std::fs::copy(&built_plugin, backend_dir.join(plugin_file)).unwrap_or_else(|err| {
            panic!(
                "failed to copy example plugin {}: {}",
                built_plugin.display(),
                err
            )
        });
    }
}