diskann_benchmark_runner/any.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};
7
8/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization.
9///
10/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`]
11/// and is passed to beckend benchmarks for matching and execution.
12#[derive(Debug)]
13pub struct Any {
14 any: Box<dyn SerializableAny>,
15 tag: &'static str,
16}
17
18/// The score given unsuccessful downcasts in [`Any::try_match`].
19pub const MATCH_FAIL: FailureScore = FailureScore(10_000);
20
21impl Any {
22 /// Construct a new [`Any`] around `any` and associate it with the name `tag`.
23 ///
24 /// The tag is included as merely a debugging and readability aid and usually should
25 /// belong to a [`crate::Input::tag`] that generated `any`.
26 pub fn new<T>(any: T, tag: &'static str) -> Self
27 where
28 T: serde::Serialize + std::fmt::Debug + 'static,
29 {
30 Self {
31 any: Box::new(any),
32 tag,
33 }
34 }
35
36 /// A lower level API for constructing an [`Any`] that decouples the serialized
37 /// representation from the inmemory representation.
38 ///
39 /// When serialized, the **exact** representation of `repr` will be used.
40 ///
41 /// This is useful in some contexts where as part of input resolution, a fully resolved
42 /// input struct contains elements that are not serializable.
43 ///
44 /// Like [`Any::new`], the tag is included for debugging and readability.
45 pub fn raw<T>(any: T, repr: serde_json::Value, tag: &'static str) -> Self
46 where
47 T: std::fmt::Debug + 'static,
48 {
49 Self {
50 any: Box::new(Raw::new(any, repr)),
51 tag,
52 }
53 }
54
55 /// Return the benchmark tag associated with this benchmarks.
56 pub fn tag(&self) -> &'static str {
57 self.tag
58 }
59
60 /// Return the Rust [`std::any::TypeId`] for the contained object.
61 pub fn type_id(&self) -> std::any::TypeId {
62 self.any.as_any().type_id()
63 }
64
65 /// Return `true` if the runtime value is `T`. Otherwise, return false.
66 ///
67 /// ```rust
68 /// use diskann_benchmark_runner::any::Any;
69 ///
70 /// let value = Any::new(42usize, "usize");
71 /// assert!(value.is::<usize>());
72 /// assert!(!value.is::<u32>());
73 /// ```
74 #[must_use = "this function has no side effects"]
75 pub fn is<T>(&self) -> bool
76 where
77 T: std::any::Any,
78 {
79 self.any.as_any().is::<T>()
80 }
81
82 /// Return a reference to the contained object if it's runtime type is `T`.
83 ///
84 /// Otherwise return `None`.
85 ///
86 /// ```rust
87 /// use diskann_benchmark_runner::any::Any;
88 ///
89 /// let value = Any::new(42usize, "usize");
90 /// assert_eq!(*value.downcast_ref::<usize>().unwrap(), 42);
91 /// assert!(value.downcast_ref::<u32>().is_none());
92 /// ```
93 pub fn downcast_ref<T>(&self) -> Option<&T>
94 where
95 T: std::any::Any,
96 {
97 self.any.as_any().downcast_ref::<T>()
98 }
99
100 /// Attempt to downcast self to `T` and if succssful, try matching `&T` with `U` using
101 /// [`crate::dispatcher::DispatchRule`].
102 ///
103 /// Otherwise, return `Err(diskann_benchmark_runner::any::MATCH_FAIL)`.
104 ///
105 /// ```rust
106 /// use diskann_benchmark_runner::{
107 /// any::Any,
108 /// dispatcher::{self, MatchScore, FailureScore},
109 /// utils::datatype::{self, DataType, Type},
110 /// };
111 ///
112 /// let value = Any::new(DataType::Float32, "datatype");
113 ///
114 /// // A successful down cast and successful match.
115 /// assert_eq!(
116 /// value.try_match::<DataType, Type<f32>>().unwrap(),
117 /// MatchScore(0),
118 /// );
119 ///
120 /// // A successful down cast but unsuccessful match.
121 /// assert_eq!(
122 /// value.try_match::<DataType, Type<f64>>().unwrap_err(),
123 /// datatype::MATCH_FAIL,
124 /// );
125 ///
126 /// // An unsuccessful down cast.
127 /// let value = Any::new(0usize, "usize");
128 /// assert_eq!(
129 /// value.try_match::<DataType, Type<f32>>().unwrap_err(),
130 /// diskann_benchmark_runner::any::MATCH_FAIL,
131 /// );
132 /// ```
133 pub fn try_match<'a, T, U>(&'a self) -> Result<MatchScore, FailureScore>
134 where
135 U: DispatchRule<&'a T>,
136 T: 'static,
137 {
138 if let Some(cast) = self.downcast_ref::<T>() {
139 U::try_match(&cast)
140 } else {
141 Err(MATCH_FAIL)
142 }
143 }
144
145 /// Attempt to downcast self to `T` and if succssful, try converting `&T` with `U` using
146 /// [`crate::dispatcher::DispatchRule`].
147 ///
148 /// If unsuccessful, returns an error.
149 ///
150 /// ```rust
151 /// use diskann_benchmark_runner::{
152 /// any::Any,
153 /// dispatcher::{self, MatchScore, FailureScore},
154 /// utils::datatype::{self, DataType, Type},
155 /// };
156 ///
157 /// let value = Any::new(DataType::Float32, "datatype");
158 ///
159 /// // A successful down cast and successful conversion.
160 /// let _: Type<f32> = value.convert::<DataType, _>().unwrap();
161 /// ```
162 pub fn convert<'a, T, U>(&'a self) -> anyhow::Result<U>
163 where
164 U: DispatchRule<&'a T>,
165 anyhow::Error: From<U::Error>,
166 T: 'static,
167 {
168 if let Some(cast) = self.downcast_ref::<T>() {
169 Ok(U::convert(cast)?)
170 } else {
171 Err(anyhow::Error::msg("invalid dispatch"))
172 }
173 }
174
175 /// A wrapper for [`DispatchRule::description`].
176 ///
177 /// If `from` is `None` - document the expected tag for the input and return
178 /// `<U as DispatchRule<&T>>::description(f, None)`.
179 ///
180 /// If `from` is `Some` - attempt to downcast to `T`. If successful, return the dispatch
181 /// rule description for `U` on the doncast reference. Otherwise, return the expected tag.
182 ///
183 /// ```rust
184 /// use diskann_benchmark_runner::{
185 /// any::Any,
186 /// utils::datatype::{self, DataType, Type},
187 /// };
188 ///
189 /// use std::io::Write;
190 ///
191 /// struct Display(Option<Any>);
192 ///
193 /// impl std::fmt::Display for Display {
194 /// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 /// match &self.0 {
196 /// Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&&v), "my-tag"),
197 /// None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
198 /// }
199 /// }
200 /// }
201 ///
202 /// // No contained value - document the expected conversion.
203 /// assert_eq!(
204 /// Display(None).to_string(),
205 /// "tag \"my-tag\"\nfloat32",
206 /// );
207 ///
208 /// // Matching contained value.
209 /// assert_eq!(
210 /// Display(Some(Any::new(DataType::Float32, "datatype"))).to_string(),
211 /// "successful match",
212 /// );
213 ///
214 /// // Successful down cast - unsuccessful match.
215 /// assert_eq!(
216 /// Display(Some(Any::new(DataType::UInt64, "datatype"))).to_string(),
217 /// "expected \"float32\" but found \"uint64\"",
218 /// );
219 ///
220 /// // Unsuccessful down cast.
221 /// assert_eq!(
222 /// Display(Some(Any::new(0usize, "another-tag"))).to_string(),
223 /// "expected tag \"my-tag\" - instead got \"another-tag\"",
224 /// );
225 /// ```
226 pub fn description<'a, T, U>(
227 f: &mut std::fmt::Formatter<'_>,
228 from: Option<&&'a Self>,
229 tag: impl std::fmt::Display,
230 ) -> std::fmt::Result
231 where
232 U: DispatchRule<&'a T>,
233 T: 'static,
234 {
235 match from {
236 Some(this) => match this.downcast_ref::<T>() {
237 Some(a) => U::description(f, Some(&a)),
238 None => write!(
239 f,
240 "expected tag \"{}\" - instead got \"{}\"",
241 tag,
242 this.tag(),
243 ),
244 },
245 None => {
246 writeln!(f, "tag \"{}\"", tag)?;
247 U::description(f, None::<&&T>)
248 }
249 }
250 }
251
252 /// Serialize the contained object to a [`serde_json::Value`].
253 pub fn serialize(&self) -> Result<serde_json::Value, serde_json::Error> {
254 self.any.dump()
255 }
256}
257
/// Write one additional, indented line of a dispatch-rule description.
///
/// Used in `DispatchRule::description(f, _)` implementations to ensure that additional
/// description lines are properly aligned. Each invocation prefixes the format string with
/// the indentation literal below and appends a newline, exactly like [`writeln!`], and
/// evaluates to whatever `writeln!` returns for the given writer.
///
/// The first argument may be any writer expression accepted by [`writeln!`] (for example
/// `f`, `&mut buffer`, or `self.out`). It is followed by a format string literal and its
/// positional arguments; a trailing comma is accepted.
///
/// Because the format string is assembled with [`concat!`], implicitly captured arguments
/// such as `"{x}"` are not supported; pass arguments positionally instead.
#[macro_export]
macro_rules! describeln {
    ($writer:expr, $fmt:literal) => {
        // Fully qualified so that caller-side macros named `writeln`/`concat` cannot
        // change the expansion.
        ::std::writeln!($writer, ::std::concat!(" ", $fmt))
    };
    ($writer:expr, $fmt:literal, $($args:expr),* $(,)?) => {
        ::std::writeln!($writer, ::std::concat!(" ", $fmt), $($args,)*)
    };
}
269
270trait SerializableAny: std::fmt::Debug {
271 fn as_any(&self) -> &dyn std::any::Any;
272 fn dump(&self) -> Result<serde_json::Value, serde_json::Error>;
273}
274
275impl<T> SerializableAny for T
276where
277 T: std::any::Any + serde::Serialize + std::fmt::Debug,
278{
279 fn as_any(&self) -> &dyn std::any::Any {
280 self
281 }
282
283 fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
284 serde_json::to_value(self)
285 }
286}
287
/// A backend type that allows users to decouple the serialized representation from the
/// actual type.
///
/// Constructed by [`Any::raw`].
#[derive(Debug)]
struct Raw<T> {
    /// The in-memory value exposed to runtime type queries.
    value: T,
    /// The JSON handed back verbatim whenever the value is serialized.
    repr: serde_json::Value,
}
296impl<T> Raw<T> {
297 fn new(value: T, repr: serde_json::Value) -> Self {
298 Self { value, repr }
299 }
300}
301
302impl<T> SerializableAny for Raw<T>
303where
304 T: std::any::Any + std::fmt::Debug,
305{
306 fn as_any(&self) -> &dyn std::any::Any {
307 &self.value
308 }
309
310 fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
311 Ok(self.repr.clone())
312 }
313}
314
315///////////
316// Tests //
317///////////
318
#[cfg(test)]
mod tests {
    use super::*;

    use crate::utils::datatype::{self, DataType, Type};

    /// Shared assertions for an [`Any`] wrapping `42usize` under the tag `"my-tag"`.
    /// They apply whether the value was built with [`Any::new`] or [`Any::raw`].
    fn check_wrapped_usize(x: &Any) {
        assert_eq!(x.tag(), "my-tag");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());

        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(x.downcast_ref::<usize>().copied(), Some(42));
        assert!(x.downcast_ref::<u32>().is_none());

        // The internal `Raw` wrapper must never leak through runtime type queries.
        assert!(!x.is::<Raw<usize>>());
        assert!(!x.is::<Raw<u32>>());
        assert!(x.downcast_ref::<Raw<usize>>().is_none());
        assert!(x.downcast_ref::<Raw<u32>>().is_none());
    }

    #[test]
    fn test_new() {
        let wrapped = Any::new(42usize, "my-tag");
        check_wrapped_usize(&wrapped);
        assert_eq!(wrapped.serialize().unwrap(), serde_json::json!(42usize));
    }

    #[test]
    fn test_raw() {
        let wrapped = Any::raw(42usize, serde_json::json!(1.5), "my-tag");
        check_wrapped_usize(&wrapped);
        // Serialization reflects the supplied representation, not the in-memory value.
        assert_eq!(wrapped.serialize().unwrap(), serde_json::json!(1.5));
    }

    #[test]
    fn test_try_match() {
        let float = Any::new(DataType::Float32, "random-tag");

        // Downcast succeeds and the rule matches.
        assert_eq!(float.try_match::<DataType, Type<f32>>(), Ok(MatchScore(0)));

        // Downcast succeeds but the rule rejects the value.
        assert_eq!(
            float.try_match::<DataType, Type<f64>>(),
            Err(datatype::MATCH_FAIL)
        );

        // Downcast fails.
        let other = Any::new(0usize, "");
        assert_eq!(other.try_match::<DataType, Type<f32>>(), Err(MATCH_FAIL));
    }

    #[test]
    fn test_convert() {
        let float = Any::new(DataType::Float32, "random-tag");

        // Downcast and conversion both succeed.
        let _: Type<f32> = float
            .convert::<DataType, _>()
            .expect("Float32 should convert to Type<f32>");

        // A failed downcast is reported as an error rather than a panic.
        let other = Any::new(0usize, "random-rag");
        let message = other
            .convert::<DataType, Type<f32>>()
            .unwrap_err()
            .to_string();
        assert!(message.contains("invalid dispatch"), "{}", message);
    }

    #[test]
    #[should_panic(expected = "invalid dispatch")]
    fn test_convert_inner_error() {
        let float = Any::new(DataType::Float32, "random-tag");
        let _ = float.convert::<DataType, Type<u64>>();
    }

    #[test]
    fn test_description() {
        struct Describe(Option<Any>);

        impl std::fmt::Display for Describe {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // `Option<Any>` -> `Option<&Any>` -> `Option<&&Any>`, the shape expected
                // by `Any::description`.
                Any::description::<DataType, Type<f32>>(f, self.0.as_ref().as_ref(), "my-tag")
            }
        }

        let render = |value: Option<Any>| Describe(value).to_string();

        // Nothing to inspect: document the expected tag and conversion.
        assert_eq!(render(None), "tag \"my-tag\"\nfloat32");

        // Downcast succeeds and the value matches.
        assert_eq!(
            render(Some(Any::new(DataType::Float32, ""))),
            "successful match"
        );

        // Downcast succeeds but the value does not match.
        assert_eq!(
            render(Some(Any::new(DataType::UInt64, ""))),
            "expected \"float32\" but found \"uint64\""
        );

        // Downcast fails: report the tag mismatch.
        assert_eq!(
            render(Some(Any::new(0usize, ""))),
            "expected tag \"my-tag\" - instead got \"\""
        );
    }
}