potato_util/
utils.rs

1use crate::error::UtilError;
2
3use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
4use pyo3::prelude::*;
5
6use pyo3::types::{
7    PyAny, PyBool, PyDict, PyDictMethods, PyFloat, PyInt, PyList, PyString, PyTuple,
8};
9use pyo3::IntoPyObjectExt;
10use serde::Serialize;
11use serde_json::json;
12use serde_json::Value;
13use serde_json::Value::{Null, Object};
14use std::ops::RangeInclusive;
15use std::path::Path;
16use uuid::Uuid;
17pub fn create_uuid7() -> String {
18    Uuid::now_v7().to_string()
19}
20
21pub struct PyHelperFuncs {}
22
23impl PyHelperFuncs {
24    pub fn __str__<T: Serialize>(object: T) -> String {
25        match ColoredFormatter::with_styler(
26            PrettyFormatter::default(),
27            Styler {
28                key: Color::Rgb(75, 57, 120).foreground(),
29                string_value: Color::Rgb(4, 205, 155).foreground(),
30                float_value: Color::Rgb(4, 205, 155).foreground(),
31                integer_value: Color::Rgb(4, 205, 155).foreground(),
32                bool_value: Color::Rgb(4, 205, 155).foreground(),
33                nil_value: Color::Rgb(4, 205, 155).foreground(),
34                ..Default::default()
35            },
36        )
37        .to_colored_json(&object, ColorMode::On)
38        {
39            Ok(json) => json,
40            Err(e) => format!("Failed to serialize to json: {e}"),
41        }
42        // serialize the struct to a string
43    }
44
45    pub fn __json__<T: Serialize>(object: T) -> String {
46        match serde_json::to_string_pretty(&object) {
47            Ok(json) => json,
48            Err(e) => format!("Failed to serialize to json: {e}"),
49        }
50    }
51
52    /// Save a struct to a JSON file
53    ///
54    /// # Arguments
55    ///
56    /// * `model` - A reference to a struct that implements the `Serialize` trait
57    /// * `path` - A reference to a `Path` object that holds the path to the file
58    ///
59    /// # Returns
60    ///
61    /// A `Result` containing `()` or a `UtilError`
62    ///
63    /// # Errors
64    ///
65    /// This function will return an error if:
66    /// - The struct cannot be serialized to a string
67    pub fn save_to_json<T>(model: T, path: &Path) -> Result<(), UtilError>
68    where
69        T: Serialize,
70    {
71        // serialize the struct to a string
72        let json =
73            serde_json::to_string_pretty(&model).map_err(|_| UtilError::SerializationError)?;
74
75        // ensure .json extension
76        let path = path.with_extension("json");
77
78        if !path.exists() {
79            // ensure path exists, create if not
80            let parent_path = path.parent().ok_or(UtilError::GetParentPathError)?;
81
82            std::fs::create_dir_all(parent_path).map_err(|_| UtilError::CreateDirectoryError)?;
83        }
84
85        std::fs::write(path, json).map_err(|_| UtilError::WriteError)?;
86
87        Ok(())
88    }
89}
90
91pub fn json_to_pydict<'py>(
92    py: Python,
93    value: &Value,
94    dict: &Bound<'py, PyDict>,
95) -> Result<Bound<'py, PyDict>, UtilError> {
96    match value {
97        Value::Object(map) => {
98            for (k, v) in map {
99                let py_value = match v {
100                    Value::Null => py.None(),
101                    Value::Bool(b) => b.into_py_any(py)?,
102                    Value::Number(n) => {
103                        if let Some(i) = n.as_i64() {
104                            i.into_py_any(py)?
105                        } else if let Some(f) = n.as_f64() {
106                            f.into_py_any(py)?
107                        } else {
108                            return Err(UtilError::InvalidNumber);
109                        }
110                    }
111                    Value::String(s) => s.into_py_any(py)?,
112                    Value::Array(arr) => {
113                        let py_list = PyList::empty(py);
114                        for item in arr {
115                            let py_item = json_to_pyobject(py, item)?;
116                            py_list.append(py_item)?;
117                        }
118                        py_list.into_py_any(py)?
119                    }
120                    Value::Object(_) => {
121                        let nested_dict = PyDict::new(py);
122                        json_to_pydict(py, v, &nested_dict)?;
123                        nested_dict.into_py_any(py)?
124                    }
125                };
126                dict.set_item(k, py_value)?;
127            }
128        }
129        _ => return Err(UtilError::RootMustBeObjectError),
130    }
131
132    Ok(dict.clone())
133}
134
135/// Converts a serde_json::Value to a PyObject. Including support for nested objects and arrays.
136/// This function handles all Serde JSON types:
137/// - Serde Null -> Python None
138/// - Serde Bool -> Python bool
139/// - Serde String -> Python str
140/// - Serde Number -> Python int or float
141/// - Serde Array -> Python list (with each item converted to Python type)
142/// - Serde Object -> Python dict (with each key-value pair converted to Python type)
143////// # Arguments
144/// * `py` - A Python interpreter instance.
145/// * `value` - A reference to a serde_json::Value object.
146/// # Returns
147/// * `Ok(PyObject)` if the conversion was successful.
148/// * `Err(UtilError)` if the conversion failed.
149pub fn json_to_pyobject(py: Python, value: &Value) -> Result<PyObject, UtilError> {
150    Ok(match value {
151        Value::Null => py.None(),
152        Value::Bool(b) => b.into_py_any(py)?,
153        Value::Number(n) => {
154            if let Some(i) = n.as_i64() {
155                i.into_py_any(py)?
156            } else if let Some(f) = n.as_f64() {
157                f.into_py_any(py)?
158            } else {
159                return Err(UtilError::InvalidNumber);
160            }
161        }
162        Value::String(s) => s.into_py_any(py)?,
163        Value::Array(arr) => {
164            let py_list = PyList::empty(py);
165            for item in arr {
166                let py_item = json_to_pyobject(py, item)?;
167                py_list.append(py_item)?;
168            }
169            py_list.into_py_any(py)?
170        }
171        Value::Object(_) => {
172            let nested_dict = PyDict::new(py);
173            json_to_pydict(py, value, &nested_dict)?;
174            nested_dict.into_py_any(py)?
175        }
176    })
177}
178
179pub fn pyobject_to_json(obj: &Bound<'_, PyAny>) -> Result<Value, UtilError> {
180    if obj.is_instance_of::<PyDict>() {
181        let dict = obj.downcast::<PyDict>()?;
182        let mut map = serde_json::Map::new();
183        for (key, value) in dict.iter() {
184            let key_str = key.extract::<String>()?;
185            let json_value = pyobject_to_json(&value)?;
186            map.insert(key_str, json_value);
187        }
188        Ok(Value::Object(map))
189    } else if obj.is_instance_of::<PyList>() {
190        let list = obj.downcast::<PyList>()?;
191        let mut vec = Vec::new();
192        for item in list.iter() {
193            vec.push(pyobject_to_json(&item)?);
194        }
195        Ok(Value::Array(vec))
196    } else if obj.is_instance_of::<PyTuple>() {
197        let tuple = obj.downcast::<PyTuple>()?;
198        let mut vec = Vec::new();
199        for item in tuple.iter() {
200            vec.push(pyobject_to_json(&item)?);
201        }
202        Ok(Value::Array(vec))
203    } else if obj.is_instance_of::<PyString>() {
204        let s = obj.extract::<String>()?;
205        Ok(Value::String(s))
206    } else if obj.is_instance_of::<PyFloat>() {
207        let f = obj.extract::<f64>()?;
208        Ok(json!(f))
209    } else if obj.is_instance_of::<PyBool>() {
210        let b = obj.extract::<bool>()?;
211        Ok(json!(b))
212    } else if obj.is_instance_of::<PyInt>() {
213        let i = obj.extract::<i64>()?;
214        Ok(json!(i))
215    } else if obj.is_none() {
216        Ok(Value::Null)
217    } else {
218        // display "cant show" for unsupported types
219        // call obj.str to get the string representation
220        // if error, default to "unsupported type"
221        let obj_str = match obj.str() {
222            Ok(s) => s
223                .extract::<String>()
224                .unwrap_or_else(|_| "unsupported type".to_string()),
225            Err(_) => "unsupported type".to_string(),
226        };
227
228        Ok(Value::String(obj_str))
229    }
230}
231
232pub fn version() -> String {
233    env!("CARGO_PKG_VERSION").to_string()
234}
235
236pub fn update_serde_value(value: &mut Value, key: &str, new_value: Value) -> Result<(), UtilError> {
237    if let Value::Object(map) = value {
238        map.insert(key.to_string(), new_value);
239        Ok(())
240    } else {
241        Err(UtilError::RootMustBeObjectError)
242    }
243}
244
245/// Updates a serde_json::Value object with another serde_json::Value object.
246/// Both types must be of the `Object` variant.
247/// If a key in the source object does not exist in the destination object,
248/// it will be added with the value from the source object.
249/// # Arguments
250/// * `dest` - A mutable reference to the destination serde_json::Value object.
251/// * `src` - A reference to the source serde_json::Value object.
252/// # Returns
253/// * `Ok(())` if the update was successful.    
254/// * `Err(UtilError::RootMustBeObjectError)` if either `dest` or `src` is not an `Object`.
255pub fn update_serde_map_with(
256    dest: &mut serde_json::Value,
257    src: &serde_json::Value,
258) -> Result<(), UtilError> {
259    match (dest, src) {
260        (&mut Object(ref mut map_dest), Object(ref map_src)) => {
261            // map_dest and map_src both are Map<String, Value>
262            for (key, value) in map_src {
263                // if key is not in map_dest, create a Null object
264                // then only, update the value
265                *map_dest.entry(key.clone()).or_insert(Null) = value.clone();
266            }
267            Ok(())
268        }
269        (_, _) => Err(UtilError::RootMustBeObjectError),
270    }
271}
272
273/// Extracts a string value from a Python object.
274pub fn extract_string_value(py_value: &Bound<'_, PyAny>) -> Result<String, UtilError> {
275    // Try to extract as string first (most common case)
276    if let Ok(string_val) = py_value.extract::<String>() {
277        return Ok(string_val);
278    }
279
280    // Try to extract as integer
281    if let Ok(int_val) = py_value.extract::<i64>() {
282        return Ok(int_val.to_string());
283    }
284
285    // Try to extract as float
286    if let Ok(float_val) = py_value.extract::<f64>() {
287        return Ok(float_val.to_string());
288    }
289
290    // Try to extract as boolean
291    if let Ok(bool_val) = py_value.extract::<bool>() {
292        return Ok(bool_val.to_string());
293    }
294
295    // For complex objects, convert to JSON but extract the value without quotes
296    let json_value = pyobject_to_json(py_value)?;
297
298    match json_value {
299        Value::String(s) => Ok(s),
300        Value::Number(n) => Ok(n.to_string()),
301        Value::Bool(b) => Ok(b.to_string()),
302        Value::Null => Ok("null".to_string()),
303        _ => {
304            // For arrays and objects, serialize to JSON string
305            let json_string = serde_json::to_string(&json_value)?;
306            Ok(json_string)
307        }
308    }
309}
310
311#[pyclass]
312#[derive(Debug, Serialize, Clone)]
313pub struct ResponseLogProbs {
314    #[pyo3(get)]
315    pub token: String,
316
317    #[pyo3(get)]
318    pub logprob: f64,
319}
320
321#[pyclass]
322#[derive(Debug, Serialize, Clone)]
323pub struct LogProbs {
324    #[pyo3(get)]
325    pub tokens: Vec<ResponseLogProbs>,
326}
327
328#[pymethods]
329impl LogProbs {
330    pub fn __str__(&self) -> String {
331        PyHelperFuncs::__str__(self)
332    }
333}
334
335/// Calculate a weighted score base on the log probabilities of tokens 1-5.
336pub fn calculate_weighted_score(log_probs: &[ResponseLogProbs]) -> Result<Option<f64>, UtilError> {
337    let score_range = RangeInclusive::new(1, 5);
338    let mut score_probs = Vec::new();
339    let mut weighted_sum = 0.0;
340    let mut total_prob = 0.0;
341
342    for log_prob in log_probs {
343        let token = log_prob.token.parse::<u64>().ok();
344
345        if let Some(token_val) = token {
346            if score_range.contains(&token_val) {
347                let prob = log_prob.logprob.exp();
348                score_probs.push((token_val, prob));
349            }
350        }
351    }
352
353    for (score, logprob) in score_probs {
354        weighted_sum += score as f64 * logprob;
355        total_prob += logprob;
356    }
357
358    if total_prob > 0.0 {
359        Ok(Some(weighted_sum / total_prob))
360    } else {
361        Ok(None)
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ResponseLogProbs` entry; keeps the test tables compact.
    fn entry(token: &str, logprob: f64) -> ResponseLogProbs {
        ResponseLogProbs {
            token: token.to_string(),
            logprob,
        }
    }

    #[test]
    fn test_calculate_weighted_score() {
        let log_probs = [entry("1", 0.9), entry("2", 0.8), entry("3", 0.7)];

        let score = calculate_weighted_score(&log_probs)
            .expect("calculate_weighted_score should not fail")
            .expect("in-range tokens should yield a score");

        // The exp-weighted mean of 1, 2, 3 with these weights is ~1.93.
        assert_eq!(score.round(), 2.0);
    }

    #[test]
    fn test_calculate_weighted_score_empty() {
        let result =
            calculate_weighted_score(&[]).expect("calculate_weighted_score should not fail");
        assert_eq!(result, None);
    }

    #[test]
    fn test_calculate_weighted_score_ignores_out_of_range_tokens() {
        // Only "5" is a valid score token; every other entry must be skipped.
        let log_probs = [
            entry("0", -0.1),
            entry("6", -0.2),
            entry("abc", -0.3),
            entry("5", -0.4),
        ];

        let score = calculate_weighted_score(&log_probs)
            .expect("calculate_weighted_score should not fail")
            .expect("the \"5\" token should yield a score");

        assert!((score - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_calculate_weighted_score_no_valid_tokens() {
        let log_probs = [entry("yes", -0.1), entry("10", -0.2)];

        let result = calculate_weighted_score(&log_probs)
            .expect("calculate_weighted_score should not fail");

        assert_eq!(result, None);
    }
}