// burn_central_runtime/inference/model.rs

1use std::any::Any;
2use std::fmt::{Debug, Display, Formatter};
3use std::thread::JoinHandle;
4
5/// Dedicated host thread owning the model instance to serialize mutable access and allow cheap cloning of access handles.
6pub struct ModelHost<M> {
7    accessor: ModelAccessor<M>,
8    abort_tx: crossbeam::channel::Sender<()>,
9    join_handle: Option<JoinHandle<M>>,
10}
11
/// Type-erased return value of a closure executed on the host thread; `Send` so it
/// can travel back to the submitting thread, `Any` so it can be downcast there.
type BoxAny = Box<dyn Any + Send>;
13
14/// Internal message variants for model host thread operations.
15enum Msg<M> {
16    Call {
17        f: Box<dyn FnOnce(&mut M) -> BoxAny + Send>,
18        ret: crossbeam::channel::Sender<BoxAny>,
19    },
20}
21
22impl<M: 'static + Send> ModelHost<M> {
23    /// Spawn a background thread hosting the provided model.
24    pub fn spawn(model: M) -> Self {
25        let (abort_tx, abort_rx) = crossbeam::channel::unbounded::<()>();
26        let (tx, rx) = crossbeam::channel::unbounded::<Msg<M>>();
27        let join_handle = std::thread::spawn(move || {
28            let mut m = model;
29            loop {
30                crossbeam::channel::select! {
31                    recv(rx) -> msg => {
32                        match msg {
33                            Ok(Msg::Call { f, ret }) => {
34                                let r = f(&mut m);
35                                let _ = ret.send(r);
36                            }
37                            Err(_) => break,
38                        }
39                    }
40                    recv(abort_rx) -> _ => {
41                        break;
42                    }
43                }
44            }
45            m
46        });
47        Self {
48            accessor: ModelAccessor { tx },
49            abort_tx,
50            join_handle: Some(join_handle),
51        }
52    }
53
54    /// Get a cloneable accessor to interact with the hosted model.
55    pub fn accessor(&self) -> ModelAccessor<M> {
56        self.accessor.clone()
57    }
58
59    /// Stop the host thread and return ownership of the inner model.
60    pub fn into_model(mut self) -> M {
61        let _ = self.abort_tx.send(());
62
63        self.join_handle
64            .take()
65            .expect("Should have join handle")
66            .join()
67            .expect("Thread should not panic")
68    }
69}
70
71impl<M> std::ops::Deref for ModelHost<M> {
72    type Target = ModelAccessor<M>;
73
74    fn deref(&self) -> &Self::Target {
75        &self.accessor
76    }
77}
78
79impl<M> Drop for ModelHost<M> {
80    fn drop(&mut self) {
81        let _ = self.abort_tx.send(());
82        let _ = self.join_handle.take().unwrap().join();
83    }
84}
85
86/// Cloneable handle used to execute closures against the model on its host thread.
87pub struct ModelAccessor<M> {
88    tx: crossbeam::channel::Sender<Msg<M>>,
89}
90
91impl<M: Debug> Debug for ModelAccessor<M> {
92    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
93        let debug_str = self.submit(|m| format!("{m:?}"));
94        write!(f, "{debug_str}")
95    }
96}
97
98impl<M: Display> Display for ModelAccessor<M> {
99    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100        let display_str = self.submit(|m| format!("{m}"));
101        write!(f, "{display_str}")
102    }
103}
104
105impl<M> Clone for ModelAccessor<M> {
106    fn clone(&self) -> Self {
107        Self {
108            tx: self.tx.clone(),
109        }
110    }
111}
112
113impl<M> ModelAccessor<M> {
114    /// Run a closure that returns a value on the model thread, waiting for completion.
115    pub fn submit<R: Send + 'static>(&self, f: impl FnOnce(&mut M) -> R + Send + 'static) -> R {
116        let (ret_tx, ret_rx) = crossbeam::channel::bounded(1);
117        let _ = self.tx.send(Msg::Call {
118            f: Box::new(move |m| Box::new(f(m)) as BoxAny),
119            ret: ret_tx,
120        });
121        let r = ret_rx.recv().unwrap();
122        *r.downcast::<R>().unwrap()
123    }
124}