use std::sync::{Arc, atomic::Ordering};
use kithara_abr::{AbrController, AbrPeerId};
use kithara_events::{
BandwidthSource, CancelReason, DownloaderEvent, EventBus, RequestId, RequestMethod,
};
use kithara_net::{HttpClient, NetError};
use kithara_platform::{
CancelGroup,
time::{Duration, Instant},
tokio,
tokio::task,
};
use kithara_test_utils::kithara;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use super::{
cmd::FetchCmd,
downloader::DownloaderInner,
peer::{InternalCmd, ResponseTarget, SlotEntry},
response::{BodyStream, FetchResponse},
};
/// Marks a request as started.
///
/// Counts it in `inflight` first, then, when a bus is attached, publishes
/// `DownloaderEvent::RequestStarted` carrying how long it waited in the queue.
#[kithara::probe(request_id, wait_in_queue)]
fn start_request(
    bus: Option<&EventBus>,
    inflight: &std::sync::atomic::AtomicUsize,
    request_id: RequestId,
    wait_in_queue: Duration,
) {
    inflight.fetch_add(1, Ordering::Relaxed);
    if let Some(event_bus) = bus {
        let event = DownloaderEvent::RequestStarted {
            request_id,
            wait_in_queue,
        };
        event_bus.publish(event);
    }
}
/// Records a successfully completed request.
///
/// When any bytes were transferred, a `BandwidthSource::Network` sample is fed
/// to `abr` for `peer_id`. Then, if a bus is attached, `RequestCompleted` is
/// published together with the throughput derived by `bandwidth_bps`.
#[kithara::probe(request_id, bytes_transferred, duration)]
fn finish_request(
    bus: Option<&EventBus>,
    abr: &AbrController,
    peer_id: AbrPeerId,
    request_id: RequestId,
    bytes_transferred: u64,
    duration: Duration,
) {
    // Only non-empty transfers are reported to ABR.
    let has_sample = bytes_transferred != 0;
    if has_sample {
        abr.record_bandwidth(peer_id, bytes_transferred, duration, BandwidthSource::Network);
    }
    if let Some(event_bus) = bus {
        let throughput = bandwidth_bps(bytes_transferred, duration);
        event_bus.publish(DownloaderEvent::RequestCompleted {
            request_id,
            bytes_transferred,
            duration,
            bandwidth_bps: throughput,
        });
    }
}
/// Records that a request was cancelled.
///
/// When a bus is attached, publishes `RequestCancelled` with `reason` and the
/// number of bytes already transferred. `was_in_flight` is listed in the probe
/// attribute but is not otherwise used by this body.
#[kithara::probe(request_id, reason, bytes_transferred, was_in_flight)]
fn abort_request(
    bus: Option<&EventBus>,
    request_id: RequestId,
    reason: CancelReason,
    bytes_transferred: u64,
    was_in_flight: bool,
) {
    // Explicitly discarded so the parameter kept for the probe stays warning-free.
    let _ = was_in_flight;
    if let Some(event_bus) = bus {
        let event = DownloaderEvent::RequestCancelled {
            request_id,
            reason,
            bytes_transferred,
        };
        event_bus.publish(event);
    }
}
/// Records that a request failed with `err`.
///
/// When a bus is attached, publishes `RequestFailed` carrying a clone of the
/// error and the caller-supplied `retryable` flag.
#[kithara::probe(request_id, retryable)]
fn fail_request(bus: Option<&EventBus>, request_id: RequestId, err: &NetError, retryable: bool) {
    if let Some(event_bus) = bus {
        let error = err.clone();
        let event = DownloaderEvent::RequestFailed {
            request_id,
            retryable,
            error,
        };
        event_bus.publish(event);
    }
}
/// Entries of one batch that share the same cancel group, matched by pointer
/// identity through `CancelGroup::equals_ptr`.
struct EpochGroup {
    // Cancel group shared by every entry in `entries`.
    cancel: CancelGroup,
    // Entries in the order they appeared in the batch.
    entries: Vec<SlotEntry>,
}
/// A batch of queued fetches partitioned into `EpochGroup`s, one per distinct
/// cancel group, in order of first appearance.
///
/// Built through `FromIterator<SlotEntry>` and consumed by `process`.
pub(super) struct BatchGroup {
    epochs: Vec<EpochGroup>,
}
impl FromIterator<SlotEntry> for BatchGroup {
fn from_iter<I: IntoIterator<Item = SlotEntry>>(entries: I) -> Self {
let mut epochs: Vec<EpochGroup> = Vec::new();
for entry in entries {
let found = epochs
.iter_mut()
.find(|g| g.cancel.equals_ptr(&entry.cmd.cancel));
match found {
Some(group) => group.entries.push(entry),
None => epochs.push(EpochGroup {
cancel: entry.cmd.cancel.clone(),
entries: vec![entry],
}),
}
}
Self { epochs }
}
}
impl BatchGroup {
    /// Returns `true` when the batch holds no epoch groups, and therefore no
    /// entries.
    pub(super) fn is_empty(&self) -> bool {
        self.epochs.is_empty()
    }

    /// Dispatches every entry in the batch and returns how many fetch tasks
    /// were spawned.
    ///
    /// Behavior by case:
    /// - If a group's shared cancel group has already fired, the group is
    ///   drained without spawning. Each entry is reported through
    ///   `deliver_cancelled_with_event`.
    /// - Otherwise, each entry waits until `inner.inflight` drops below
    ///   `inner.max_concurrent` and is then passed to `spawn_fetch`.
    ///
    /// Cancellation is checked both before and after the capacity wait. A
    /// command cancelled while waiting for a slot is therefore reported as
    /// cancelled instead of being started.
    pub(super) async fn process(self, inner: &DownloaderInner) -> usize {
        let mut dispatched: usize = 0;
        for group in self.epochs {
            if group.cancel.is_cancelled() {
                for entry in group.entries {
                    deliver_cancelled_with_event(entry.cmd, &entry.peer_cancel);
                }
                continue;
            }
            for entry in group.entries {
                if entry.cmd.cancel.is_cancelled() {
                    deliver_cancelled_with_event(entry.cmd, &entry.peer_cancel);
                    continue;
                }
                // NOTE(review): this waits by repeatedly yielding instead of
                // awaiting a notification, so it stays runnable while the
                // downloader is at capacity.
                while inner.inflight.load(Ordering::Relaxed) >= inner.max_concurrent {
                    task::yield_now().await;
                }
                // The wait above may span many scheduler turns. Re-check so a
                // command cancelled meanwhile neither takes a concurrency slot
                // nor publishes `RequestStarted` (via `spawn_fetch`).
                if entry.cmd.cancel.is_cancelled() {
                    deliver_cancelled_with_event(entry.cmd, &entry.peer_cancel);
                    continue;
                }
                spawn_fetch(inner, entry.cmd, entry.peer_cancel);
                dispatched += 1;
                // Yield to the scheduler between dispatches.
                task::yield_now().await;
            }
        }
        dispatched
    }
}
/// Starts one fetch on a detached task.
///
/// Before spawning, the request is recorded as started via `start_request`
/// (this bumps `inflight` and may publish `RequestStarted`). The spawned task
/// then runs `establish` followed by `deliver`.
///
/// The concurrency slot taken by `start_request` is released by a drop guard:
/// `inflight` is decremented and `fetch_waker` is woken. The guard runs when
/// the task body finishes, and also if the body unwinds from a panic (for
/// example inside a user callback) or the task is dropped early. A failing
/// request therefore cannot permanently leak a slot and stall `process`.
fn spawn_fetch(inner: &DownloaderInner, internal: InternalCmd, peer_cancel: CancellationToken) {
    // Runs the stored closure exactly once, when the guard is dropped.
    struct OnDrop<F: FnOnce()>(Option<F>);

    impl<F: FnOnce()> Drop for OnDrop<F> {
        fn drop(&mut self) {
            if let Some(action) = self.0.take() {
                action();
            }
        }
    }

    let client = inner.client.clone();
    let chunk_timeout = inner.chunk_timeout;
    let soft_timeout = inner.soft_timeout;
    let inflight = inner.inflight.clone();
    let fetch_waker = inner.fetch_waker.clone();
    let abr = Arc::clone(&inner.abr);
    let downloader_cancel = inner.cancel.clone();
    let peer_id = internal.peer_id;
    let request_id = internal.request_id;
    let started = Instant::now();
    let wait_in_queue = started.saturating_duration_since(internal.enqueued_at);
    let mut cmd = internal.cmd;
    // Callbacks are detached from the command so `cmd` can be consumed by
    // `establish` while `deliver` still owns them.
    let writer = cmd.writer.take();
    let on_complete_cb = cmd.on_complete.take();
    let on_response_cb = cmd.on_response.take();
    let bus = internal.bus;
    let cancel = internal.cancel.clone();
    let target = internal.response;
    let epoch_cancel = cmd.cancel.clone();
    start_request(bus.as_ref(), &inflight, request_id, wait_in_queue);
    task::spawn(async move {
        // Declared first so it is dropped last. The slot stays held until
        // delivery has finished, and is still released on unwind or early drop.
        let _slot = OnDrop(Some(move || {
            inflight.fetch_sub(1, Ordering::Relaxed);
            fetch_waker.wake();
        }));
        let result = establish(
            &client,
            chunk_timeout,
            soft_timeout,
            &cancel,
            bus.clone(),
            cmd,
            request_id,
        )
        .await;
        deliver(
            request_id,
            DeliveryContext {
                result,
                writer,
                on_complete_cb,
                on_response_cb,
                abr,
                peer_id,
                started,
                bus,
                target,
                peer_cancel: &peer_cancel,
                epoch_cancel: epoch_cancel.as_ref(),
                downloader_cancel: &downloader_cancel,
            },
        )
        .await;
    });
}
/// Awaits `fut` to completion. If it is still pending once `soft` has elapsed,
/// a single `DownloaderEvent::LoadSlow` is published on `bus` (when present).
///
/// The soft timeout never aborts the future. After reporting, this keeps
/// waiting on `fut` and returns its output unchanged.
#[kithara::probe(request_id)]
async fn with_soft_timeout<F, T>(
    fut: F,
    soft: Duration,
    bus: Option<&EventBus>,
    request_id: RequestId,
) -> T
where
    F: Future<Output = T>,
{
    let begun = Instant::now();
    tokio::pin!(fut);
    tokio::select! {
        () = tokio::time::sleep(soft) => {
            if let Some(event_bus) = bus {
                let elapsed = begun.elapsed();
                event_bus.publish(DownloaderEvent::LoadSlow { request_id, elapsed });
            }
            fut.await
        }
        output = &mut fut => output,
    }
}
#[kithara::probe(request_id)]
async fn establish(
client: &HttpClient,
chunk_timeout: Duration,
soft_timeout: Duration,
cancel: &CancelGroup,
bus: Option<EventBus>,
cmd: FetchCmd,
request_id: RequestId,
) -> Result<FetchResponse, NetError> {
let FetchCmd {
method,
url,
range,
headers,
validator,
..
} = cmd;
if tracing::enabled!(tracing::Level::TRACE) {
let names: Vec<&str> = headers
.as_ref()
.map(|h| h.iter().map(|(k, _)| k).collect())
.unwrap_or_default();
tracing::trace!(%url, ?method, ?range, header_names = ?names, "fetch: outgoing FetchCmd");
}
if method == RequestMethod::Head {
let resp_headers = tokio::select! {
() = cancel.cancelled() => return Err(NetError::Cancelled),
r = with_soft_timeout(client.head(url, headers), soft_timeout, bus.as_ref(), request_id) => r?,
};
return Ok(FetchResponse {
headers: resp_headers,
body: BodyStream::empty(),
});
}
let fetch_url = url.clone();
let fetch = async {
match range {
Some(range) => client.get_range(url, range, headers).await,
None => client.stream(url, headers).await,
}
};
let byte_stream = tokio::select! {
() = cancel.cancelled() => return Err(NetError::Cancelled),
r = with_soft_timeout(fetch, soft_timeout, bus.as_ref(), request_id) => r?,
};
if let Some(validate) = validator
&& let Err(e) = validate(&byte_stream.headers)
{
warn!(url = %fetch_url, error = %e, "fetch rejected by response validator");
return Err(e);
}
let resp_headers = byte_stream.headers.clone();
let body = BodyStream::wrap_http(byte_stream, cancel.clone(), chunk_timeout);
Ok(FetchResponse {
body,
headers: resp_headers,
})
}
/// Converts a transfer of `bytes` over `duration` into bits per second.
///
/// The duration is used at nanosecond precision. Truncating to whole
/// milliseconds would overstate throughput by up to ~2x for short transfers;
/// for example, 1.5 ms would be treated as 1 ms.
///
/// Durations shorter than 1 ms are treated as exactly 1 ms. This avoids
/// division by zero and absurd spikes for near-instant transfers.
///
/// The result saturates at `u64::MAX`. Intermediate math is done in `u128`,
/// which cannot overflow for any `u64` byte count.
fn bandwidth_bps(bytes: u64, duration: Duration) -> u64 {
    const BITS_PER_BYTE: u128 = 8;
    const NANOS_PER_SEC: u128 = 1_000_000_000;
    const MIN_NANOS: u128 = 1_000_000;
    let nanos = duration.as_nanos().max(MIN_NANOS);
    let bps = u128::from(bytes) * BITS_PER_BYTE * NANOS_PER_SEC / nanos;
    u64::try_from(bps).unwrap_or(u64::MAX)
}
/// Decides which token caused a cancellation.
///
/// Tokens are checked in priority order: peer, then epoch (if any), then
/// downloader. `BeforeStart` is the fallback when none of them has fired.
fn classify_cancel(
    peer_cancel: &CancellationToken,
    epoch_cancel: Option<&CancellationToken>,
    downloader_cancel: &CancellationToken,
) -> CancelReason {
    if peer_cancel.is_cancelled() {
        return CancelReason::PeerCancel;
    }
    if epoch_cancel.is_some_and(CancellationToken::is_cancelled) {
        return CancelReason::EpochCancel;
    }
    if downloader_cancel.is_cancelled() {
        CancelReason::DownloaderShutdown
    } else {
        CancelReason::BeforeStart
    }
}
/// Everything `deliver` needs once `establish` has produced a result.
///
/// Built inside the task spawned by `spawn_fetch`. The borrowed cancellation
/// tokens are only used to classify a cancellation (via
/// `publish_failure_or_cancel` / `classify_cancel`).
struct DeliveryContext<'a> {
    // Downloader-wide token; checked last when classifying a cancellation.
    downloader_cancel: &'a CancellationToken,
    // Per-peer token; checked first when classifying a cancellation.
    peer_cancel: &'a CancellationToken,
    // ABR peer credited with the bandwidth sample on success.
    peer_id: AbrPeerId,
    // ABR controller that receives the bandwidth sample via `finish_request`.
    abr: Arc<AbrController>,
    // Taken right before the request was marked started; used to time the transfer.
    started: Instant,
    // Optional bus for request lifecycle events.
    bus: Option<EventBus>,
    // Epoch token carried by the command, if any; checked between peer and downloader.
    epoch_cancel: Option<&'a CancellationToken>,
    // Streaming target only: called with (bytes, headers, error) when the request ends.
    on_complete_cb: Option<super::cmd::OnCompleteFn>,
    // Streaming target only: called with the response headers before the body is written.
    on_response_cb: Option<super::cmd::OnResponseFn>,
    // Streaming target only: receives each body chunk.
    writer: Option<super::cmd::WriterFn>,
    // Where the result is handed off.
    target: ResponseTarget,
    // Outcome of `establish`.
    result: Result<FetchResponse, NetError>,
}
/// Hands the outcome of `establish` to its consumer.
///
/// Behavior by target:
/// - `ResponseTarget::Channel`: the result is sent as-is. Nothing is published
///   here and the callbacks in `ctx` are not invoked. A dropped receiver is
///   ignored.
/// - `ResponseTarget::Streaming` with `Ok`: `on_response_cb` sees the headers,
///   then the body is pumped through `writer`. Success goes to
///   `finish_request`; a write error goes to `publish_failure_or_cancel`.
///   `on_complete_cb` is told the outcome either way.
/// - `ResponseTarget::Streaming` with `Err`: the error is published through
///   `publish_failure_or_cancel`, and `on_complete_cb` receives it without
///   headers.
#[kithara::probe(request_id)]
async fn deliver(request_id: RequestId, ctx: DeliveryContext<'_>) {
    let DeliveryContext {
        target,
        result,
        mut writer,
        on_complete_cb,
        on_response_cb,
        abr,
        peer_id,
        started,
        bus,
        peer_cancel,
        epoch_cancel,
        downloader_cancel,
    } = ctx;

    match (target, result) {
        (ResponseTarget::Channel(tx), outcome) => {
            let _ = tx.send(outcome);
        }
        (ResponseTarget::Streaming, Err(err)) => {
            publish_failure_or_cancel(
                bus.as_ref(),
                request_id,
                &err,
                0,
                peer_cancel,
                epoch_cancel,
                downloader_cancel,
            );
            if let Some(on_complete) = on_complete_cb {
                on_complete(0, None, Some(&err));
            }
        }
        (ResponseTarget::Streaming, Ok(resp)) => {
            // NOTE(review): with no writer this arm does nothing at all. No
            // event is published and `on_complete_cb` is never invoked. Confirm
            // that streaming commands always carry a writer.
            if let Some(sink) = writer.as_mut() {
                let headers = resp.headers.clone();
                if let Some(on_response) = on_response_cb {
                    on_response(&headers);
                }
                let written = resp.body.write_all(|chunk| sink(chunk)).await;
                let elapsed = started.elapsed();
                match written {
                    Ok(total) => {
                        finish_request(bus.as_ref(), &abr, peer_id, request_id, total, elapsed);
                        if let Some(on_complete) = on_complete_cb {
                            on_complete(total, Some(&headers), None);
                        }
                    }
                    Err(err) => {
                        publish_failure_or_cancel(
                            bus.as_ref(),
                            request_id,
                            &err,
                            0,
                            peer_cancel,
                            epoch_cancel,
                            downloader_cancel,
                        );
                        if let Some(on_complete) = on_complete_cb {
                            on_complete(0, Some(&headers), Some(&err));
                        }
                    }
                }
            }
        }
    }
}
/// Publishes the terminal event for a request that ended with `err`.
///
/// `NetError::Cancelled` becomes a cancellation: the reason is derived by
/// `classify_cancel`, and the request is reported as having been in flight.
/// Every other error becomes a failure tagged with `err.is_retryable()`.
fn publish_failure_or_cancel(
    bus: Option<&EventBus>,
    request_id: RequestId,
    err: &NetError,
    bytes_transferred: u64,
    peer_cancel: &CancellationToken,
    epoch_cancel: Option<&CancellationToken>,
    downloader_cancel: &CancellationToken,
) {
    match err {
        NetError::Cancelled => {
            let cause = classify_cancel(peer_cancel, epoch_cancel, downloader_cancel);
            abort_request(bus, request_id, cause, bytes_transferred, true);
        }
        other => fail_request(bus, request_id, other, other.is_retryable()),
    }
}
pub(super) fn deliver_cancelled_with_event(internal: InternalCmd, peer_cancel: &CancellationToken) {
let request_id = internal.request_id;
let bus = internal.bus.clone();
let epoch_cancel = internal.cmd.cancel.clone();
let placeholder_inner = CancellationToken::new(); let reason = classify_cancel(peer_cancel, epoch_cancel.as_ref(), &placeholder_inner);
abort_request(bus.as_ref(), request_id, reason, 0, false);
deliver_cancelled(internal.response, internal.cmd);
}
/// Completes a command's consumer with `NetError::Cancelled`, without
/// publishing any event.
///
/// Channel targets receive `Err(NetError::Cancelled)`, and a dropped receiver
/// is ignored. Streaming targets have their `on_complete` callback (if any)
/// invoked with zero bytes, no headers and the cancellation error.
pub(super) fn deliver_cancelled(target: ResponseTarget, mut cmd: FetchCmd) {
    let cancelled = NetError::Cancelled;
    match target {
        ResponseTarget::Streaming => {
            if let Some(on_complete) = cmd.on_complete.take() {
                on_complete(0, None, Some(&cancelled));
            }
        }
        ResponseTarget::Channel(tx) => {
            let _ = tx.send(Err(cancelled));
        }
    }
}