dspy_rs/data/
prediction.rs1use serde::{Deserialize, Serialize};
2use std::{collections::HashMap, ops::Index};
3
4use crate::LmUsage;
5
6#[derive(Serialize, Deserialize, Default, Debug, Clone)]
7pub struct Prediction {
8 pub data: HashMap<String, serde_json::Value>,
9 pub lm_usage: LmUsage,
10}
11
12impl Prediction {
13 pub fn new(data: HashMap<String, serde_json::Value>, lm_usage: LmUsage) -> Self {
14 Self { data, lm_usage }
15 }
16
17 pub fn get(&self, key: &str, default: Option<&str>) -> serde_json::Value {
18 self.data
19 .get(key)
20 .unwrap_or(&default.unwrap_or_default().to_string().into())
21 .clone()
22 }
23
24 pub fn keys(&self) -> Vec<String> {
25 self.data.keys().cloned().collect()
26 }
27
28 pub fn values(&self) -> Vec<serde_json::Value> {
29 self.data.values().cloned().collect()
30 }
31
32 pub fn set_lm_usage(&mut self, lm_usage: LmUsage) -> Self {
33 self.lm_usage = lm_usage;
34 self.clone()
35 }
36}
37
38impl Index<String> for Prediction {
39 type Output = serde_json::Value;
40
41 fn index(&self, index: String) -> &Self::Output {
42 &self.data[&index]
43 }
44}