diskann_benchmark_runner/
any.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};
7
8/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization.
9///
10/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`]
11/// and is passed to beckend benchmarks for matching and execution.
12#[derive(Debug)]
13pub struct Any {
14    any: Box<dyn SerializableAny>,
15    tag: &'static str,
16}
17
18/// The score given unsuccessful downcasts in [`Any::try_match`].
19pub const MATCH_FAIL: FailureScore = FailureScore(10_000);
20
21impl Any {
22    /// A crate-local constructor.
23    ///
24    /// We don't expose this as a public method to ensure that users go through
25    /// [`crate::Checker::any`] and therefore get the correct tag and run
26    /// [`crate::CheckDeserialization`].
27    pub(crate) fn new<T>(any: T, tag: &'static str) -> Self
28    where
29        T: serde::Serialize + std::fmt::Debug + 'static,
30    {
31        Self {
32            any: Box::new(any),
33            tag,
34        }
35    }
36
37    /// Return a new [`Any`] without a tag.
38    ///
39    /// This method exists for examples and testing and generally should not be used.
40    ///
41    /// Users should use [`crate::Checker::any`] to construct a properly tagged [`Any`].
42    pub fn untagged<T>(any: T) -> Self
43    where
44        T: serde::Serialize + std::fmt::Debug + 'static,
45    {
46        Self::new(any, "")
47    }
48
49    /// Return the benchmark tag associated with this benchmarks.
50    pub fn tag(&self) -> &'static str {
51        self.tag
52    }
53
54    /// Return the Rust [`std::any::TypeId`] for the contained object.
55    pub fn type_id(&self) -> std::any::TypeId {
56        self.any.as_any().type_id()
57    }
58
59    /// Return `true` if the runtime value is `T`. Otherwise, return false.
60    ///
61    /// ```rust
62    /// use diskann_benchmark_runner::any::Any;
63    ///
64    /// let value = Any::untagged(42usize);
65    /// assert!(value.is::<usize>());
66    /// assert!(!value.is::<u32>());
67    /// ```
68    #[must_use = "this function has no side effects"]
69    pub fn is<T>(&self) -> bool
70    where
71        T: std::any::Any,
72    {
73        self.any.as_any().is::<T>()
74    }
75
76    /// Return a reference to the contained object if it's runtime type is `T`.
77    ///
78    /// Otherwise return `None`.
79    ///
80    /// ```rust
81    /// use diskann_benchmark_runner::any::Any;
82    ///
83    /// let value = Any::untagged(42usize);
84    /// assert_eq!(*value.downcast_ref::<usize>().unwrap(), 42);
85    /// assert!(value.downcast_ref::<u32>().is_none());
86    /// ```
87    pub fn downcast_ref<T>(&self) -> Option<&T>
88    where
89        T: std::any::Any,
90    {
91        self.any.as_any().downcast_ref::<T>()
92    }
93
94    /// Attempt to downcast self to `T` and if succssful, try matching `&T` with `U` using
95    /// [`crate::dispatcher::DispatchRule`].
96    ///
97    /// Otherwise, return `Err(diskann_benchmark_runner::any::MATCH_FAIL)`.
98    ///
99    /// ```rust
100    /// use diskann_benchmark_runner::{
101    ///     any::Any,
102    ///     dispatcher::{self, MatchScore, FailureScore},
103    ///     utils::datatype::{self, DataType, Type},
104    /// };
105    ///
106    /// let value = Any::untagged(DataType::Float32);
107    ///
108    /// // A successful down cast and successful match.
109    /// assert_eq!(
110    ///     value.try_match::<DataType, Type<f32>>().unwrap(),
111    ///     MatchScore(0),
112    /// );
113    ///
114    /// // A successful down cast but unsuccessful match.
115    /// assert_eq!(
116    ///     value.try_match::<DataType, Type<f64>>().unwrap_err(),
117    ///     datatype::MATCH_FAIL,
118    /// );
119    ///
120    /// // An unsuccessful down cast.
121    /// let value = Any::untagged(0usize);
122    /// assert_eq!(
123    ///     value.try_match::<DataType, Type<f32>>().unwrap_err(),
124    ///     diskann_benchmark_runner::any::MATCH_FAIL,
125    /// );
126    /// ```
127    pub fn try_match<'a, T, U>(&'a self) -> Result<MatchScore, FailureScore>
128    where
129        U: DispatchRule<&'a T>,
130        T: 'static,
131    {
132        if let Some(cast) = self.downcast_ref::<T>() {
133            U::try_match(&cast)
134        } else {
135            Err(MATCH_FAIL)
136        }
137    }
138
139    /// Attempt to downcast self to `T` and if succssful, try converting `&T` with `U` using
140    /// [`crate::dispatcher::DispatchRule`].
141    ///
142    /// If unsuccessful, returns an error.
143    ///
144    /// ```rust
145    /// use diskann_benchmark_runner::{
146    ///     any::Any,
147    ///     dispatcher::{self, MatchScore, FailureScore},
148    ///     utils::datatype::{self, DataType, Type},
149    /// };
150    ///
151    /// let value = Any::untagged(DataType::Float32);
152    ///
153    /// // A successful down cast and successful conversion.
154    /// let _: Type<f32> = value.convert::<DataType, _>().unwrap();
155    /// ```
156    pub fn convert<'a, T, U>(&'a self) -> anyhow::Result<U>
157    where
158        U: DispatchRule<&'a T>,
159        anyhow::Error: From<U::Error>,
160        T: 'static,
161    {
162        if let Some(cast) = self.downcast_ref::<T>() {
163            Ok(U::convert(cast)?)
164        } else {
165            Err(anyhow::Error::msg("invalid dispatch"))
166        }
167    }
168
169    /// A wrapper for [`DispatchRule::description`].
170    ///
171    /// If `from` is `None` - document the expected tag for the input and return
172    /// `<U as DispatchRule<&T>>::description(f, None)`.
173    ///
174    /// If `from` is `Some` - attempt to downcast to `T`. If successful, return the dispatch
175    /// rule description for `U` on the doncast reference. Otherwise, return the expected tag.
176    ///
177    /// ```rust
178    /// use diskann_benchmark_runner::{
179    ///     any::Any,
180    ///     utils::datatype::{self, DataType, Type},
181    /// };
182    ///
183    /// use std::io::Write;
184    ///
185    /// struct Display(Option<Any>);
186    ///
187    /// impl std::fmt::Display for Display {
188    ///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189    ///         match &self.0 {
190    ///             Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&&v), "my-tag"),
191    ///             None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
192    ///         }
193    ///     }
194    /// }
195    ///
196    /// // No contained value - document the expected conversion.
197    /// assert_eq!(
198    ///     Display(None).to_string(),
199    ///     "tag \"my-tag\"\nfloat32",
200    /// );
201    ///
202    /// // Matching contained value.
203    /// assert_eq!(
204    ///     Display(Some(Any::untagged(DataType::Float32))).to_string(),
205    ///     "successful match",
206    /// );
207    ///
208    /// // Successful down cast - unsuccessful match.
209    /// assert_eq!(
210    ///     Display(Some(Any::untagged(DataType::UInt64))).to_string(),
211    ///     "expected \"float32\" but found \"uint64\"",
212    /// );
213    ///
214    /// // Unsuccessful down cast.
215    /// assert_eq!(
216    ///     Display(Some(Any::untagged(0usize))).to_string(),
217    ///     "expected tag \"my-tag\" - instead got \"\"",
218    /// );
219    /// ```
220    pub fn description<'a, T, U>(
221        f: &mut std::fmt::Formatter<'_>,
222        from: Option<&&'a Self>,
223        tag: impl std::fmt::Display,
224    ) -> std::fmt::Result
225    where
226        U: DispatchRule<&'a T>,
227        T: 'static,
228    {
229        match from {
230            Some(this) => match this.downcast_ref::<T>() {
231                Some(a) => U::description(f, Some(&a)),
232                None => write!(
233                    f,
234                    "expected tag \"{}\" - instead got \"{}\"",
235                    tag,
236                    this.tag(),
237                ),
238            },
239            None => {
240                writeln!(f, "tag \"{}\"", tag)?;
241                U::description(f, None::<&&T>)
242            }
243        }
244    }
245
246    /// Serialize the contained object to a [`serde_json::Value`].
247    pub fn serialize(&self) -> Result<serde_json::Value, serde_json::Error> {
248        self.any.dump()
249    }
250}
251
/// Write one line of additional description text, indented by eight spaces.
///
/// Intended for use in `DispatchRule::description(f, _)` implementations to ensure that
/// additional description lines are properly aligned. The macro prepends eight spaces to
/// the format string and forwards everything to [`writeln!`], so it evaluates to the same
/// `Result` that `writeln!` would.
///
/// The writer may be any expression accepted by `writeln!` (e.g. `f`, `&mut buf`, or
/// `self.out`). The format string must be a string literal because it is joined with the
/// indentation via `concat!`; for the same reason, pass format arguments explicitly rather
/// than capturing them inline (`"{x}"`). A trailing comma after the arguments is allowed.
#[macro_export]
macro_rules! describeln {
    ($writer:expr, $fmt:literal) => {
        writeln!($writer, concat!("        ", $fmt))
    };
    ($writer:expr, $fmt:literal, $($args:expr),* $(,)?) => {
        writeln!($writer, concat!("        ", $fmt), $($args,)*)
    };
}
263
264trait SerializableAny: std::fmt::Debug {
265    fn as_any(&self) -> &dyn std::any::Any;
266    fn dump(&self) -> Result<serde_json::Value, serde_json::Error>;
267}
268
269impl<T> SerializableAny for T
270where
271    T: std::any::Any + serde::Serialize + std::fmt::Debug,
272{
273    fn as_any(&self) -> &dyn std::any::Any {
274        self
275    }
276
277    fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
278        serde_json::to_value(self)
279    }
280}
281
282///////////
283// Tests //
284///////////
285
#[cfg(test)]
mod tests {
    use super::*;

    use crate::utils::datatype::{self, DataType, Type};

    #[test]
    fn test_untagged() {
        let x = Any::untagged(42usize);
        assert_eq!(x.tag(), "");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());
    }

    #[test]
    fn test_new() {
        let x = Any::new(42usize, "my-tag");
        assert_eq!(x.tag(), "my-tag");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());

        assert_eq!(
            x.serialize().unwrap(),
            serde_json::Value::Number(serde_json::value::Number::from(42usize))
        );
    }

    #[test]
    fn test_try_match() {
        let value = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful match.
        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap(),
            MatchScore(0),
        );

        // A successful down cast but unsuccessful match.
        assert_eq!(
            value.try_match::<DataType, Type<f64>>().unwrap_err(),
            datatype::MATCH_FAIL,
        );

        // An unsuccessful down cast.
        let value = Any::untagged(0usize);
        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap_err(),
            MATCH_FAIL,
        );
    }

    #[test]
    fn test_convert() {
        let value = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful conversion.
        let _: Type<f32> = value.convert::<DataType, _>().unwrap();

        // An unsuccessful down cast should return an error.
        let value = Any::new(0usize, "random-tag");
        let err = value.convert::<DataType, Type<f32>>().unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invalid dispatch"), "{}", msg);
    }

    // The down cast to `DataType` succeeds here and the returned `Result` is discarded,
    // so the expected panic must originate inside the `Type<u64>` conversion rule itself.
    #[test]
    #[should_panic(expected = "invalid dispatch")]
    fn test_convert_inner_error() {
        let value = Any::new(DataType::Float32, "random-tag");
        let _ = value.convert::<DataType, Type<u64>>();
    }

    #[test]
    fn test_description() {
        struct Display(Option<Any>);

        impl std::fmt::Display for Display {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match &self.0 {
                    Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&v), "my-tag"),
                    None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
                }
            }
        }

        // No contained value - document the expected conversion.
        assert_eq!(Display(None).to_string(), "tag \"my-tag\"\nfloat32");

        // Matching contained value.
        assert_eq!(
            Display(Some(Any::untagged(DataType::Float32))).to_string(),
            "successful match",
        );

        // Successful down cast - unsuccessful match.
        assert_eq!(
            Display(Some(Any::untagged(DataType::UInt64))).to_string(),
            "expected \"float32\" but found \"uint64\"",
        );

        // Unsuccessful down cast.
        assert_eq!(
            Display(Some(Any::untagged(0usize))).to_string(),
            "expected tag \"my-tag\" - instead got \"\"",
        );
    }

    #[test]
    fn test_describeln() {
        use std::fmt::Write;

        let mut s = String::new();
        describeln!(s, "first").unwrap();
        describeln!(s, "second {}", 2).unwrap();
        assert_eq!(s, "        first\n        second 2\n");
    }
}