diskann-benchmark-runner 0.51.0

DiskANN is a fast approximate nearest-neighbor search library for high-dimensional data
Documentation
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use serde::{Deserialize, Serialize};

use crate::{
    dispatcher::{FailureScore, MatchScore},
    Any, Checkpoint, Input, Output,
};

/// A benchmark that can be registered with the runner.
///
/// A benchmark pairs an [`Input`] type with a serializable result type
/// ([`Output`](Self::Output)). Every candidate input is first checked with
/// [`try_match`](Self::try_match), and only inputs that match successfully are ever handed
/// to [`run`](Self::run).
pub trait Benchmark: 'static {
    /// The [`Input`] type that this benchmark is matched against.
    type Input: Input + 'static;

    /// The concrete, serializable result type produced by [`run`](Self::run).
    type Output: Serialize;

    /// Decide whether this benchmark is compatible with `input`.
    ///
    /// A compatible input yields `Ok(MatchScore)`. The [`MatchScore`]s of all benchmarks are
    /// gathered and the benchmark with the lowest final score wins. Ties are resolved by an
    /// unspecified tie-breaking procedure.
    ///
    /// An incompatible input yields `Err(FailureScore)`. The [`crate::registry::Benchmarks`]
    /// registry ranks the "nearest misses" by their [`FailureScore`]s. Implementations are
    /// therefore encouraged to return graded scores that help users debug why nothing matched.
    fn try_match(&self, input: &Self::Input) -> Result<MatchScore, FailureScore>;

    /// Write a human-readable description of this benchmark to `f`.
    ///
    /// When `input` is `None`, give a high-level overview of the benchmark. When `input` is
    /// `Some` and did not match, explain what was expected so the user can correct it.
    fn description(
        &self,
        f: &mut std::fmt::Formatter<'_>,
        input: Option<&Self::Input>,
    ) -> std::fmt::Result;

    /// Execute the benchmark on `input` and return its results.
    ///
    /// Any printing must go through `output`. `checkpoint` lets long-running benchmarks save
    /// intermediate results periodically, so an early error does not lose all prior work.
    ///
    /// Implementors may rely on [`Self::try_match`] having returned `Ok` for `input`.
    fn run(
        &self,
        input: &Self::Input,
        checkpoint: Checkpoint<'_>,
        output: &mut dyn Output,
    ) -> anyhow::Result<Self::Output>;
}

/// A [`Benchmark`] that can also compare the results of two runs against each other.
///
/// Each regression benchmark carries a "tolerance" [`Input`]. Its runtime values control how
/// much the results may drift between runs before the comparison fails.
///
/// What counts as a pass or a failure is decided entirely by the [`Regression`]
/// implementation.
///
/// See: [`register_regression`](crate::registry::Benchmarks::register_regression).
pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
    /// The [`Input`] type carrying the tolerances for this regression check.
    type Tolerances: Input + 'static;

    /// Summary reported when the regression check succeeds.
    type Pass: Serialize + std::fmt::Display + 'static;

    /// Summary reported when the regression check fails.
    type Fail: Serialize + std::fmt::Display + 'static;

    /// Compare the outputs of two benchmark runs, `before` and `after`.
    ///
    /// `tolerances` holds the runtime-tunable thresholds used to decide whether a regression
    /// occurred. `input` is the raw input that [`Benchmark::run`] was given when it produced
    /// both `before` and `after`.
    ///
    /// Implementations should not print to `stdout` or any other stream. Every diagnostic
    /// belongs in the returned [`PassFail`] so it can be reported upstream.
    fn check(
        &self,
        tolerances: &Self::Tolerances,
        input: &Self::Input,
        before: &Self::Output,
        after: &Self::Output,
    ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>>;
}

/// Outcome of a [`Regression`] check.
///
/// `P` is the report produced when the check passes, and `F` is the report produced when it
/// fails. Diagnostics are carried in these reports for upstream reporting instead of being
/// printed by the check itself.
///
/// Equality and hashing are derived so outcomes can be compared directly, for example in
/// tests, whenever `P` and `F` support them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PassFail<P, F> {
    /// The check passed; carries the pass report.
    Pass(P),
    /// The check failed; carries the failure report.
    Fail(F),
}

//////////////
// Internal //
//////////////

pub(crate) mod internal {
    use super::*;

    use anyhow::Context;
    use thiserror::Error;

    /// Object-safe counterpart of [`super::Benchmark`].
    ///
    /// It lets benchmarks of different concrete types be stored in the registry behind a
    /// single trait object. Inputs arrive as type-erased [`Any`] values, and results are
    /// returned as `serde_json::Value`.
    pub(crate) trait Benchmark {
        /// Type-erased analogue of [`super::Benchmark::try_match`].
        fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;

        /// Type-erased analogue of [`super::Benchmark::description`].
        fn description(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            input: Option<&Any>,
        ) -> std::fmt::Result;

        /// Type-erased analogue of [`super::Benchmark::run`], with the result serialized to
        /// JSON.
        fn run(
            &self,
            input: &Any,
            checkpoint: Checkpoint<'_>,
            output: &mut dyn Output,
        ) -> anyhow::Result<serde_json::Value>;

        /// Return a handle for running regression checks on this benchmark, or `None` if
        /// regression checks are not supported.
        fn as_regression(&self) -> Option<&dyn Regression>;
    }

    /// A regression report serialized to JSON.
    ///
    /// The original value is kept alongside, boxed behind [`std::fmt::Display`], so it can
    /// still be rendered for humans later.
    pub(crate) struct Checked {
        /// JSON form of the report.
        pub(crate) json: serde_json::Value,
        /// Human-readable form of the report.
        pub(crate) display: Box<dyn std::fmt::Display>,
    }

    impl Checked {
        /// Capture `value` in both forms: its JSON serialization and a boxed
        /// [`std::fmt::Display`] handle.
        ///
        /// # Errors
        ///
        /// Fails if `value` cannot be serialized to `serde_json::Value`.
        fn new<T>(value: T) -> Result<Self, serde_json::Error>
        where
            T: Serialize + std::fmt::Display + 'static,
        {
            let json = serde_json::to_value(&value)?;
            let display: Box<dyn std::fmt::Display> = Box::new(value);
            Ok(Checked { json, display })
        }
    }

    /// A [`PassFail`] whose pass and fail reports have both been converted into [`Checked`].
    pub(crate) type CheckedPassFail = PassFail<Checked, Checked>;

    /// Object-safe counterpart of [`super::Regression`].
    ///
    /// It works with type-erased inputs and JSON-serialized benchmark outputs.
    pub(crate) trait Regression {
        /// Dynamic handle to the tolerance [`Input`] type used by this regression check.
        fn tolerance(&self) -> &dyn crate::input::DynInput;

        /// Tag of the benchmark [`Input`] that this regression check applies to.
        fn input_tag(&self) -> &'static str;

        /// Compare the serialized `before` and `after` outputs under `tolerances`.
        ///
        /// `input` is the type-erased benchmark input that produced both outputs.
        fn check(
            &self,
            tolerances: &Any,
            input: &Any,
            before: &serde_json::Value,
            after: &serde_json::Value,
        ) -> anyhow::Result<CheckedPassFail>;
    }

    /// Debug-format a type-erased regression check by the tags of its tolerance and input
    /// types.
    impl std::fmt::Debug for dyn Regression + '_ {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let tolerance = self.tolerance().tag();
            let input_tag = self.input_tag();
            f.debug_struct("dyn Regression")
                .field("tolerance", &tolerance)
                .field("input_tag", &input_tag)
                .finish()
        }
    }

    /// Compile-time strategy that decides whether a benchmark is exposed as a
    /// [`Regression`] object.
    ///
    /// [`Wrapper`] is parameterized by one of the marker types below. The marker determines
    /// what [`Benchmark::as_regression`] yields.
    pub(crate) trait AsRegression<T> {
        fn as_regression(benchmark: &T) -> Option<&dyn Regression>;
    }

    /// Marker for benchmarks without regression support; always yields `None`.
    #[derive(Debug, Clone, Copy)]
    pub(crate) struct NoRegression;

    impl<T> AsRegression<T> for NoRegression {
        fn as_regression(_: &T) -> Option<&dyn Regression> {
            None
        }
    }

    /// Marker for benchmarks with regression support.
    ///
    /// It yields the benchmark itself, viewed through the blanket [`Regression`] impl
    /// available for any `T: super::Regression`.
    #[derive(Debug, Clone, Copy)]
    pub(crate) struct WithRegression;

    impl<T> AsRegression<T> for WithRegression
    where
        T: super::Regression,
    {
        fn as_regression(benchmark: &T) -> Option<&dyn Regression> {
            let erased: &dyn Regression = benchmark;
            Some(erased)
        }
    }

    /// Blanket bridge from the typed public [`super::Regression`] trait to the object-safe
    /// internal [`Regression`] trait.
    ///
    /// The check proceeds in three steps:
    /// 1. The type-erased tolerances and input are downcast back to `T`'s concrete types.
    /// 2. The JSON `before`/`after` outputs are deserialized into `T::Output`, and the typed
    ///    check is run.
    /// 3. The resulting pass/fail report is converted into a [`Checked`] value.
    impl<T> Regression for T
    where
        T: super::Regression,
    {
        fn tolerance(&self) -> &dyn crate::input::DynInput {
            &crate::input::Wrapper::<T::Tolerances>::INSTANCE
        }

        fn input_tag(&self) -> &'static str {
            T::Input::tag()
        }

        /// # Errors
        ///
        /// Returns an error in any of these cases:
        /// - `tolerances` or `input` do not hold `T`'s expected types (an internal error).
        /// - Either JSON output fails to deserialize into `T::Output`.
        /// - The typed check itself fails.
        /// - The resulting report cannot be serialized back to JSON.
        fn check(
            &self,
            tolerances: &Any,
            input: &Any,
            before: &serde_json::Value,
            after: &serde_json::Value,
        ) -> anyhow::Result<CheckedPassFail> {
            let tolerances = tolerances
                .downcast_ref::<T::Tolerances>()
                .ok_or_else(|| BadDownCast::new(T::Tolerances::tag(), tolerances.tag()))
                .context("failed to obtain tolerance")?;

            let input = input
                .downcast_ref::<T::Input>()
                .ok_or_else(|| BadDownCast::new(T::Input::tag(), input.tag()))
                .context("failed to obtain input")?;

            let before = T::Output::deserialize(before)
                .map_err(|err| DeserializationError::new(Kind::Before, err))?;

            let after = T::Output::deserialize(after)
                .map_err(|err| DeserializationError::new(Kind::After, err))?;

            // `T` implements both the public trait and this internal trait, and both define a
            // method named `check`. Name the typed one explicitly instead of relying on
            // method-resolution priority.
            let outcome =
                <T as super::Regression>::check(self, tolerances, input, &before, &after)?;

            let checked = match outcome {
                PassFail::Pass(pass) => PassFail::Pass(
                    Checked::new(pass).context("failed to serialize regression pass report")?,
                ),
                PassFail::Fail(fail) => PassFail::Fail(
                    Checked::new(fail).context("failed to serialize regression failure report")?,
                ),
            };

            Ok(checked)
        }
    }

    /// Adapter that erases the concrete type of a [`super::Benchmark`], so it can be used as
    /// an internal [`Benchmark`] trait object.
    ///
    /// `R` is a marker ([`NoRegression`] by default, or [`WithRegression`]) that controls
    /// whether the wrapped benchmark is exposed for regression checks.
    #[derive(Debug, Clone, Copy)]
    pub(crate) struct Wrapper<T, R = NoRegression> {
        benchmark: T,
        _regression: R,
    }

    impl<T, R> Wrapper<T, R> {
        /// Wrap `benchmark`, using the `regression` marker to select regression support.
        pub(crate) const fn new(benchmark: T, regression: R) -> Self {
            Wrapper {
                _regression: regression,
                benchmark,
            }
        }
    }

    /// [`FailureScore`] returned by [`Benchmark::try_match`] when the type-erased input is not
    /// of the benchmark's [`Input`] type at all.
    const MATCH_FAIL: FailureScore = FailureScore(10_000);

    impl<T, R> Benchmark for Wrapper<T, R>
    where
        T: super::Benchmark,
        R: AsRegression<T>,
    {
        /// Downcast `input` and defer to the wrapped benchmark.
        ///
        /// A type mismatch scores [`MATCH_FAIL`].
        fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
            match input.downcast_ref::<T::Input>() {
                Some(concrete) => self.benchmark.try_match(concrete),
                None => Err(MATCH_FAIL),
            }
        }

        /// Describe the wrapped benchmark.
        ///
        /// - Without an input, write the expected input tag on its own line, followed by the
        ///   benchmark's own description.
        /// - With an input of the wrong type, report only the tag mismatch.
        /// - Otherwise, let the wrapped benchmark describe the concrete input.
        fn description(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            input: Option<&Any>,
        ) -> std::fmt::Result {
            let Some(input) = input else {
                writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
                return self.benchmark.description(f, None);
            };

            if let Some(concrete) = input.downcast_ref::<T::Input>() {
                return self.benchmark.description(f, Some(concrete));
            }

            write!(
                f,
                "expected tag \"{}\" - instead got \"{}\"",
                <T::Input as Input>::tag(),
                input.tag(),
            )
        }

        /// Downcast `input`, run the wrapped benchmark, and serialize its output to JSON.
        ///
        /// A type mismatch is reported as an internal [`BadDownCast`] error.
        fn run(
            &self,
            input: &Any,
            checkpoint: Checkpoint<'_>,
            output: &mut dyn Output,
        ) -> anyhow::Result<serde_json::Value> {
            let Some(concrete) = input.downcast_ref::<T::Input>() else {
                return Err(BadDownCast::new(T::Input::tag(), input.tag()).into());
            };
            let produced = self.benchmark.run(concrete, checkpoint, output)?;
            Ok(serde_json::to_value(produced)?)
        }

        // Regression support is selected at compile time by the `R` marker.
        fn as_regression(&self) -> Option<&dyn Regression> {
            R::as_regression(&self.benchmark)
        }
    }

    //--------//
    // Errors //
    //--------//

    #[derive(Debug, Clone, Copy, Error)]
    #[error(
        "INTERNAL ERROR: bad downcast - expected \"{}\" but got \"{}\"",
        self.expected,
        self.got
    )]
    struct BadDownCast {
        expected: &'static str,
        got: &'static str,
    }

    impl BadDownCast {
        fn new(expected: &'static str, got: &'static str) -> Self {
            Self { expected, got }
        }
    }

    #[derive(Debug, Error)]
    #[error(
        "the \"{}\" results do not match the output schema expected by this benchmark",
        self.kind
    )]
    struct DeserializationError {
        kind: Kind,
        source: serde_json::Error,
    }

    impl DeserializationError {
        fn new(kind: Kind, source: serde_json::Error) -> Self {
            Self { kind, source }
        }
    }

    /// Identifies which of the two compared result sets an error refers to.
    #[derive(Debug, Clone, Copy)]
    enum Kind {
        Before,
        After,
    }

    impl std::fmt::Display for Kind {
        /// Render as the lowercase label `before` or `after`. Width and fill flags are
        /// ignored.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(match *self {
                Kind::Before => "before",
                Kind::After => "after",
            })
        }
    }
}