1use std::{
14 future::Future,
15 marker::PhantomData,
16 pin::Pin,
17 task::{self, Poll},
18};
19
20use futures::future;
21use rlua::{
22 Chunk, Context, FromLuaMulti, Function, MultiValue, Result, Scope, Thread, ThreadStatus,
23 ToLuaMulti, UserData, UserDataMethods,
24};
25use scoped_tls::scoped_thread_local;
26
27pub mod prelude {
30 pub use super::{ChunkExt, ContextExt, FunctionExt, ScopeExt};
31}
32
// Type-erased pointer to the `task::Context` of the `PollThreadFut` that is
// currently being polled on this thread. `PollThreadFut::poll` sets it around
// each resume of its Lua thread. The async callbacks (`poller_fn` and the
// `FutGen` userdata's `poll` method) cast it back to `*mut task::Context` to
// poll their Rust futures. Accessing it via `with` while it is unset panics
// (scoped-tls semantics).
// NOTE(review): presumably erased to `()` because `task::Context` carries a
// lifetime parameter — confirm.
scoped_thread_local!(static FUTURE_CTX: *mut ());
41
/// Extension methods on [`Context`] for exposing async Rust functions to Lua.
pub trait ContextExt<'lua> {
    /// Wraps an async Rust function as a Lua function.
    ///
    /// Every Lua call invokes `func` with the call's context and arguments to
    /// obtain a future. That future is driven by a per-call poller function
    /// together with the Lua helper loaded from `make-poller.lua`.
    ///
    /// The poller polls using the task context that [`FunctionExt::call_async`]
    /// and [`ChunkExt::call_async`] publish while they resume Lua. If the poller
    /// runs while no such future is being polled, it panics, because
    /// `FUTURE_CTX` is unset.
    /// NOTE(review): how the helper yields between polls lives in
    /// `make-poller.lua` — confirm there.
    ///
    /// # Errors
    ///
    /// Returns an error if the Rust callback cannot be created, or if loading,
    /// evaluating, or calling the helper chunk fails.
    fn create_async_function<Arg, Ret, RetFut, F>(self, func: F) -> Result<Function<'lua>>
    where
        Arg: FromLuaMulti<'lua>,
        Ret: ToLuaMulti<'lua>,
        RetFut: 'static + Send + Future<Output = Result<Ret>>,
        F: 'static + Send + Fn(Context<'lua>, Arg) -> RetFut;

    /// Like [`ContextExt::create_async_function`], but accepts an `FnMut`, so
    /// `func` may mutate captured state between calls.
    ///
    /// # Errors
    ///
    /// Same as [`ContextExt::create_async_function`].
    fn create_async_function_mut<Arg, Ret, RetFut, F>(self, func: F) -> Result<Function<'lua>>
    where
        Arg: FromLuaMulti<'lua>,
        Ret: ToLuaMulti<'lua>,
        RetFut: 'static + Send + Future<Output = Result<Ret>>,
        F: 'static + Send + FnMut(Context<'lua>, Arg) -> RetFut;
}
65
66fn poller_fn<'lua, Ret, RetFut>(
67 ctx: Context<'lua>,
68 mut fut: Pin<Box<RetFut>>,
69) -> Result<Function<'lua>>
70where
71 Ret: ToLuaMulti<'lua>,
72 RetFut: 'static + Send + Future<Output = Result<Ret>>,
73{
74 ctx.create_function_mut(move |ctx, _: MultiValue<'lua>| {
75 FUTURE_CTX.with(|fut_ctx| {
76 let fut_ctx_ref = unsafe { &mut *(*fut_ctx as *mut task::Context) };
77 match Future::poll(fut.as_mut(), fut_ctx_ref) {
78 Poll::Pending => ToLuaMulti::to_lua_multi((rlua::Value::Nil, false), ctx),
79 Poll::Ready(v) => {
80 let v = ToLuaMulti::to_lua_multi(v?, ctx)?.into_vec();
81 ToLuaMulti::to_lua_multi((v, true), ctx)
82 }
83 }
84 })
85 })
86}
87
/// Lua source of the helper used by `ContextExt::create_async_function{,_mut}`.
///
/// The chunk is evaluated to a function, which is then called with the Rust
/// "start call" function. Its return value is the Lua function handed back to
/// the caller.
/// NOTE(review): the polling/yield loop lives in `make-poller.lua`, which is
/// not visible here.
static MAKE_POLLER: &[u8] = include_bytes!("make-poller.lua");
89
90impl<'lua> ContextExt<'lua> for Context<'lua> {
91 fn create_async_function<Arg, Ret, RetFut, F>(self, func: F) -> Result<Function<'lua>>
92 where
93 Arg: FromLuaMulti<'lua>,
94 Ret: ToLuaMulti<'lua>,
95 RetFut: 'static + Send + Future<Output = Result<Ret>>,
96 F: 'static + Send + Fn(Context<'lua>, Arg) -> RetFut,
97 {
98 let wrapped_fun = self.create_function(move |ctx, arg| {
99 let fut = Box::pin(func(ctx, arg));
100 poller_fn(ctx, fut)
101 })?;
102
103 self.load(MAKE_POLLER)
104 .set_name(b"coroutine yield helper")?
105 .eval::<Function<'lua>>()? .call(wrapped_fun)
107 }
108
109 fn create_async_function_mut<Arg, Ret, RetFut, F>(self, mut func: F) -> Result<Function<'lua>>
110 where
111 Arg: FromLuaMulti<'lua>,
112 Ret: ToLuaMulti<'lua>,
113 RetFut: 'static + Send + Future<Output = Result<Ret>>,
114 F: 'static + Send + FnMut(Context<'lua>, Arg) -> RetFut,
115 {
116 let wrapped_fun = self.create_function_mut(move |ctx, arg| {
117 let fut = Box::pin(func(ctx, arg));
118 poller_fn(ctx, fut)
119 })?;
120
121 self.load(MAKE_POLLER)
122 .set_name(b"coroutine yield helper")?
123 .eval::<Function<'lua>>()? .call(wrapped_fun)
125 }
126}
127
/// Userdata backing scoped async functions: a future generator plus the single
/// future currently being driven, if any.
struct FutGen<Arg, RetFut, F> {
    /// Produces a fresh future from a Lua call's context and arguments.
    gen: F,
    /// The in-flight future. `None` before the first call, and after a future
    /// has resolved.
    cur_fut: Option<Pin<Box<RetFut>>>,
    /// Ties `Arg` to the type without storing a value of it.
    _phantom: PhantomData<fn(Arg)>,
}

impl<Arg, RetFut, F> FutGen<Arg, RetFut, F> {
    /// Wraps `generator` with no future in flight.
    fn new(generator: F) -> Self {
        Self {
            _phantom: PhantomData,
            cur_fut: None,
            gen: generator,
        }
    }
}
143
144impl<'scope, Arg, Ret, RetFut, F> UserData for FutGen<Arg, RetFut, F>
145where
146 Arg: for<'all> FromLuaMulti<'all>,
147 Ret: for<'all> ToLuaMulti<'all>,
148 RetFut: 'scope + Future<Output = Result<Ret>>,
149 F: 'scope + for<'all> FnMut(Context<'all>, Arg) -> RetFut,
150{
151 fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
152 methods.add_method_mut("set_arg", |ctx, this, arg: Arg| {
153 assert!(
154 this.cur_fut.is_none(),
155 "called set_arg without first polling previous future to completion"
156 );
157 this.cur_fut = Some(Box::pin((this.gen)(ctx, arg)));
158 Ok(())
159 });
160
161 methods.add_method_mut("poll", |ctx, this, _: ()| {
162 let mut fut = this
163 .cur_fut
164 .take()
165 .expect("called poll without first calling set_arg");
166 FUTURE_CTX.with(|fut_ctx| {
167 let fut_ctx_ref = unsafe { &mut *(*fut_ctx as *mut task::Context) };
169 match Future::poll(fut.as_mut(), fut_ctx_ref) {
170 Poll::Pending => {
171 this.cur_fut = Some(fut); ToLuaMulti::to_lua_multi((rlua::Value::Nil, false), ctx)
173 }
174 Poll::Ready(v) => {
175 let v = ToLuaMulti::to_lua_multi(v?, ctx)?.into_vec();
176 ToLuaMulti::to_lua_multi((v, true), ctx)
177 }
178 }
179 })
180 });
181 }
182}
183
/// Lua source of the helper used by `ScopeExt::create_async_function{,_mut}`.
///
/// The chunk is evaluated to a function, which is then called with a `FutGen`
/// userdata exposing the `set_arg` and `poll` methods. Its return value is the
/// Lua function handed back to the caller.
/// NOTE(review): the polling/yield loop lives in `make-userdata-poller.lua`,
/// which is not visible here.
static MAKE_USERDATA_POLLER: &[u8] = include_bytes!("make-userdata-poller.lua");
185
/// Extension methods on [`Scope`] for exposing async Rust functions that borrow
/// non-`'static` data to Lua.
pub trait ScopeExt<'lua, 'scope> {
    /// Scoped counterpart of [`ContextExt::create_async_function`]: `func` and
    /// the futures it returns only need to live for `'scope`.
    ///
    /// `func` is stored in a `FutGen` userdata created with
    /// `Scope::create_nonstatic_userdata`. That userdata is passed to the Lua
    /// helper loaded from `make-userdata-poller.lua`, and the function the
    /// helper returns is handed back to the caller.
    ///
    /// The userdata holds at most one in-flight future at a time, so overlapping
    /// calls of the same function are not supported.
    ///
    /// # Errors
    ///
    /// Returns an error if the userdata cannot be created, or if loading,
    /// evaluating, or calling the helper chunk fails.
    fn create_async_function<Arg, Ret, RetFut, F>(
        &self,
        ctx: Context<'lua>,
        func: F,
    ) -> Result<Function<'lua>>
    where
        Arg: 'scope + for<'all> FromLuaMulti<'all>,
        Ret: 'scope + for<'all> ToLuaMulti<'all>,
        RetFut: 'scope + Future<Output = Result<Ret>>,
        F: 'scope + for<'all> Fn(Context<'all>, Arg) -> RetFut;

    /// Like [`ScopeExt::create_async_function`], but accepts an `FnMut`, so
    /// `func` may mutate captured state between calls.
    ///
    /// # Errors
    ///
    /// Same as [`ScopeExt::create_async_function`].
    fn create_async_function_mut<Arg, Ret, RetFut, F>(
        &self,
        ctx: Context<'lua>,
        func: F,
    ) -> Result<Function<'lua>>
    where
        Arg: 'scope + for<'all> FromLuaMulti<'all>,
        Ret: 'scope + for<'all> ToLuaMulti<'all>,
        RetFut: 'scope + Future<Output = Result<Ret>>,
        F: 'scope + for<'all> FnMut(Context<'all>, Arg) -> RetFut;
}
219
220impl<'lua, 'scope> ScopeExt<'lua, 'scope> for Scope<'lua, 'scope> {
221 fn create_async_function<Arg, Ret, RetFut, F>(
222 &self,
223 ctx: Context<'lua>,
224 func: F,
225 ) -> Result<Function<'lua>>
226 where
227 Arg: 'scope + for<'all> FromLuaMulti<'all>,
228 Ret: 'scope + for<'all> ToLuaMulti<'all>,
229 RetFut: 'scope + Future<Output = Result<Ret>>,
230 F: 'scope + for<'all> Fn(Context<'all>, Arg) -> RetFut,
231 {
232 let ud = self.create_nonstatic_userdata(FutGen::new(func))?;
233 ctx.load(MAKE_USERDATA_POLLER)
234 .set_name(b"coroutine yield helper")?
235 .eval::<Function<'lua>>()? .call(ud)
237 }
238
239 fn create_async_function_mut<Arg, Ret, RetFut, F>(
240 &self,
241 ctx: Context<'lua>,
242 func: F,
243 ) -> Result<Function<'lua>>
244 where
245 Arg: 'scope + for<'all> FromLuaMulti<'all>,
246 Ret: 'scope + for<'all> ToLuaMulti<'all>,
247 RetFut: 'scope + Future<Output = Result<Ret>>,
248 F: 'scope + for<'all> FnMut(Context<'all>, Arg) -> RetFut,
249 {
250 let ud = self.create_nonstatic_userdata(FutGen::new(func))?;
251 ctx.load(MAKE_USERDATA_POLLER)
252 .set_name(b"coroutine yield helper")?
253 .eval::<Function<'lua>>()? .call(ud)
255 }
256}
257
258struct PollThreadFut<'lua, Arg, Ret> {
259 args: Option<Arg>,
262 ctx: Context<'lua>,
263 thread: Thread<'lua>,
264 _phantom: PhantomData<Ret>,
265}
266
267impl<'lua, Arg, Ret> Future for PollThreadFut<'lua, Arg, Ret>
268where
269 Arg: ToLuaMulti<'lua>,
270 Ret: FromLuaMulti<'lua>,
271{
272 type Output = Result<Ret>;
273
274 fn poll(mut self: Pin<&mut Self>, fut_ctx: &mut task::Context) -> Poll<Result<Ret>> {
275 FUTURE_CTX.set(&(fut_ctx as *mut _ as *mut ()), || {
276 let taken_args = unsafe { self.as_mut().get_unchecked_mut().args.take() };
277
278 let resume_ret = if let Some(a) = taken_args {
279 self.thread.resume::<_, rlua::MultiValue>(a)
280 } else {
281 self.thread.resume::<_, rlua::MultiValue>(())
282 };
283
284 match resume_ret {
285 Err(e) => Poll::Ready(Err(e)),
286 Ok(v) => {
287 match self.thread.status() {
288 ThreadStatus::Resumable => Poll::Pending,
289
290 ThreadStatus::Unresumable => {
291 Poll::Ready(FromLuaMulti::from_lua_multi(v, self.ctx))
292 }
293
294 ThreadStatus::Error => unreachable!(),
296 }
297 }
298 }
299 })
300 }
301}
302
/// Extension methods on [`Function`] for calling Lua functions that may invoke
/// async Rust functions.
pub trait FunctionExt<'lua> {
    /// Calls this function from a future.
    ///
    /// The function runs in a new Lua thread created from a clone of `self`.
    /// Each poll of the returned future resumes that thread once, passing
    /// `args` on the first resume only. During the resume, the poll's task
    /// context is made available to async callbacks. When the thread finishes,
    /// its return values are converted to `Ret`.
    ///
    /// # Errors
    ///
    /// The future resolves to an error in any of these cases:
    /// - the Lua thread cannot be created;
    /// - resuming it fails (e.g. the Lua code raises an error);
    /// - the final return values cannot be converted to `Ret`.
    fn call_async<'fut, Arg, Ret>(
        &self,
        ctx: Context<'lua>,
        args: Arg,
    ) -> Pin<Box<dyn 'fut + Future<Output = Result<Ret>>>>
    where
        'lua: 'fut,
        Arg: 'fut + ToLuaMulti<'lua>,
        Ret: 'fut + FromLuaMulti<'lua>;
}
321
322impl<'lua> FunctionExt<'lua> for Function<'lua> {
323 fn call_async<'fut, Arg, Ret>(
324 &self,
325 ctx: Context<'lua>,
326 args: Arg,
327 ) -> Pin<Box<dyn 'fut + Future<Output = Result<Ret>>>>
328 where
329 'lua: 'fut,
330 Arg: 'fut + ToLuaMulti<'lua>,
331 Ret: 'fut + FromLuaMulti<'lua>,
332 {
333 let thread = match ctx.create_thread(self.clone()) {
334 Ok(thread) => thread,
335 Err(e) => return Box::pin(future::err(e)),
336 };
337
338 Box::pin(PollThreadFut {
339 args: Some(args),
340 ctx,
341 thread,
342 _phantom: PhantomData,
343 })
344 }
345}
346
/// Extension methods on [`Chunk`] for running Lua chunks that may invoke async
/// Rust functions.
pub trait ChunkExt<'lua, 'a> {
    /// Executes the chunk from a future.
    ///
    /// Equivalent to [`ChunkExt::call_async`] with no arguments and `Ret = ()`.
    ///
    /// # Errors
    ///
    /// Same as [`ChunkExt::call_async`].
    fn exec_async<'fut>(
        self,
        ctx: Context<'lua>,
    ) -> Pin<Box<dyn 'fut + Future<Output = Result<()>>>>
    where
        'lua: 'fut;

    /// Loads the chunk as a function and calls it via
    /// [`FunctionExt::call_async`].
    ///
    /// # Errors
    ///
    /// If loading the chunk as a function fails, the returned future resolves
    /// immediately to that error. Otherwise the errors are those of
    /// [`FunctionExt::call_async`].
    fn call_async<'fut, Arg, Ret>(
        self,
        ctx: Context<'lua>,
        args: Arg,
    ) -> Pin<Box<dyn 'fut + Future<Output = Result<Ret>>>>
    where
        'lua: 'fut,
        Arg: 'fut + ToLuaMulti<'lua>,
        Ret: 'fut + FromLuaMulti<'lua>;
}
385
386impl<'lua, 'a> ChunkExt<'lua, 'a> for Chunk<'lua, 'a> {
387 fn exec_async<'fut>(
388 self,
389 ctx: Context<'lua>,
390 ) -> Pin<Box<dyn 'fut + Future<Output = Result<()>>>>
391 where
392 'lua: 'fut,
393 {
394 self.call_async(ctx, ())
395 }
396
397 fn call_async<'fut, Arg, Ret>(
414 self,
415 ctx: Context<'lua>,
416 args: Arg,
417 ) -> Pin<Box<dyn 'fut + Future<Output = Result<Ret>>>>
418 where
419 'lua: 'fut,
420 Arg: 'fut + ToLuaMulti<'lua>,
421 Ret: 'fut + FromLuaMulti<'lua>,
422 {
423 let fun = match self.into_function() {
424 Ok(fun) => fun,
425 Err(e) => return Box::pin(future::err(e)),
426 };
427
428 fun.call_async(ctx, args)
429 }
430}
431
// Integration tests: async Rust functions called from Lua code that is driven
// by the `call_async` / `exec_async` futures and `futures::executor::block_on`.
#[cfg(test)]
mod tests {
    use super::*;

    use std::cell::Cell;
    use std::rc::Rc;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use futures::executor;
    use rlua::{Error, Lua};

    // An async function whose future is immediately ready returns its result to
    // the Lua caller.
    #[test]
    fn async_fn() {
        let lua = Lua::new();

        lua.context(|lua_ctx| {
            let globals = lua_ctx.globals();

            let f = lua_ctx
                .create_async_function(|_, a: usize| future::ok(a + 1))
                .unwrap();
            globals.set("f", f).unwrap();

            assert_eq!(
                executor::block_on(
                    lua_ctx
                        .load(r#"function(a) return f(a) - 1 end"#)
                        .eval::<Function>()
                        .unwrap()
                        .call_async::<_, usize>(lua_ctx, 2)
                )
                .unwrap(),
                2
            );
        });
    }

    // The future awaits a 50 ms timer before resolving, so it is presumably
    // polled at least once while pending before the Lua caller gets the
    // result.
    #[test]
    fn actually_awaiting_fn() {
        let lua = Lua::new();

        lua.context(|lua_ctx| {
            let globals = lua_ctx.globals();

            let f = lua_ctx
                .create_async_function(|_, a: usize| async move {
                    futures_timer::Delay::new(Duration::from_millis(50)).await;
                    Ok(a + 1)
                })
                .unwrap();
            globals.set("f", f).unwrap();

            assert_eq!(
                executor::block_on(
                    lua_ctx
                        .load(r#"function(a) return f(a) - 1 end"#)
                        .set_name(b"example")
                        .expect("failed to set name")
                        .eval::<Function>()
                        .expect("failed to eval")
                        .call_async::<_, usize>(lua_ctx, 2)
                )
                .expect("failed to call"),
                2
            );
        });
    }

    // `create_async_function_mut` calls the FnMut generator exactly once per
    // Lua call; the counter goes from 0 to 1.
    #[test]
    fn async_fn_mut() {
        let lua = Lua::new();

        lua.context(|lua_ctx| {
            let globals = lua_ctx.globals();

            let v = Arc::new(Mutex::new(0));
            let v_clone = v.clone();
            let f = lua_ctx
                .create_async_function_mut(move |_, a: usize| {
                    *v_clone.lock().unwrap() += 1;
                    future::ok(a + 1)
                })
                .unwrap();
            globals.set("f", f).unwrap();

            assert_eq!(*v.lock().unwrap(), 0);
            assert_eq!(
                executor::block_on(
                    lua_ctx
                        .load(r#"function(a) return f(a) - 1 end"#)
                        .set_name(b"example")
                        .expect("failed to set name")
                        .eval::<Function>()
                        .expect("failed to eval")
                        .call_async::<_, usize>(lua_ctx, 2)
                )
                .expect("failed to call"),
                2
            );
            assert_eq!(*v.lock().unwrap(), 1);
        });
    }

    // A chunk run with `exec_async` can call an async function at top level.
    // The globals it defines are then usable from both `call_async` and plain
    // `eval`.
    #[test]
    fn async_chunk() {
        let lua = Lua::new();

        lua.context(|lua_ctx| {
            let globals = lua_ctx.globals();

            let f = lua_ctx
                .create_async_function(|_, a: usize| async move {
                    futures_timer::Delay::new(Duration::from_millis(50)).await;
                    Ok(a + 1)
                })
                .unwrap();
            globals.set("f", f).unwrap();

            executor::block_on(
                lua_ctx
                    .load(
                        r#"
                        bar = f(1)
                        function foo(a)
                            return a + bar
                        end
                        "#,
                    )
                    .set_name(b"foo")
                    .expect("failed to set name")
                    .exec_async(lua_ctx),
            )
            .expect("failed to exec");

            assert_eq!(
                executor::block_on(
                    lua_ctx
                        .load(r#"return foo(1)"#)
                        .call_async::<_, usize>(lua_ctx, ()),
                )
                .expect("failed to call"),
                3,
            );

            assert_eq!(
                lua_ctx
                    .load(r#"foo(1)"#)
                    .eval::<usize>()
                    .expect("failed to eval"),
                3,
            );
        });
    }

    // A scoped async function may capture non-'static data (an `Rc`). It must:
    // - not clone the `Rc` while being called;
    // - release it when the scope ends;
    // - fail with `CallbackError` when called after the scope.
    #[test]
    fn scopes_do_drop_things() {
        Lua::new().context(|lua| {
            let rc = Rc::new(Cell::new(0));
            lua.scope(|scope| {
                let rc_clone = rc.clone();
                assert_eq!(Rc::strong_count(&rc), 2);
                let f: Function = scope
                    .create_async_function(lua, move |_, ()| {
                        rc_clone.set(rc_clone.get() + 21);
                        future::ok(())
                    })
                    .unwrap();
                assert_eq!(Rc::strong_count(&rc), 2);
                lua.globals().set("bad", f.clone()).unwrap();
                assert_eq!(Rc::strong_count(&rc), 2);
                executor::block_on(f.call_async::<_, ()>(lua, ())).expect("call failed");
                assert_eq!(Rc::strong_count(&rc), 2);
                executor::block_on(f.call_async::<_, ()>(lua, ())).expect("call failed");
                assert_eq!(Rc::strong_count(&rc), 2);
            });
            assert_eq!(rc.get(), 42);
            assert_eq!(Rc::strong_count(&rc), 1);

            let call_res = executor::block_on(
                lua.globals()
                    .get::<_, Function>("bad")
                    .unwrap()
                    .call_async::<_, ()>(lua, ()),
            );
            match call_res {
                Err(Error::CallbackError { .. }) => {}
                r => panic!("improper return for destructed function: {:?}", r),
            };
        });
    }

    // Same as `scopes_do_drop_things`, for the FnMut variant. It also checks
    // that mutable state captured by the generator (`v`) persists across calls.
    #[test]
    fn scopes_async_fn_mut() {
        Lua::new().context(|lua| {
            let rc = Rc::new(Cell::new(0));
            lua.scope(|scope| {
                let rc_clone = rc.clone();
                let mut v = 0;
                assert_eq!(Rc::strong_count(&rc), 2);
                let f: Function = scope
                    .create_async_function_mut(lua, move |_, ()| {
                        v += 21;
                        rc_clone.set(v);
                        future::ok(())
                    })
                    .unwrap();
                assert_eq!(Rc::strong_count(&rc), 2);
                lua.globals().set("bad", f.clone()).unwrap();
                assert_eq!(Rc::strong_count(&rc), 2);
                executor::block_on(f.call_async::<_, ()>(lua, ())).expect("call failed");
                assert_eq!(Rc::strong_count(&rc), 2);
                executor::block_on(f.call_async::<_, ()>(lua, ())).expect("call failed");
                assert_eq!(Rc::strong_count(&rc), 2);
            });
            assert_eq!(rc.get(), 42);
            assert_eq!(Rc::strong_count(&rc), 1);

            let call_res = executor::block_on(
                lua.globals()
                    .get::<_, Function>("bad")
                    .unwrap()
                    .call_async::<_, ()>(lua, ()),
            );
            match call_res {
                Err(Error::CallbackError { .. }) => {}
                r => panic!("improper return for destructed function: {:?}", r),
            };
        });
    }
}