use block::{ConcreteBlock, IntoConcreteBlock};
use libc::{c_double, c_float};
use objc::{msg_send, sel, sel_impl};
use crate::{
foundation::{NSArray, NSDictionary, NSError, NSIndexSet, NSNumber, NSString, UInt, NSURL},
object,
objective_c_runtime::{
macros::interface_impl,
nil,
traits::{FromId, PNSObject},
},
utils::{to_bool, to_optional},
};
use super::NLLanguage;
/// Distance value reported by an `NLEmbedding` (neighbor distances, string-to-string
/// distances); it is a C `double` so it can cross the Objective-C boundary unchanged.
pub type NLDistance = c_double;
/// Distance metric used when querying an embedding for neighbors or distances.
///
/// Values are passed by value through `msg_send!`, so the representation is pinned
/// to a 64-bit signed integer (assumed to match `NSInteger` on 64-bit targets).
#[repr(i64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum NLDistanceType {
    /// Cosine distance; the discriminant is spelled out because it crosses the FFI boundary.
    Cosine = 0,
}
// Declares the `NLEmbedding` wrapper type through the crate's `object!` macro; the
// `impl` below attaches the Objective-C messages for the `NLEmbedding` class.
// NOTE(review): the field layout/ownership semantics come from `object!` — presumably an
// owned Objective-C object pointer; confirm against the macro definition.
object! {
unsafe pub struct NLEmbedding;
}
#[interface_impl(NSObject)]
impl NLEmbedding {
#[method]
pub fn word_embedding_for_language(language: NLLanguage) -> Option<NLEmbedding> {
unsafe {
to_optional(msg_send![
Self::m_class(),
wordEmbeddingForLanguage: language
])
}
}
#[method]
pub fn word_embedding_for_language_revision(
language: NLLanguage,
revision: UInt,
) -> Option<NLEmbedding> {
unsafe {
to_optional(
msg_send![Self::m_class(), wordEmbeddingForLanguage:language revision: revision],
)
}
}
#[method]
pub fn embedding_with_contents_of_url(url: &NSURL) -> Result<Self, NSError>
where
Self: Sized + FromId,
{
let mut error = NSError::m_alloc();
let ptr = unsafe {
Self::from_id(
msg_send![Self::m_class(), embeddingWithContentsOfURL:url.m_self() error: &mut error],
)
};
if error.m_self() != nil {
Err(error)
} else {
Ok(ptr)
}
}
#[method]
pub fn sentence_embedding_for_language(language: NLLanguage) -> Option<NLEmbedding> {
unsafe {
to_optional(msg_send![
Self::m_class(),
sentenceEmbeddingForLanguage: language
])
}
}
#[method]
pub fn sentence_embedding_for_language_revision(
language: NLLanguage,
revision: UInt,
) -> Option<NLEmbedding> {
unsafe {
to_optional(msg_send![
Self::m_class(),
sentenceEmbeddingForLanguage: language revision: revision
])
}
}
#[method]
pub fn neighbors_for_string_maximum_count_distance_type(
&self,
string: &NSString,
max_count: UInt,
distance_type: NLDistanceType,
) -> NSArray<NSString> {
unsafe {
NSArray::from_id(
msg_send![self.m_self(), neighborsForString:string.m_self() maximumCount: max_count distanceType: distance_type],
)
}
}
#[method]
pub fn neighbors_for_string_maximum_count_maximum_distance_distance_type(
&self,
string: &NSString,
max_count: UInt,
max_distance: NLDistance,
distance_type: NLDistanceType,
) -> NSArray<NSString> {
unsafe {
NSArray::from_id(
msg_send![self.m_self(), neighborsForString: string.m_self() maximumCount: max_count maximumDistance: max_distance distanceType: distance_type],
)
}
}
#[method]
pub fn neighbors_for_vector_maximum_count_distance_type(
&self,
vector: &NSArray<NSNumber>,
max_count: UInt,
distance_type: NLDistanceType,
) -> NSArray<NSString> {
unsafe {
NSArray::from_id(
msg_send![self.m_self(), neighborsForVector:vector.m_self() maximumCount: max_count distanceType: distance_type],
)
}
}
#[method]
pub fn neighbors_for_vector_maximum_count_maximum_distance_distance_type(
&self,
vector: &NSArray<NSNumber>,
max_count: UInt,
max_distance: NLDistance,
distance_type: NLDistanceType,
) -> NSArray<NSString> {
unsafe {
NSArray::from_id(
msg_send![self.m_self(), neighborsForVector: vector.m_self() maximumCount: max_count maximumDistance: max_distance distanceType: distance_type],
)
}
}
#[method]
pub fn enumerate_neighbors_for_string_maximum_count_distance_type_using_block<F>(
&self,
string: &NSString,
max_count: UInt,
distance_type: NLDistanceType,
block: F,
) where
F: IntoConcreteBlock<(NSString, NLDistance, *mut bool), Ret = ()> + 'static,
{
let block = ConcreteBlock::new(block);
let block = block.copy();
unsafe {
msg_send![self.m_self(), enumerateNeighborsForString: string.m_self() maximumCount: max_count distanceType: distance_type usingBlock: block]
}
}
#[method]
pub fn enumerate_neighbors_for_string_maximum_count_maximum_distance_distance_type_using_block<
F,
>(
&self,
string: &NSString,
max_count: UInt,
max_distance: NLDistance,
distance_type: NLDistanceType,
block: F,
) where
F: IntoConcreteBlock<(NSString, NLDistance, *mut bool), Ret = ()> + 'static,
{
let block = ConcreteBlock::new(block);
let block = block.copy();
unsafe {
msg_send![self.m_self(), enumerateNeighborsForString: string.m_self() maximumCount: max_count maximumDistance: max_distance distanceType: distance_type usingBlock: block]
}
}
#[method]
pub fn enumerate_neighbors_for_vector_maximum_count_distance_type_using_block<F>(
&self,
vector: &NSArray<NSNumber>,
max_count: UInt,
distance_type: NLDistanceType,
block: F,
) where
F: IntoConcreteBlock<(NSString, NLDistance, *mut bool), Ret = ()> + 'static,
{
let block = ConcreteBlock::new(block);
let block = block.copy();
unsafe {
msg_send![self.m_self(), enumerateNeighborsForVector: vector.m_self() maximumCount: max_count distanceType: distance_type usingBlock: block]
}
}
#[method]
pub fn enumerate_neighbors_for_vector_maximum_count_maximum_distance_distance_type_using_block<
F,
>(
&self,
vector: &NSArray<NSNumber>,
max_count: UInt,
max_distance: NLDistance,
distance_type: NLDistanceType,
block: F,
) where
F: IntoConcreteBlock<(NSString, NLDistance, *mut bool), Ret = ()> + 'static,
{
let block = ConcreteBlock::new(block);
let block = block.copy();
unsafe {
msg_send![self.m_self(), enumerateNeighborsForVector: vector.m_self() maximumCount: max_count maximumDistance: max_distance distanceType: distance_type usingBlock: block]
}
}
#[method]
pub fn distance_between_string_and_string_distance_type(
&self,
first: &NSString,
second: &NSString,
distance_type: NLDistanceType,
) -> NLDistance {
unsafe {
msg_send![self.m_self(), distanceBetweenString: first.m_self() andString: second.m_self() distanceType: distance_type]
}
}
#[property]
pub fn dimension(&self) -> UInt {
unsafe { msg_send![self.m_self(), dimension] }
}
#[property]
pub fn vocabulary_size(&self) -> UInt {
unsafe { msg_send![self.m_self(), vocabularySize] }
}
#[property]
pub fn language(&self) -> Option<NLLanguage> {
unsafe { to_optional(msg_send![self.m_self(), language]) }
}
#[method]
pub fn contains_string(&self, string: &NSString) -> bool {
unsafe { to_bool(msg_send![self.m_self(), containsString: string.m_self()]) }
}
#[method]
pub fn vector_for_string(&self, string: &NSString) -> NSArray<NSNumber> {
unsafe { NSArray::from_id(msg_send![self.m_self(), vectorForString: string.m_self()]) }
}
#[method]
pub fn get_vector_for_string(&self, vector: &mut [c_float], string: &NSString) -> bool {
unsafe { to_bool(msg_send![self.m_self(), getVector: vector forString: string.m_self()]) }
}
#[property]
pub fn revision(&self) -> UInt {
unsafe { msg_send![self.m_self(), revision] }
}
#[method]
pub fn write_embedding_for_dictionary_language_revision_to_url(
dictionary: &NSDictionary<NSString, NSArray<NSNumber>>,
language: NLLanguage,
revision: UInt,
url: &NSURL,
) -> Result<bool, NSError> {
let mut error = NSError::m_alloc();
let ptr = unsafe {
to_bool(
msg_send![Self::m_class(), writeEmbeddingForDictionary: dictionary.m_self() language: language revision: revision toURL: url.m_self() error: &mut error ],
)
};
if error.m_self() != nil {
Err(error)
} else {
Ok(ptr)
}
}
#[method]
pub fn current_revision_for_language(language: NLLanguage) -> UInt {
unsafe { msg_send![Self::m_class(), currentRevisionForLanguage: language] }
}
#[method]
pub fn supported_revisions_for_language(language: NLLanguage) -> NSIndexSet {
unsafe {
NSIndexSet::from_id(msg_send![
Self::m_class(),
supportedRevisionsForLanguage: language
])
}
}
#[method]
pub fn current_sentence_embedding_revision_for_language(language: NLLanguage) -> UInt {
unsafe {
msg_send![
Self::m_class(),
currentSentenceEmbeddingRevisionForLanguage: language
]
}
}
#[method]
pub fn supported_sentence_embedding_revisions_for_language(language: NLLanguage) -> NSIndexSet {
unsafe {
NSIndexSet::from_id(msg_send![
Self::m_class(),
supportedSentenceEmbeddingRevisionsForLanguage: language
])
}
}
}