Skip to main content

diskann_benchmark_runner/
any.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};
7
8/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization.
9///
10/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`]
11/// and is passed to beckend benchmarks for matching and execution.
12#[derive(Debug)]
13pub struct Any {
14    any: Box<dyn SerializableAny>,
15    tag: &'static str,
16}
17
18/// The score given unsuccessful downcasts in [`Any::try_match`].
19pub const MATCH_FAIL: FailureScore = FailureScore(10_000);
20
21impl Any {
22    /// Construct a new [`Any`] around `any` and associate it with the name `tag`.
23    ///
24    /// The tag is included as merely a debugging and readability aid and usually should
25    /// belong to a [`crate::Input::tag`] that generated `any`.
26    pub fn new<T>(any: T, tag: &'static str) -> Self
27    where
28        T: serde::Serialize + std::fmt::Debug + 'static,
29    {
30        Self {
31            any: Box::new(any),
32            tag,
33        }
34    }
35
36    /// A lower level API for constructing an [`Any`] that decouples the serialized
37    /// representation from the inmemory representation.
38    ///
39    /// When serialized, the **exact** representation of `repr` will be used.
40    ///
41    /// This is useful in some contexts where as part of input resolution, a fully resolved
42    /// input struct contains elements that are not serializable.
43    ///
44    /// Like [`Any::new`], the tag is included for debugging and readability.
45    pub fn raw<T>(any: T, repr: serde_json::Value, tag: &'static str) -> Self
46    where
47        T: std::fmt::Debug + 'static,
48    {
49        Self {
50            any: Box::new(Raw::new(any, repr)),
51            tag,
52        }
53    }
54
55    /// Return the benchmark tag associated with this benchmarks.
56    pub fn tag(&self) -> &'static str {
57        self.tag
58    }
59
60    /// Return the Rust [`std::any::TypeId`] for the contained object.
61    pub fn type_id(&self) -> std::any::TypeId {
62        self.any.as_any().type_id()
63    }
64
65    /// Return `true` if the runtime value is `T`. Otherwise, return false.
66    ///
67    /// ```rust
68    /// use diskann_benchmark_runner::any::Any;
69    ///
70    /// let value = Any::new(42usize, "usize");
71    /// assert!(value.is::<usize>());
72    /// assert!(!value.is::<u32>());
73    /// ```
74    #[must_use = "this function has no side effects"]
75    pub fn is<T>(&self) -> bool
76    where
77        T: std::any::Any,
78    {
79        self.any.as_any().is::<T>()
80    }
81
82    /// Return a reference to the contained object if it's runtime type is `T`.
83    ///
84    /// Otherwise return `None`.
85    ///
86    /// ```rust
87    /// use diskann_benchmark_runner::any::Any;
88    ///
89    /// let value = Any::new(42usize, "usize");
90    /// assert_eq!(*value.downcast_ref::<usize>().unwrap(), 42);
91    /// assert!(value.downcast_ref::<u32>().is_none());
92    /// ```
93    pub fn downcast_ref<T>(&self) -> Option<&T>
94    where
95        T: std::any::Any,
96    {
97        self.any.as_any().downcast_ref::<T>()
98    }
99
100    /// Attempt to downcast self to `T` and if succssful, try matching `&T` with `U` using
101    /// [`crate::dispatcher::DispatchRule`].
102    ///
103    /// Otherwise, return `Err(diskann_benchmark_runner::any::MATCH_FAIL)`.
104    ///
105    /// ```rust
106    /// use diskann_benchmark_runner::{
107    ///     any::Any,
108    ///     dispatcher::{self, MatchScore, FailureScore},
109    ///     utils::datatype::{self, DataType, Type},
110    /// };
111    ///
112    /// let value = Any::new(DataType::Float32, "datatype");
113    ///
114    /// // A successful down cast and successful match.
115    /// assert_eq!(
116    ///     value.try_match::<DataType, Type<f32>>().unwrap(),
117    ///     MatchScore(0),
118    /// );
119    ///
120    /// // A successful down cast but unsuccessful match.
121    /// assert_eq!(
122    ///     value.try_match::<DataType, Type<f64>>().unwrap_err(),
123    ///     datatype::MATCH_FAIL,
124    /// );
125    ///
126    /// // An unsuccessful down cast.
127    /// let value = Any::new(0usize, "usize");
128    /// assert_eq!(
129    ///     value.try_match::<DataType, Type<f32>>().unwrap_err(),
130    ///     diskann_benchmark_runner::any::MATCH_FAIL,
131    /// );
132    /// ```
133    pub fn try_match<'a, T, U>(&'a self) -> Result<MatchScore, FailureScore>
134    where
135        U: DispatchRule<&'a T>,
136        T: 'static,
137    {
138        if let Some(cast) = self.downcast_ref::<T>() {
139            U::try_match(&cast)
140        } else {
141            Err(MATCH_FAIL)
142        }
143    }
144
145    /// Attempt to downcast self to `T` and if succssful, try converting `&T` with `U` using
146    /// [`crate::dispatcher::DispatchRule`].
147    ///
148    /// If unsuccessful, returns an error.
149    ///
150    /// ```rust
151    /// use diskann_benchmark_runner::{
152    ///     any::Any,
153    ///     dispatcher::{self, MatchScore, FailureScore},
154    ///     utils::datatype::{self, DataType, Type},
155    /// };
156    ///
157    /// let value = Any::new(DataType::Float32, "datatype");
158    ///
159    /// // A successful down cast and successful conversion.
160    /// let _: Type<f32> = value.convert::<DataType, _>().unwrap();
161    /// ```
162    pub fn convert<'a, T, U>(&'a self) -> anyhow::Result<U>
163    where
164        U: DispatchRule<&'a T>,
165        anyhow::Error: From<U::Error>,
166        T: 'static,
167    {
168        if let Some(cast) = self.downcast_ref::<T>() {
169            Ok(U::convert(cast)?)
170        } else {
171            Err(anyhow::Error::msg("invalid dispatch"))
172        }
173    }
174
175    /// A wrapper for [`DispatchRule::description`].
176    ///
177    /// If `from` is `None` - document the expected tag for the input and return
178    /// `<U as DispatchRule<&T>>::description(f, None)`.
179    ///
180    /// If `from` is `Some` - attempt to downcast to `T`. If successful, return the dispatch
181    /// rule description for `U` on the doncast reference. Otherwise, return the expected tag.
182    ///
183    /// ```rust
184    /// use diskann_benchmark_runner::{
185    ///     any::Any,
186    ///     utils::datatype::{self, DataType, Type},
187    /// };
188    ///
189    /// use std::io::Write;
190    ///
191    /// struct Display(Option<Any>);
192    ///
193    /// impl std::fmt::Display for Display {
194    ///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195    ///         match &self.0 {
196    ///             Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&&v), "my-tag"),
197    ///             None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
198    ///         }
199    ///     }
200    /// }
201    ///
202    /// // No contained value - document the expected conversion.
203    /// assert_eq!(
204    ///     Display(None).to_string(),
205    ///     "tag \"my-tag\"\nfloat32",
206    /// );
207    ///
208    /// // Matching contained value.
209    /// assert_eq!(
210    ///     Display(Some(Any::new(DataType::Float32, "datatype"))).to_string(),
211    ///     "successful match",
212    /// );
213    ///
214    /// // Successful down cast - unsuccessful match.
215    /// assert_eq!(
216    ///     Display(Some(Any::new(DataType::UInt64, "datatype"))).to_string(),
217    ///     "expected \"float32\" but found \"uint64\"",
218    /// );
219    ///
220    /// // Unsuccessful down cast.
221    /// assert_eq!(
222    ///     Display(Some(Any::new(0usize, "another-tag"))).to_string(),
223    ///     "expected tag \"my-tag\" - instead got \"another-tag\"",
224    /// );
225    /// ```
226    pub fn description<'a, T, U>(
227        f: &mut std::fmt::Formatter<'_>,
228        from: Option<&&'a Self>,
229        tag: impl std::fmt::Display,
230    ) -> std::fmt::Result
231    where
232        U: DispatchRule<&'a T>,
233        T: 'static,
234    {
235        match from {
236            Some(this) => match this.downcast_ref::<T>() {
237                Some(a) => U::description(f, Some(&a)),
238                None => write!(
239                    f,
240                    "expected tag \"{}\" - instead got \"{}\"",
241                    tag,
242                    this.tag(),
243                ),
244            },
245            None => {
246                writeln!(f, "tag \"{}\"", tag)?;
247                U::description(f, None::<&&T>)
248            }
249        }
250    }
251
252    /// Serialize the contained object to a [`serde_json::Value`].
253    pub fn serialize(&self) -> Result<serde_json::Value, serde_json::Error> {
254        self.any.dump()
255    }
256}
257
258trait SerializableAny: std::fmt::Debug {
259    fn as_any(&self) -> &dyn std::any::Any;
260    fn dump(&self) -> Result<serde_json::Value, serde_json::Error>;
261}
262
263impl<T> SerializableAny for T
264where
265    T: std::any::Any + serde::Serialize + std::fmt::Debug,
266{
267    fn as_any(&self) -> &dyn std::any::Any {
268        self
269    }
270
271    fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
272        serde_json::to_value(self)
273    }
274}
275
276// A backend type that allows users to decouple the serialized representation from the
277// actual type.
278#[derive(Debug)]
279struct Raw<T> {
280    value: T,
281    repr: serde_json::Value,
282}
283
284impl<T> Raw<T> {
285    fn new(value: T, repr: serde_json::Value) -> Self {
286        Self { value, repr }
287    }
288}
289
290impl<T> SerializableAny for Raw<T>
291where
292    T: std::any::Any + std::fmt::Debug,
293{
294    fn as_any(&self) -> &dyn std::any::Any {
295        &self.value
296    }
297
298    fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
299        Ok(self.repr.clone())
300    }
301}
302
303///////////
304// Tests //
305///////////
306
#[cfg(test)]
mod tests {
    use super::*;

    use crate::utils::datatype::{self, DataType, Type};

    /// Checks shared by every `Any` that holds `42usize` under the tag "my-tag",
    /// independent of which constructor produced it.
    fn assert_holds_usize_42(x: &Any) {
        assert_eq!(x.tag(), "my-tag");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());

        // The `Raw` wrapper must never be observable through the public API.
        assert!(!x.is::<Raw<usize>>());
        assert!(!x.is::<Raw<u32>>());
        assert!(x.downcast_ref::<Raw<usize>>().is_none());
        assert!(x.downcast_ref::<Raw<u32>>().is_none());
    }

    #[test]
    fn test_new() {
        let x = Any::new(42usize, "my-tag");
        assert_holds_usize_42(&x);

        // Serialization goes through the value's own `Serialize` implementation.
        let expected = serde_json::Value::Number(serde_json::value::Number::from(42usize));
        assert_eq!(x.serialize().unwrap(), expected);
    }

    #[test]
    fn test_raw() {
        let repr = serde_json::json!(1.5);
        let x = Any::raw(42usize, repr, "my-tag");
        assert_holds_usize_42(&x);

        // Serialization yields the supplied representation, not the wrapped value.
        let expected =
            serde_json::Value::Number(serde_json::value::Number::from_f64(1.5).unwrap());
        assert_eq!(x.serialize().unwrap(), expected);
    }

    #[test]
    fn test_try_match() {
        let float = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful match.
        let score = float.try_match::<DataType, Type<f32>>().unwrap();
        assert_eq!(score, MatchScore(0));

        // A successful down cast but unsuccessful match.
        let failure = float.try_match::<DataType, Type<f64>>().unwrap_err();
        assert_eq!(failure, datatype::MATCH_FAIL);

        // An unsuccessful down cast.
        let not_a_datatype = Any::new(0usize, "");
        let failure = not_a_datatype
            .try_match::<DataType, Type<f32>>()
            .unwrap_err();
        assert_eq!(failure, MATCH_FAIL);
    }

    #[test]
    fn test_convert() {
        let float = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful conversion.
        let _: Type<f32> = float.convert::<DataType, _>().unwrap();

        // An invalid match should return an error.
        let not_a_datatype = Any::new(0usize, "random-rag");
        let msg = not_a_datatype
            .convert::<DataType, Type<f32>>()
            .unwrap_err()
            .to_string();
        assert!(msg.contains("invalid dispatch"), "{}", msg);
    }

    #[test]
    #[should_panic(expected = "invalid dispatch")]
    fn test_convert_inner_error() {
        // The downcast succeeds, so the outcome is decided by `Type<u64>`'s own
        // conversion of a `Float32` data type.
        let float = Any::new(DataType::Float32, "random-tag");
        let _ = float.convert::<DataType, Type<u64>>();
    }

    #[test]
    fn test_description() {
        // Adapter that routes `Any::description` through `Display` so that the output can
        // be captured with `to_string`.
        struct Describe(Option<Any>);

        impl std::fmt::Display for Describe {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match &self.0 {
                    Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&v), "my-tag"),
                    None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
                }
            }
        }

        // No contained value - document the expected conversion.
        let rendered = Describe(None).to_string();
        assert_eq!(rendered, "tag \"my-tag\"\nfloat32");

        // Matching contained value.
        let rendered = Describe(Some(Any::new(DataType::Float32, ""))).to_string();
        assert_eq!(rendered, "successful match");

        // Successful down cast - unsuccessful match.
        let rendered = Describe(Some(Any::new(DataType::UInt64, ""))).to_string();
        assert_eq!(rendered, "expected \"float32\" but found \"uint64\"");

        // Unsuccessful down cast.
        let rendered = Describe(Some(Any::new(0usize, ""))).to_string();
        assert_eq!(rendered, "expected tag \"my-tag\" - instead got \"\"");
    }
}