dspy_rs/data/
prediction.rs

1use serde::{Deserialize, Serialize};
2use std::{collections::HashMap, ops::Index};
3
4use crate::LmUsage;
5
6#[derive(Serialize, Deserialize, Default, Debug, Clone)]
7pub struct Prediction {
8    pub data: HashMap<String, serde_json::Value>,
9    pub lm_usage: LmUsage,
10}
11
12impl Prediction {
13    pub fn new(data: HashMap<String, serde_json::Value>, lm_usage: LmUsage) -> Self {
14        Self { data, lm_usage }
15    }
16
17    pub fn get(&self, key: &str, default: Option<&str>) -> serde_json::Value {
18        self.data
19            .get(key)
20            .unwrap_or(&default.unwrap_or_default().to_string().into())
21            .clone()
22    }
23
24    pub fn keys(&self) -> Vec<String> {
25        self.data.keys().cloned().collect()
26    }
27
28    pub fn values(&self) -> Vec<serde_json::Value> {
29        self.data.values().cloned().collect()
30    }
31
32    pub fn set_lm_usage(&mut self, lm_usage: LmUsage) -> Self {
33        self.lm_usage = lm_usage;
34        self.clone()
35    }
36}
37
38impl Index<String> for Prediction {
39    type Output = serde_json::Value;
40
41    fn index(&self, index: String) -> &Self::Output {
42        &self.data[&index]
43    }
44}