candle_coreml/state.rs
1//! CoreML state management for autoregressive inference
2
3use candle_core::Error as CandleError;
4
5#[cfg(target_os = "macos")]
6use objc2::rc::{autoreleasepool, Retained};
7#[cfg(target_os = "macos")]
8use objc2_core_ml::{MLModel, MLState};
9
/// Opaque wrapper around Core ML's `MLState` for stateful inference.
///
/// This provides persistent state management for autoregressive models,
/// enabling efficient KV-cache reuse across token generation steps.
///
/// # Thread Safety
///
/// Each `CoreMLState` instance must be used by only one thread at a time.
/// Concurrent predictions using the same state object result in undefined behavior.
///
/// # Example
///
/// ```rust,no_run
/// use candle_core::{Device, Tensor};
/// use candle_coreml::CoreMLModel;
///
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let model = CoreMLModel::load("model.mlmodelc")?;
/// let device = Device::Cpu;
///
/// // Create state for efficient autoregressive generation
/// let mut state = model.make_state()?;
///
/// // Generate tokens sequentially with persistent KV-cache
/// for _step in 0..10 {
///     let input = Tensor::ones((1, 1), candle_core::DType::I64, &device)?;
///     let _output = model.predict_with_state(&[&input], &mut state)?;
///     // Process output...
/// }
/// # Ok(())
/// # }
/// ```
#[cfg(target_os = "macos")]
pub struct CoreMLState {
    /// Owned (retained) handle to the Core ML state object.
    inner: Retained<MLState>,
}

/// Placeholder for [`CoreMLState`] on platforms without Core ML.
///
/// Core ML is only available on macOS. On other targets this type carries no
/// data, cannot be constructed outside this module (its only field is
/// private), and its constructor always returns an error.
#[cfg(not(target_os = "macos"))]
pub struct CoreMLState {
    _phantom: std::marker::PhantomData<()>,
}

// `Debug` deliberately hides the underlying Core ML handle and always renders
// as `CoreMLState { .. }`, on every platform.
impl std::fmt::Debug for CoreMLState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CoreMLState").finish_non_exhaustive()
    }
}
56}
57
58#[cfg(target_os = "macos")]
59impl CoreMLState {
60 /// Create a new state object for the given CoreML model.
61 ///
62 /// # Arguments
63 ///
64 /// * `model` - Reference to the MLModel to create state for
65 ///
66 /// # Returns
67 ///
68 /// A new `CoreMLState` instance, or an error if state creation fails.
69 /// For stateless models, this returns an empty state object that can
70 /// still be used with stateful prediction methods.
71 pub(crate) fn new(model: &Retained<MLModel>) -> Result<Self, CandleError> {
72 autoreleasepool(|_| {
73 // SAFETY: CoreML's MLModel::newState returns a valid retained MLState associated with the model.
74 // It does not borrow stack data and follows objc2 ownership rules.
75 let state = unsafe { model.newState() };
76 Ok(CoreMLState { inner: state })
77 })
78 }
79
80 /// Get a reference to the underlying MLState for CoreML operations.
81 pub(crate) fn inner(&self) -> &MLState {
82 &self.inner
83 }
84}
85
86#[cfg(not(target_os = "macos"))]
87impl CoreMLState {
88 pub(crate) fn new(_model: &()) -> Result<Self, CandleError> {
89 Err(CandleError::Msg(
90 "CoreML state is only available on macOS".to_string(),
91 ))
92 }
93}