Skip to main content

diskann_benchmark_runner/
any.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};
7
8/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization.
9///
10/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`]
11/// and is passed to beckend benchmarks for matching and execution.
12#[derive(Debug)]
13pub struct Any {
14    any: Box<dyn SerializableAny>,
15    tag: &'static str,
16}
17
18/// The score given unsuccessful downcasts in [`Any::try_match`].
19pub const MATCH_FAIL: FailureScore = FailureScore(10_000);
20
21impl Any {
22    /// Construct a new [`Any`] around `any` and associate it with the name `tag`.
23    ///
24    /// The tag is included as merely a debugging and readability aid and usually should
25    /// belong to a [`crate::Input::tag`] that generated `any`.
26    pub fn new<T>(any: T, tag: &'static str) -> Self
27    where
28        T: serde::Serialize + std::fmt::Debug + 'static,
29    {
30        Self {
31            any: Box::new(any),
32            tag,
33        }
34    }
35
36    /// A lower level API for constructing an [`Any`] that decouples the serialized
37    /// representation from the inmemory representation.
38    ///
39    /// When serialized, the **exact** representation of `repr` will be used.
40    ///
41    /// This is useful in some contexts where as part of input resolution, a fully resolved
42    /// input struct contains elements that are not serializable.
43    ///
44    /// Like [`Any::new`], the tag is included for debugging and readability.
45    pub fn raw<T>(any: T, repr: serde_json::Value, tag: &'static str) -> Self
46    where
47        T: std::fmt::Debug + 'static,
48    {
49        Self {
50            any: Box::new(Raw::new(any, repr)),
51            tag,
52        }
53    }
54
55    /// Return the benchmark tag associated with this benchmarks.
56    pub fn tag(&self) -> &'static str {
57        self.tag
58    }
59
60    /// Return the Rust [`std::any::TypeId`] for the contained object.
61    pub fn type_id(&self) -> std::any::TypeId {
62        self.any.as_any().type_id()
63    }
64
65    /// Return `true` if the runtime value is `T`. Otherwise, return false.
66    ///
67    /// ```rust
68    /// use diskann_benchmark_runner::any::Any;
69    ///
70    /// let value = Any::new(42usize, "usize");
71    /// assert!(value.is::<usize>());
72    /// assert!(!value.is::<u32>());
73    /// ```
74    #[must_use = "this function has no side effects"]
75    pub fn is<T>(&self) -> bool
76    where
77        T: std::any::Any,
78    {
79        self.any.as_any().is::<T>()
80    }
81
82    /// Return a reference to the contained object if it's runtime type is `T`.
83    ///
84    /// Otherwise return `None`.
85    ///
86    /// ```rust
87    /// use diskann_benchmark_runner::any::Any;
88    ///
89    /// let value = Any::new(42usize, "usize");
90    /// assert_eq!(*value.downcast_ref::<usize>().unwrap(), 42);
91    /// assert!(value.downcast_ref::<u32>().is_none());
92    /// ```
93    pub fn downcast_ref<T>(&self) -> Option<&T>
94    where
95        T: std::any::Any,
96    {
97        self.any.as_any().downcast_ref::<T>()
98    }
99
100    /// Attempt to downcast self to `T` and if succssful, try matching `&T` with `U` using
101    /// [`crate::dispatcher::DispatchRule`].
102    ///
103    /// Otherwise, return `Err(diskann_benchmark_runner::any::MATCH_FAIL)`.
104    ///
105    /// ```rust
106    /// use diskann_benchmark_runner::{
107    ///     any::Any,
108    ///     dispatcher::{self, MatchScore, FailureScore},
109    ///     utils::datatype::{self, DataType, Type},
110    /// };
111    ///
112    /// let value = Any::new(DataType::Float32, "datatype");
113    ///
114    /// // A successful down cast and successful match.
115    /// assert_eq!(
116    ///     value.try_match::<DataType, Type<f32>>().unwrap(),
117    ///     MatchScore(0),
118    /// );
119    ///
120    /// // A successful down cast but unsuccessful match.
121    /// assert_eq!(
122    ///     value.try_match::<DataType, Type<f64>>().unwrap_err(),
123    ///     datatype::MATCH_FAIL,
124    /// );
125    ///
126    /// // An unsuccessful down cast.
127    /// let value = Any::new(0usize, "usize");
128    /// assert_eq!(
129    ///     value.try_match::<DataType, Type<f32>>().unwrap_err(),
130    ///     diskann_benchmark_runner::any::MATCH_FAIL,
131    /// );
132    /// ```
133    pub fn try_match<'a, T, U>(&'a self) -> Result<MatchScore, FailureScore>
134    where
135        U: DispatchRule<&'a T>,
136        T: 'static,
137    {
138        if let Some(cast) = self.downcast_ref::<T>() {
139            U::try_match(&cast)
140        } else {
141            Err(MATCH_FAIL)
142        }
143    }
144
145    /// Attempt to downcast self to `T` and if succssful, try converting `&T` with `U` using
146    /// [`crate::dispatcher::DispatchRule`].
147    ///
148    /// If unsuccessful, returns an error.
149    ///
150    /// ```rust
151    /// use diskann_benchmark_runner::{
152    ///     any::Any,
153    ///     dispatcher::{self, MatchScore, FailureScore},
154    ///     utils::datatype::{self, DataType, Type},
155    /// };
156    ///
157    /// let value = Any::new(DataType::Float32, "datatype");
158    ///
159    /// // A successful down cast and successful conversion.
160    /// let _: Type<f32> = value.convert::<DataType, _>().unwrap();
161    /// ```
162    pub fn convert<'a, T, U>(&'a self) -> anyhow::Result<U>
163    where
164        U: DispatchRule<&'a T>,
165        anyhow::Error: From<U::Error>,
166        T: 'static,
167    {
168        if let Some(cast) = self.downcast_ref::<T>() {
169            Ok(U::convert(cast)?)
170        } else {
171            Err(anyhow::Error::msg("invalid dispatch"))
172        }
173    }
174
175    /// A wrapper for [`DispatchRule::description`].
176    ///
177    /// If `from` is `None` - document the expected tag for the input and return
178    /// `<U as DispatchRule<&T>>::description(f, None)`.
179    ///
180    /// If `from` is `Some` - attempt to downcast to `T`. If successful, return the dispatch
181    /// rule description for `U` on the doncast reference. Otherwise, return the expected tag.
182    ///
183    /// ```rust
184    /// use diskann_benchmark_runner::{
185    ///     any::Any,
186    ///     utils::datatype::{self, DataType, Type},
187    /// };
188    ///
189    /// use std::io::Write;
190    ///
191    /// struct Display(Option<Any>);
192    ///
193    /// impl std::fmt::Display for Display {
194    ///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195    ///         match &self.0 {
196    ///             Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&&v), "my-tag"),
197    ///             None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
198    ///         }
199    ///     }
200    /// }
201    ///
202    /// // No contained value - document the expected conversion.
203    /// assert_eq!(
204    ///     Display(None).to_string(),
205    ///     "tag \"my-tag\"\nfloat32",
206    /// );
207    ///
208    /// // Matching contained value.
209    /// assert_eq!(
210    ///     Display(Some(Any::new(DataType::Float32, "datatype"))).to_string(),
211    ///     "successful match",
212    /// );
213    ///
214    /// // Successful down cast - unsuccessful match.
215    /// assert_eq!(
216    ///     Display(Some(Any::new(DataType::UInt64, "datatype"))).to_string(),
217    ///     "expected \"float32\" but found \"uint64\"",
218    /// );
219    ///
220    /// // Unsuccessful down cast.
221    /// assert_eq!(
222    ///     Display(Some(Any::new(0usize, "another-tag"))).to_string(),
223    ///     "expected tag \"my-tag\" - instead got \"another-tag\"",
224    /// );
225    /// ```
226    pub fn description<'a, T, U>(
227        f: &mut std::fmt::Formatter<'_>,
228        from: Option<&&'a Self>,
229        tag: impl std::fmt::Display,
230    ) -> std::fmt::Result
231    where
232        U: DispatchRule<&'a T>,
233        T: 'static,
234    {
235        match from {
236            Some(this) => match this.downcast_ref::<T>() {
237                Some(a) => U::description(f, Some(&a)),
238                None => write!(
239                    f,
240                    "expected tag \"{}\" - instead got \"{}\"",
241                    tag,
242                    this.tag(),
243                ),
244            },
245            None => {
246                writeln!(f, "tag \"{}\"", tag)?;
247                U::description(f, None::<&&T>)
248            }
249        }
250    }
251
252    /// Serialize the contained object to a [`serde_json::Value`].
253    pub fn serialize(&self) -> Result<serde_json::Value, serde_json::Error> {
254        self.any.dump()
255    }
256}
257
/// Write one line of description text, indented by eight spaces.
///
/// Used in `DispatchRule::description(f, _)` to ensure that additional description
/// lines are properly aligned. It accepts the same arguments as [`writeln!`], prefixes
/// the (literal) format string with the indentation, and evaluates to the result of the
/// underlying [`writeln!`] call.
///
/// The writer may be any expression accepted by [`writeln!`] (for example `f`,
/// `&mut buffer`, or `self.out`), not only a bare identifier.
#[macro_export]
macro_rules! describeln {
    ($writer:expr, $fmt:literal) => {
        writeln!($writer, concat!("        ", $fmt))
    };
    ($writer:expr, $fmt:literal, $($args:expr),* $(,)?) => {
        writeln!($writer, concat!("        ", $fmt), $($args,)*)
    };
}
269
270trait SerializableAny: std::fmt::Debug {
271    fn as_any(&self) -> &dyn std::any::Any;
272    fn dump(&self) -> Result<serde_json::Value, serde_json::Error>;
273}
274
275impl<T> SerializableAny for T
276where
277    T: std::any::Any + serde::Serialize + std::fmt::Debug,
278{
279    fn as_any(&self) -> &dyn std::any::Any {
280        self
281    }
282
283    fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
284        serde_json::to_value(self)
285    }
286}
287
288// A backend type that allows users to decouple the serialized representation from the
289// actual type.
290#[derive(Debug)]
291struct Raw<T> {
292    value: T,
293    repr: serde_json::Value,
294}
295
296impl<T> Raw<T> {
297    fn new(value: T, repr: serde_json::Value) -> Self {
298        Self { value, repr }
299    }
300}
301
302impl<T> SerializableAny for Raw<T>
303where
304    T: std::any::Any + std::fmt::Debug,
305{
306    fn as_any(&self) -> &dyn std::any::Any {
307        &self.value
308    }
309
310    fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
311        Ok(self.repr.clone())
312    }
313}
314
///////////
// Tests //
///////////
318
#[cfg(test)]
mod tests {
    use super::*;

    use crate::utils::datatype::{self, DataType, Type};

    #[test]
    fn test_new() {
        let x = Any::new(42usize, "my-tag");
        assert_eq!(x.tag(), "my-tag");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());

        assert!(!x.is::<Raw<usize>>());
        assert!(!x.is::<Raw<u32>>());
        assert!(x.downcast_ref::<Raw<usize>>().is_none());
        assert!(x.downcast_ref::<Raw<u32>>().is_none());

        assert_eq!(
            x.serialize().unwrap(),
            serde_json::Value::Number(serde_json::value::Number::from(42usize))
        );
    }

    #[test]
    fn test_raw() {
        let repr = serde_json::json!(1.5);
        let x = Any::raw(42usize, repr, "my-tag");
        assert_eq!(x.tag(), "my-tag");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());

        assert!(!x.is::<Raw<usize>>());
        assert!(!x.is::<Raw<u32>>());
        assert!(x.downcast_ref::<Raw<usize>>().is_none());
        assert!(x.downcast_ref::<Raw<u32>>().is_none());

        assert_eq!(
            x.serialize().unwrap(),
            serde_json::Value::Number(serde_json::value::Number::from_f64(1.5).unwrap())
        );
    }

    #[test]
    fn test_raw_non_serializable() {
        // `Any::raw` exists precisely for values that do not implement `Serialize`.
        #[derive(Debug)]
        struct Opaque(u32);

        let repr = serde_json::json!({ "value": 7 });
        let x = Any::raw(Opaque(7), repr.clone(), "opaque");
        assert_eq!(x.tag(), "opaque");
        assert!(x.is::<Opaque>());
        assert_eq!(x.type_id(), std::any::TypeId::of::<Opaque>());
        assert_eq!(x.downcast_ref::<Opaque>().unwrap().0, 7);
        assert_eq!(x.serialize().unwrap(), repr);
    }

    #[test]
    fn test_try_match() {
        let value = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful match.
        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap(),
            MatchScore(0),
        );

        // A successful down cast but unsuccessful match.
        assert_eq!(
            value.try_match::<DataType, Type<f64>>().unwrap_err(),
            datatype::MATCH_FAIL,
        );

        // An unsuccessful down cast.
        let value = Any::new(0usize, "");
        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap_err(),
            MATCH_FAIL,
        );
    }

    #[test]
    fn test_convert() {
        let value = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful conversion.
        let _: Type<f32> = value.convert::<DataType, _>().unwrap();

        // An invalid match should return an error.
        let value = Any::new(0usize, "random-tag");
        let err = value.convert::<DataType, Type<f32>>().unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invalid dispatch"), "{}", msg);
    }

    #[test]
    #[should_panic(expected = "invalid dispatch")]
    fn test_convert_inner_error() {
        let value = Any::new(DataType::Float32, "random-tag");
        let _ = value.convert::<DataType, Type<u64>>();
    }

    #[test]
    fn test_description() {
        struct Display(Option<Any>);

        impl std::fmt::Display for Display {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match &self.0 {
                    Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&v), "my-tag"),
                    None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
                }
            }
        }

        // No contained value - document the expected conversion.
        assert_eq!(Display(None).to_string(), "tag \"my-tag\"\nfloat32",);

        // Matching contained value.
        assert_eq!(
            Display(Some(Any::new(DataType::Float32, ""))).to_string(),
            "successful match",
        );

        // Successful down cast - unsuccessful match.
        assert_eq!(
            Display(Some(Any::new(DataType::UInt64, ""))).to_string(),
            "expected \"float32\" but found \"uint64\"",
        );

        // Unsuccessful down cast.
        assert_eq!(
            Display(Some(Any::new(0usize, ""))).to_string(),
            "expected tag \"my-tag\" - instead got \"\"",
        );
    }

    #[test]
    fn test_describeln() {
        use std::fmt::Write;

        let mut s = String::new();
        describeln!(s, "alpha").unwrap();
        describeln!(s, "{}-{}", 1, 2).unwrap();
        describeln!(s, "gamma {}", 3,).unwrap();
        assert_eq!(s, "        alpha\n        1-2\n        gamma 3\n");
    }
}