diskann_benchmark_runner/any.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};
7
8/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization.
9///
10/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`]
11/// and is passed to beckend benchmarks for matching and execution.
12#[derive(Debug)]
13pub struct Any {
14 any: Box<dyn SerializableAny>,
15 tag: &'static str,
16}
17
18/// The score given unsuccessful downcasts in [`Any::try_match`].
19pub const MATCH_FAIL: FailureScore = FailureScore(10_000);
20
21impl Any {
22 /// Construct a new [`Any`] around `any` and associate it with the name `tag`.
23 ///
24 /// The tag is included as merely a debugging and readability aid and usually should
25 /// belong to a [`crate::Input::tag`] that generated `any`.
26 pub fn new<T>(any: T, tag: &'static str) -> Self
27 where
28 T: serde::Serialize + std::fmt::Debug + 'static,
29 {
30 Self {
31 any: Box::new(any),
32 tag,
33 }
34 }
35
36 /// A lower level API for constructing an [`Any`] that decouples the serialized
37 /// representation from the inmemory representation.
38 ///
39 /// When serialized, the **exact** representation of `repr` will be used.
40 ///
41 /// This is useful in some contexts where as part of input resolution, a fully resolved
42 /// input struct contains elements that are not serializable.
43 ///
44 /// Like [`Any::new`], the tag is included for debugging and readability.
45 pub fn raw<T>(any: T, repr: serde_json::Value, tag: &'static str) -> Self
46 where
47 T: std::fmt::Debug + 'static,
48 {
49 Self {
50 any: Box::new(Raw::new(any, repr)),
51 tag,
52 }
53 }
54
55 /// Return the benchmark tag associated with this benchmarks.
56 pub fn tag(&self) -> &'static str {
57 self.tag
58 }
59
60 /// Return the Rust [`std::any::TypeId`] for the contained object.
61 pub fn type_id(&self) -> std::any::TypeId {
62 self.any.as_any().type_id()
63 }
64
65 /// Return `true` if the runtime value is `T`. Otherwise, return false.
66 ///
67 /// ```rust
68 /// use diskann_benchmark_runner::any::Any;
69 ///
70 /// let value = Any::new(42usize, "usize");
71 /// assert!(value.is::<usize>());
72 /// assert!(!value.is::<u32>());
73 /// ```
74 #[must_use = "this function has no side effects"]
75 pub fn is<T>(&self) -> bool
76 where
77 T: std::any::Any,
78 {
79 self.any.as_any().is::<T>()
80 }
81
82 /// Return a reference to the contained object if it's runtime type is `T`.
83 ///
84 /// Otherwise return `None`.
85 ///
86 /// ```rust
87 /// use diskann_benchmark_runner::any::Any;
88 ///
89 /// let value = Any::new(42usize, "usize");
90 /// assert_eq!(*value.downcast_ref::<usize>().unwrap(), 42);
91 /// assert!(value.downcast_ref::<u32>().is_none());
92 /// ```
93 pub fn downcast_ref<T>(&self) -> Option<&T>
94 where
95 T: std::any::Any,
96 {
97 self.any.as_any().downcast_ref::<T>()
98 }
99
100 /// Attempt to downcast self to `T` and if succssful, try matching `&T` with `U` using
101 /// [`crate::dispatcher::DispatchRule`].
102 ///
103 /// Otherwise, return `Err(diskann_benchmark_runner::any::MATCH_FAIL)`.
104 ///
105 /// ```rust
106 /// use diskann_benchmark_runner::{
107 /// any::Any,
108 /// dispatcher::{self, MatchScore, FailureScore},
109 /// utils::datatype::{self, DataType, Type},
110 /// };
111 ///
112 /// let value = Any::new(DataType::Float32, "datatype");
113 ///
114 /// // A successful down cast and successful match.
115 /// assert_eq!(
116 /// value.try_match::<DataType, Type<f32>>().unwrap(),
117 /// MatchScore(0),
118 /// );
119 ///
120 /// // A successful down cast but unsuccessful match.
121 /// assert_eq!(
122 /// value.try_match::<DataType, Type<f64>>().unwrap_err(),
123 /// datatype::MATCH_FAIL,
124 /// );
125 ///
126 /// // An unsuccessful down cast.
127 /// let value = Any::new(0usize, "usize");
128 /// assert_eq!(
129 /// value.try_match::<DataType, Type<f32>>().unwrap_err(),
130 /// diskann_benchmark_runner::any::MATCH_FAIL,
131 /// );
132 /// ```
133 pub fn try_match<'a, T, U>(&'a self) -> Result<MatchScore, FailureScore>
134 where
135 U: DispatchRule<&'a T>,
136 T: 'static,
137 {
138 if let Some(cast) = self.downcast_ref::<T>() {
139 U::try_match(&cast)
140 } else {
141 Err(MATCH_FAIL)
142 }
143 }
144
145 /// Attempt to downcast self to `T` and if succssful, try converting `&T` with `U` using
146 /// [`crate::dispatcher::DispatchRule`].
147 ///
148 /// If unsuccessful, returns an error.
149 ///
150 /// ```rust
151 /// use diskann_benchmark_runner::{
152 /// any::Any,
153 /// dispatcher::{self, MatchScore, FailureScore},
154 /// utils::datatype::{self, DataType, Type},
155 /// };
156 ///
157 /// let value = Any::new(DataType::Float32, "datatype");
158 ///
159 /// // A successful down cast and successful conversion.
160 /// let _: Type<f32> = value.convert::<DataType, _>().unwrap();
161 /// ```
162 pub fn convert<'a, T, U>(&'a self) -> anyhow::Result<U>
163 where
164 U: DispatchRule<&'a T>,
165 anyhow::Error: From<U::Error>,
166 T: 'static,
167 {
168 if let Some(cast) = self.downcast_ref::<T>() {
169 Ok(U::convert(cast)?)
170 } else {
171 Err(anyhow::Error::msg("invalid dispatch"))
172 }
173 }
174
175 /// A wrapper for [`DispatchRule::description`].
176 ///
177 /// If `from` is `None` - document the expected tag for the input and return
178 /// `<U as DispatchRule<&T>>::description(f, None)`.
179 ///
180 /// If `from` is `Some` - attempt to downcast to `T`. If successful, return the dispatch
181 /// rule description for `U` on the doncast reference. Otherwise, return the expected tag.
182 ///
183 /// ```rust
184 /// use diskann_benchmark_runner::{
185 /// any::Any,
186 /// utils::datatype::{self, DataType, Type},
187 /// };
188 ///
189 /// use std::io::Write;
190 ///
191 /// struct Display(Option<Any>);
192 ///
193 /// impl std::fmt::Display for Display {
194 /// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 /// match &self.0 {
196 /// Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&&v), "my-tag"),
197 /// None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
198 /// }
199 /// }
200 /// }
201 ///
202 /// // No contained value - document the expected conversion.
203 /// assert_eq!(
204 /// Display(None).to_string(),
205 /// "tag \"my-tag\"\nfloat32",
206 /// );
207 ///
208 /// // Matching contained value.
209 /// assert_eq!(
210 /// Display(Some(Any::new(DataType::Float32, "datatype"))).to_string(),
211 /// "successful match",
212 /// );
213 ///
214 /// // Successful down cast - unsuccessful match.
215 /// assert_eq!(
216 /// Display(Some(Any::new(DataType::UInt64, "datatype"))).to_string(),
217 /// "expected \"float32\" but found \"uint64\"",
218 /// );
219 ///
220 /// // Unsuccessful down cast.
221 /// assert_eq!(
222 /// Display(Some(Any::new(0usize, "another-tag"))).to_string(),
223 /// "expected tag \"my-tag\" - instead got \"another-tag\"",
224 /// );
225 /// ```
226 pub fn description<'a, T, U>(
227 f: &mut std::fmt::Formatter<'_>,
228 from: Option<&&'a Self>,
229 tag: impl std::fmt::Display,
230 ) -> std::fmt::Result
231 where
232 U: DispatchRule<&'a T>,
233 T: 'static,
234 {
235 match from {
236 Some(this) => match this.downcast_ref::<T>() {
237 Some(a) => U::description(f, Some(&a)),
238 None => write!(
239 f,
240 "expected tag \"{}\" - instead got \"{}\"",
241 tag,
242 this.tag(),
243 ),
244 },
245 None => {
246 writeln!(f, "tag \"{}\"", tag)?;
247 U::description(f, None::<&&T>)
248 }
249 }
250 }
251
252 /// Serialize the contained object to a [`serde_json::Value`].
253 pub fn serialize(&self) -> Result<serde_json::Value, serde_json::Error> {
254 self.any.dump()
255 }
256}
257
258trait SerializableAny: std::fmt::Debug {
259 fn as_any(&self) -> &dyn std::any::Any;
260 fn dump(&self) -> Result<serde_json::Value, serde_json::Error>;
261}
262
263impl<T> SerializableAny for T
264where
265 T: std::any::Any + serde::Serialize + std::fmt::Debug,
266{
267 fn as_any(&self) -> &dyn std::any::Any {
268 self
269 }
270
271 fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
272 serde_json::to_value(self)
273 }
274}
275
276// A backend type that allows users to decouple the serialized representation from the
277// actual type.
278#[derive(Debug)]
279struct Raw<T> {
280 value: T,
281 repr: serde_json::Value,
282}
283
284impl<T> Raw<T> {
285 fn new(value: T, repr: serde_json::Value) -> Self {
286 Self { value, repr }
287 }
288}
289
290impl<T> SerializableAny for Raw<T>
291where
292 T: std::any::Any + std::fmt::Debug,
293{
294 fn as_any(&self) -> &dyn std::any::Any {
295 &self.value
296 }
297
298 fn dump(&self) -> Result<serde_json::Value, serde_json::Error> {
299 Ok(self.repr.clone())
300 }
301}
302
303///////////
304// Tests //
305///////////
306
#[cfg(test)]
mod tests {
    use super::*;

    use crate::utils::datatype::{self, DataType, Type};

    #[test]
    fn test_new() {
        let x = Any::new(42usize, "my-tag");
        assert_eq!(x.tag(), "my-tag");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());

        assert!(!x.is::<Raw<usize>>());
        assert!(!x.is::<Raw<u32>>());
        assert!(x.downcast_ref::<Raw<usize>>().is_none());
        assert!(x.downcast_ref::<Raw<u32>>().is_none());

        assert_eq!(
            x.serialize().unwrap(),
            serde_json::Value::Number(serde_json::value::Number::from(42usize))
        );
    }

    #[test]
    fn test_raw() {
        let repr = serde_json::json!(1.5);
        let x = Any::raw(42usize, repr, "my-tag");
        assert_eq!(x.tag(), "my-tag");
        assert_eq!(x.type_id(), std::any::TypeId::of::<usize>());
        assert!(x.is::<usize>());
        assert!(!x.is::<u32>());
        assert_eq!(*x.downcast_ref::<usize>().unwrap(), 42);
        assert!(x.downcast_ref::<u32>().is_none());

        assert!(!x.is::<Raw<usize>>());
        assert!(!x.is::<Raw<u32>>());
        assert!(x.downcast_ref::<Raw<usize>>().is_none());
        assert!(x.downcast_ref::<Raw<u32>>().is_none());

        assert_eq!(
            x.serialize().unwrap(),
            serde_json::Value::Number(serde_json::value::Number::from_f64(1.5).unwrap())
        );
    }

    /// Values built with `Any::raw` must match and convert based on their in-memory type,
    /// while serializing exactly the representation they were given.
    #[test]
    fn test_raw_dispatch() {
        let repr = serde_json::json!("not-a-datatype");
        let value = Any::raw(DataType::Float32, repr.clone(), "raw-tag");

        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap(),
            MatchScore(0),
        );
        assert_eq!(
            value.try_match::<DataType, Type<f64>>().unwrap_err(),
            datatype::MATCH_FAIL,
        );
        let _: Type<f32> = value.convert::<DataType, _>().unwrap();

        assert_eq!(value.serialize().unwrap(), repr);
    }

    #[test]
    fn test_try_match() {
        let value = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful match.
        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap(),
            MatchScore(0),
        );

        // A successful down cast but unsuccessful match.
        assert_eq!(
            value.try_match::<DataType, Type<f64>>().unwrap_err(),
            datatype::MATCH_FAIL,
        );

        // An unsuccessful down cast.
        let value = Any::new(0usize, "");
        assert_eq!(
            value.try_match::<DataType, Type<f32>>().unwrap_err(),
            MATCH_FAIL,
        );
    }

    #[test]
    fn test_convert() {
        let value = Any::new(DataType::Float32, "random-tag");

        // A successful down cast and successful conversion.
        let _: Type<f32> = value.convert::<DataType, _>().unwrap();

        // An invalid match should return an error.
        let value = Any::new(0usize, "random-tag");
        let err = value.convert::<DataType, Type<f32>>().unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("invalid dispatch"), "{}", msg);
    }

    #[test]
    #[should_panic(expected = "invalid dispatch")]
    fn test_convert_inner_error() {
        let value = Any::new(DataType::Float32, "random-tag");
        let _ = value.convert::<DataType, Type<u64>>();
    }

    #[test]
    fn test_description() {
        struct Display(Option<Any>);

        impl std::fmt::Display for Display {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match &self.0 {
                    Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&v), "my-tag"),
                    None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
                }
            }
        }

        // No contained value - document the expected conversion.
        assert_eq!(Display(None).to_string(), "tag \"my-tag\"\nfloat32",);

        // Matching contained value.
        assert_eq!(
            Display(Some(Any::new(DataType::Float32, ""))).to_string(),
            "successful match",
        );

        // Successful down cast - unsuccessful match.
        assert_eq!(
            Display(Some(Any::new(DataType::UInt64, ""))).to_string(),
            "expected \"float32\" but found \"uint64\"",
        );

        // Unsuccessful down cast.
        assert_eq!(
            Display(Some(Any::new(0usize, ""))).to_string(),
            "expected tag \"my-tag\" - instead got \"\"",
        );
    }
}