rig/
evals.rs

1//! Evals.
2//! From OpenAI's evals repo:
3//! > Evals provide a framework for evaluating large language models (LLMs) or systems built using LLMs. We offer an existing registry of evals to test different dimensions of OpenAI models and the ability to write your own custom evals for use cases you care about. You can also use your data to build private evals which represent the common LLMs patterns in your workflow without exposing any of that data publicly.
4
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    completion::CompletionModel,
10    embeddings::EmbeddingModel,
11    extractor::{Extractor, ExtractorBuilder},
12};
13
/// Errors that can occur within the evals module.
/// `Display` output is generated by `thiserror` from the `#[error(...)]` attributes.
#[derive(Debug, thiserror::Error)]
pub enum EvalError {
    /// A mandatory field was null when attempting to initialise a struct
    /// (e.g. a builder `build()` called before a required setter).
    #[error("Field must not be null: {0}")]
    FieldCannotBeNull(String),
    /// Generic eval module error; wraps stringified errors from external calls.
    #[error("Eval error: {0}")]
    Custom(String),
}
24
/// The outcome of an evaluation (ie, sending an input to an LLM which then gets tested against a set of criteria).
/// Invalid results due to things like functions returning errors should be encoded as invalid evaluation outcomes.
///
/// Serialized in serde's adjacently-tagged form, e.g.
/// `{"outcome": "Pass", "data": ...}`.
#[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(tag = "outcome", content = "data")]
pub enum EvalOutcome<Output> {
    /// Evaluation passed; carries the score/judgment produced.
    Pass(Output),
    /// Evaluation failed; carries the score/judgment produced.
    Fail(Output),
    /// Evaluation was invalidated (reason in field), e.g. an API call failed.
    Invalid(String),
}
37
38impl<Output> EvalOutcome<Output> {
39    /// Check whether or not an evaluation has passed.
40    pub fn is_pass(&self) -> bool {
41        matches!(self, EvalOutcome::Pass(_))
42    }
43
44    /// Gets the score from an eval (assuming it isn't invalid).
45    pub fn score(&self) -> Option<&Output> {
46        match self {
47            EvalOutcome::Pass(o) | EvalOutcome::Fail(o) => Some(o),
48            EvalOutcome::Invalid(_) => None,
49        }
50    }
51}
52
53/// A trait to encode evaluators - types that can be used to test LLM outputs against criteria.
54/// Evaluators come in all shapes and sizes, and additionally may themselves use LLMs (although there are many heuristics you can use that don't).
55/// There are three possible states that an LLM can result in:
56/// - Pass (the output passed all criteria)
57/// - Fail (the output failed one or all criteria)
58/// - Invalid (the output was unable to be retrieved due to an external failure like an API call fail)
59pub trait Eval<Output>
60where
61    Output: for<'a> Deserialize<'a> + Serialize + Clone + Send + Sync,
62    Self: Sized + Send + Sync + 'static,
63{
64    fn eval(&self, input: String) -> impl Future<Output = EvalOutcome<Output>> + Send;
65
66    /// Send a bunch of inputs to be evaluated all in one call.
67    /// You can set the concurrency limit to help alleviate issues
68    /// with model provider API limits, as sending requests too quickly may
69    /// result in throttling or temporary request refusal.
70    fn eval_batch(
71        &self,
72        input: Vec<String>,
73        concurrency_limit: usize,
74    ) -> impl Future<Output = Vec<EvalOutcome<Output>>> + Send {
75        use futures::StreamExt;
76        async move {
77            let thing: Vec<EvalOutcome<Output>> = futures::stream::iter(input)
78                .map(|x| Self::eval(self, x))
79                .buffered(concurrency_limit)
80                .collect()
81                .await;
82
83            thing
84        }
85    }
86}
87
/// A semantic similarity metric. Uses cosine similarity.
/// In broad terms, cosine similarity can be used to measure how similar two documents are.
/// This can be useful for things like quickly testing semantic similarity between two documents.
/// The reference answer is embedded once at construction time, so each
/// evaluation only needs to embed the incoming input.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct SemanticSimilarityMetric<E> {
    // Model used to embed each evaluated input.
    embedding_model: E,
    // Minimum cosine similarity (inclusive) for an input to count as a pass.
    threshold: f64,
    // Human-readable reference answer, exposed via `reference_answer()`.
    reference_answer: String,
    // Precomputed embedding of `reference_answer`.
    reference_answer_embedding: Vec<f64>,
}
99
100impl<E> SemanticSimilarityMetric<E>
101where
102    E: EmbeddingModel,
103{
104    pub fn builder(embedding_model: E) -> SemanticSimilarityMetricBuilder<E> {
105        SemanticSimilarityMetricBuilder::new(embedding_model)
106    }
107
108    pub fn reference_answer(&self) -> &str {
109        &self.reference_answer
110    }
111}
112
/// A builder struct for [`SemanticSimilarityMetric`].
/// Both `threshold` and `reference_answer` must be set before `build`,
/// which otherwise returns [`EvalError::FieldCannotBeNull`].
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct SemanticSimilarityMetricBuilder<E> {
    embedding_model: E,
    // Required; validated in `build`.
    threshold: Option<f64>,
    // Required; embedded during `build`.
    reference_answer: Option<String>,
}
121
122impl<E> SemanticSimilarityMetricBuilder<E>
123where
124    E: EmbeddingModel,
125{
126    pub fn new(embedding_model: E) -> Self {
127        Self {
128            embedding_model,
129            threshold: None,
130            reference_answer: None,
131        }
132    }
133
134    pub fn threshold(mut self, threshold: f64) -> Self {
135        self.threshold = Some(threshold);
136        self
137    }
138
139    pub fn reference_answer(mut self, reference_answer: &str) -> Self {
140        self.reference_answer = Some(reference_answer.to_string());
141        self
142    }
143
144    pub async fn build(self) -> Result<SemanticSimilarityMetric<E>, EvalError> {
145        let threshold = self
146            .threshold
147            .ok_or(EvalError::FieldCannotBeNull("threshold".into()))?;
148        let reference_answer = self
149            .reference_answer
150            .ok_or(EvalError::FieldCannotBeNull("reference_answer".into()))?;
151        let reference_answer_embedding = self
152            .embedding_model
153            .embed_text(&reference_answer)
154            .await
155            .map_err(|x| EvalError::Custom(x.to_string()))?
156            .vec;
157
158        let res = SemanticSimilarityMetric {
159            embedding_model: self.embedding_model,
160            threshold,
161            reference_answer,
162            reference_answer_embedding,
163        };
164
165        Ok(res)
166    }
167}
168
/// The scoring metric used for [`SemanticSimilarityMetric`].
#[derive(Deserialize, Serialize, Clone, Debug)]
#[non_exhaustive]
pub struct SemanticSimilarityMetricScore {
    /// Cosine similarity between the input and the reference answer embedding.
    pub score: f64,
}
175
176impl<E> Eval<SemanticSimilarityMetricScore> for SemanticSimilarityMetric<E>
177where
178    E: EmbeddingModel + 'static,
179{
180    async fn eval(&self, input: String) -> EvalOutcome<SemanticSimilarityMetricScore> {
181        let input = match self.embedding_model.embed_text(&input).await {
182            Ok(res) => res.vec,
183            Err(e) => return EvalOutcome::Invalid(e.to_string()),
184        };
185        let ref_answer = &self.reference_answer_embedding;
186
187        let dot: f64 = input.iter().zip(ref_answer).map(|(x, y)| x * y).sum();
188        let norm_a = input.iter().map(|x| x * x).sum::<f64>().sqrt();
189        let norm_b = ref_answer.iter().map(|x| x * x).sum::<f64>().sqrt();
190
191        let cosine_sim = dot / (norm_a * norm_b);
192
193        if cosine_sim >= self.threshold {
194            EvalOutcome::Pass(SemanticSimilarityMetricScore { score: cosine_sim })
195        } else {
196            EvalOutcome::Fail(SemanticSimilarityMetricScore { score: cosine_sim })
197        }
198    }
199}
200
/// An LLM as a judge that judges an output by a given schema (and outputs the schema).
/// The schema type uses the [`Judgment`] trait, which simply enforces a single function
/// that checks whether it passes or not.
pub struct LlmJudgeMetric<M, T>
where
    M: CompletionModel,
    T: Judgment + Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
{
    // Extractor that asks the model to produce a `T` judgment from the input.
    ext: Extractor<M, T>,
}
210
/// An LLM as a judge that judges an output by a given schema (and outputs the schema).
/// Unlike [`LlmJudgeMetric`], this type uses a caller-supplied closure that takes the
/// extracted type and returns a `bool`, instead of requiring `T: Judgment`.
pub struct LlmJudgeMetricWithFn<M, T>
where
    M: CompletionModel,
    T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
{
    // Extractor that asks the model to produce a `T` judgment from the input.
    ext: Extractor<M, T>,
    // Pass/fail predicate applied to the extracted judgment.
    evaluator: Box<dyn Fn(&T) -> bool + Send + Sync>,
}
221
/// Builder for [`LlmJudgeMetric`]. Wraps an [`ExtractorBuilder`]; the judge
/// preamble is applied when `build` is called.
pub struct LlmJudgeBuilder<M, T>
where
    M: CompletionModel,
    T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
{
    ext: ExtractorBuilder<M, T>,
}
229
/// Builder for [`LlmJudgeMetricWithFn`]. Created from [`LlmJudgeBuilder::with_fn`];
/// carries the pass/fail closure alongside the extractor builder.
pub struct LlmJudgeBuilderWithFn<M, T>
where
    M: CompletionModel,
    T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
{
    ext: ExtractorBuilder<M, T>,
    // Pass/fail predicate applied to the extracted judgment.
    evaluator: Box<dyn Fn(&T) -> bool + Send + Sync>,
}
238
239impl<M, T> LlmJudgeBuilder<M, T>
240where
241    M: CompletionModel,
242    T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
243{
244    pub fn new(ext: ExtractorBuilder<M, T>) -> Self {
245        Self { ext }
246    }
247
248    pub fn with_fn<F>(self, f: F) -> LlmJudgeBuilderWithFn<M, T>
249    where
250        F: Fn(&T) -> bool + Send + Sync + 'static,
251    {
252        LlmJudgeBuilderWithFn {
253            ext: self.ext,
254            evaluator: Box::new(f),
255        }
256    }
257
258    pub fn build(self) -> LlmJudgeMetric<M, T>
259    where
260        T: Judgment + 'static,
261    {
262        let ext = self
263            .ext
264            .preamble(
265                "Judge the prompt input by the schema given and return it as a JSON tool result",
266            )
267            .build();
268        LlmJudgeMetric { ext }
269    }
270}
271
272impl<M, T> LlmJudgeBuilderWithFn<M, T>
273where
274    M: CompletionModel,
275    T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
276{
277    pub fn with_fn<F2>(mut self, f: F2) -> Self
278    where
279        F2: Fn(&T) -> bool + Send + Sync + 'static,
280    {
281        self.evaluator = Box::new(f);
282        self
283    }
284
285    pub fn build(self) -> LlmJudgeMetricWithFn<M, T> {
286        let ext = self
287            .ext
288            .preamble(
289                "Judge the prompt input by the schema given and return it as a JSON tool result",
290            )
291            .build();
292        LlmJudgeMetricWithFn {
293            ext,
294            evaluator: self.evaluator,
295        }
296    }
297}
298
/// A helper trait for [`LlmJudgeMetric`].
/// Types that implement `Judgment` generally have a very standard way of either passing or failing.
/// As such, this can be enforced as a trait.
pub trait Judgment {
    /// Whether this judgment counts as a pass.
    fn passes(&self) -> bool;
}
305
306impl<M, T> Eval<T> for LlmJudgeMetric<M, T>
307where
308    M: CompletionModel + 'static,
309    T: Judgment + Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + Clone + 'static,
310{
311    async fn eval(&self, input: String) -> EvalOutcome<T> {
312        match self.ext.extract(input).await {
313            Ok(judgment) => {
314                if judgment.passes() {
315                    EvalOutcome::Pass(judgment)
316                } else {
317                    EvalOutcome::Fail(judgment)
318                }
319            }
320            Err(e) => EvalOutcome::Invalid(e.to_string()),
321        }
322    }
323}
324
325impl<M, T> Eval<T> for LlmJudgeMetricWithFn<M, T>
326where
327    M: CompletionModel + 'static,
328    T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + Clone + 'static,
329{
330    async fn eval(&self, input: String) -> EvalOutcome<T> {
331        match self.ext.extract(input).await {
332            Ok(judgment) => {
333                if (self.evaluator)(&judgment) {
334                    EvalOutcome::Pass(judgment)
335                } else {
336                    EvalOutcome::Fail(judgment)
337                }
338            }
339            Err(e) => EvalOutcome::Invalid(e.to_string()),
340        }
341    }
342}
343
344impl<M, T> From<ExtractorBuilder<M, T>> for LlmJudgeBuilder<M, T>
345where
346    M: CompletionModel,
347    T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
348{
349    fn from(ext: ExtractorBuilder<M, T>) -> Self {
350        Self::new(ext)
351    }
352}
353
/// An eval that scores an output based on some given criteria.
/// Built via [`LlmScoreMetricBuilder`]; passes when the model's score meets `threshold`.
#[non_exhaustive]
pub struct LlmScoreMetric<M>
where
    M: CompletionModel,
{
    // Extractor configured with the criteria preamble; returns a score + feedback.
    agent: Extractor<M, LlmScoreMetricScore>,
    // Minimum score (inclusive) required for a Pass outcome.
    threshold: f64,
}
363
/// The scoring output returned by [`LlmScoreMetric`].
/// Must also be used as the Extractor return type when passed into [`LlmScoreMetric`].
#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
pub struct LlmScoreMetricScore {
    /// A score between 0.0 and 1.0 inclusive (values outside this range are
    /// rejected as `Invalid` during evaluation).
    pub score: f64,
    /// Feedback on a given input in relation to the required criteria to be met.
    pub feedback: String,
}
373
374impl<M> Eval<LlmScoreMetricScore> for LlmScoreMetric<M>
375where
376    M: CompletionModel + 'static,
377{
378    async fn eval(&self, input: String) -> EvalOutcome<LlmScoreMetricScore> {
379        let res = match self.agent.extract(input).await {
380            Ok(res) => res,
381            Err(e) => return EvalOutcome::Invalid(e.to_string()),
382        };
383
384        if !(0.0..=1.0).contains(&res.score) {
385            return EvalOutcome::Invalid(format!(
386                "Score {} outside valid range [0.0, 1.0]",
387                res.score
388            ));
389        }
390
391        if res.score >= self.threshold {
392            EvalOutcome::Pass(res)
393        } else {
394            EvalOutcome::Fail(res)
395        }
396    }
397}
398
/// Builder for [`LlmScoreMetric`]. Collects criteria strings and a required
/// threshold; the criteria are joined into the extractor preamble at `build`.
#[non_exhaustive]
pub struct LlmScoreMetricBuilder<M>
where
    M: CompletionModel,
{
    agent: ExtractorBuilder<M, LlmScoreMetricScore>,
    // Criteria lines appended one-per-call via `criteria()`.
    criteria: Vec<String>,
    // Required; checked in `build`.
    threshold: Option<f64>,
}
408
409impl<M> LlmScoreMetricBuilder<M>
410where
411    M: CompletionModel,
412{
413    pub fn new(agent: ExtractorBuilder<M, LlmScoreMetricScore>) -> Self {
414        Self {
415            agent,
416            criteria: Vec::new(),
417            threshold: None,
418        }
419    }
420
421    pub fn threshold(mut self, threshold: f64) -> Self {
422        self.threshold = Some(threshold);
423        self
424    }
425
426    pub fn criteria(mut self, criteria: &str) -> Self {
427        self.criteria.push(criteria.to_string());
428        self
429    }
430
431    pub fn build(self) -> Result<LlmScoreMetric<M>, EvalError> {
432        let threshold = self
433            .threshold
434            .ok_or(EvalError::FieldCannotBeNull("threshold".into()))?;
435        let preamble = format!(
436            "You are an evaluation model. Score the input based on these criteria:\n{}\n\n\
437            Provide a score between 0.0 and 1.0 (where 1.0 is best) and explain your reasoning.",
438            self.criteria.join("\n")
439        );
440
441        let agent = self.agent.preamble(&preamble).build();
442
443        Ok(LlmScoreMetric { agent, threshold })
444    }
445}