Skip to main content

ort/value/
impl_map.rs

1use alloc::{boxed::Box, format, string::String, sync::Arc, vec, vec::Vec};
2use core::{
3	ffi::c_void,
4	fmt::{self, Debug},
5	hash::Hash,
6	marker::PhantomData,
7	mem,
8	ptr::{self, NonNull},
9	slice
10};
11#[cfg(feature = "std")]
12use std::collections::HashMap;
13
14use super::{
15	DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker,
16	impl_tensor::{DynTensor, IntoTensorElementType, PrimitiveTensorElementType, Tensor, TensorElementType}
17};
18use crate::{
19	AsPointer, ErrorCode,
20	error::{Error, Result},
21	memory::Allocator,
22	ortsys
23};
24
/// Marker trait for value-type markers that describe ONNX map values
/// (implemented in this file by [`DynMapValueType`] and [`MapValueType`]).
pub trait MapValueTypeMarker: ValueTypeMarker {
	// NOTE(review): `private_trait!` presumably seals this trait against downstream impls — confirm.
	private_trait!();
}
28
29#[derive(Debug)]
30pub struct DynMapValueType;
31impl ValueTypeMarker for DynMapValueType {
32	fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
33		f.write_str("DynMap")
34	}
35
36	private_impl!();
37}
38impl MapValueTypeMarker for DynMapValueType {
39	private_impl!();
40}
41
42impl DowncastableTarget for DynMapValueType {
43	fn can_downcast(dtype: &ValueType) -> bool {
44		matches!(dtype, ValueType::Map { .. })
45	}
46
47	private_impl!();
48}
49
50#[derive(Debug)]
51pub struct MapValueType<K: IntoTensorElementType + Clone + Hash + Eq, V: IntoTensorElementType + Debug>(PhantomData<(K, V)>);
52impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> ValueTypeMarker for MapValueType<K, V> {
53	fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
54		f.write_str("Map<")?;
55		<TensorElementType as fmt::Display>::fmt(&K::into_tensor_element_type(), f)?;
56		f.write_str(", ")?;
57		<TensorElementType as fmt::Display>::fmt(&V::into_tensor_element_type(), f)?;
58		f.write_str(">")
59	}
60
61	private_impl!();
62}
63impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> MapValueTypeMarker for MapValueType<K, V> {
64	private_impl!();
65}
66
67impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> DowncastableTarget for MapValueType<K, V> {
68	fn can_downcast(dtype: &ValueType) -> bool {
69		match dtype {
70			ValueType::Map { key, value } => *key == K::into_tensor_element_type() && *value == V::into_tensor_element_type(),
71			_ => false
72		}
73	}
74
75	private_impl!();
76}
77
/// An owned map value whose key/value element types are only known at runtime.
pub type DynMap = Value<DynMapValueType>;
/// An owned map value with statically-known key type `K` and value type `V`.
pub type Map<K, V> = Value<MapValueType<K, V>>;

/// A shared reference to a type-erased map value.
pub type DynMapRef<'v> = ValueRef<'v, DynMapValueType>;
/// A mutable reference to a type-erased map value.
pub type DynMapRefMut<'v> = ValueRefMut<'v, DynMapValueType>;
/// A shared reference to a map value with key type `K` and value type `V`.
pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType<K, V>>;
/// A mutable reference to a map value with key type `K` and value type `V`.
pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType<K, V>>;
85
86impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
87	pub fn try_extract_key_values<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<Vec<(K, V)>> {
88		match self.dtype() {
89			ValueType::Map { key, value } => {
90				let k_type = K::into_tensor_element_type();
91				if k_type != *key {
92					return Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Map<{:?}, _> (value has K type {:?})", k_type, key)));
93				}
94				let v_type = V::into_tensor_element_type();
95				if v_type != *value {
96					return Err(Error::new_with_code(
97						ErrorCode::InvalidArgument,
98						format!("Cannot extract Map<{}, {}> from Map<{}, {}>", K::into_tensor_element_type(), V::into_tensor_element_type(), k_type, v_type)
99					));
100				}
101
102				let allocator = Allocator::default();
103
104				let mut key_tensor_ptr = ptr::null_mut();
105				ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr().cast_mut(), &mut key_tensor_ptr)?; nonNull(key_tensor_ptr)];
106				let key_value: DynTensor = unsafe { Value::from_ptr(key_tensor_ptr, None) };
107				if K::into_tensor_element_type() != TensorElementType::String {
108					let dtype = key_value.dtype();
109					let (key_tensor_shape, key_tensor) = match dtype {
110						ValueType::Tensor { ty, shape, .. } => {
111							let mem = key_value.memory_info();
112							if !mem.is_cpu_accessible() {
113								return Err(Error::new(format!(
114									"Cannot extract from value on device `{}`, which is not CPU accessible",
115									mem.allocation_device().as_str()
116								)));
117							}
118
119							if *ty == K::into_tensor_element_type() {
120								let mut output_array_ptr: *mut K = ptr::null_mut();
121								let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr;
122								let output_array_ptr_ptr_void: *mut *mut c_void = output_array_ptr_ptr.cast();
123								ortsys![unsafe GetTensorMutableData(key_tensor_ptr.as_ptr(), output_array_ptr_ptr_void)?];
124								if output_array_ptr.is_null() {
125									output_array_ptr = NonNull::dangling().as_ptr();
126								}
127
128								(shape, unsafe { slice::from_raw_parts(output_array_ptr, shape.num_elements()) })
129							} else {
130								return Err(Error::new_with_code(
131									ErrorCode::InvalidArgument,
132									format!(
133										"Cannot extract Map<{}, {}> from Map<{}, {}>",
134										K::into_tensor_element_type(),
135										V::into_tensor_element_type(),
136										k_type,
137										v_type
138									)
139								));
140							}
141						}
142						_ => unreachable!()
143					};
144
145					let mut value_tensor_ptr = ptr::null_mut();
146					ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr().cast_mut(), &mut value_tensor_ptr)?; nonNull(value_tensor_ptr)];
147					let value_value: DynTensor = unsafe { Value::from_ptr(value_tensor_ptr, None) };
148					let (value_tensor_shape, value_tensor) = value_value.try_extract_tensor::<V>()?;
149
150					assert_eq!(key_tensor_shape.len(), 1);
151					assert_eq!(value_tensor_shape.len(), 1);
152					assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
153
154					let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
155					for i in 0..key_tensor_shape[0] as usize {
156						vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
157					}
158					Ok(vec)
159				} else {
160					let (key_tensor_shape, key_tensor) = key_value.try_extract_strings()?;
161					// SAFETY: `IntoTensorElementType` is a private trait, and we only map the `String` type to `TensorElementType::String`,
162					// so at this point, `K` is **always** the `String` type, and this transmute really does nothing but please the type
163					// checker.
164					let key_tensor: Vec<K> = unsafe { mem::transmute(key_tensor) };
165
166					let mut value_tensor_ptr = ptr::null_mut();
167					ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr().cast_mut(), &mut value_tensor_ptr)?; nonNull(value_tensor_ptr)];
168					let value_value: DynTensor = unsafe { Value::from_ptr(value_tensor_ptr, None) };
169					let (value_tensor_shape, value_tensor) = value_value.try_extract_tensor::<V>()?;
170
171					assert_eq!(key_tensor_shape.len(), 1);
172					assert_eq!(value_tensor_shape.len(), 1);
173					assert_eq!(key_tensor_shape[0], value_tensor_shape[0]);
174
175					let mut vec = Vec::with_capacity(key_tensor_shape[0] as _);
176					for i in 0..key_tensor_shape[0] as usize {
177						vec.push((key_tensor[i].clone(), value_tensor[i].clone()));
178					}
179					Ok(vec.into_iter().collect())
180				}
181			}
182			t => Err(Error::new_with_code(
183				ErrorCode::InvalidArgument,
184				format!("Cannot extract Map<{}, {}> from {t}", K::into_tensor_element_type(), V::into_tensor_element_type())
185			))
186		}
187	}
188
189	#[cfg(feature = "std")]
190	#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
191	pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<HashMap<K, V>> {
192		self.try_extract_key_values().map(|c| c.into_iter().collect())
193	}
194}
195
196impl<K: PrimitiveTensorElementType + Debug + Clone + Hash + Eq + 'static, V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
197	/// Creates a [`Map`] from an iterable emitting `K` and `V`.
198	///
199	/// ```
200	/// # use std::collections::HashMap;
201	/// # use ort::value::Map;
202	/// # fn main() -> ort::Result<()> {
203	/// let mut map = HashMap::<i64, f32>::new();
204	/// map.insert(0, 1.0);
205	/// map.insert(1, 2.0);
206	/// map.insert(2, 3.0);
207	///
208	/// let value = Map::<i64, f32>::new(map)?;
209	///
210	/// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0);
211	/// # 	Ok(())
212	/// # }
213	/// ```
214	pub fn new(data: impl IntoIterator<Item = (K, V)>) -> Result<Self> {
215		let (keys, values): (Vec<K>, Vec<V>) = data.into_iter().unzip();
216		Self::new_kv(Tensor::from_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?)
217	}
218}
219
220impl<V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType<String, V>> {
221	/// Creates a [`Map`] from an iterable emitting `K` and `V`.
222	///
223	/// ```
224	/// # use std::collections::HashMap;
225	/// # use ort::value::Map;
226	/// # fn main() -> ort::Result<()> {
227	/// let mut map = HashMap::<String, f32>::new();
228	/// map.insert("one".to_string(), 1.0);
229	/// map.insert("two".to_string(), 2.0);
230	/// map.insert("three".to_string(), 3.0);
231	///
232	/// let value = Map::<String, f32>::new(map)?;
233	///
234	/// assert_eq!(*value.extract_map().get("one").unwrap(), 1.0);
235	/// # 	Ok(())
236	/// # }
237	/// ```
238	pub fn new(data: impl IntoIterator<Item = (String, V)>) -> Result<Self> {
239		let (keys, values): (Vec<String>, Vec<V>) = data.into_iter().unzip();
240		Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys.as_slice()))?, Tensor::from_array((vec![values.len()], values))?)
241	}
242}
243
244impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTensorElementType + Debug + Clone + 'static> Value<MapValueType<K, V>> {
245	/// Creates a [`Map`] from two tensors of keys & values respectively.
246	///
247	/// ```
248	/// # use std::collections::HashMap;
249	/// # use ort::value::{Map, Tensor};
250	/// # fn main() -> ort::Result<()> {
251	/// let keys = Tensor::<i64>::from_array(([4], vec![0, 1, 2, 3]))?;
252	/// let values = Tensor::<f32>::from_array(([4], vec![1., 2., 3., 4.]))?;
253	///
254	/// let value = Map::new_kv(keys, values)?;
255	///
256	/// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0);
257	/// # 	Ok(())
258	/// # }
259	/// ```
260	pub fn new_kv(keys: Tensor<K>, values: Tensor<V>) -> Result<Self> {
261		let mut value_ptr = ptr::null_mut();
262		let values: [DynValue; 2] = [keys.into_dyn(), values.into_dyn()];
263		let value_ptrs: Vec<*const ort_sys::OrtValue> = values.iter().map(|c| c.ptr()).collect();
264		ortsys![
265			unsafe CreateValue(value_ptrs.as_ptr(), 2, ort_sys::ONNXType::ONNX_TYPE_MAP, &mut value_ptr)?;
266			nonNull(value_ptr)
267		];
268		Ok(Value {
269			inner: ValueInner::new_backed(
270				value_ptr,
271				ValueType::Map {
272					key: K::into_tensor_element_type(),
273					value: V::into_tensor_element_type()
274				},
275				None,
276				true,
277				Box::new(values)
278			),
279			_markers: PhantomData
280		})
281	}
282}
283
284impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: PrimitiveTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
285	pub fn extract_key_values(&self) -> Vec<(K, V)> {
286		self.try_extract_key_values().expect("Failed to extract map")
287	}
288
289	#[cfg(feature = "std")]
290	#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
291	pub fn extract_map(&self) -> HashMap<K, V> {
292		self.try_extract_map().expect("Failed to extract map")
293	}
294}
295
296impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug + Clone> Value<MapValueType<K, V>> {
297	/// Converts from a strongly-typed [`Map<K, V>`] to a type-erased [`DynMap`].
298	#[inline]
299	pub fn upcast(self) -> DynMap {
300		unsafe { self.transmute_type() }
301	}
302
303	/// Converts from a strongly-typed [`Map<K, V>`] to a reference to a type-erased [`DynMap`].
304	#[inline]
305	pub fn upcast_ref(&self) -> DynMapRef<'_> {
306		DynMapRef::new(Value {
307			inner: Arc::clone(&self.inner),
308			_markers: PhantomData
309		})
310	}
311
312	/// Converts from a strongly-typed [`Map<K, V>`] to a mutable reference to a type-erased [`DynMap`].
313	#[inline]
314	pub fn upcast_mut(&mut self) -> DynMapRefMut<'_> {
315		DynMapRefMut::new(Value {
316			inner: Arc::clone(&self.inner),
317			_markers: PhantomData
318		})
319	}
320}