Skip to main content

coreml_native/
description.rs

//! Model introspection types: feature descriptions, shape constraints, and model metadata.
//!
//! Covers FR-4.1, FR-4.2, FR-4.3.
4
5use crate::tensor::DataType;
6
/// Constraint on the shape of a multi-array feature.
///
/// Describes which shapes a multi-array input or output accepts. The type is
/// fully comparable and hashable, so constraints can be deduplicated or used as
/// map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ShapeConstraint {
    /// Fixed shape -- only one shape is allowed.
    Fixed(Vec<usize>),
    /// One of several enumerated shapes.
    Enumerated(Vec<Vec<usize>>),
    /// Each dimension has an independent `(min, max)` size range, one entry per
    /// dimension.
    ///
    /// NOTE(review): when built from Core ML, `max` is derived from the reported
    /// `NSRange` as `location + length`; whether it should be treated as
    /// inclusive or exclusive is not established here — confirm against the
    /// Core ML documentation before relying on it.
    Range(Vec<(usize, usize)>),
    /// Unknown or unspecified constraint.
    Unspecified,
}
19
20/// Description of a model feature (input or output).
21#[derive(Debug, Clone)]
22pub struct FeatureDescription {
23    name: String,
24    feature_type: FeatureType,
25    shape: Option<Vec<usize>>,
26    data_type: Option<DataType>,
27    is_optional: bool,
28    /// For MultiArray features, the shape constraint type.
29    shape_constraint: Option<ShapeConstraint>,
30}
31
32impl FeatureDescription {
33    pub fn name(&self) -> &str { &self.name }
34    pub fn feature_type(&self) -> &FeatureType { &self.feature_type }
35    pub fn shape(&self) -> Option<&[usize]> { self.shape.as_deref() }
36    pub fn data_type(&self) -> Option<DataType> { self.data_type }
37    pub fn is_optional(&self) -> bool { self.is_optional }
38    pub fn shape_constraint(&self) -> Option<&ShapeConstraint> { self.shape_constraint.as_ref() }
39}
40
/// The type of a model feature.
///
/// Core ML feature types without a dedicated variant are reported as
/// [`FeatureType::Invalid`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Multi-dimensional array feature.
    MultiArray,
    /// Image feature.
    Image,
    /// Dictionary feature.
    Dictionary,
    /// Sequence feature.
    Sequence,
    /// String feature.
    String,
    /// 64-bit integer feature.
    Int64,
    /// Double-precision floating-point feature.
    Double,
    /// Invalid or unrecognized feature type.
    Invalid,
}

impl std::fmt::Display for FeatureType {
    /// Writes the variant name (e.g. `MultiArray`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::MultiArray => "MultiArray",
            Self::Image => "Image",
            Self::Dictionary => "Dictionary",
            Self::Sequence => "Sequence",
            Self::String => "String",
            Self::Int64 => "Int64",
            Self::Double => "Double",
            Self::Invalid => "Invalid",
        };
        // `pad` (unlike `write!`) honors width/alignment flags such as `{:>10}`.
        f.pad(name)
    }
}
68
/// Model metadata.
///
/// Every `Option` field is `None` when the model does not provide the
/// corresponding value.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct ModelMetadata {
    /// Author recorded in the model metadata.
    pub author: Option<String>,
    /// Free-form description recorded in the model metadata.
    pub description: Option<String>,
    /// Version string recorded in the model metadata.
    pub version: Option<String>,
    /// License recorded in the model metadata.
    pub license: Option<String>,
    /// The name of the predicted feature (for classifier models).
    pub predicted_feature_name: Option<String>,
    /// The name of the predicted probabilities feature (for classifier models).
    pub predicted_probabilities_name: Option<String>,
    /// Whether the model supports on-device updates.
    pub is_updatable: bool,
}
83
84// ─── Apple platform builders ────────────────────────────────────────────────
85
/// Converts a Core ML feature-description dictionary into a list of
/// [`FeatureDescription`]s sorted by feature name.
///
/// Multi-array features also carry their shape, element data type, and shape
/// constraint when Core ML reports a multi-array constraint. All other features
/// leave those fields as `None`. Core ML feature types with no matching
/// [`FeatureType`] variant are reported as [`FeatureType::Invalid`].
#[cfg(target_vendor = "apple")]
pub(crate) fn extract_features(
    descriptions: &objc2_foundation::NSDictionary<
        objc2_foundation::NSString,
        objc2_core_ml::MLFeatureDescription,
    >,
) -> Vec<FeatureDescription> {
    use crate::ffi;
    use objc2_core_ml::MLFeatureType;

    let keys = descriptions.allKeys();
    let mut result = Vec::with_capacity(keys.len());

    for key in keys.iter() {
        // Keys come from `allKeys`, so the lookup is expected to succeed; skip
        // the entry rather than panic if it somehow does not.
        let Some(desc) = descriptions.objectForKey(&key) else {
            continue;
        };

        let feature_type = match unsafe { desc.r#type() } {
            MLFeatureType::MultiArray => FeatureType::MultiArray,
            MLFeatureType::Image => FeatureType::Image,
            MLFeatureType::Dictionary => FeatureType::Dictionary,
            MLFeatureType::Sequence => FeatureType::Sequence,
            MLFeatureType::String => FeatureType::String,
            MLFeatureType::Int64 => FeatureType::Int64,
            MLFeatureType::Double => FeatureType::Double,
            _ => FeatureType::Invalid,
        };
        let is_optional = unsafe { desc.isOptional() };

        // Shape, data type, and shape constraint are only read for multi-array
        // features that expose a multi-array constraint.
        let multi_array = if feature_type == FeatureType::MultiArray {
            unsafe { desc.multiArrayConstraint() }
        } else {
            None
        };

        let (shape, data_type, shape_constraint) = match multi_array {
            Some(c) => {
                let ns_shape = unsafe { c.shape() };
                let shape = ffi::nsarray_to_shape(&ns_shape);
                let data_type = ffi::ml_to_datatype(unsafe { c.dataType() }.0);
                let ml_constraint = unsafe { c.shapeConstraint() };
                let constraint = shape_constraint_from_ml(&ml_constraint);
                (Some(shape), data_type, Some(constraint))
            }
            None => (None, None, None),
        };

        result.push(FeatureDescription {
            name: ffi::nsstring_to_string(&key),
            feature_type,
            shape,
            data_type,
            is_optional,
            shape_constraint,
        });
    }

    // Names are dictionary keys and therefore unique, so an unstable sort
    // yields the same order as a stable one.
    result.sort_unstable_by(|a, b| a.name.cmp(&b.name));
    result
}

/// Converts a Core ML multi-array shape constraint into a [`ShapeConstraint`].
///
/// Enumerated and range constraints are converted element by element. Any
/// other constraint type maps to [`ShapeConstraint::Unspecified`].
#[cfg(target_vendor = "apple")]
fn shape_constraint_from_ml(sc: &objc2_core_ml::MLMultiArrayShapeConstraint) -> ShapeConstraint {
    use crate::ffi;
    use objc2_core_ml::MLMultiArrayShapeConstraintType;

    match unsafe { sc.r#type() } {
        MLMultiArrayShapeConstraintType::Enumerated => {
            let enum_shapes = unsafe { sc.enumeratedShapes() };
            ShapeConstraint::Enumerated(
                enum_shapes
                    .iter()
                    .map(|s| ffi::nsarray_to_shape(&s))
                    .collect(),
            )
        }
        MLMultiArrayShapeConstraintType::Range => {
            let range_vals = unsafe { sc.sizeRangeForDimension() };
            ShapeConstraint::Range(
                range_vals
                    .iter()
                    .map(|val| {
                        let r = unsafe { val.rangeValue() };
                        // Saturate instead of overflowing (a debug-build panic)
                        // if the reported length is close to `usize::MAX`.
                        (r.location, r.location.saturating_add(r.length))
                    })
                    .collect(),
            )
        }
        _ => ShapeConstraint::Unspecified,
    }
}
176
/// Reads model-level metadata from a Core ML model description.
///
/// String metadata entries (author, description, version, license) are `None`
/// when the key is absent or its value is not an `NSString`. The classifier
/// feature names and the updatable flag come directly from the corresponding
/// `MLModelDescription` properties.
#[cfg(target_vendor = "apple")]
pub(crate) fn extract_metadata(
    model_desc: &objc2_core_ml::MLModelDescription,
) -> ModelMetadata {
    use crate::ffi;

    let meta = unsafe { model_desc.metadata() };

    // Looks up one metadata entry and returns it only if it is a string.
    //
    // NOTE(review): the key literals passed below are assumed to equal the
    // string values of Core ML's `MLModel*Key` metadata constants — confirm
    // against the SDK headers.
    let string_entry = |key: &str| -> Option<String> {
        let ns_key = ffi::str_to_nsstring(key);
        let value = meta.objectForKey(&ns_key)?;
        let s = value.downcast_ref::<objc2_foundation::NSString>()?;
        Some(ffi::nsstring_to_string(s))
    };

    ModelMetadata {
        author: string_entry("MLModelAuthorKey"),
        description: string_entry("MLModelDescriptionKey"),
        version: string_entry("MLModelVersionStringKey"),
        license: string_entry("MLModelLicenseKey"),
        predicted_feature_name: unsafe { model_desc.predictedFeatureName() }
            .map(|s| ffi::nsstring_to_string(&s)),
        predicted_probabilities_name: unsafe { model_desc.predictedProbabilitiesName() }
            .map(|s| ffi::nsstring_to_string(&s)),
        is_updatable: unsafe { model_desc.isUpdatable() },
    }
}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn feature_type_display() {
        // Every variant, not just a sample, so a typo in any arm is caught.
        let cases = [
            (FeatureType::MultiArray, "MultiArray"),
            (FeatureType::Image, "Image"),
            (FeatureType::Dictionary, "Dictionary"),
            (FeatureType::Sequence, "Sequence"),
            (FeatureType::String, "String"),
            (FeatureType::Int64, "Int64"),
            (FeatureType::Double, "Double"),
            (FeatureType::Invalid, "Invalid"),
        ];
        for (ft, expected) in &cases {
            assert_eq!(format!("{}", ft), *expected);
        }
    }

    #[test]
    fn feature_type_equality() {
        assert_eq!(FeatureType::MultiArray, FeatureType::MultiArray);
        assert_ne!(FeatureType::MultiArray, FeatureType::Image);
    }

    #[test]
    fn feature_type_is_copy() {
        let a = FeatureType::Image;
        let b = a;
        // `a` is still usable after the move because the type is `Copy`.
        assert_eq!(a, b);
    }

    #[test]
    fn metadata_default() {
        let m = ModelMetadata::default();
        assert!(m.author.is_none());
        assert!(m.description.is_none());
        assert!(m.version.is_none());
        assert!(m.license.is_none());
        assert!(m.predicted_feature_name.is_none());
        assert!(m.predicted_probabilities_name.is_none());
        assert!(!m.is_updatable);
    }

    #[test]
    fn feature_description_accessors() {
        let fd = FeatureDescription {
            name: "input".into(),
            feature_type: FeatureType::MultiArray,
            shape: Some(vec![1, 128, 500]),
            data_type: Some(DataType::Float32),
            is_optional: false,
            shape_constraint: Some(ShapeConstraint::Fixed(vec![1, 128, 500])),
        };
        assert_eq!(fd.name(), "input");
        assert_eq!(fd.feature_type(), &FeatureType::MultiArray);
        assert_eq!(fd.shape(), Some(&[1, 128, 500][..]));
        assert_eq!(fd.data_type(), Some(DataType::Float32));
        assert!(!fd.is_optional());
        assert_eq!(
            fd.shape_constraint(),
            Some(&ShapeConstraint::Fixed(vec![1, 128, 500])),
        );
    }

    #[test]
    fn feature_description_non_multiarray() {
        let fd = FeatureDescription {
            name: "label".into(),
            feature_type: FeatureType::String,
            shape: None,
            data_type: None,
            is_optional: true,
            shape_constraint: None,
        };
        assert_eq!(fd.feature_type(), &FeatureType::String);
        assert!(fd.shape().is_none());
        assert!(fd.data_type().is_none());
        assert!(fd.shape_constraint().is_none());
        assert!(fd.is_optional());
    }

    #[test]
    fn feature_description_clone_preserves_fields() {
        let fd = FeatureDescription {
            name: "x".into(),
            feature_type: FeatureType::MultiArray,
            shape: Some(vec![2, 3]),
            data_type: Some(DataType::Float32),
            is_optional: false,
            shape_constraint: Some(ShapeConstraint::Range(vec![(1, 4), (3, 3)])),
        };
        let copy = fd.clone();
        assert_eq!(copy.name(), fd.name());
        assert_eq!(copy.feature_type(), fd.feature_type());
        assert_eq!(copy.shape(), fd.shape());
        assert_eq!(copy.data_type(), fd.data_type());
        assert_eq!(copy.is_optional(), fd.is_optional());
        assert_eq!(copy.shape_constraint(), fd.shape_constraint());
    }

    #[test]
    fn shape_constraint_types() {
        let fixed = ShapeConstraint::Fixed(vec![1, 128]);
        let enum_c = ShapeConstraint::Enumerated(vec![vec![1, 128], vec![1, 256]]);
        let range_c = ShapeConstraint::Range(vec![(1, 10), (64, 512)]);
        let unspec = ShapeConstraint::Unspecified;

        assert_ne!(fixed, enum_c);
        assert_ne!(enum_c, range_c);
        assert_ne!(range_c, unspec);
        assert_eq!(fixed, ShapeConstraint::Fixed(vec![1, 128]));
        assert_eq!(range_c.clone(), range_c);
    }
}