dspy_rs/data/prediction.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::{collections::HashMap, ops::Index};
4
5use crate::LmUsage;
6
7#[derive(Serialize, Deserialize, Default, Debug, Clone)]
8pub struct Prediction {
9    pub data: HashMap<String, serde_json::Value>,
10    pub lm_usage: LmUsage,
11}
12
13impl Prediction {
14    pub fn new(data: HashMap<String, serde_json::Value>, lm_usage: LmUsage) -> Self {
15        Self { data, lm_usage }
16    }
17
18    pub fn get(&self, key: &str, default: Option<&str>) -> serde_json::Value {
19        self.data
20            .get(key)
21            .unwrap_or(&default.unwrap_or_default().to_string().into())
22            .clone()
23    }
24
25    pub fn keys(&self) -> Vec<String> {
26        self.data.keys().cloned().collect()
27    }
28
29    pub fn values(&self) -> Vec<serde_json::Value> {
30        self.data.values().cloned().collect()
31    }
32
33    pub fn set_lm_usage(&mut self, lm_usage: LmUsage) -> Self {
34        self.lm_usage = lm_usage;
35        self.clone()
36    }
37}
38
39impl Index<String> for Prediction {
40    type Output = serde_json::Value;
41
42    fn index(&self, index: String) -> &Self::Output {
43        &self.data[&index]
44    }
45}
46
47impl IntoIterator for Prediction {
48    type Item = (String, Value);
49    type IntoIter = std::collections::hash_map::IntoIter<String, Value>;
50
51    fn into_iter(self) -> Self::IntoIter {
52        self.data.into_iter()
53    }
54}
55
56impl From<Vec<(String, Value)>> for Prediction {
57    fn from(value: Vec<(String, Value)>) -> Self {
58        Self {
59            data: value.into_iter().collect(),
60            lm_usage: LmUsage::default(),
61        }
62    }
63}