// burn_central_runtime/inference/model.rs
use std::any::Any;
use std::fmt::{Debug, Display, Formatter};
use std::thread::JoinHandle;
4
5pub struct ModelHost<M> {
7 accessor: ModelAccessor<M>,
8 abort_tx: crossbeam::channel::Sender<()>,
9 join_handle: Option<JoinHandle<M>>,
10}
11
12type BoxAny = Box<dyn Any + Send>;
13
14enum Msg<M> {
16 Call {
17 f: Box<dyn FnOnce(&mut M) -> BoxAny + Send>,
18 ret: crossbeam::channel::Sender<BoxAny>,
19 },
20}
21
22impl<M: 'static + Send> ModelHost<M> {
23 pub fn spawn(model: M) -> Self {
25 let (abort_tx, abort_rx) = crossbeam::channel::unbounded::<()>();
26 let (tx, rx) = crossbeam::channel::unbounded::<Msg<M>>();
27 let join_handle = std::thread::spawn(move || {
28 let mut m = model;
29 loop {
30 crossbeam::channel::select! {
31 recv(rx) -> msg => {
32 match msg {
33 Ok(Msg::Call { f, ret }) => {
34 let r = f(&mut m);
35 let _ = ret.send(r);
36 }
37 Err(_) => break,
38 }
39 }
40 recv(abort_rx) -> _ => {
41 break;
42 }
43 }
44 }
45 m
46 });
47 Self {
48 accessor: ModelAccessor { tx },
49 abort_tx,
50 join_handle: Some(join_handle),
51 }
52 }
53
54 pub fn accessor(&self) -> ModelAccessor<M> {
56 self.accessor.clone()
57 }
58
59 pub fn into_model(mut self) -> M {
61 let _ = self.abort_tx.send(());
62
63 self.join_handle
64 .take()
65 .expect("Should have join handle")
66 .join()
67 .expect("Thread should not panic")
68 }
69}
70
71impl<M> std::ops::Deref for ModelHost<M> {
72 type Target = ModelAccessor<M>;
73
74 fn deref(&self) -> &Self::Target {
75 &self.accessor
76 }
77}
78
79impl<M> Drop for ModelHost<M> {
80 fn drop(&mut self) {
81 let _ = self.abort_tx.send(());
82 let _ = self.join_handle.take().unwrap().join();
83 }
84}
85
86pub struct ModelAccessor<M> {
88 tx: crossbeam::channel::Sender<Msg<M>>,
89}
90
91impl<M: Debug> Debug for ModelAccessor<M> {
92 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
93 let debug_str = self.submit(|m| format!("{m:?}"));
94 write!(f, "{debug_str}")
95 }
96}
97
98impl<M: Display> Display for ModelAccessor<M> {
99 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100 let display_str = self.submit(|m| format!("{m}"));
101 write!(f, "{display_str}")
102 }
103}
104
105impl<M> Clone for ModelAccessor<M> {
106 fn clone(&self) -> Self {
107 Self {
108 tx: self.tx.clone(),
109 }
110 }
111}
112
113impl<M> ModelAccessor<M> {
114 pub fn submit<R: Send + 'static>(&self, f: impl FnOnce(&mut M) -> R + Send + 'static) -> R {
116 let (ret_tx, ret_rx) = crossbeam::channel::bounded(1);
117 let _ = self.tx.send(Msg::Call {
118 f: Box::new(move |m| Box::new(f(m)) as BoxAny),
119 ret: ret_tx,
120 });
121 let r = ret_rx.recv().unwrap();
122 *r.downcast::<R>().unwrap()
123 }
124}