Skip to main content

potato_util/
utils.rs

1use crate::error::UtilError;
2
3use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
4use pyo3::prelude::*;
5
6use pyo3::types::{PyAny, PyDict, PyList};
7use pyo3::IntoPyObjectExt;
8use serde::Serialize;
9use serde_json::Value;
10use serde_json::Value::{Null, Object};
11use std::ops::RangeInclusive;
12use std::path::Path;
13use uuid::Uuid;
14pub fn create_uuid7() -> String {
15    Uuid::now_v7().to_string()
16}
17use pythonize::{depythonize, pythonize};
18use tracing::warn;
19pub struct PyHelperFuncs {}
20
21impl PyHelperFuncs {
22    /// Convert any type implementing `IntoPyObject` to a Python object
23    /// # Arguments
24    /// * `py` - A Python interpreter instance
25    /// * `object` - A reference to an object implementing `IntoPyObject`
26    /// # Returns
27    /// * `Result<Bound<'py, PyAny>, UtilError>` - A result containing the Python object or an error
28    pub fn to_bound_py_object<'py, T>(
29        py: Python<'py>,
30        object: &T,
31    ) -> Result<Bound<'py, PyAny>, UtilError>
32    where
33        T: IntoPyObject<'py> + Clone,
34    {
35        Ok(object.clone().into_bound_py_any(py)?)
36    }
37    pub fn __str__<T: Serialize>(object: T) -> String {
38        match ColoredFormatter::with_styler(
39            PrettyFormatter::default(),
40            Styler {
41                key: Color::Rgb(75, 57, 120).foreground(),
42                string_value: Color::Rgb(4, 205, 155).foreground(),
43                float_value: Color::Rgb(4, 205, 155).foreground(),
44                integer_value: Color::Rgb(4, 205, 155).foreground(),
45                bool_value: Color::Rgb(4, 205, 155).foreground(),
46                nil_value: Color::Rgb(4, 205, 155).foreground(),
47                ..Default::default()
48            },
49        )
50        .to_colored_json(&object, ColorMode::On)
51        {
52            Ok(json) => json,
53            Err(e) => format!("Failed to serialize to json: {e}"),
54        }
55        // serialize the struct to a string
56    }
57
58    pub fn __json__<T: Serialize>(object: T) -> String {
59        match serde_json::to_string_pretty(&object) {
60            Ok(json) => json,
61            Err(e) => format!("Failed to serialize to json: {e}"),
62        }
63    }
64
65    /// Save a struct to a JSON file
66    ///
67    /// # Arguments
68    ///
69    /// * `model` - A reference to a struct that implements the `Serialize` trait
70    /// * `path` - A reference to a `Path` object that holds the path to the file
71    ///
72    /// # Returns
73    ///
74    /// A `Result` containing `()` or a `UtilError`
75    ///
76    /// # Errors
77    ///
78    /// This function will return an error if:
79    /// - The struct cannot be serialized to a string
80    pub fn save_to_json<T>(model: T, path: &Path) -> Result<(), UtilError>
81    where
82        T: Serialize,
83    {
84        // serialize the struct to a string
85        let json =
86            serde_json::to_string_pretty(&model).map_err(|_| UtilError::SerializationError)?;
87
88        // ensure .json extension
89        let path = path.with_extension("json");
90
91        if !path.exists() {
92            // ensure path exists, create if not
93            let parent_path = path.parent().ok_or(UtilError::GetParentPathError)?;
94
95            std::fs::create_dir_all(parent_path).map_err(|_| UtilError::CreateDirectoryError)?;
96        }
97
98        std::fs::write(path, json).map_err(|_| UtilError::WriteError)?;
99
100        Ok(())
101    }
102}
103
104pub fn vec_to_py_object<'py>(
105    py: Python<'py>,
106    vec: &Vec<Value>,
107) -> Result<Bound<'py, PyList>, UtilError> {
108    let py_list = PyList::empty(py);
109    for item in vec {
110        let py_item = pythonize(py, item)?;
111        py_list.append(py_item)?;
112    }
113    Ok(py_list)
114}
115
116pub fn version() -> String {
117    env!("CARGO_PKG_VERSION").to_string()
118}
119
120pub fn update_serde_value(value: &mut Value, key: &str, new_value: Value) -> Result<(), UtilError> {
121    if let Value::Object(map) = value {
122        map.insert(key.to_string(), new_value);
123        Ok(())
124    } else {
125        Err(UtilError::RootMustBeObjectError)
126    }
127}
128
129/// Updates a serde_json::Value object with another serde_json::Value object.
130/// Both types must be of the `Object` variant.
131/// If a key in the source object does not exist in the destination object,
132/// it will be added with the value from the source object.
133/// # Arguments
134/// * `dest` - A mutable reference to the destination serde_json::Value object.
135/// * `src` - A reference to the source serde_json::Value object.
136/// # Returns
137/// * `Ok(())` if the update was successful.
138/// * `Err(UtilError::RootMustBeObjectError)` if either `dest` or `src` is not an `Object`.
139pub fn update_serde_map_with(
140    dest: &mut serde_json::Value,
141    src: &serde_json::Value,
142) -> Result<(), UtilError> {
143    match (dest, src) {
144        (&mut Object(ref mut map_dest), Object(ref map_src)) => {
145            // map_dest and map_src both are Map<String, Value>
146            for (key, value) in map_src {
147                // if key is not in map_dest, create a Null object
148                // then only, update the value
149                *map_dest.entry(key.clone()).or_insert(Null) = value.clone();
150            }
151            Ok(())
152        }
153        (_, _) => Err(UtilError::RootMustBeObjectError),
154    }
155}
156
157/// Extracts a string value from a Python object.
158pub fn extract_string_value(py_value: &Bound<'_, PyAny>) -> Result<String, UtilError> {
159    // Try to extract as string first (most common case)
160    if let Ok(string_val) = py_value.extract::<String>() {
161        return Ok(string_val);
162    }
163    // Try to extract as boolean
164    if let Ok(bool_val) = py_value.extract::<bool>() {
165        return Ok(bool_val.to_string());
166    }
167
168    // Try to extract as integer
169    if let Ok(int_val) = py_value.extract::<i64>() {
170        return Ok(int_val.to_string());
171    }
172
173    // Try to extract as float
174    if let Ok(float_val) = py_value.extract::<f64>() {
175        return Ok(float_val.to_string());
176    }
177
178    // For complex objects, convert to JSON but extract the value without quotes
179    let json_value = depythonize(py_value)?;
180
181    match json_value {
182        Value::String(s) => Ok(s),
183        Value::Number(n) => Ok(n.to_string()),
184        Value::Bool(b) => Ok(b.to_string()),
185        Value::Null => Ok("null".to_string()),
186        _ => {
187            // For arrays and objects, serialize to JSON string
188            let json_string = serde_json::to_string(&json_value)?;
189            Ok(json_string)
190        }
191    }
192}
193
194#[pyclass]
195#[derive(Debug, Serialize, Clone)]
196pub struct TokenLogProbs {
197    #[pyo3(get)]
198    pub token: String,
199
200    #[pyo3(get)]
201    pub logprob: f64,
202}
203
204#[pyclass]
205#[derive(Debug, Serialize, Clone)]
206pub struct ResponseLogProbs {
207    #[pyo3(get)]
208    pub tokens: Vec<TokenLogProbs>,
209}
210
211#[pymethods]
212impl ResponseLogProbs {
213    pub fn __str__(&self) -> String {
214        PyHelperFuncs::__str__(self)
215    }
216}
217
218/// Calculate a weighted score base on the log probabilities of tokens 1-5.
219pub fn calculate_weighted_score(log_probs: &[TokenLogProbs]) -> Result<Option<f64>, UtilError> {
220    let score_range = RangeInclusive::new(1, 5);
221    let mut score_probs = Vec::new();
222    let mut weighted_sum = 0.0;
223    let mut total_prob = 0.0;
224
225    for log_prob in log_probs {
226        let token = log_prob.token.parse::<u64>().ok();
227
228        if let Some(token_val) = token {
229            if score_range.contains(&token_val) {
230                let prob = log_prob.logprob.exp();
231                score_probs.push((token_val, prob));
232            }
233        }
234    }
235
236    for (score, logprob) in score_probs {
237        weighted_sum += score as f64 * logprob;
238        total_prob += logprob;
239    }
240
241    if total_prob > 0.0 {
242        Ok(Some(weighted_sum / total_prob))
243    } else {
244        Ok(None)
245    }
246}
247
248/// Generic function to convert text to a structured output model
249/// It is expected that output_model is a pydantic model or a potatohead type that implements serde json deserialization
250/// via model_validate_json method.
251/// Flow:
252/// 1. Attempt to validate the model using model_validate_json
253/// 2. If validation fails, attempt to parse the text as JSON and convert to python object
254/// # Arguments
255/// * `py` - A Python interpreter instance
256/// * `text` - The text to be converted (typically from an LLM response that returns structured output)
257/// * `output_model` - A bound python object representing the output model
258/// # Returns
259/// * `Result<Bound<'py, PyAny>, UtilError>` - A result containing the structured output or an error
260pub fn convert_text_to_structured_output<'py>(
261    py: Python<'py>,
262    text: String,
263    output_model: &Bound<'py, PyAny>,
264) -> Result<Bound<'py, PyAny>, UtilError> {
265    let output = output_model.call_method1("model_validate_json", (&text,));
266    match output {
267        Ok(obj) => {
268            // Successfully validated the model
269            Ok(obj)
270        }
271        Err(err) => {
272            // Model validation failed
273            // convert string to json and then to python object
274            warn!(
275                "Failed to validate model: {}, Attempting fallback to JSON parsing",
276                err
277            );
278            let val = serde_json::from_str::<serde_json::Value>(&text)?;
279            Ok(pythonize(py, &val)?)
280        }
281    }
282}
283
284pub fn is_pydantic_basemodel(py: Python, obj: &Bound<'_, PyAny>) -> Result<bool, UtilError> {
285    let pydantic = match py.import("pydantic") {
286        Ok(module) => module,
287        // return false if pydantic cannot be imported
288        Err(_) => return Ok(false),
289    };
290
291    let basemodel = pydantic.getattr("BaseModel")?;
292
293    // check if context is a pydantic model
294    let is_basemodel = obj
295        .is_instance(&basemodel)
296        .map_err(|e| UtilError::FailedToCheckPydanticModel(e.to_string()))?;
297
298    Ok(is_basemodel)
299}
300
301fn process_dict_with_nested_models(
302    py: Python<'_>,
303    dict: &Bound<'_, PyAny>,
304) -> Result<Value, UtilError> {
305    let py_dict = dict.cast::<PyDict>()?;
306    let mut result = serde_json::Map::new();
307
308    for (key, value) in py_dict.iter() {
309        let key_str: String = key.extract()?;
310        let processed_value = depythonize_object_to_value(py, &value)?;
311        result.insert(key_str, processed_value);
312    }
313
314    Ok(Value::Object(result))
315}
316
317pub fn depythonize_object_to_value<'py>(
318    py: Python<'py>,
319    value: &Bound<'py, PyAny>,
320) -> Result<Value, UtilError> {
321    let py_value = if is_pydantic_basemodel(py, value)? {
322        let model = value.call_method0("model_dump")?;
323        depythonize(&model)?
324    } else if value.is_instance_of::<PyDict>() {
325        process_dict_with_nested_models(py, value)?
326    } else {
327        depythonize(value)?
328    };
329    Ok(py_value)
330}
331
332/// Helper function to extract result from LLM response text
333/// If an output model is provided, it will attempt to convert the text to the structured output
334/// using the provided model. If no model is provided, it will attempt to convert the response to an appropriate
335/// Python type directly.
336/// # Arguments
337/// * `py` - A Python interpreter instance
338/// * `text` - The text to be converted (typically from an LLM response)
339/// * `output_model` - An optional bound python object representing the output model
340/// # Returns
341/// * `Result<Bound<'py, PyAny>, UtilError>` - A result containing the structured output or an error
342pub fn construct_structured_response<'py>(
343    py: Python<'py>,
344    text: String,
345    output_model: Option<&Bound<'py, PyAny>>,
346) -> Result<Bound<'py, PyAny>, UtilError> {
347    match output_model {
348        Some(model) => convert_text_to_structured_output(py, text, model),
349        None => {
350            // No output model provided, return the text as a Python string
351            let val = Value::String(text);
352            Ok(pythonize(py, &val)?)
353        }
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    // Small constructor to keep the fixtures below readable.
    fn entry(token: &str, logprob: f64) -> TokenLogProbs {
        TokenLogProbs {
            token: token.into(),
            logprob,
        }
    }

    // Scores 1-3 with decreasing weights should average out to roughly 2.
    #[test]
    fn test_calculate_weighted_score() {
        let log_probs = [entry("1", 0.9), entry("2", 0.8), entry("3", 0.7)];

        let score = calculate_weighted_score(&log_probs)
            .expect("scoring should succeed")
            .expect("score tokens were supplied, so a score is expected");

        // round to int
        assert_eq!(score.round(), 2.0);
    }

    // With no entries there is no weight at all, so no score is produced.
    #[test]
    fn test_calculate_weighted_score_empty() {
        let outcome = calculate_weighted_score(&[]);
        assert!(matches!(outcome, Ok(None)));
    }
}
391}