use std::sync::Arc;
use block2::DynBlock;
use objc2::rc::Retained;
use objc2::{AllocAnyThread, DefinedClass, define_class, msg_send};
use objc2_foundation::{
NSCopying, NSData, NSError, NSObject, NSObjectProtocol, NSURLAuthenticationChallenge,
NSURLAuthenticationMethodServerTrust, NSURLCredential, NSURLResponse, NSURLSession,
NSURLSessionAuthChallengeDisposition, NSURLSessionDataDelegate, NSURLSessionDataTask,
NSURLSessionDelegate, NSURLSessionResponseDisposition, NSURLSessionTask,
NSURLSessionTaskDelegate,
};
use super::TaskSharedContext;
use std::collections::HashMap;
use std::sync::Mutex;
/// Instance state stored inside each [`DataTaskDelegate`] object.
pub struct DataTaskDelegateIvars {
    /// Per-task shared contexts, keyed by the task identifier the delegate
    /// callbacks read from `NSURLSessionTask::taskIdentifier`.
    ///
    /// Entries are added by `DataTaskDelegate::register_task` and removed when
    /// the task's completion callback fires.
    pub task_contexts: Mutex<HashMap<usize, Arc<TaskSharedContext>>>,
}
define_class!(
#[unsafe(super = NSObject)]
#[name = "fraktDataTaskDelegate"]
#[ivars = DataTaskDelegateIvars]
pub struct DataTaskDelegate;
unsafe impl NSObjectProtocol for DataTaskDelegate {}
unsafe impl NSURLSessionDelegate for DataTaskDelegate {}
unsafe impl NSURLSessionTaskDelegate for DataTaskDelegate {
#[unsafe(method(URLSession:task:didCompleteWithError:))]
fn URLSession_task_didCompleteWithError(
&self,
_session: &NSURLSession,
task: &NSURLSessionTask,
error: Option<&NSError>,
) {
let ivars = self.ivars();
let task_id = unsafe { task.taskIdentifier() } as usize;
if let Ok(mut contexts) = ivars.task_contexts.lock() {
if let Some(shared_context) = contexts.remove(&task_id) {
if let Some(error) = error {
shared_context.set_error(error.copy());
} else {
shared_context.mark_completed();
}
}
}
}
#[unsafe(method(URLSession:task:didReceiveChallenge:completionHandler:))]
fn URLSession_task_didReceiveChallenge_completionHandler(
&self,
session: &NSURLSession,
_task: &NSURLSessionTask,
challenge: &NSURLAuthenticationChallenge,
completion_handler: &DynBlock<
dyn Fn(NSURLSessionAuthChallengeDisposition, *mut NSURLCredential),
>,
) {
unsafe {
let protection_space = challenge.protectionSpace();
let auth_method = protection_space.authenticationMethod();
if auth_method.isEqualToString(&NSURLAuthenticationMethodServerTrust) {
let _config = session.configuration();
completion_handler.call((
NSURLSessionAuthChallengeDisposition::PerformDefaultHandling,
std::ptr::null_mut(),
));
} else {
completion_handler.call((
NSURLSessionAuthChallengeDisposition::PerformDefaultHandling,
std::ptr::null_mut(),
));
}
}
}
#[unsafe(method(URLSession:task:didSendBodyData:totalBytesSent:totalBytesExpectedToSend:))]
fn URLSession_task_didSendBodyData_totalBytesSent_totalBytesExpectedToSend(
&self,
_session: &NSURLSession,
task: &NSURLSessionTask,
_bytes_sent: i64,
total_bytes_sent: i64,
total_bytes_expected_to_send: i64,
) {
let ivars = self.ivars();
let task_id = unsafe { task.taskIdentifier() } as usize;
if let Ok(contexts) = ivars.task_contexts.lock() {
if let Some(shared_context) = contexts.get(&task_id) {
if total_bytes_expected_to_send > 0 {
shared_context
.set_total_bytes_expected(total_bytes_expected_to_send as u64);
}
let previous_bytes = shared_context
.bytes_downloaded
.load(std::sync::atomic::Ordering::Acquire);
let additional = (total_bytes_sent as u64).saturating_sub(previous_bytes);
if additional > 0 {
shared_context.update_progress(additional);
}
}
}
}
}
unsafe impl NSURLSessionDataDelegate for DataTaskDelegate {
#[unsafe(method(URLSession:dataTask:didReceiveResponse:completionHandler:))]
fn URLSession_dataTask_didReceiveResponse_completionHandler(
&self,
_session: &NSURLSession,
data_task: &NSURLSessionDataTask,
response: &NSURLResponse,
completion_handler: &DynBlock<dyn Fn(NSURLSessionResponseDisposition)>,
) {
let ivars = self.ivars();
let task_id = unsafe { data_task.taskIdentifier() } as usize;
if let Ok(contexts) = ivars.task_contexts.lock() {
if let Some(shared_context) = contexts.get(&task_id) {
shared_context
.response
.store(Some(Arc::new(response.copy())));
let expected_length = unsafe { response.expectedContentLength() };
if expected_length > 0 {
shared_context.set_total_bytes_expected(expected_length as u64);
}
}
}
completion_handler.call((NSURLSessionResponseDisposition::Allow,));
}
#[unsafe(method(URLSession:dataTask:didReceiveData:))]
fn URLSession_dataTask_didReceiveData(
&self,
_session: &NSURLSession,
data_task: &NSURLSessionDataTask,
data: &NSData,
) {
let ivars = self.ivars();
let task_id = unsafe { data_task.taskIdentifier() } as usize;
if let Ok(contexts) = ivars.task_contexts.lock() {
if let Some(shared_context) = contexts.get(&task_id) {
let bytes = data.to_vec();
if let Ok(mut buffer) = shared_context.response_buffer.try_lock() {
let max_size = shared_context
.max_response_buffer_size
.load(std::sync::atomic::Ordering::Acquire);
if buffer.len() as u64 + bytes.len() as u64 <= max_size {
buffer.extend_from_slice(&bytes);
shared_context.update_progress(bytes.len() as u64);
}
}
}
}
}
}
);
impl DataTaskDelegate {
    /// Creates a delegate whose task registry starts out empty.
    pub fn new() -> Retained<Self> {
        let delegate = Self::alloc().set_ivars(DataTaskDelegateIvars {
            task_contexts: Mutex::new(HashMap::new()),
        });
        // SAFETY: the ivars were set on the freshly allocated object above, and
        // the superclass is `NSObject`, whose `init` takes no arguments.
        unsafe { msg_send![super(delegate), init] }
    }

    /// Associates `context` with `task_id` so the delegate callbacks for that
    /// task can find it. An existing entry for the same id is replaced.
    ///
    /// A poisoned registry lock is recovered rather than ignored: silently
    /// skipping the insert would leave the task without a context, so none of
    /// its callbacks would ever be delivered.
    pub fn register_task(&self, task_id: usize, context: Arc<TaskSharedContext>) {
        self.ivars()
            .task_contexts
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(task_id, context);
    }
}