potato_util/
utils.rs

1use crate::error::UtilError;
2
3use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
4use pyo3::prelude::*;
5
6use pyo3::types::{
7    PyAny, PyBool, PyDict, PyDictMethods, PyFloat, PyInt, PyList, PyString, PyTuple,
8};
9use pyo3::IntoPyObjectExt;
10use serde::Serialize;
11use serde_json::json;
12use serde_json::Value;
13use serde_json::Value::{Null, Object};
14use std::ops::RangeInclusive;
15use std::path::Path;
16use uuid::Uuid;
17pub fn create_uuid7() -> String {
18    Uuid::now_v7().to_string()
19}
20
/// Stateless namespace for helper functions shared by Python-facing types
/// (colored/plain JSON rendering and saving to a JSON file), e.g. `LogProbs::__str__`.
pub struct PyHelperFuncs {}
22
23impl PyHelperFuncs {
24    pub fn __str__<T: Serialize>(object: T) -> String {
25        match ColoredFormatter::with_styler(
26            PrettyFormatter::default(),
27            Styler {
28                key: Color::Rgb(75, 57, 120).foreground(),
29                string_value: Color::Rgb(4, 205, 155).foreground(),
30                float_value: Color::Rgb(4, 205, 155).foreground(),
31                integer_value: Color::Rgb(4, 205, 155).foreground(),
32                bool_value: Color::Rgb(4, 205, 155).foreground(),
33                nil_value: Color::Rgb(4, 205, 155).foreground(),
34                ..Default::default()
35            },
36        )
37        .to_colored_json(&object, ColorMode::On)
38        {
39            Ok(json) => json,
40            Err(e) => format!("Failed to serialize to json: {e}"),
41        }
42        // serialize the struct to a string
43    }
44
45    pub fn __json__<T: Serialize>(object: T) -> String {
46        match serde_json::to_string_pretty(&object) {
47            Ok(json) => json,
48            Err(e) => format!("Failed to serialize to json: {e}"),
49        }
50    }
51
52    /// Save a struct to a JSON file
53    ///
54    /// # Arguments
55    ///
56    /// * `model` - A reference to a struct that implements the `Serialize` trait
57    /// * `path` - A reference to a `Path` object that holds the path to the file
58    ///
59    /// # Returns
60    ///
61    /// A `Result` containing `()` or a `UtilError`
62    ///
63    /// # Errors
64    ///
65    /// This function will return an error if:
66    /// - The struct cannot be serialized to a string
67    pub fn save_to_json<T>(model: T, path: &Path) -> Result<(), UtilError>
68    where
69        T: Serialize,
70    {
71        // serialize the struct to a string
72        let json =
73            serde_json::to_string_pretty(&model).map_err(|_| UtilError::SerializationError)?;
74
75        // ensure .json extension
76        let path = path.with_extension("json");
77
78        if !path.exists() {
79            // ensure path exists, create if not
80            let parent_path = path.parent().ok_or(UtilError::GetParentPathError)?;
81
82            std::fs::create_dir_all(parent_path).map_err(|_| UtilError::CreateDirectoryError)?;
83        }
84
85        std::fs::write(path, json).map_err(|_| UtilError::WriteError)?;
86
87        Ok(())
88    }
89}
90
91pub fn json_to_pydict<'py>(
92    py: Python,
93    value: &Value,
94    dict: &Bound<'py, PyDict>,
95) -> Result<Bound<'py, PyDict>, UtilError> {
96    match value {
97        Value::Object(map) => {
98            for (k, v) in map {
99                let py_value = match v {
100                    Value::Null => py.None(),
101                    Value::Bool(b) => b.into_py_any(py)?,
102                    Value::Number(n) => {
103                        if let Some(i) = n.as_i64() {
104                            i.into_py_any(py)?
105                        } else if let Some(f) = n.as_f64() {
106                            f.into_py_any(py)?
107                        } else {
108                            return Err(UtilError::InvalidNumber);
109                        }
110                    }
111                    Value::String(s) => s.into_py_any(py)?,
112                    Value::Array(arr) => {
113                        let py_list = PyList::empty(py);
114                        for item in arr {
115                            let py_item = json_to_pyobject(py, item)?;
116                            py_list.append(py_item)?;
117                        }
118                        py_list.into_py_any(py)?
119                    }
120                    Value::Object(_) => {
121                        let nested_dict = PyDict::new(py);
122                        json_to_pydict(py, v, &nested_dict)?;
123                        nested_dict.into_py_any(py)?
124                    }
125                };
126                dict.set_item(k, py_value)?;
127            }
128        }
129        _ => return Err(UtilError::RootMustBeObjectError),
130    }
131
132    Ok(dict.clone())
133}
134
135/// Converts a serde_json::Value to a PyObject. Including support for nested objects and arrays.
136/// This function handles all Serde JSON types:
137/// - Serde Null -> Python None
138/// - Serde Bool -> Python bool
139/// - Serde String -> Python str
140/// - Serde Number -> Python int or float
141/// - Serde Array -> Python list (with each item converted to Python type)
142/// - Serde Object -> Python dict (with each key-value pair converted to Python type)
143////// # Arguments
144/// * `py` - A Python interpreter instance.
145/// * `value` - A reference to a serde_json::Value object.
146/// # Returns
147/// * `Ok(PyObject)` if the conversion was successful.
148/// * `Err(UtilError)` if the conversion failed.
149pub fn json_to_pyobject(py: Python, value: &Value) -> Result<Py<PyAny>, UtilError> {
150    Ok(match value {
151        Value::Null => py.None(),
152        Value::Bool(b) => b.into_py_any(py)?,
153        Value::Number(n) => {
154            if let Some(i) = n.as_i64() {
155                i.into_py_any(py)?
156            } else if let Some(f) = n.as_f64() {
157                f.into_py_any(py)?
158            } else {
159                return Err(UtilError::InvalidNumber);
160            }
161        }
162        Value::String(s) => s.into_py_any(py)?,
163        Value::Array(arr) => {
164            let py_list = PyList::empty(py);
165            for item in arr {
166                let py_item = json_to_pyobject(py, item)?;
167                py_list.append(py_item)?;
168            }
169            py_list.into_py_any(py)?
170        }
171        Value::Object(_) => {
172            let nested_dict = PyDict::new(py);
173            json_to_pydict(py, value, &nested_dict)?;
174            nested_dict.into_py_any(py)?
175        }
176    })
177}
178
179pub fn vec_to_py_object<'py>(
180    py: Python<'py>,
181    vec: &Vec<Value>,
182) -> Result<Bound<'py, PyList>, UtilError> {
183    let py_list = PyList::empty(py);
184    for item in vec {
185        let py_item = json_to_pyobject(py, item)?;
186        py_list.append(py_item)?;
187    }
188    Ok(py_list)
189}
190
191pub fn pyobject_to_json(obj: &Bound<'_, PyAny>) -> Result<Value, UtilError> {
192    if obj.is_instance_of::<PyDict>() {
193        let dict = obj.downcast::<PyDict>()?;
194        let mut map = serde_json::Map::new();
195        for (key, value) in dict.iter() {
196            let key_str = key.extract::<String>()?;
197            let json_value = pyobject_to_json(&value)?;
198            map.insert(key_str, json_value);
199        }
200        Ok(Value::Object(map))
201    } else if obj.is_instance_of::<PyList>() {
202        let list = obj.downcast::<PyList>()?;
203        let mut vec = Vec::new();
204        for item in list.iter() {
205            vec.push(pyobject_to_json(&item)?);
206        }
207        Ok(Value::Array(vec))
208    } else if obj.is_instance_of::<PyTuple>() {
209        let tuple = obj.downcast::<PyTuple>()?;
210        let mut vec = Vec::new();
211        for item in tuple.iter() {
212            vec.push(pyobject_to_json(&item)?);
213        }
214        Ok(Value::Array(vec))
215    } else if obj.is_instance_of::<PyString>() {
216        let s = obj.extract::<String>()?;
217        Ok(Value::String(s))
218    } else if obj.is_instance_of::<PyFloat>() {
219        let f = obj.extract::<f64>()?;
220        Ok(json!(f))
221    } else if obj.is_instance_of::<PyBool>() {
222        let b = obj.extract::<bool>()?;
223        Ok(json!(b))
224    } else if obj.is_instance_of::<PyInt>() {
225        let i = obj.extract::<i64>()?;
226        Ok(json!(i))
227    } else if obj.is_none() {
228        Ok(Value::Null)
229    } else {
230        // display "cant show" for unsupported types
231        // call obj.str to get the string representation
232        // if error, default to "unsupported type"
233        let obj_str = match obj.str() {
234            Ok(s) => s
235                .extract::<String>()
236                .unwrap_or_else(|_| "unsupported type".to_string()),
237            Err(_) => "unsupported type".to_string(),
238        };
239
240        Ok(Value::String(obj_str))
241    }
242}
243
244pub fn version() -> String {
245    env!("CARGO_PKG_VERSION").to_string()
246}
247
248pub fn update_serde_value(value: &mut Value, key: &str, new_value: Value) -> Result<(), UtilError> {
249    if let Value::Object(map) = value {
250        map.insert(key.to_string(), new_value);
251        Ok(())
252    } else {
253        Err(UtilError::RootMustBeObjectError)
254    }
255}
256
257/// Updates a serde_json::Value object with another serde_json::Value object.
258/// Both types must be of the `Object` variant.
259/// If a key in the source object does not exist in the destination object,
260/// it will be added with the value from the source object.
261/// # Arguments
262/// * `dest` - A mutable reference to the destination serde_json::Value object.
263/// * `src` - A reference to the source serde_json::Value object.
264/// # Returns
265/// * `Ok(())` if the update was successful.    
266/// * `Err(UtilError::RootMustBeObjectError)` if either `dest` or `src` is not an `Object`.
267pub fn update_serde_map_with(
268    dest: &mut serde_json::Value,
269    src: &serde_json::Value,
270) -> Result<(), UtilError> {
271    match (dest, src) {
272        (&mut Object(ref mut map_dest), Object(ref map_src)) => {
273            // map_dest and map_src both are Map<String, Value>
274            for (key, value) in map_src {
275                // if key is not in map_dest, create a Null object
276                // then only, update the value
277                *map_dest.entry(key.clone()).or_insert(Null) = value.clone();
278            }
279            Ok(())
280        }
281        (_, _) => Err(UtilError::RootMustBeObjectError),
282    }
283}
284
285/// Extracts a string value from a Python object.
286pub fn extract_string_value(py_value: &Bound<'_, PyAny>) -> Result<String, UtilError> {
287    // Try to extract as string first (most common case)
288    if let Ok(string_val) = py_value.extract::<String>() {
289        return Ok(string_val);
290    }
291
292    // Try to extract as integer
293    if let Ok(int_val) = py_value.extract::<i64>() {
294        return Ok(int_val.to_string());
295    }
296
297    // Try to extract as float
298    if let Ok(float_val) = py_value.extract::<f64>() {
299        return Ok(float_val.to_string());
300    }
301
302    // Try to extract as boolean
303    if let Ok(bool_val) = py_value.extract::<bool>() {
304        return Ok(bool_val.to_string());
305    }
306
307    // For complex objects, convert to JSON but extract the value without quotes
308    let json_value = pyobject_to_json(py_value)?;
309
310    match json_value {
311        Value::String(s) => Ok(s),
312        Value::Number(n) => Ok(n.to_string()),
313        Value::Bool(b) => Ok(b.to_string()),
314        Value::Null => Ok("null".to_string()),
315        _ => {
316            // For arrays and objects, serialize to JSON string
317            let json_string = serde_json::to_string(&json_value)?;
318            Ok(json_string)
319        }
320    }
321}
322
/// A single token paired with its log probability, exposed to Python as a class
/// with read-only attributes (`#[pyo3(get)]`).
#[pyclass]
#[derive(Debug, Serialize, Clone)]
pub struct ResponseLogProbs {
    /// Token text. `calculate_weighted_score` tries to parse it as an integer score.
    #[pyo3(get)]
    pub token: String,

    /// Log probability of the token; `calculate_weighted_score` converts it back to a
    /// probability with `exp()`, i.e. it is treated as a natural log.
    #[pyo3(get)]
    pub logprob: f64,
}
332
/// A collection of per-token log probabilities, exposed to Python as a class with
/// a read-only `tokens` attribute.
#[pyclass]
#[derive(Debug, Serialize, Clone)]
pub struct LogProbs {
    /// The token/log-probability entries, in the order they are stored.
    #[pyo3(get)]
    pub tokens: Vec<ResponseLogProbs>,
}

#[pymethods]
impl LogProbs {
    /// Python `str()` hook: renders this object as colored, pretty-printed JSON
    /// via `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
346
347/// Calculate a weighted score base on the log probabilities of tokens 1-5.
348pub fn calculate_weighted_score(log_probs: &[ResponseLogProbs]) -> Result<Option<f64>, UtilError> {
349    let score_range = RangeInclusive::new(1, 5);
350    let mut score_probs = Vec::new();
351    let mut weighted_sum = 0.0;
352    let mut total_prob = 0.0;
353
354    for log_prob in log_probs {
355        let token = log_prob.token.parse::<u64>().ok();
356
357        if let Some(token_val) = token {
358            if score_range.contains(&token_val) {
359                let prob = log_prob.logprob.exp();
360                score_probs.push((token_val, prob));
361            }
362        }
363    }
364
365    for (score, logprob) in score_probs {
366        weighted_sum += score as f64 * logprob;
367        total_prob += logprob;
368    }
369
370    if total_prob > 0.0 {
371        Ok(Some(weighted_sum / total_prob))
372    } else {
373        Ok(None)
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ResponseLogProbs` entry for test inputs.
    fn lp(token: &str, logprob: f64) -> ResponseLogProbs {
        ResponseLogProbs {
            token: token.into(),
            logprob,
        }
    }

    #[test]
    fn test_calculate_weighted_score() {
        let log_probs = vec![lp("1", 0.9), lp("2", 0.8), lp("3", 0.7)];

        let val = calculate_weighted_score(&log_probs)
            .expect("scoring does not fail")
            .expect("in-range tokens produce a score");
        // (1·e^0.9 + 2·e^0.8 + 3·e^0.7) / (e^0.9 + e^0.8 + e^0.7) ≈ 1.9334
        assert!((val - 1.9334).abs() < 1e-3, "unexpected score {val}");
    }

    #[test]
    fn test_calculate_weighted_score_empty() {
        let log_probs: Vec<ResponseLogProbs> = vec![];
        let result = calculate_weighted_score(&log_probs);
        assert_eq!(result.expect("scoring does not fail"), None);
    }

    #[test]
    fn test_calculate_weighted_score_ignores_non_score_tokens() {
        // "0" and "6" are outside 1..=5 and "yes" is not numeric; only "4" counts.
        let log_probs = vec![lp("0", 0.0), lp("6", 0.0), lp("yes", 0.0), lp("4", -0.5)];
        let val = calculate_weighted_score(&log_probs)
            .expect("scoring does not fail")
            .expect("one in-range token produces a score");
        assert!((val - 4.0).abs() < 1e-12, "unexpected score {val}");

        let only_ignored = vec![lp("0", 0.0), lp("6", 0.0), lp("yes", 0.0)];
        assert_eq!(
            calculate_weighted_score(&only_ignored).expect("scoring does not fail"),
            None
        );
    }
}