use objc::{msg_send, sel, sel_impl};
use crate::{
core_ml::MLModel,
foundation::{NSArray, NSDictionary, NSError, NSNumber, NSString, UInt, NSURL},
object,
objective_c_runtime::{
macros::interface_impl,
nil,
traits::{FromId, PNSObject},
},
utils::to_optional,
};
use super::NLModelConfiguration;
// Declares the public `NLModel` type through the crate's `object!` macro.
// The macro expansion is not visible in this file; the methods in the impl
// below send messages to the Objective-C `NLModel` class and its instances
// through this type.
object! {
unsafe pub struct NLModel;
}
#[interface_impl(NSObject)]
impl NLModel {
#[method]
pub fn model_with_mlmodel(model: &MLModel) -> Result<Self, NSError>
where
Self: Sized + FromId,
{
unsafe {
let mut error = NSError::m_alloc();
let ptr = Self::from_id(
msg_send![Self::m_class(), modelWithMLModel: model.m_self() error: &mut error],
);
if error.m_self() != nil {
Err(error)
} else {
Ok(ptr)
}
}
}
#[method]
pub fn model_with_contents_of_url(url: &NSURL) -> Result<Self, NSError>
where
Self: Sized + FromId,
{
unsafe {
let mut error = NSError::m_alloc();
let ptr = Self::from_id(
msg_send![Self::m_class(), modelWithContentsOfURL: url.m_self() error: &mut error],
);
if error.m_self() != nil {
Err(error)
} else {
Ok(ptr)
}
}
}
#[method]
pub fn predicted_label_for_string(&self, string: &NSString) -> Option<NSString> {
unsafe { to_optional(msg_send![self.m_self(), predictedLabelForString: string.m_self()]) }
}
#[method]
pub fn predicted_labels_for_tokens(&self, tokens: &NSArray<NSString>) -> NSArray<NSString> {
unsafe {
NSArray::from_id(msg_send![self.m_self(), predictedLabelsForTokens: tokens.m_self()])
}
}
#[method]
pub fn predicted_label_hypotheses_for_string_maximum_count(
&self,
string: &NSString,
max_count: UInt,
) -> NSDictionary<NSString, NSNumber> {
unsafe {
NSDictionary::from_id(
msg_send![self.m_self(), predictedLabelHypothesesForString: string.m_self() maximumCount: max_count],
)
}
}
#[method]
pub fn predicted_label_hypotheses_for_tokens_maximum_count(
&self,
tokens: &NSArray<NSString>,
max_count: UInt,
) -> NSArray<NSDictionary<NSString, NSNumber>> {
unsafe {
NSArray::from_id(
msg_send![self.m_self(), predictedLabelHypothesesForTokens: tokens.m_self() maximumCount: max_count],
)
}
}
#[property]
pub fn configuration(&self) -> NLModelConfiguration {
unsafe { NLModelConfiguration::from_id(msg_send![self.m_self(), configuration]) }
}
}