use std::sync::{Arc, Mutex as StdMutex};
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, RANGE};
use crate::download_trait::{BreakpointDownload, DownloadHeadCtx, DownloadRangeGetCtx};
use crate::error::{InnerErrorCode, MeowError};
use crate::TransferTask;
use super::time::now_unix_secs;
use super::{PresignedDownloadUrlRefresher, PresignedRangeDownloadPlan};
/// Breakpoint (ranged) downloader driven by a presigned-URL plan.
///
/// The plan lives behind an `Arc<Mutex<_>>`, so the derived `Clone` is cheap and
/// every clone shares the same plan: a refresh performed through one clone is
/// observed by all of them.
#[derive(Clone)]
pub struct PresignedRangeDownload {
    /// Current download plan. It is replaced wholesale when the presigned URL is refreshed.
    plan: Arc<StdMutex<PresignedRangeDownloadPlan>>,
    /// Optional source of fresh plans. `None` means the plan is never refreshed.
    url_refresher: Option<Arc<dyn PresignedDownloadUrlRefresher>>,
}
impl PresignedRangeDownload {
    /// Creates a downloader for `plan` with no URL refresher configured.
    pub fn new(plan: PresignedRangeDownloadPlan) -> Self {
        Self {
            plan: Arc::new(StdMutex::new(plan)),
            url_refresher: None,
        }
    }

    /// Attaches a refresher that is asked for a new plan when the current
    /// presigned range URL is about to expire (see [`Self::ensure_fresh_plan`]).
    pub fn with_url_refresher(mut self, refresher: Arc<dyn PresignedDownloadUrlRefresher>) -> Self {
        self.url_refresher = Some(refresher);
        self
    }

    /// Returns a snapshot (clone) of the current plan.
    ///
    /// # Errors
    /// `InvalidTaskState` if the plan mutex is poisoned.
    pub fn plan(&self) -> Result<PresignedRangeDownloadPlan, MeowError> {
        self.plan
            .lock()
            .map(|g| g.clone())
            .map_err(|_| Self::lock_poisoned_error())
    }

    /// Error reported whenever the shared plan mutex is found poisoned.
    fn lock_poisoned_error() -> MeowError {
        MeowError::from_code_str(
            InnerErrorCode::InvalidTaskState,
            "presigned range download plan lock poisoned",
        )
    }

    /// Copies every header in `extra` into `target`, overriding existing entries.
    ///
    /// For each header name present in `extra`, all values previously stored
    /// under that name in `target` are dropped and *all* of `extra`'s values
    /// are appended. Calling `insert` per (name, value) pair would keep only
    /// the last value of a multi-valued header.
    fn merge_headers(target: &mut HeaderMap, extra: &HeaderMap) {
        // `keys()` yields each distinct name once, even when it has several values.
        for name in extra.keys() {
            target.remove(name);
            for value in extra.get_all(name) {
                target.append(name.clone(), value.clone());
            }
        }
    }

    /// True when the plan has an expiry and "now + refresh_before_secs" has reached it.
    /// Plans without `range_expires_at_unix_secs` never need refreshing.
    fn should_refresh_plan(plan: &PresignedRangeDownloadPlan) -> Result<bool, MeowError> {
        let Some(expires_at) = plan.range_expires_at_unix_secs else {
            return Ok(false);
        };
        Ok(now_unix_secs()?.saturating_add(plan.refresh_before_secs) >= expires_at)
    }

    /// True when the plan has an expiry and the current time has reached it.
    fn is_plan_expired(plan: &PresignedRangeDownloadPlan) -> Result<bool, MeowError> {
        let Some(expires_at) = plan.range_expires_at_unix_secs else {
            return Ok(false);
        };
        Ok(now_unix_secs()? >= expires_at)
    }

    /// Returns a plan whose range URL is usable now, refreshing it if needed.
    ///
    /// The outcome depends on the plan's expiry and whether a refresher is configured:
    /// - If the plan is not yet inside its refresh window, it is returned unchanged.
    /// - If it is inside the window and no refresher is configured, it is still
    ///   returned, unless it has already expired.
    /// - Otherwise the refresher is called. The refreshed plan inherits the old
    ///   `total_size` when it reports none, and it is stored back as the current plan.
    ///
    /// The lock is not held while the refresher runs, so concurrent callers
    /// may each trigger a refresh; the last one to finish wins.
    ///
    /// # Errors
    /// - `InvalidTaskState` if the lock is poisoned.
    /// - `InvalidTaskState` if the URL expired with no refresher configured.
    /// - `InvalidTaskState` if the refreshed plan reports a different `total_size`.
    /// - Any error from the refresher or the clock.
    pub(crate) fn ensure_fresh_plan(&self) -> Result<PresignedRangeDownloadPlan, MeowError> {
        let plan = self.plan()?;
        if !Self::should_refresh_plan(&plan)? {
            return Ok(plan);
        }
        let Some(refresher) = &self.url_refresher else {
            if Self::is_plan_expired(&plan)? {
                return Err(MeowError::from_code_str(
                    InnerErrorCode::InvalidTaskState,
                    "presigned range URL expired and no refresher is configured",
                ));
            }
            // Inside the refresh window but not yet expired: keep using the current URL.
            return Ok(plan);
        };
        let mut refreshed = refresher.refresh_range_download(&plan)?;
        // A size change would corrupt already-downloaded ranges, so reject it.
        if let (Some(old), Some(new)) = (plan.total_size, refreshed.total_size) {
            if old != new {
                return Err(MeowError::from_code(
                    InnerErrorCode::InvalidTaskState,
                    format!("refreshed range total_size mismatch: old={old} new={new}"),
                ));
            }
        }
        refreshed.total_size = refreshed.total_size.or(plan.total_size);
        let mut guard = self.plan.lock().map_err(|_| Self::lock_poisoned_error())?;
        *guard = refreshed.clone();
        Ok(refreshed)
    }
}
impl BreakpointDownload for PresignedRangeDownload {
    /// Total object size recorded in the current plan, if any.
    /// Returns `None` when the plan lock is poisoned.
    fn total_size_hint(&self, _task: &TransferTask) -> Option<u64> {
        self.plan.lock().ok().and_then(|g| g.total_size)
    }

    /// URL for the HEAD probe. Uses the plan's `head_url` when set and
    /// otherwise falls back to the task's own URL. The fallback is also
    /// used when the plan lock is poisoned.
    fn head_url(&self, task: &TransferTask) -> String {
        self.plan
            .lock()
            .ok()
            .and_then(|g| g.head_url.clone())
            .unwrap_or_else(|| task.url().to_string())
    }

    /// Presigned URL for ranged GET requests, read from the current plan.
    ///
    /// NOTE(review): this does not trigger a refresh itself, and a poisoned
    /// lock yields an empty string. Confirm callers tolerate both.
    fn range_url(&self, _task: &TransferTask) -> String {
        self.plan
            .lock()
            .map(|g| g.range_url.clone())
            .unwrap_or_default()
    }

    /// Adds the plan's `head_headers` to the HEAD request, overriding existing entries.
    fn merge_head_headers(&self, ctx: DownloadHeadCtx<'_>) -> Result<(), MeowError> {
        let plan = self.plan()?;
        Self::merge_headers(ctx.base, &plan.head_headers);
        Ok(())
    }

    /// Prepares headers for one ranged GET.
    ///
    /// The plan is refreshed first if needed. A default `Accept` header is
    /// added when none is present, and the plan's `range_headers` are merged
    /// in. Finally `Range` is set to this request's `range_value`. `Range` is
    /// written last so a static `Range` entry in the plan can never clobber
    /// the per-chunk range.
    ///
    /// # Errors
    /// - Any error from [`PresignedRangeDownload::ensure_fresh_plan`].
    /// - `ParameterEmpty` if `range_value` is not a valid header value. In
    ///   that case `ctx.base` is left untouched.
    fn merge_range_get_headers(&self, ctx: DownloadRangeGetCtx<'_>) -> Result<(), MeowError> {
        let plan = self.ensure_fresh_plan()?;
        // Validate before mutating `ctx.base`, so a bad value has no side effects.
        let range = HeaderValue::from_str(ctx.range_value).map_err(|e| {
            MeowError::from_code(
                InnerErrorCode::ParameterEmpty,
                format!("invalid range header value '{}': {e}", ctx.range_value),
            )
        })?;
        if !ctx.base.contains_key(ACCEPT) {
            ctx.base.insert(
                ACCEPT,
                HeaderValue::from_static(crate::http_breakpoint::DEFAULT_RANGE_ACCEPT),
            );
        }
        Self::merge_headers(ctx.base, &plan.range_headers);
        ctx.base.insert(RANGE, range);
        Ok(())
    }
}