candle_coreml/
state.rs

1//! CoreML state management for autoregressive inference
2
3use candle_core::Error as CandleError;
4
5#[cfg(target_os = "macos")]
6use objc2::rc::{autoreleasepool, Retained};
7#[cfg(target_os = "macos")]
8use objc2_core_ml::{MLModel, MLState};
9
/// Opaque wrapper around Core ML's `MLState` for stateful inference.
///
/// This provides persistent state management for autoregressive models,
/// enabling efficient KV-cache reuse across token generation steps.
///
/// # Thread Safety
///
/// Each `CoreMLState` instance must be used by only one thread at a time.
/// Concurrent predictions using the same state object result in undefined behavior.
///
/// # Example
///
/// ```rust,no_run
/// use candle_core::{Device, Tensor};
/// use candle_coreml::CoreMLModel;
///
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let model = CoreMLModel::load("model.mlmodelc")?;
/// let device = Device::Cpu;
///
/// // Create state for efficient autoregressive generation
/// let mut state = model.make_state()?;
///
/// // Generate tokens sequentially with persistent KV-cache
/// for _step in 0..10 {
///     let input = Tensor::ones((1, 1), candle_core::DType::I64, &device)?;
///     let _output = model.predict_with_state(&[&input], &mut state)?;
///     // Process output...
/// }
/// # Ok(())
/// # }
/// ```
#[cfg(target_os = "macos")]
pub struct CoreMLState {
    /// Strong (retained) reference to the underlying Core ML state object.
    inner: Retained<MLState>,
}
46
/// Placeholder for Core ML's `MLState` on platforms without Core ML.
///
/// This type exists so that code naming `CoreMLState` still compiles off macOS.
/// On these targets `CoreMLState::new` always returns an error. Because the only
/// field is private, the type cannot be constructed outside this module.
#[cfg(not(target_os = "macos"))]
pub struct CoreMLState {
    _phantom: std::marker::PhantomData<()>,
}
51
52impl std::fmt::Debug for CoreMLState {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("CoreMLState").finish_non_exhaustive()
55    }
56}
57
#[cfg(target_os = "macos")]
impl CoreMLState {
    /// Create a new state object for the given Core ML model.
    ///
    /// # Arguments
    ///
    /// * `model` - Reference to the `MLModel` to create state for.
    ///
    /// # Returns
    ///
    /// A `CoreMLState` wrapping the `MLState` returned by `MLModel::newState`.
    /// For stateless models, Core ML returns an empty state object that can
    /// still be used with stateful prediction methods.
    ///
    /// # Errors
    ///
    /// This macOS implementation never returns `Err`. The `Result` return type
    /// keeps the signature aligned with the non-macOS fallback, which always errors.
    ///
    /// NOTE(review): `MLModel::newState` is a macOS 15+ Core ML API. Confirm that
    /// callers only reach this on OS versions that provide it.
    pub(crate) fn new(model: &Retained<MLModel>) -> Result<Self, CandleError> {
        // Objects autoreleased during the call are released when this pool drains.
        // The returned `Retained` is a strong reference, so it outlives the pool.
        let inner = autoreleasepool(|_| {
            // SAFETY: `newState` takes no arguments and is invoked on a live, retained
            // `MLModel`. It hands back an owned `Retained<MLState>`, so nothing borrowed
            // from the pool or the stack escapes this closure.
            unsafe { model.newState() }
        });
        Ok(Self { inner })
    }

    /// Borrow the underlying `MLState` for passing to Core ML prediction APIs.
    pub(crate) fn inner(&self) -> &MLState {
        &self.inner
    }
}
85
86#[cfg(not(target_os = "macos"))]
87impl CoreMLState {
88    pub(crate) fn new(_model: &()) -> Result<Self, CandleError> {
89        Err(CandleError::Msg(
90            "CoreML state is only available on macOS".to_string(),
91        ))
92    }
93}