1use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 completion::CompletionModel,
10 embeddings::EmbeddingModel,
11 extractor::{Extractor, ExtractorBuilder},
12};
13
14#[derive(Debug, thiserror::Error)]
16pub enum EvalError {
17 #[error("Field must not be null: {0}")]
19 FieldCannotBeNull(String),
20 #[error("Eval error: {0}")]
22 Custom(String),
23}
24
25#[derive(Deserialize, Serialize, Clone, Debug)]
28#[serde(tag = "outcome", content = "data")]
29pub enum EvalOutcome<Output> {
30 Pass(Output),
32 Fail(Output),
34 Invalid(String),
36}
37
38impl<Output> EvalOutcome<Output> {
39 pub fn is_pass(&self) -> bool {
41 matches!(self, EvalOutcome::Pass(_))
42 }
43
44 pub fn score(&self) -> Option<&Output> {
46 match self {
47 EvalOutcome::Pass(o) | EvalOutcome::Fail(o) => Some(o),
48 EvalOutcome::Invalid(_) => None,
49 }
50 }
51}
52
53pub trait Eval<Output>
60where
61 Output: for<'a> Deserialize<'a> + Serialize + Clone + Send + Sync,
62 Self: Sized + Send + Sync + 'static,
63{
64 fn eval(&self, input: String) -> impl Future<Output = EvalOutcome<Output>> + Send;
65
66 fn eval_batch(
71 &self,
72 input: Vec<String>,
73 concurrency_limit: usize,
74 ) -> impl Future<Output = Vec<EvalOutcome<Output>>> + Send {
75 use futures::StreamExt;
76 async move {
77 let thing: Vec<EvalOutcome<Output>> = futures::stream::iter(input)
78 .map(|x| Self::eval(self, x))
79 .buffered(concurrency_limit)
80 .collect()
81 .await;
82
83 thing
84 }
85 }
86}
87
/// Passes an input when the cosine similarity between its embedding and the embedding of a
/// reference answer reaches a threshold.
///
/// Create one with `SemanticSimilarityMetric::builder`. The reference answer is embedded once,
/// when the builder's `build` runs, not on every evaluation.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct SemanticSimilarityMetric<E> {
    /// Model used to embed each evaluated input.
    embedding_model: E,
    /// Minimum cosine similarity an input needs to pass.
    threshold: f64,
    /// The reference text that inputs are compared against.
    reference_answer: String,
    /// Embedding of `reference_answer`, precomputed at build time.
    reference_answer_embedding: Vec<f64>,
}
99
100impl<E> SemanticSimilarityMetric<E>
101where
102 E: EmbeddingModel,
103{
104 pub fn builder(embedding_model: E) -> SemanticSimilarityMetricBuilder<E> {
105 SemanticSimilarityMetricBuilder::new(embedding_model)
106 }
107
108 pub fn reference_answer(&self) -> &str {
109 &self.reference_answer
110 }
111}
112
/// Builder for a `SemanticSimilarityMetric`.
///
/// Both the threshold and the reference answer are required; they stay `None` until the
/// corresponding setter is called.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct SemanticSimilarityMetricBuilder<E> {
    /// Model that will embed the reference answer and, later, every evaluated input.
    embedding_model: E,
    /// Minimum cosine similarity needed to pass, once set.
    threshold: Option<f64>,
    /// Text to compare inputs against, once set.
    reference_answer: Option<String>,
}
121
122impl<E> SemanticSimilarityMetricBuilder<E>
123where
124 E: EmbeddingModel,
125{
126 pub fn new(embedding_model: E) -> Self {
127 Self {
128 embedding_model,
129 threshold: None,
130 reference_answer: None,
131 }
132 }
133
134 pub fn threshold(mut self, threshold: f64) -> Self {
135 self.threshold = Some(threshold);
136 self
137 }
138
139 pub fn reference_answer(mut self, reference_answer: &str) -> Self {
140 self.reference_answer = Some(reference_answer.to_string());
141 self
142 }
143
144 pub async fn build(self) -> Result<SemanticSimilarityMetric<E>, EvalError> {
145 let threshold = self
146 .threshold
147 .ok_or(EvalError::FieldCannotBeNull("threshold".into()))?;
148 let reference_answer = self
149 .reference_answer
150 .ok_or(EvalError::FieldCannotBeNull("reference_answer".into()))?;
151 let reference_answer_embedding = self
152 .embedding_model
153 .embed_text(&reference_answer)
154 .await
155 .map_err(|x| EvalError::Custom(x.to_string()))?
156 .vec;
157
158 let res = SemanticSimilarityMetric {
159 embedding_model: self.embedding_model,
160 threshold,
161 reference_answer,
162 reference_answer_embedding,
163 };
164
165 Ok(res)
166 }
167}
168
169#[derive(Deserialize, Serialize, Clone, Debug)]
171#[non_exhaustive]
172pub struct SemanticSimilarityMetricScore {
173 pub score: f64,
174}
175
176impl<E> Eval<SemanticSimilarityMetricScore> for SemanticSimilarityMetric<E>
177where
178 E: EmbeddingModel + 'static,
179{
180 async fn eval(&self, input: String) -> EvalOutcome<SemanticSimilarityMetricScore> {
181 let input = match self.embedding_model.embed_text(&input).await {
182 Ok(res) => res.vec,
183 Err(e) => return EvalOutcome::Invalid(e.to_string()),
184 };
185 let ref_answer = &self.reference_answer_embedding;
186
187 let dot: f64 = input.iter().zip(ref_answer).map(|(x, y)| x * y).sum();
188 let norm_a = input.iter().map(|x| x * x).sum::<f64>().sqrt();
189 let norm_b = ref_answer.iter().map(|x| x * x).sum::<f64>().sqrt();
190
191 let cosine_sim = dot / (norm_a * norm_b);
192
193 if cosine_sim >= self.threshold {
194 EvalOutcome::Pass(SemanticSimilarityMetricScore { score: cosine_sim })
195 } else {
196 EvalOutcome::Fail(SemanticSimilarityMetricScore { score: cosine_sim })
197 }
198 }
199}
200
201pub struct LlmJudgeMetric<M, T>
204where
205 M: CompletionModel,
206 T: Judgment + Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
207{
208 ext: Extractor<M, T>,
209}
210
211pub struct LlmJudgeMetricWithFn<M, T>
214where
215 M: CompletionModel,
216 T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
217{
218 ext: Extractor<M, T>,
219 evaluator: Box<dyn Fn(&T) -> bool + Send + Sync>,
220}
221
222pub struct LlmJudgeBuilder<M, T>
223where
224 M: CompletionModel,
225 T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
226{
227 ext: ExtractorBuilder<M, T>,
228}
229
230pub struct LlmJudgeBuilderWithFn<M, T>
231where
232 M: CompletionModel,
233 T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
234{
235 ext: ExtractorBuilder<M, T>,
236 evaluator: Box<dyn Fn(&T) -> bool + Send + Sync>,
237}
238
239impl<M, T> LlmJudgeBuilder<M, T>
240where
241 M: CompletionModel,
242 T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
243{
244 pub fn new(ext: ExtractorBuilder<M, T>) -> Self {
245 Self { ext }
246 }
247
248 pub fn with_fn<F>(self, f: F) -> LlmJudgeBuilderWithFn<M, T>
249 where
250 F: Fn(&T) -> bool + Send + Sync + 'static,
251 {
252 LlmJudgeBuilderWithFn {
253 ext: self.ext,
254 evaluator: Box::new(f),
255 }
256 }
257
258 pub fn build(self) -> LlmJudgeMetric<M, T>
259 where
260 T: Judgment + 'static,
261 {
262 let ext = self
263 .ext
264 .preamble(
265 "Judge the prompt input by the schema given and return it as a JSON tool result",
266 )
267 .build();
268 LlmJudgeMetric { ext }
269 }
270}
271
272impl<M, T> LlmJudgeBuilderWithFn<M, T>
273where
274 M: CompletionModel,
275 T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
276{
277 pub fn with_fn<F2>(mut self, f: F2) -> Self
278 where
279 F2: Fn(&T) -> bool + Send + Sync + 'static,
280 {
281 self.evaluator = Box::new(f);
282 self
283 }
284
285 pub fn build(self) -> LlmJudgeMetricWithFn<M, T> {
286 let ext = self
287 .ext
288 .preamble(
289 "Judge the prompt input by the schema given and return it as a JSON tool result",
290 )
291 .build();
292 LlmJudgeMetricWithFn {
293 ext,
294 evaluator: self.evaluator,
295 }
296 }
297}
298
/// A structured judgment that can decide on its own whether the judged input passed.
///
/// Implement this on the type an `LlmJudgeMetric` extracts. The metric reports
/// `EvalOutcome::Pass` when [`passes`](Judgment::passes) returns `true`, and
/// `EvalOutcome::Fail` otherwise.
pub trait Judgment {
    /// Returns `true` if this judgment counts as a pass.
    fn passes(&self) -> bool;
}
305
306impl<M, T> Eval<T> for LlmJudgeMetric<M, T>
307where
308 M: CompletionModel + 'static,
309 T: Judgment + Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + Clone + 'static,
310{
311 async fn eval(&self, input: String) -> EvalOutcome<T> {
312 match self.ext.extract(input).await {
313 Ok(judgment) => {
314 if judgment.passes() {
315 EvalOutcome::Pass(judgment)
316 } else {
317 EvalOutcome::Fail(judgment)
318 }
319 }
320 Err(e) => EvalOutcome::Invalid(e.to_string()),
321 }
322 }
323}
324
325impl<M, T> Eval<T> for LlmJudgeMetricWithFn<M, T>
326where
327 M: CompletionModel + 'static,
328 T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + Clone + 'static,
329{
330 async fn eval(&self, input: String) -> EvalOutcome<T> {
331 match self.ext.extract(input).await {
332 Ok(judgment) => {
333 if (self.evaluator)(&judgment) {
334 EvalOutcome::Pass(judgment)
335 } else {
336 EvalOutcome::Fail(judgment)
337 }
338 }
339 Err(e) => EvalOutcome::Invalid(e.to_string()),
340 }
341 }
342}
343
344impl<M, T> From<ExtractorBuilder<M, T>> for LlmJudgeBuilder<M, T>
345where
346 M: CompletionModel,
347 T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
348{
349 fn from(ext: ExtractorBuilder<M, T>) -> Self {
350 Self::new(ext)
351 }
352}
353
354#[non_exhaustive]
356pub struct LlmScoreMetric<M>
357where
358 M: CompletionModel,
359{
360 agent: Extractor<M, LlmScoreMetricScore>,
361 threshold: f64,
362}
363
364#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
367pub struct LlmScoreMetricScore {
368 pub score: f64,
370 pub feedback: String,
372}
373
374impl<M> Eval<LlmScoreMetricScore> for LlmScoreMetric<M>
375where
376 M: CompletionModel + 'static,
377{
378 async fn eval(&self, input: String) -> EvalOutcome<LlmScoreMetricScore> {
379 let res = match self.agent.extract(input).await {
380 Ok(res) => res,
381 Err(e) => return EvalOutcome::Invalid(e.to_string()),
382 };
383
384 if !(0.0..=1.0).contains(&res.score) {
385 return EvalOutcome::Invalid(format!(
386 "Score {} outside valid range [0.0, 1.0]",
387 res.score
388 ));
389 }
390
391 if res.score >= self.threshold {
392 EvalOutcome::Pass(res)
393 } else {
394 EvalOutcome::Fail(res)
395 }
396 }
397}
398
399#[non_exhaustive]
400pub struct LlmScoreMetricBuilder<M>
401where
402 M: CompletionModel,
403{
404 agent: ExtractorBuilder<M, LlmScoreMetricScore>,
405 criteria: Vec<String>,
406 threshold: Option<f64>,
407}
408
409impl<M> LlmScoreMetricBuilder<M>
410where
411 M: CompletionModel,
412{
413 pub fn new(agent: ExtractorBuilder<M, LlmScoreMetricScore>) -> Self {
414 Self {
415 agent,
416 criteria: Vec::new(),
417 threshold: None,
418 }
419 }
420
421 pub fn threshold(mut self, threshold: f64) -> Self {
422 self.threshold = Some(threshold);
423 self
424 }
425
426 pub fn criteria(mut self, criteria: &str) -> Self {
427 self.criteria.push(criteria.to_string());
428 self
429 }
430
431 pub fn build(self) -> Result<LlmScoreMetric<M>, EvalError> {
432 let threshold = self
433 .threshold
434 .ok_or(EvalError::FieldCannotBeNull("threshold".into()))?;
435 let preamble = format!(
436 "You are an evaluation model. Score the input based on these criteria:\n{}\n\n\
437 Provide a score between 0.0 and 1.0 (where 1.0 is best) and explain your reasoning.",
438 self.criteria.join("\n")
439 );
440
441 let agent = self.agent.preamble(&preamble).build();
442
443 Ok(LlmScoreMetric { agent, threshold })
444 }
445}