burn_central_runtime/inference/
job.rs

1use super::error::InferenceError;
2use super::streaming::CancelToken;
3use std::thread::JoinHandle;
4
5/// Handle to a running inference job thread.
6///
7/// Provides access to the output stream (`stream`) plus cancellation (`cancel`) and
8/// a `join` method to retrieve the final result (or error) once the handler terminates.
9pub struct JobHandle<S> {
10    pub id: String,
11    pub stream: crossbeam::channel::Receiver<S>,
12    cancel: CancelToken,
13    join: Option<JoinHandle<Result<(), InferenceError>>>,
14}
15
16impl<S> JobHandle<S> {
17    pub fn new(
18        id: String,
19        stream: crossbeam::channel::Receiver<S>,
20        cancel: CancelToken,
21        join: JoinHandle<Result<(), InferenceError>>,
22    ) -> Self {
23        Self {
24            id,
25            stream,
26            cancel,
27            join: Some(join),
28        }
29    }
30
31    /// Cancel the running job. This will signal the job to stop processing as soon as possible.
32    /// Note that this does not immediately kill the thread, but rather requests it to stop.
33    /// The inference function has to use the `CancelToken` to check for cancellation.
34    pub fn cancel(&self) {
35        self.cancel.cancel();
36    }
37
38    /// Wait for the job to finish and return the result.
39    pub fn join(mut self) -> Result<(), InferenceError> {
40        if let Some(join) = self.join.take() {
41            let res = join.join();
42            match res {
43                Ok(r) => r,
44                Err(e) => Err(InferenceError::ThreadPanicked(format!("{e:?}"))),
45            }
46        } else {
47            Ok(())
48        }
49    }
50}