use std::fmt;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use ipc_channel::ipc::IpcSender;
use ipc_channel::router::ROUTER;
use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
use serde::de::VariantAccess;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use servo_config::opts;
use crate::generic_channel::{
GenericReceiver, GenericReceiverVariants, SendError, SendResult, use_ipc,
};
/// Type-erased closure that receives every message (or receive error) delivered to a
/// [`GenericCallback`].
pub type MsgCallback<T> = dyn FnMut(Result<T, ipc_channel::IpcError>) + Send;

/// A callback handle that can be sent to other threads and, when backed by IPC, to other
/// processes.
///
/// Which representation is used is decided when the callback is constructed (see
/// `GenericCallback::new` / `GenericCallback::new_blocking`, both of which consult `use_ipc()`).
pub struct GenericCallback<T: Serialize + Send + 'static>(GenericCallbackVariants<T>);

/// Backing representation of a [`GenericCallback`].
enum GenericCallbackVariants<T: Serialize + Send + 'static> {
    /// Values are sent through this IPC sender to the receiver it was paired with at
    /// construction time.
    CrossProcess(IpcSender<T>),
    /// The closure itself, shared between clones and invoked directly (under the mutex) by
    /// `GenericCallback::send` on the sending thread.
    InProcess(Arc<Mutex<MsgCallback<T>>>),
}
impl<T> Clone for GenericCallback<T>
where
T: Serialize + Send + 'static,
{
fn clone(&self) -> Self {
let variant = match &self.0 {
GenericCallbackVariants::CrossProcess(sender) => {
GenericCallbackVariants::CrossProcess((*sender).clone())
},
GenericCallbackVariants::InProcess(callback) => {
GenericCallbackVariants::InProcess(callback.clone())
},
};
GenericCallback(variant)
}
}
impl<T> MallocSizeOf for GenericCallback<T>
where
    T: Serialize + Send + 'static,
{
    /// Always reports zero bytes.
    ///
    /// Neither the IPC sender nor the type-erased closure behind the `Arc<Mutex<..>>` is
    /// measured here.
    fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
        const UNMEASURED: usize = 0;
        UNMEASURED
    }
}
impl<T> GenericCallback<T>
where
    T: for<'de> Deserialize<'de> + Serialize + Send + 'static,
{
    /// Wraps `callback` so that it runs for every value passed to [`GenericCallback::send`].
    ///
    /// When `use_ipc()` is true, a fresh IPC channel is created and its receiving end is
    /// registered with the global `ROUTER`, which invokes `callback` for each incoming message.
    /// Deserialization failures are passed to `callback` as `Err`. Presumably this happens on
    /// the router's own thread (that behaviour lives in ipc-channel, not here).
    ///
    /// Otherwise the closure is stored behind an `Arc<Mutex<..>>` and called synchronously by
    /// `send`.
    ///
    /// # Errors
    /// Returns an error only if creating the IPC channel fails.
    pub fn new<F: FnMut(Result<T, ipc_channel::IpcError>) + Send + 'static>(
        mut callback: F,
    ) -> Result<Self, ipc_channel::IpcError> {
        if !use_ipc() {
            let shared = Arc::new(Mutex::new(callback));
            return Ok(GenericCallback(GenericCallbackVariants::InProcess(shared)));
        }

        let (ipc_sender, ipc_receiver) = ipc_channel::ipc::channel()?;
        // Adapt the router's error type to the one the caller's closure expects.
        let forward = move |msg: Result<T, ipc_channel::SerDeError>| {
            callback(msg.map_err(|err| err.into()))
        };
        ROUTER.add_typed_route(ipc_receiver, Box::new(forward));
        Ok(GenericCallback(GenericCallbackVariants::CrossProcess(ipc_sender)))
    }

    /// Creates a callback paired with a receiver that the caller can block on.
    ///
    /// In IPC mode the callback is simply the sending half of a new IPC channel.
    ///
    /// In-process, the callback forwards into a `crossbeam_channel::bounded(1)` channel. This
    /// means a `send` that finds the slot still occupied waits until the receiver drains it.
    /// If the receiver has been dropped, the failure is only logged and `send` still returns
    /// `Ok(())`.
    ///
    /// # Errors
    /// Returns an error only if creating the IPC channel fails.
    pub fn new_blocking() -> Result<(Self, GenericReceiver<T>), ipc_channel::IpcError> {
        if !use_ipc() {
            let (tx, rx) = crossbeam_channel::bounded(1);
            let forwarder = Arc::new(Mutex::new(move |msg| {
                if tx.send(msg).is_err() {
                    log::error!("Error in callback");
                }
            }));
            return Ok((
                GenericCallback(GenericCallbackVariants::InProcess(forwarder)),
                GenericReceiver(GenericReceiverVariants::Crossbeam(rx)),
            ));
        }

        let (tx, rx) = ipc_channel::ipc::channel()?;
        Ok((
            GenericCallback(GenericCallbackVariants::CrossProcess(tx)),
            GenericReceiver(GenericReceiverVariants::Ipc(rx)),
        ))
    }

    /// Delivers `value` to the callback.
    ///
    /// In-process, the closure runs on the calling thread while the mutex is held. Calling
    /// `send` on the same callback from inside that closure would therefore re-lock a
    /// non-reentrant `std::sync::Mutex`, which deadlocks or panics.
    ///
    /// # Panics
    /// Panics if the in-process mutex is poisoned, i.e. a previous invocation panicked.
    ///
    /// # Errors
    /// For IPC-backed callbacks, serialization failures map to `SendError::SerializationError`
    /// and I/O or disconnection failures map to `SendError::Disconnected`. The in-process path
    /// always returns `Ok(())`.
    pub fn send(&self, value: T) -> SendResult {
        let ipc_sender = match &self.0 {
            GenericCallbackVariants::InProcess(callback) => {
                let mut guard = callback.lock().expect("poisoned");
                (*guard)(Ok(value));
                return Ok(());
            },
            GenericCallbackVariants::CrossProcess(sender) => sender,
        };
        ipc_sender.send(value).map_err(|error| match error {
            ipc_channel::IpcError::SerializationError(ser_de_error) => {
                SendError::SerializationError(ser_de_error.to_string())
            },
            ipc_channel::IpcError::Io(_) | ipc_channel::IpcError::Disconnected => {
                SendError::Disconnected
            },
        })
    }
}
impl<T> Serialize for GenericCallback<T>
where
    T: Serialize + Send + 'static,
{
    /// Encodes the callback as a serde newtype enum variant named `GenericCallback`.
    ///
    /// * `CrossProcess` (index 0) serializes the wrapped `IpcSender` directly.
    /// * `InProcess` (index 1) clones the shared `Arc`, leaks the clone inside a `Box`, and
    ///   serializes the box's address as a `usize`. This only works because sender and
    ///   receiver share one address space. Ownership of that leaked box is expected to be
    ///   reclaimed by the matching deserialization. A payload that is never deserialized keeps
    ///   the closure alive forever, and serializing twice leaks two boxes.
    ///
    /// # Errors
    /// Fails for an `InProcess` callback when `opts::get().multiprocess` is set. Also
    /// propagates any error from the serializer itself.
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        let shared = match &self.0 {
            GenericCallbackVariants::CrossProcess(sender) => {
                return s.serialize_newtype_variant("GenericCallback", 0, "CrossProcess", sender);
            },
            GenericCallbackVariants::InProcess(shared) => shared,
        };

        if opts::get().multiprocess {
            return Err(serde::ser::Error::custom(
                "InProcess callback can't be serialized in multiprocess mode",
            ));
        }

        // Box the (fat) `Arc` so that a thin pointer to it fits in a `usize`.
        let leaked: &mut Arc<_> = Box::leak(Box::new(Arc::clone(shared)));
        let leaked_addr = leaked as *mut Arc<_> as usize;
        s.serialize_newtype_variant("GenericCallback", 1, "InProcess", &leaked_addr)
    }
}
struct GenericCallbackVisitor<T> {
marker: PhantomData<T>,
}
impl<'de, T> serde::de::Visitor<'de> for GenericCallbackVisitor<T>
where
T: Serialize + Deserialize<'de> + Send + 'static,
{
type Value = GenericCallback<T>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a GenericCallback variant")
}
fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
where
A: serde::de::EnumAccess<'de>,
{
#[derive(Deserialize)]
enum GenericCallbackVariantNames {
CrossProcess,
InProcess,
}
let (variant_name, variant_data): (GenericCallbackVariantNames, _) = data.variant()?;
match variant_name {
GenericCallbackVariantNames::CrossProcess => variant_data
.newtype_variant::<IpcSender<T>>()
.map(|sender| GenericCallback(GenericCallbackVariants::CrossProcess(sender))),
GenericCallbackVariantNames::InProcess => {
if use_ipc() {
return Err(serde::de::Error::custom(
"InProcess callback found in multiprocess mode",
));
}
let addr = variant_data.newtype_variant::<usize>()?;
let ptr = addr as *mut Arc<Mutex<_>>;
#[expect(unsafe_code)]
let callback = unsafe { Box::from_raw(ptr) };
Ok(GenericCallback(GenericCallbackVariants::InProcess(
*callback,
)))
},
}
}
}
impl<'a, T> Deserialize<'a> for GenericCallback<T>
where
    T: Serialize + Deserialize<'a> + Send + 'static,
{
    /// Reads a `GenericCallback` enum with variants `CrossProcess` and `InProcess`. The actual
    /// decoding is delegated to `GenericCallbackVisitor`.
    fn deserialize<D>(d: D) -> Result<GenericCallback<T>, D::Error>
    where
        D: Deserializer<'a>,
    {
        const VARIANTS: &[&str] = &["CrossProcess", "InProcess"];
        let visitor = GenericCallbackVisitor {
            marker: PhantomData,
        };
        d.deserialize_enum("GenericCallback", VARIANTS, visitor)
    }
}
impl<T> fmt::Debug for GenericCallback<T>
where
    T: Serialize + Send + 'static,
{
    /// Prints an opaque placeholder. The wrapped closure or IPC sender has nothing useful to
    /// show.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("GenericCallback(..)")
    }
}
#[cfg(test)]
mod single_process_callback_test {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::generic_channel::GenericCallback;

    /// Builds a callback that stores every delivered value into the returned atomic.
    fn storing_callback() -> (Arc<AtomicUsize>, GenericCallback<usize>) {
        let number = Arc::new(AtomicUsize::new(0));
        let number_clone = Arc::clone(&number);
        let callback = move |msg: Result<usize, ipc_channel::IpcError>| {
            number_clone.store(msg.unwrap(), Ordering::SeqCst)
        };
        (number, GenericCallback::new(callback).unwrap())
    }

    #[test]
    fn generic_callback() {
        let (number, generic_callback) = storing_callback();
        std::thread::scope(|s| {
            // Unwrap inside the thread so that a failed send fails the test at its source
            // instead of surfacing only as a wrong value below.
            s.spawn(move || generic_callback.send(42).unwrap());
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn generic_callback_via_generic_sender() {
        let (number, generic_callback) = storing_callback();
        let (tx, rx) = crate::generic_channel::channel().unwrap();
        tx.send(generic_callback).unwrap();
        std::thread::scope(|s| {
            s.spawn(move || {
                let callback = rx.recv().unwrap();
                callback.send(42).unwrap();
            });
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn generic_callback_via_ipc_sender() {
        let (number, generic_callback) = storing_callback();
        let (tx, rx) = ipc_channel::ipc::channel().unwrap();
        tx.send(generic_callback).unwrap();
        std::thread::scope(|s| {
            s.spawn(move || {
                let callback = rx.recv().unwrap();
                callback.send(42).unwrap();
            });
        });
        assert_eq!(number.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn generic_callback_blocking() {
        let (callback, receiver) = GenericCallback::new_blocking().unwrap();
        let sender = std::thread::spawn(move || {
            // Delay the send so that `recv` below actually has to block for it.
            std::thread::sleep(std::time::Duration::from_millis(100));
            assert!(callback.send(42).is_ok());
        });
        assert_eq!(receiver.recv().unwrap(), 42);
        // Join so that a panic in the sending thread is reported by this test.
        sender.join().expect("sending thread panicked");
    }
}