1use std::{
2 fmt,
3 hash::{Hash, Hasher},
4};
5
6use gc_arena::{Collect, Gc, Mutation};
7
8use crate::{
9 errors::Error,
10 meta_ops::MetaResult,
11 objects::{Function, Value},
12 Context,
13};
14
15#[derive(Debug, PartialEq, Eq, Collect)]
16#[collect(no_drop)]
17pub enum CallbackReturn<'gc> {
18 Return(Value<'gc>),
19 TailCall(Function<'gc>, Vec<Value<'gc>>),
20}
21
22impl<'gc, const N: usize> From<MetaResult<'gc, N>> for CallbackReturn<'gc> {
23 fn from(value: MetaResult<'gc, N>) -> Self {
24 match value {
25 MetaResult::Value(v) => CallbackReturn::Return(v),
26 MetaResult::Call(f, args) => CallbackReturn::TailCall(f, Vec::from(args)),
27 }
28 }
29}
30
31pub trait Callback<'gc>: Collect {
32 fn call(
33 &mut self,
34 ctx: Context<'gc>,
35 args: Vec<Value<'gc>>,
36 ) -> Result<CallbackReturn<'gc>, Error<'gc>>;
37}
38
39#[derive(Copy, Clone, Collect)]
41#[collect(no_drop)]
42pub struct AnyCallback<'gc>(Gc<'gc, Header<'gc>>);
43
44struct Header<'gc> {
45 call: unsafe fn(
46 *const (),
47 Context<'gc>,
48 Vec<Value<'gc>>,
49 ) -> Result<CallbackReturn<'gc>, Error<'gc>>,
50}
51
impl<'gc> AnyCallback<'gc> {
    /// Allocates `callback` on the GC heap and returns a type-erased handle to it.
    ///
    /// The concrete callback is stored in one allocation directly after a
    /// `Header` holding a trampoline monomorphized for `C`. The resulting `Gc`
    /// is then cast to point at just that header, which lets `AnyCallback`
    /// remain a plain `Copy` pointer regardless of `C`.
    pub fn new<C: Callback<'gc> + 'gc>(mc: &Mutation<'gc>, callback: C) -> Self {
        // `#[repr(C)]` fixes declaration order, so `header` is at offset 0.
        // A pointer to a `HeaderCallback<C>` is then also a valid pointer to
        // its `Header`, which the `Gc::cast` below relies on.
        #[repr(C)]
        struct HeaderCallback<'gc, C> {
            header: Header<'gc>,
            callback: C,
        }

        // Only `callback` can hold GC pointers; `header` is just a function
        // pointer. Both tracing and `needs_trace` therefore defer entirely to `C`.
        unsafe impl<'gc, C: Collect> Collect for HeaderCallback<'gc, C> {
            fn needs_trace() -> bool
            where
                Self: Sized,
            {
                C::needs_trace()
            }

            fn trace(&self, cc: &gc_arena::Collection) {
                self.callback.trace(cc)
            }
        }

        let hc = Gc::new(
            mc,
            HeaderCallback {
                header: Header {
                    // Trampoline for this particular `C`. It reinterprets the
                    // erased pointer as the full `HeaderCallback<C>` and
                    // forwards to `Callback::call`.
                    //
                    // Contract: `ptr` must be the address of the
                    // `HeaderCallback<C>` allocated here. `AnyCallback::call`
                    // upholds this by passing `Gc::as_ptr` of the handle.
                    //
                    // NOTE(review): this creates a `&mut C` through a pointer
                    // taken from a shared, `Copy` `Gc`. If the callback is
                    // re-entered (e.g. it holds an `AnyCallback` to itself and
                    // calls it), two `&mut C` would alias, which is UB.
                    // Reassigning GC pointers inside `C` this way also appears
                    // to skip gc-arena's write barrier (`Gc::write`).
                    // Changing `Callback::call` to take `&self` would avoid
                    // both; confirm before any implementor relies on mutating
                    // `self`.
                    call: |ptr, ctx, args| unsafe {
                        let hc = ptr as *mut HeaderCallback<C>;
                        ((*hc).callback).call(ctx, args)
                    },
                },
                callback,
            },
        );

        // SAFETY: `HeaderCallback` is `#[repr(C)]` with `header` as its first
        // field, so the allocation begins with a valid `Header`.
        Self(unsafe { Gc::cast::<Header>(hc) })
    }

    /// Creates a callback from a `'static` closure that needs no GC data.
    ///
    /// This is `from_fn_with` with a `()` root; the closure receives only the
    /// context and arguments.
    pub fn from_fn<F>(mc: &Mutation<'gc>, call: F) -> AnyCallback<'gc>
    where
        F: 'static + Fn(Context<'gc>, Vec<Value<'gc>>) -> Result<CallbackReturn<'gc>, Error<'gc>>,
    {
        Self::from_fn_with(mc, (), move |_, ctx, args| call(ctx, args))
    }

    /// Creates a callback from a `'static` closure plus a traced `root` value.
    ///
    /// `root` may hold `'gc` pointers. It is stored (and traced) alongside the
    /// closure and passed to it by reference on every call. Because the
    /// closure itself must be `'static`, any GC data it needs has to come
    /// through `root`.
    pub fn from_fn_with<R, F>(mc: &Mutation<'gc>, root: R, call: F) -> AnyCallback<'gc>
    where
        R: 'gc + Collect,
        F: 'static
            + Fn(&R, Context<'gc>, Vec<Value<'gc>>) -> Result<CallbackReturn<'gc>, Error<'gc>>,
    {
        // Pairs the GC-traced `root` with the closure. `call` is marked
        // `require_static`: it is `'static`, so it holds no `'gc` data to trace.
        #[derive(Collect)]
        #[collect(no_drop)]
        struct RootCallback<R, F> {
            root: R,
            #[collect(require_static)]
            call: F,
        }

        // Adapts the closure to the `Callback` trait by supplying `root`.
        impl<'gc, R, F> Callback<'gc> for RootCallback<R, F>
        where
            R: 'gc + Collect,
            F: 'static
                + Fn(&R, Context<'gc>, Vec<Value<'gc>>) -> Result<CallbackReturn<'gc>, Error<'gc>>,
        {
            fn call(
                &mut self,
                ctx: Context<'gc>,
                args: Vec<Value<'gc>>,
            ) -> Result<CallbackReturn<'gc>, Error<'gc>> {
                (self.call)(&self.root, ctx, args)
            }
        }

        AnyCallback::new(mc, RootCallback { root, call })
    }

    /// Returns the address of the underlying GC allocation.
    ///
    /// This address is the callback's identity: the `Debug`, `PartialEq` and
    /// `Hash` impls for `AnyCallback` are all defined in terms of it.
    pub fn as_ptr(self) -> *const () {
        Gc::as_ptr(self.0) as *const ()
    }

    /// Invokes the callback with `ctx` and `args`.
    ///
    /// Dispatches through the type-erased trampoline stored in the header,
    /// passing it the allocation address so it can recover the concrete
    /// callback.
    pub fn call(
        self,
        ctx: Context<'gc>,
        args: Vec<Value<'gc>>,
    ) -> Result<CallbackReturn<'gc>, Error<'gc>> {
        // SAFETY: the only constructor of `AnyCallback` in this file is
        // `new`. It pairs `header.call` with the `HeaderCallback<C>` that this
        // pointer addresses, which is the trampoline's contract.
        unsafe { (self.0.call)(Gc::as_ptr(self.0) as *const (), ctx, args) }
    }
}
143
144impl<'gc> fmt::Debug for AnyCallback<'gc> {
145 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
146 fmt.debug_tuple("Callback").field(&self.as_ptr()).finish()
147 }
148}
149
150impl<'gc> PartialEq for AnyCallback<'gc> {
151 fn eq(&self, other: &AnyCallback<'gc>) -> bool {
152 self.as_ptr() == other.as_ptr()
153 }
154}
155
156impl<'gc> Eq for AnyCallback<'gc> {}
157
158impl<'gc> Hash for AnyCallback<'gc> {
159 fn hash<H: Hasher>(&self, state: &mut H) {
160 self.as_ptr().hash(state)
161 }
162}
163
#[cfg(test)]
mod tests {
    use gc_arena::{Arena, Rootable};

    use crate::{context::State, objects::CallbackReturn};

    use super::*;

    /// A user-defined `Callback` wrapped in `AnyCallback` must dispatch to the
    /// concrete implementation's `call`.
    #[test]
    fn test_dyn_callback() {
        // Ignores its state and arguments and always answers 42.
        #[derive(Collect)]
        #[collect(require_static)]
        struct CB(i64);

        impl<'gc> Callback<'gc> for CB {
            fn call(
                &mut self,
                _ctx: Context<'gc>,
                _args: Vec<Value<'gc>>,
            ) -> Result<CallbackReturn<'gc>, Error<'gc>> {
                Ok(CallbackReturn::Return(Value::Int(42)))
            }
        }

        let arena = Arena::<Rootable![State<'_>]>::new(Default::default(), |mc| State::new(mc));
        arena.mutate(|mc, state| {
            let ctx = state.ctx(mc);
            let erased = AnyCallback::new(mc, CB(17));
            let outcome = erased.call(ctx, Vec::new());
            assert_eq!(outcome, Ok(CallbackReturn::Return(Value::Int(42))));
        });
    }
}