1use alloc::{boxed::Box, format, string::String, sync::Arc, vec, vec::Vec};
2use core::{
3 ffi::c_void,
4 fmt::{self, Debug},
5 hash::Hash,
6 marker::PhantomData,
7 mem,
8 ptr::{self, NonNull},
9 slice
10};
11#[cfg(feature = "std")]
12use std::collections::HashMap;
13
14use super::{
15 DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker,
16 impl_tensor::{DynTensor, IntoTensorElementType, PrimitiveTensorElementType, Tensor, TensorElementType}
17};
18use crate::{
19 AsPointer, ErrorCode,
20 error::{Error, Result},
21 memory::Allocator,
22 ortsys
23};
24
25pub trait MapValueTypeMarker: ValueTypeMarker {
26 private_trait!();
27}
28
29#[derive(Debug)]
30pub struct DynMapValueType;
31impl ValueTypeMarker for DynMapValueType {
32 fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
33 f.write_str("DynMap")
34 }
35
36 private_impl!();
37}
38impl MapValueTypeMarker for DynMapValueType {
39 private_impl!();
40}
41
42impl DowncastableTarget for DynMapValueType {
43 fn can_downcast(dtype: &ValueType) -> bool {
44 matches!(dtype, ValueType::Map { .. })
45 }
46
47 private_impl!();
48}
49
50#[derive(Debug)]
51pub struct MapValueType<K: IntoTensorElementType + Clone + Hash + Eq, V: IntoTensorElementType + Debug>(PhantomData<(K, V)>);
52impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> ValueTypeMarker for MapValueType<K, V> {
53 fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
54 f.write_str("Map<")?;
55 <TensorElementType as fmt::Display>::fmt(&K::into_tensor_element_type(), f)?;
56 f.write_str(", ")?;
57 <TensorElementType as fmt::Display>::fmt(&V::into_tensor_element_type(), f)?;
58 f.write_str(">")
59 }
60
61 private_impl!();
62}
63impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> MapValueTypeMarker for MapValueType<K, V> {
64 private_impl!();
65}
66
67impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> DowncastableTarget for MapValueType<K, V> {
68 fn can_downcast(dtype: &ValueType) -> bool {
69 match dtype {
70 ValueType::Map { key, value } => *key == K::into_tensor_element_type() && *value == V::into_tensor_element_type(),
71 _ => false
72 }
73 }
74
75 private_impl!();
76}
77
/// A map [`Value`] whose key/value element types are only known at runtime.
pub type DynMap = Value<DynMapValueType>;
/// A map [`Value`] with key element type `K` and value element type `V`.
pub type Map<K, V> = Value<MapValueType<K, V>>;

/// A [`ValueRef`] to a map whose element types are only known at runtime.
pub type DynMapRef<'v> = ValueRef<'v, DynMapValueType>;
/// A [`ValueRefMut`] to a map whose element types are only known at runtime.
pub type DynMapRefMut<'v> = ValueRefMut<'v, DynMapValueType>;
/// A [`ValueRef`] to a map with key element type `K` and value element type `V`.
pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType<K, V>>;
/// A [`ValueRefMut`] to a map with key element type `K` and value element type `V`.
pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType<K, V>>;
85
86impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
87 pub fn try_extract_key_values<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<Vec<(K, V)>> {
88 match self.dtype() {
89 ValueType::Map { key, value } => {
90 let k_type = K::into_tensor_element_type();
91 if k_type != *key {
92 return Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Map<{:?}, _> (value has K type {:?})", k_type, key)));
93 }
94 let v_type = V::into_tensor_element_type();
95 if v_type != *value {
96 return Err(Error::new_with_code(
97 ErrorCode::InvalidArgument,
98 format!("Cannot extract Map<{}, {}> from Map<{}, {}>", K::into_tensor_element_type(), V::into_tensor_element_type(), k_type, v_type)
99 ));
100 }
101
102 let allocator = Allocator::default();
103
104 let mut key_tensor_ptr = ptr::null_mut();
105 ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr().cast_mut(), &mut key_tensor_ptr)?; nonNull(key_tensor_ptr)];
106 let key_value: DynTensor = unsafe { Value::from_ptr(key_tensor_ptr, None) };
107 if K::into_tensor_element_type() != TensorElementType::String {
108 let dtype = key_value.dtype();
109 let (key_tensor_shape, key_tensor) = match dtype {
110 ValueType::Tensor { ty, shape, .. } => {
111 let mem = key_value.memory_info();
112 if !mem.is_cpu_accessible() {
113 return Err(Error::new(format!(
114 "Cannot extract from value on device `{}`, which is not CPU accessible",
115 mem.allocation_device().as_str()
116 )));
117 }
118
119 if *ty == K::into_tensor_element_type() {
120 let mut output_array_ptr: *mut K = ptr::null_mut();
121 let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr;
122 let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
123 ortsys![unsafe GetTensorMutableData(key_tensor_ptr.as_ptr(), output_array_ptr_ptr_void)?];
124 if output_array_ptr.is_null() {
125 output_array_ptr = NonNull::dangling().as_ptr();
126 }
127
128 (shape, unsafe { slice::from_raw_parts(output_array_ptr, shape.num_elements()) })
129 } else {
130 return Err(Error::new_with_code(
131 ErrorCode::InvalidArgument,
132 format!(
133 "Cannot extract Map<{}, {}> from Map<{}, {}>",
134 K::into_tensor_element_type(),
135 V::into_tensor_element_type(),
136 k_type,
137 v_type
138 )
139 ));
140 }
141 }
142 _ => unreachable!()
143 };
144
145 let mut value_tensor_ptr = ptr::null_mut();
146 ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr().cast_mut(), &mut value_tensor_ptr)?; nonNull(value_tensor_ptr)];
147 let value_value: DynTensor = unsafe { Value::from_ptr(value_tensor_ptr, None) };
148 let (value_tensor_shape, value_tensor) = value_value.try_extract_tensor::<V>()?;
149
150 assert_eq!(key_tensor_shape.len(), 1);
151 assert_eq!(value_tensor_shape.len(), 1);
152 assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
153
154 let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
155 for i in 0..key_tensor_shape[0] as usize {
156 vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
157 }
158 Ok(vec)
159 } else {
160 let (key_tensor_shape, key_tensor) = key_value.try_extract_strings()?;
161 let key_tensor: Vec<K> = unsafe { mem::transmute(key_tensor) };
165
166 let mut value_tensor_ptr = ptr::null_mut();
167 ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr().cast_mut(), &mut value_tensor_ptr)?; nonNull(value_tensor_ptr)];
168 let value_value: DynTensor = unsafe { Value::from_ptr(value_tensor_ptr, None) };
169 let (value_tensor_shape, value_tensor) = value_value.try_extract_tensor::<V>()?;
170
171 assert_eq!(key_tensor_shape.len(), 1);
172 assert_eq!(value_tensor_shape.len(), 1);
173 assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
174
175 let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
176 for i in 0..key_tensor_shape[0] as usize {
177 vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
178 }
179 Ok(vec.into_iter().collect())
180 }
181 }
182 t => Err(Error::new_with_code(
183 ErrorCode::InvalidArgument,
184 format!("Cannot extract Map<{}, {}> from {t}", K::into_tensor_element_type(), V::into_tensor_element_type())
185 ))
186 }
187 }
188
189 #[cfg(feature = "std")]
190 #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
191 pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<HashMap<K, V>> {
192 self.try_extract_key_values().map(|c| c.into_iter().collect())
193 }
194}
195
196impl<K: PrimitiveTensorElementType + Debug + Clone + Hash + Eq + 'static, V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
197 pub fn new(data: impl IntoIterator<Item = (K, V)>) -> Result<Self> {
215 let (keys, values): (Vec<K>, Vec<V>) = data.into_iter().unzip();
216 Self::new_kv(Tensor::from_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?)
217 }
218}
219
220impl<V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<String, V>> {
221 pub fn new(data: impl IntoIterator<Item = (String, V)>) -> Result<Self> {
239 let (keys, values): (Vec<String>, Vec<V>) = data.into_iter().unzip();
240 Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys.as_slice()))?, Tensor::from_array((vec![values.len()], values))?)
241 }
242}
243
244impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
245 pub fn new_kv(keys: Tensor<K>, values: Tensor<V>) -> Result<Self> {
261 let mut value_ptr = ptr::null_mut();
262 let values: [DynValue; 2] = [keys.into_dyn(), values.into_dyn()];
263 let value_ptrs: Vec<*const ort_sys::OrtValue> = values.iter().map(|c| c.ptr()).collect();
264 ortsys![
265 unsafe CreateValue(value_ptrs.as_ptr(), 2, ort_sys::ONNXType::ONNX_TYPE_MAP, &mut value_ptr)?;
266 nonNull(value_ptr)
267 ];
268 Ok(Value {
269 inner: ValueInner::new_backed(
270 value_ptr,
271 ValueType::Map {
272 key: K::into_tensor_element_type(),
273 value: V::into_tensor_element_type()
274 },
275 None,
276 true,
277 Box::new(values)
278 ),
279 _markers: PhantomData
280 })
281 }
282}
283
284impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: PrimitiveTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
285 pub fn extract_key_values(&self) -> Vec<(K, V)> {
286 self.try_extract_key_values().expect("Failed to extract map")
287 }
288
289 #[cfg(feature = "std")]
290 #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
291 pub fn extract_map(&self) -> HashMap<K, V> {
292 self.try_extract_map().expect("Failed to extract map")
293 }
294}
295
296impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
297 #[inline]
299 pub fn upcast(self) -> DynMap {
300 unsafe { self.transmute_type() }
301 }
302
303 #[inline]
305 pub fn upcast_ref(&self) -> DynMapRef<'_> {
306 DynMapRef::new(Value {
307 inner: Arc::clone(&self.inner),
308 _markers: PhantomData
309 })
310 }
311
312 #[inline]
314 pub fn upcast_mut(&mut self) -> DynMapRefMut<'_> {
315 DynMapRefMut::new(Value {
316 inner: Arc::clone(&self.inner),
317 _markers: PhantomData
318 })
319 }
320}