// lucia_lang/objects/callback.rs

1use std::{
2    fmt,
3    hash::{Hash, Hasher},
4};
5
6use gc_arena::{Collect, Gc, Mutation};
7
8use crate::{
9    errors::Error,
10    meta_ops::MetaResult,
11    objects::{Function, Value},
12    Context,
13};
14
15#[derive(Debug, PartialEq, Eq, Collect)]
16#[collect(no_drop)]
17pub enum CallbackReturn<'gc> {
18    Return(Value<'gc>),
19    TailCall(Function<'gc>, Vec<Value<'gc>>),
20}
21
22impl<'gc, const N: usize> From<MetaResult<'gc, N>> for CallbackReturn<'gc> {
23    fn from(value: MetaResult<'gc, N>) -> Self {
24        match value {
25            MetaResult::Value(v) => CallbackReturn::Return(v),
26            MetaResult::Call(f, args) => CallbackReturn::TailCall(f, Vec::from(args)),
27        }
28    }
29}
30
31pub trait Callback<'gc>: Collect {
32    fn call(
33        &mut self,
34        ctx: Context<'gc>,
35        args: Vec<Value<'gc>>,
36    ) -> Result<CallbackReturn<'gc>, Error<'gc>>;
37}
38
39// Represents a callback as a single pointer with an inline VTable header.
40#[derive(Copy, Clone, Collect)]
41#[collect(no_drop)]
42pub struct AnyCallback<'gc>(Gc<'gc, Header<'gc>>);
43
44struct Header<'gc> {
45    call: unsafe fn(
46        *const (),
47        Context<'gc>,
48        Vec<Value<'gc>>,
49    ) -> Result<CallbackReturn<'gc>, Error<'gc>>,
50}
51
52impl<'gc> AnyCallback<'gc> {
53    pub fn new<C: Callback<'gc> + 'gc>(mc: &Mutation<'gc>, callback: C) -> Self {
54        #[repr(C)]
55        struct HeaderCallback<'gc, C> {
56            header: Header<'gc>,
57            callback: C,
58        }
59
60        // SAFETY: We can't auto-implement `Collect` due to the function pointer lifetimes, but
61        // function pointers can't hold any data. It would be nice if function pointers could have
62        // higher rank `for<'gc>` lifetimes.
63        unsafe impl<'gc, C: Collect> Collect for HeaderCallback<'gc, C> {
64            fn needs_trace() -> bool
65            where
66                Self: Sized,
67            {
68                C::needs_trace()
69            }
70
71            fn trace(&self, cc: &gc_arena::Collection) {
72                self.callback.trace(cc)
73            }
74        }
75
76        let hc = Gc::new(
77            mc,
78            HeaderCallback {
79                header: Header {
80                    call: |ptr, ctx, args| unsafe {
81                        let hc = ptr as *mut HeaderCallback<C>;
82                        ((*hc).callback).call(ctx, args)
83                    },
84                },
85                callback,
86            },
87        );
88
89        Self(unsafe { Gc::cast::<Header>(hc) })
90    }
91
92    pub fn from_fn<F>(mc: &Mutation<'gc>, call: F) -> AnyCallback<'gc>
93    where
94        F: 'static + Fn(Context<'gc>, Vec<Value<'gc>>) -> Result<CallbackReturn<'gc>, Error<'gc>>,
95    {
96        Self::from_fn_with(mc, (), move |_, ctx, args| call(ctx, args))
97    }
98
99    pub fn from_fn_with<R, F>(mc: &Mutation<'gc>, root: R, call: F) -> AnyCallback<'gc>
100    where
101        R: 'gc + Collect,
102        F: 'static
103            + Fn(&R, Context<'gc>, Vec<Value<'gc>>) -> Result<CallbackReturn<'gc>, Error<'gc>>,
104    {
105        #[derive(Collect)]
106        #[collect(no_drop)]
107        struct RootCallback<R, F> {
108            root: R,
109            #[collect(require_static)]
110            call: F,
111        }
112
113        impl<'gc, R, F> Callback<'gc> for RootCallback<R, F>
114        where
115            R: 'gc + Collect,
116            F: 'static
117                + Fn(&R, Context<'gc>, Vec<Value<'gc>>) -> Result<CallbackReturn<'gc>, Error<'gc>>,
118        {
119            fn call(
120                &mut self,
121                ctx: Context<'gc>,
122                args: Vec<Value<'gc>>,
123            ) -> Result<CallbackReturn<'gc>, Error<'gc>> {
124                (self.call)(&self.root, ctx, args)
125            }
126        }
127
128        AnyCallback::new(mc, RootCallback { root, call })
129    }
130
131    pub fn as_ptr(self) -> *const () {
132        Gc::as_ptr(self.0) as *const ()
133    }
134
135    pub fn call(
136        self,
137        ctx: Context<'gc>,
138        args: Vec<Value<'gc>>,
139    ) -> Result<CallbackReturn<'gc>, Error<'gc>> {
140        unsafe { (self.0.call)(Gc::as_ptr(self.0) as *const (), ctx, args) }
141    }
142}
143
144impl<'gc> fmt::Debug for AnyCallback<'gc> {
145    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
146        fmt.debug_tuple("Callback").field(&self.as_ptr()).finish()
147    }
148}
149
150impl<'gc> PartialEq for AnyCallback<'gc> {
151    fn eq(&self, other: &AnyCallback<'gc>) -> bool {
152        self.as_ptr() == other.as_ptr()
153    }
154}
155
156impl<'gc> Eq for AnyCallback<'gc> {}
157
158impl<'gc> Hash for AnyCallback<'gc> {
159    fn hash<H: Hasher>(&self, state: &mut H) {
160        self.as_ptr().hash(state)
161    }
162}
163
#[cfg(test)]
mod tests {
    use gc_arena::{Arena, Rootable};

    use crate::context::State;

    // Brings in `CallbackReturn`, `Callback`, `AnyCallback`, `Collect`, `Context`, `Value`, `Error`.
    use super::*;

    /// A stateful callback invoked through `AnyCallback` must see its own stored data.
    ///
    /// This checks that the type-erased dispatch really reaches the concrete `CB` value.
    #[test]
    fn test_dyn_callback() {
        #[derive(Collect)]
        #[collect(require_static)]
        struct CB(i64);

        impl<'gc> Callback<'gc> for CB {
            fn call(
                &mut self,
                _ctx: Context<'gc>,
                _args: Vec<Value<'gc>>,
            ) -> Result<CallbackReturn<'gc>, Error<'gc>> {
                // Return the stored field, not a literal, so a wrong pointer cast in the
                // dispatch thunk would be detected.
                Ok(CallbackReturn::Return(Value::Int(self.0)))
            }
        }

        let arena = Arena::<Rootable![State<'_>]>::new(Default::default(), |mc| State::new(mc));
        arena.mutate(|mc, state| {
            let ctx = state.ctx(mc);
            let dyn_callback = AnyCallback::new(mc, CB(17));
            assert_eq!(
                dyn_callback.call(ctx, Vec::new()),
                Ok(CallbackReturn::Return(Value::Int(17)))
            );
        });
    }
}