diskann_benchmark_runner/any.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};
7
8/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization.
9///
10/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`]
11/// and is passed to beckend benchmarks for matching and execution.
12#[derive(Debug)]
13pub struct Any {
14 any: Box<dyn SerializableAny>,
15 tag: &'static str,
16}
17
18/// The score given unsuccessful downcasts in [`Any::try_match`].
19pub const MATCH_FAIL: FailureScore = FailureScore(10_000);
20
21impl Any {
22 /// A crate-local constructor.
23 ///
24 /// We don't expose this as a public method to ensure that users go through
25 /// [`crate::Checker::any`] and therefore get the correct tag and run
26 /// [`crate::CheckDeserialization`].
27 pub(crate) fn new<T>(any: T, tag: &'static str) -> Self
28 where
29 T: serde::Serialize + std::fmt::Debug + 'static,
30 {
31 Self {
32 any: Box::new(any),
33 tag,
34 }
35 }
36
37 /// Return a new [`Any`] without a tag.
38 ///
39 /// This method exists for examples and testing and generally should not be used.
40 ///
41 /// Users should use [`crate::Checker::any`] to construct a properly tagged [`Any`].
42 pub fn untagged<T>(any: T) -> Self
43 where
44 T: serde::Serialize + std::fmt::Debug + 'static,
45 {
46 Self::new(any, "")
47 }
48
49 /// Return the benchmark tag associated with this benchmarks.
50 pub fn tag(&self) -> &'static str {
51 self.tag
52 }
53
54 /// Return the Rust [`std::any::TypeId`] for the contained object.
55 pub fn type_id(&self) -> std::any::TypeId {
56 self.any.as_any().type_id()
57 }
58
59 /// Return `true` if the runtime value is `T`. Otherwise, return false.
60 ///
61 /// ```rust
62 /// use diskann_benchmark_runner::any::Any;
63 ///
64 /// let value = Any::untagged(42usize);
65 /// assert!(value.is::<usize>());
66 /// assert!(!value.is::<u32>());
67 /// ```
68 #[must_use = "this function has no side effects"]
69 pub fn is<T>(&self) -> bool
70 where
71 T: std::any::Any,
72 {
73 self.any.as_any().is::<T>()
74 }
75
76 /// Return a reference to the contained object if it's runtime type is `T`.
77 ///
78 /// Otherwise return `None`.
79 ///
80 /// ```rust
81 /// use diskann_benchmark_runner::any::Any;
82 ///
83 /// let value = Any::untagged(42usize);
84 /// assert_eq!(*value.downcast_ref::<usize>().unwrap(), 42);
85 /// assert!(value.downcast_ref::<u32>().is_none());
86 /// ```
87 pub fn downcast_ref<T>(&self) -> Option<&T>
88 where
89 T: std::any::Any,
90 {
91 self.any.as_any().downcast_ref::<T>()
92 }
93
94 /// Attempt to downcast self to `T` and if succssful, try matching `&T` with `U` using
95 /// [`crate::dispatcher::DispatchRule`].
96 ///
97 /// Otherwise, return `Err(diskann_benchmark_runner::any::MATCH_FAIL)`.
98 ///
99 /// ```rust
100 /// use diskann_benchmark_runner::{
101 /// any::Any,
102 /// dispatcher::{self, MatchScore, FailureScore},
103 /// utils::datatype::{self, DataType, Type},
104 /// };
105 ///
106 /// let value = Any::untagged(DataType::Float32);
107 ///
108 /// // A successful down cast and successful match.
109 /// assert_eq!(
110 /// value.try_match::<DataType, Type<f32>>().unwrap(),
111 /// MatchScore(0),
112 /// );
113 ///
114 /// // A successful down cast but unsuccessful match.
115 /// assert_eq!(
116 /// value.try_match::<DataType, Type<f64>>().unwrap_err(),
117 /// datatype::MATCH_FAIL,
118 /// );
119 ///
120 /// // An unsuccessful down cast.
121 /// let value = Any::untagged(0usize);
122 /// assert_eq!(
123 /// value.try_match::<DataType, Type<f32>>().unwrap_err(),
124 /// diskann_benchmark_runner::any::MATCH_FAIL,
125 /// );
126 /// ```
127 pub fn try_match<'a, T, U>(&'a self) -> Result<MatchScore, FailureScore>
128 where
129 U: DispatchRule<&'a T>,
130 T: 'static,
131 {
132 if let Some(cast) = self.downcast_ref::<T>() {
133 U::try_match(&cast)
134 } else {
135 Err(MATCH_FAIL)
136 }
137 }
138
139 /// Attempt to downcast self to `T` and if succssful, try converting `&T` with `U` using
140 /// [`crate::dispatcher::DispatchRule`].
141 ///
142 /// If unsuccessful, returns an error.
143 ///
144 /// ```rust
145 /// use diskann_benchmark_runner::{
146 /// any::Any,
147 /// dispatcher::{self, MatchScore, FailureScore},
148 /// utils::datatype::{self, DataType, Type},
149 /// };
150 ///
151 /// let value = Any::untagged(DataType::Float32);
152 ///
153 /// // A successful down cast and successful conversion.
154 /// let _: Type<f32> = value.convert::<DataType, _>().unwrap();
155 /// ```
156 pub fn convert<'a, T, U>(&'a self) -> anyhow::Result<U>
157 where
158 U: DispatchRule<&'a T>,
159 anyhow::Error: From<U::Error>,
160 T: 'static,
161 {
162 if let Some(cast) = self.downcast_ref::<T>() {
163 Ok(U::convert(cast)?)
164 } else {
165 Err(anyhow::Error::msg("invalid dispatch"))
166 }
167 }
168
169 /// A wrapper for [`DispatchRule::description`].
170 ///
171 /// If `from` is `None` - document the expected tag for the input and return
172 /// `<U as DispatchRule<&T>>::description(f, None)`.
173 ///
174 /// If `from` is `Some` - attempt to downcast to `T`. If successful, return the dispatch
175 /// rule description for `U` on the doncast reference. Otherwise, return the expected tag.
176 ///
177 /// ```rust
178 /// use diskann_benchmark_runner::{
179 /// any::Any,
180 /// utils::datatype::{self, DataType, Type},
181 /// };
182 ///
183 /// use std::io::Write;
184 ///
185 /// struct Display(Option<Any>);
186 ///
187 /// impl std::fmt::Display for Display {
188 /// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 /// match &self.0 {
190 /// Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&&v), "my-tag"),
191 /// None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
192 /// }
193 /// }
194 /// }
195 ///
196 /// // No contained value - document the expected conversion.
197 /// assert_eq!(
198 /// Display(None).to_string(),
199 /// "tag \"my-tag\"\nfloat32",
200 /// );
201 ///
202 /// // Matching contained value.
203 /// assert_eq!(
204 /// Display(Some(Any::untagged(DataType::Float32))).to_string(),
205 /// "successful match",
206 /// );
207 ///
208 /// // Successful down cast - unsuccessful match.
209 /// assert_eq!(
210 /// Display(Some(Any::untagged(DataType::UInt64))).to_string(),
211 /// "expected \"float32\" but found \"uint64\"",
212 /// );
213 ///
214 /// // Unsuccessful down cast.
215 /// assert_eq!(
216 /// Display(Some(Any::untagged(0usize))).to_string(),
217 /// "expected tag \"my-tag\" - instead got \"\"",
218 /// );
219 /// ```
220 pub fn description<'a, T, U>(
221 f: &mut std::fmt::Formatter<'_>,
222 from: Option<&&'a Self>,
223 tag: impl std::fmt::Display,
224 ) -> std::fmt::Result
225 where
226 U: DispatchRule<&'a T>,
227 T: 'static,
228 {
229 match from {
230 Some(this) => match this.downcast_ref::<T>() {
231 Some(a) => U::description(f, Some(&a)),
232 None => write!(
233 f,
234 "expected tag \"{}\" - instead got \"{}\"",
235 tag,
236 this.tag(),
237 ),
238 },
239 None => {
240 writeln!(f, "tag \"{}\"", tag)?;
241 U::description(f, None::<&&T>)
242 }
243 }
244 }
245
246 /// Serialize the contained object to a [`serde_json::Value`].
247 pub fn serialize(&self) -> Result<serde_json::Value, serde_json::Error> {
248 self.any.dump()
249 }
250}
251
/// Used in `DispatchRule::description(f, _)` to ensure that additional description
/// lines are properly aligned.
///
/// Behaves like [`writeln!`], except that a fixed indentation prefix is prepended to the
/// format string. The first argument may be any expression that `writeln!` accepts as its
/// destination (a plain identifier such as `f`, a field access, an index expression, ...),
/// and the macro evaluates to whatever `writeln!` returns for that destination.
/// A trailing comma is accepted after the format string or the last argument.
#[macro_export]
macro_rules! describeln {
    ($writer:expr, $fmt:literal $(,)?) => {
        ::core::writeln!($writer, ::core::concat!(" ", $fmt))
    };
    ($writer:expr, $fmt:literal, $($args:expr),* $(,)?) => {
        ::core::writeln!($writer, ::core::concat!(" ", $fmt), $($args,)*)
    };
}
263
264trait SerializableAny: std::fmt::Debug {
265 fn as_any(&self) -> &dyn std::any::Any;
266 fn dump(&self) -> Result<serde_json::Value, serde_json::Error>;
267}
268
269impl<T> SerializableAny for T
270where
271 T: std::any::Any + serde::Serialize + std::fmt::Debug,
272{
273 fn as_any(&self) -> &dyn std::any::Any {
274 self
275 }
276
277 fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
278 serde_json::to_value(self)
279 }
280}
281
282///////////
283// Tests //
284///////////
285
#[cfg(test)]
mod tests {
    use super::*;

    use crate::utils::datatype::{self, DataType, Type};

    #[test]
    fn test_untagged() {
        let x = Any::untagged(42usize);
        assert_eq!(x.tag(), "");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());
    }

    #[test]
    fn test_new() {
        let x = Any::new(42usize, "my-tag");
        assert_eq!(x.tag(), "my-tag");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());

        assert_eq!(
            x.serialize().unwrap(),
            serde_json::Value::Number(serde_json::value::Number::from(42usize))
        );
    }

    #[test]
    fn test_try_match() {
        let value = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful match.
        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap(),
            MatchScore(0),
        );

        // A successful down cast but unsuccessful match.
        assert_eq!(
            value.try_match::<DataType, Type<f64>>().unwrap_err(),
            datatype::MATCH_FAIL,
        );

        // An unsuccessful down cast.
        let value = Any::untagged(0usize);
        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap_err(),
            MATCH_FAIL,
        );
    }

    #[test]
    fn test_convert() {
        let value = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful conversion.
        let _: Type<f32> = value.convert::<DataType, _>().unwrap();

        // An invalid match should return an error.
        let value = Any::new(0usize, "random-tag");
        let err = value.convert::<DataType, Type<f32>>().unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invalid dispatch"), "{}", msg);
    }

    #[test]
    #[should_panic(expected = "invalid dispatch")]
    fn test_convert_inner_error() {
        let value = Any::new(DataType::Float32, "random-tag");
        let _ = value.convert::<DataType, Type<u64>>();
    }

    #[test]
    fn test_description() {
        struct Display(Option<Any>);

        impl std::fmt::Display for Display {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match &self.0 {
                    Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&v), "my-tag"),
                    None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
                }
            }
        }

        // No contained value - document the expected conversion.
        assert_eq!(Display(None).to_string(), "tag \"my-tag\"\nfloat32",);

        // Matching contained value.
        assert_eq!(
            Display(Some(Any::untagged(DataType::Float32))).to_string(),
            "successful match",
        );

        // Successful down cast - unsuccessful match.
        assert_eq!(
            Display(Some(Any::untagged(DataType::UInt64))).to_string(),
            "expected \"float32\" but found \"uint64\"",
        );

        // Unsuccessful down cast.
        assert_eq!(
            Display(Some(Any::untagged(0usize))).to_string(),
            "expected tag \"my-tag\" - instead got \"\"",
        );
    }

    #[test]
    fn test_description_reports_actual_tag() {
        // Wraps a borrowed `Any` so the mismatch message can be rendered with `to_string`.
        struct Display<'a>(&'a Any);

        impl std::fmt::Display for Display<'_> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                Any::description::<DataType, Type<f32>>(f, Some(&self.0), "my-tag")
            }
        }

        // An unsuccessful down cast on a tagged value names the tag actually found.
        let value = Any::new(0usize, "other-tag");
        assert_eq!(
            Display(&value).to_string(),
            "expected tag \"my-tag\" - instead got \"other-tag\"",
        );
    }
}