1use crate::tensor::DataType;
6
/// Shape constraint attached to a multi-array feature.
///
/// All payloads are plain integers, so the type is `Eq` and `Hash` and can be used as a
/// key in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ShapeConstraint {
    /// Exactly one permitted shape.
    Fixed(Vec<usize>),
    /// A finite list of permitted shapes.
    Enumerated(Vec<Vec<usize>>),
    /// Per-dimension `(lower, upper)` size bounds.
    ///
    /// NOTE(review): `extract_features` derives `upper` as Core ML's NSRange
    /// `location + length`; whether callers should treat it as inclusive or exclusive
    /// is not established here — confirm against Core ML's encoding.
    Range(Vec<(usize, usize)>),
    /// No enumerated or range constraint was reported (or the type was unrecognized).
    Unspecified,
}
19
20#[derive(Debug, Clone)]
22pub struct FeatureDescription {
23 name: String,
24 feature_type: FeatureType,
25 shape: Option<Vec<usize>>,
26 data_type: Option<DataType>,
27 is_optional: bool,
28 shape_constraint: Option<ShapeConstraint>,
30}
31
32impl FeatureDescription {
33 pub fn name(&self) -> &str { &self.name }
34 pub fn feature_type(&self) -> &FeatureType { &self.feature_type }
35 pub fn shape(&self) -> Option<&[usize]> { self.shape.as_deref() }
36 pub fn data_type(&self) -> Option<DataType> { self.data_type }
37 pub fn is_optional(&self) -> bool { self.is_optional }
38 pub fn shape_constraint(&self) -> Option<&ShapeConstraint> { self.shape_constraint.as_ref() }
39}
40
/// Kind of value a model feature carries.
///
/// `Hash` is derived alongside `Eq` so feature types can be used as keys in hash-based
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Multi-dimensional array.
    MultiArray,
    /// Image.
    Image,
    /// Dictionary.
    Dictionary,
    /// Sequence.
    Sequence,
    /// String.
    String,
    /// 64-bit integer.
    Int64,
    /// Double-precision float.
    Double,
    /// Invalid or unrecognized feature type.
    Invalid,
}
53
54impl std::fmt::Display for FeatureType {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 Self::MultiArray => write!(f, "MultiArray"),
58 Self::Image => write!(f, "Image"),
59 Self::Dictionary => write!(f, "Dictionary"),
60 Self::Sequence => write!(f, "Sequence"),
61 Self::String => write!(f, "String"),
62 Self::Int64 => write!(f, "Int64"),
63 Self::Double => write!(f, "Double"),
64 Self::Invalid => write!(f, "Invalid"),
65 }
66 }
67}
68
/// Descriptive metadata read from a Core ML model description.
///
/// Every field defaults to `None` / `false`; `PartialEq`/`Eq` are derived so metadata
/// values can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelMetadata {
    /// Author string (metadata key `MLModelAuthorKey`), if present.
    pub author: Option<String>,
    /// Free-form description (metadata key `MLModelDescriptionKey`), if present.
    pub description: Option<String>,
    /// Version string (metadata key `MLModelVersionStringKey`), if present.
    pub version: Option<String>,
    /// License string (metadata key `MLModelLicenseKey`), if present.
    pub license: Option<String>,
    /// Name reported by the model description's `predictedFeatureName`, if any.
    pub predicted_feature_name: Option<String>,
    /// Name reported by the model description's `predictedProbabilitiesName`, if any.
    pub predicted_probabilities_name: Option<String>,
    /// Value of the model description's `isUpdatable` flag.
    pub is_updatable: bool,
}
83
/// Converts Core ML feature descriptions into [`FeatureDescription`]s.
///
/// Each entry of `descriptions` (keyed by feature name) becomes one
/// `FeatureDescription`. Shape, data type and shape-constraint details are filled in
/// only for `MultiArray` features that report a `multiArrayConstraint`. For every other
/// feature those fields are `None`.
///
/// The result is sorted by feature name, so its order does not depend on the
/// dictionary's key order.
#[cfg(target_vendor = "apple")]
pub(crate) fn extract_features(
    descriptions: &objc2_foundation::NSDictionary<
        objc2_foundation::NSString,
        objc2_core_ml::MLFeatureDescription,
    >,
) -> Vec<FeatureDescription> {
    use crate::ffi;

    let keys = descriptions.allKeys();
    let mut result = Vec::with_capacity(keys.len());

    for key in keys.iter() {
        // A key taken from `allKeys` should always resolve; skip rather than panic if not.
        let desc = match descriptions.objectForKey(&key) {
            Some(desc) => desc,
            None => continue,
        };

        // SAFETY: argument-free property getters on a description object kept alive by
        // `desc` for the duration of the calls.
        let (ml_type, is_optional) = unsafe { (desc.r#type(), desc.isOptional()) };
        let feature_type = feature_type_from_ml(ml_type);

        let (shape, data_type, shape_constraint) = if feature_type == FeatureType::MultiArray {
            multiarray_details(&desc)
        } else {
            (None, None, None)
        };

        result.push(FeatureDescription {
            name: ffi::nsstring_to_string(&key),
            feature_type,
            shape,
            data_type,
            is_optional,
            shape_constraint,
        });
    }

    // Names are dictionary keys, so they are distinct and stability does not matter.
    result.sort_unstable_by(|a, b| a.name.cmp(&b.name));
    result
}

/// Maps a Core ML feature type to [`FeatureType`]; unrecognized values become `Invalid`.
#[cfg(target_vendor = "apple")]
fn feature_type_from_ml(ml_type: objc2_core_ml::MLFeatureType) -> FeatureType {
    use objc2_core_ml::MLFeatureType;

    match ml_type {
        MLFeatureType::MultiArray => FeatureType::MultiArray,
        MLFeatureType::Image => FeatureType::Image,
        MLFeatureType::Dictionary => FeatureType::Dictionary,
        MLFeatureType::Sequence => FeatureType::Sequence,
        MLFeatureType::String => FeatureType::String,
        MLFeatureType::Int64 => FeatureType::Int64,
        MLFeatureType::Double => FeatureType::Double,
        // `MLFeatureType::Invalid` and any type not listed above.
        _ => FeatureType::Invalid,
    }
}

/// Reads shape, element data type and shape constraint for a `MultiArray` feature,
/// returning `(None, None, None)` when Core ML reports no multi-array constraint.
#[cfg(target_vendor = "apple")]
fn multiarray_details(
    desc: &objc2_core_ml::MLFeatureDescription,
) -> (Option<Vec<usize>>, Option<DataType>, Option<ShapeConstraint>) {
    use crate::ffi;

    // SAFETY: argument-free property getter on a live description object.
    let constraint = match unsafe { desc.multiArrayConstraint() } {
        Some(c) => c,
        None => return (None, None, None),
    };

    // SAFETY: argument-free property getters on the retained constraint object.
    let (ns_shape, raw_dtype, ns_shape_constraint) =
        unsafe { (constraint.shape(), constraint.dataType(), constraint.shapeConstraint()) };

    let shape = ffi::nsarray_to_shape(&ns_shape);
    let data_type = ffi::ml_to_datatype(raw_dtype.0);
    let shape_constraint = shape_constraint_from_ml(&ns_shape_constraint);

    (Some(shape), data_type, Some(shape_constraint))
}

/// Converts a Core ML multi-array shape constraint into a [`ShapeConstraint`].
#[cfg(target_vendor = "apple")]
fn shape_constraint_from_ml(
    sc: &objc2_core_ml::MLMultiArrayShapeConstraint,
) -> ShapeConstraint {
    use crate::ffi;
    use objc2_core_ml::MLMultiArrayShapeConstraintType;

    // SAFETY: argument-free property getter on a live shape-constraint object.
    match unsafe { sc.r#type() } {
        MLMultiArrayShapeConstraintType::Enumerated => {
            // SAFETY: argument-free property getter on a live shape-constraint object.
            let ns_shapes = unsafe { sc.enumeratedShapes() };
            ShapeConstraint::Enumerated(
                ns_shapes.iter().map(|s| ffi::nsarray_to_shape(&s)).collect(),
            )
        }
        MLMultiArrayShapeConstraintType::Range => {
            // SAFETY: argument-free property getter on a live shape-constraint object.
            let ns_ranges = unsafe { sc.sizeRangeForDimension() };
            ShapeConstraint::Range(
                ns_ranges
                    .iter()
                    .map(|v| {
                        // SAFETY: argument-free getter on an NSValue taken from the
                        // size-range array.
                        let r = unsafe { v.rangeValue() };
                        // NSRange fields are unsigned. A very large `length` (e.g. a
                        // possible encoding of an unbounded dimension — TODO confirm)
                        // must not overflow: saturate instead of panicking/wrapping.
                        (r.location, r.location.saturating_add(r.length))
                    })
                    .collect(),
            )
        }
        _ => ShapeConstraint::Unspecified,
    }
}
176
/// Builds a [`ModelMetadata`] from a Core ML model description.
///
/// The author, description, version and license entries are looked up in the
/// description's metadata dictionary by their key strings. An entry that is missing or
/// is not an `NSString` is left as `None`. The prediction names and the updatable flag
/// come straight from the corresponding model-description properties.
#[cfg(target_vendor = "apple")]
pub(crate) fn extract_metadata(
    model_desc: &objc2_core_ml::MLModelDescription,
) -> ModelMetadata {
    use crate::ffi;

    // SAFETY: argument-free property getter on a borrowed model description.
    let meta = unsafe { model_desc.metadata() };

    // Returns the metadata entry for `key` only when it is present and is a string.
    let string_entry = |key: &str| -> Option<String> {
        let ns_key = ffi::str_to_nsstring(key);
        let value = meta.objectForKey(&ns_key)?;
        value
            .downcast_ref::<objc2_foundation::NSString>()
            .map(|s| ffi::nsstring_to_string(s))
    };

    ModelMetadata {
        author: string_entry("MLModelAuthorKey"),
        description: string_entry("MLModelDescriptionKey"),
        version: string_entry("MLModelVersionStringKey"),
        license: string_entry("MLModelLicenseKey"),
        // SAFETY: argument-free property getters on a borrowed model description.
        predicted_feature_name: unsafe { model_desc.predictedFeatureName() }
            .map(|s| ffi::nsstring_to_string(&s)),
        predicted_probabilities_name: unsafe { model_desc.predictedProbabilitiesName() }
            .map(|s| ffi::nsstring_to_string(&s)),
        is_updatable: unsafe { model_desc.isUpdatable() },
    }
}
222
223#[cfg(test)]
224mod tests {
225 use super::*;
226
227 #[test]
228 fn feature_type_display() {
229 assert_eq!(format!("{}", FeatureType::MultiArray), "MultiArray");
230 assert_eq!(format!("{}", FeatureType::Image), "Image");
231 }
232
233 #[test]
234 fn feature_type_equality() {
235 assert_eq!(FeatureType::MultiArray, FeatureType::MultiArray);
236 assert_ne!(FeatureType::MultiArray, FeatureType::Image);
237 }
238
239 #[test]
240 fn metadata_default() {
241 let m = ModelMetadata::default();
242 assert!(m.author.is_none());
243 assert!(m.description.is_none());
244 assert!(m.version.is_none());
245 assert!(m.license.is_none());
246 assert!(m.predicted_feature_name.is_none());
247 assert!(m.predicted_probabilities_name.is_none());
248 assert!(!m.is_updatable);
249 }
250
251 #[test]
252 fn feature_description_accessors() {
253 let fd = FeatureDescription {
254 name: "input".into(),
255 feature_type: FeatureType::MultiArray,
256 shape: Some(vec![1, 128, 500]),
257 data_type: Some(DataType::Float32),
258 is_optional: false,
259 shape_constraint: Some(ShapeConstraint::Fixed(vec![1, 128, 500])),
260 };
261 assert_eq!(fd.name(), "input");
262 assert_eq!(fd.feature_type(), &FeatureType::MultiArray);
263 assert_eq!(fd.shape(), Some(&[1, 128, 500][..]));
264 assert_eq!(fd.data_type(), Some(DataType::Float32));
265 assert!(!fd.is_optional());
266 assert_eq!(
267 fd.shape_constraint(),
268 Some(&ShapeConstraint::Fixed(vec![1, 128, 500])),
269 );
270 }
271
272 #[test]
273 fn shape_constraint_types() {
274 let fixed = ShapeConstraint::Fixed(vec![1, 128]);
275 let enum_c = ShapeConstraint::Enumerated(vec![vec![1, 128], vec![1, 256]]);
276 let range_c = ShapeConstraint::Range(vec![(1, 10), (64, 512)]);
277 let unspec = ShapeConstraint::Unspecified;
278
279 assert_ne!(fixed, enum_c);
280 assert_ne!(enum_c, range_c);
281 assert_ne!(range_c, unspec);
282 assert_eq!(fixed, ShapeConstraint::Fixed(vec![1, 128]));
283 }
284}