use crate::error::UtilError;
use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyList};
use pyo3::IntoPyObjectExt;
use serde::Serialize;
use serde_json::Value;
use serde_json::Value::{Null, Object};
use std::ops::RangeInclusive;
use std::path::Path;
use uuid::Uuid;
/// Generates a fresh version-7 UUID and returns its textual form.
pub fn create_uuid7() -> String {
    let id = Uuid::now_v7();
    id.to_string()
}
use pythonize::{depythonize, pythonize};
use tracing::warn;
/// Field-less namespace type: its associated functions (see the `impl` block)
/// convert values to Python objects, render them as JSON, and save them to disk.
pub struct PyHelperFuncs {}
impl PyHelperFuncs {
    /// Clones `object` and converts the clone into a bound Python object.
    ///
    /// # Errors
    /// Propagates a failed conversion as a [`UtilError`].
    pub fn to_bound_py_object<'py, T>(
        py: Python<'py>,
        object: &T,
    ) -> Result<Bound<'py, PyAny>, UtilError>
    where
        T: IntoPyObject<'py> + Clone,
    {
        let owned = object.clone();
        let bound = owned.into_bound_py_any(py)?;
        Ok(bound)
    }

    /// Renders `object` as pretty-printed, colourised JSON (`ColorMode::On`).
    ///
    /// Keys get their own colour; string, float, integer, bool and null values
    /// all share a second one. If serialization fails, the returned string is
    /// the failure message rather than JSON.
    pub fn __str__<T: Serialize>(object: T) -> String {
        // Every value kind uses the same colour, so build it from one place.
        let value_style = || Color::Rgb(4, 205, 155).foreground();
        let styler = Styler {
            key: Color::Rgb(75, 57, 120).foreground(),
            string_value: value_style(),
            float_value: value_style(),
            integer_value: value_style(),
            bool_value: value_style(),
            nil_value: value_style(),
            ..Default::default()
        };

        ColoredFormatter::with_styler(PrettyFormatter::default(), styler)
            .to_colored_json(&object, ColorMode::On)
            .unwrap_or_else(|e| format!("Failed to serialize to json: {e}"))
    }

    /// Renders `object` as pretty-printed JSON without colour.
    ///
    /// If serialization fails, the returned string is the failure message
    /// rather than JSON.
    pub fn __json__<T: Serialize>(object: T) -> String {
        serde_json::to_string_pretty(&object)
            .unwrap_or_else(|e| format!("Failed to serialize to json: {e}"))
    }

    /// Serializes `model` as pretty-printed JSON and writes it to `path`, with
    /// the path's extension set to `json`.
    ///
    /// When the target file does not exist yet, its parent directories are
    /// created first.
    ///
    /// # Errors
    /// * [`UtilError::SerializationError`] if `model` cannot be serialized.
    /// * [`UtilError::GetParentPathError`] if the target path has no parent.
    /// * [`UtilError::CreateDirectoryError`] if the parent directories cannot be created.
    /// * [`UtilError::WriteError`] if the file cannot be written.
    pub fn save_to_json<T>(model: T, path: &Path) -> Result<(), UtilError>
    where
        T: Serialize,
    {
        let serialized = match serde_json::to_string_pretty(&model) {
            Ok(text) => text,
            Err(_) => return Err(UtilError::SerializationError),
        };

        let target = path.with_extension("json");
        if !target.exists() {
            let Some(parent_dir) = target.parent() else {
                return Err(UtilError::GetParentPathError);
            };
            if std::fs::create_dir_all(parent_dir).is_err() {
                return Err(UtilError::CreateDirectoryError);
            }
        }

        std::fs::write(target, serialized).map_err(|_| UtilError::WriteError)
    }
}
/// Converts a slice of JSON values into a Python `list`.
///
/// Each element is converted with `pythonize` and appended in order. The
/// parameter is a slice rather than `&Vec<Value>`, so arrays and sub-slices
/// can be passed as well; existing `&Vec<Value>` call sites still compile
/// through deref coercion.
///
/// # Errors
/// Returns a [`UtilError`] if an element cannot be converted or appended.
pub fn vec_to_py_object<'py>(
    py: Python<'py>,
    vec: &[Value],
) -> Result<Bound<'py, PyList>, UtilError> {
    let py_list = PyList::empty(py);
    for item in vec {
        py_list.append(pythonize(py, item)?)?;
    }
    Ok(py_list)
}
/// Returns the package version taken from `CARGO_PKG_VERSION` at compile time.
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Sets `key` to `new_value` on a JSON object, replacing any existing entry.
///
/// # Errors
/// Returns [`UtilError::RootMustBeObjectError`] when `value` is not a JSON object.
pub fn update_serde_value(value: &mut Value, key: &str, new_value: Value) -> Result<(), UtilError> {
    match value {
        Value::Object(fields) => {
            fields.insert(key.to_owned(), new_value);
            Ok(())
        }
        _ => Err(UtilError::RootMustBeObjectError),
    }
}
/// Copies every top-level entry of the JSON object `src` into the JSON object
/// `dest`, overwriting entries whose keys already exist.
///
/// The merge is shallow: nested objects are replaced wholesale, not merged.
///
/// # Errors
/// Returns [`UtilError::RootMustBeObjectError`] unless both `dest` and `src`
/// are JSON objects.
pub fn update_serde_map_with(
    dest: &mut serde_json::Value,
    src: &serde_json::Value,
) -> Result<(), UtilError> {
    let (Object(target), Object(updates)) = (dest, src) else {
        return Err(UtilError::RootMustBeObjectError);
    };

    updates.iter().for_each(|(key, value)| {
        // Reserve the slot (with a placeholder if absent), then overwrite it.
        *target.entry(key.clone()).or_insert(Null) = value.clone();
    });
    Ok(())
}
pub fn extract_string_value(py_value: &Bound<'_, PyAny>) -> Result<String, UtilError> {
if let Ok(string_val) = py_value.extract::<String>() {
return Ok(string_val);
}
if let Ok(bool_val) = py_value.extract::<bool>() {
return Ok(bool_val.to_string());
}
if let Ok(int_val) = py_value.extract::<i64>() {
return Ok(int_val.to_string());
}
if let Ok(float_val) = py_value.extract::<f64>() {
return Ok(float_val.to_string());
}
let json_value = depythonize(py_value)?;
match json_value {
Value::String(s) => Ok(s),
Value::Number(n) => Ok(n.to_string()),
Value::Bool(b) => Ok(b.to_string()),
Value::Null => Ok("null".to_string()),
_ => {
let json_string = serde_json::to_string(&json_value)?;
Ok(json_string)
}
}
}
/// A single token paired with its log-probability.
///
/// Exposed to Python as a class; both fields are readable from Python.
#[pyclass(from_py_object)]
#[derive(Debug, Serialize, Clone)]
pub struct TokenLogProbs {
/// Token text. `calculate_weighted_score` tries to parse it as an unsigned integer.
#[pyo3(get)]
pub token: String,
/// Log-probability of the token. `calculate_weighted_score` applies `exp()` to it
/// (presumably a natural-log probability; confirm against the producer).
#[pyo3(get)]
pub logprob: f64,
}
/// A collection of per-token log-probabilities.
///
/// Exposed to Python as a class; `tokens` is readable from Python.
#[pyclass(from_py_object)]
#[derive(Debug, Serialize, Clone)]
pub struct ResponseLogProbs {
/// The token / log-probability entries.
#[pyo3(get)]
pub tokens: Vec<TokenLogProbs>,
}
#[pymethods]
impl ResponseLogProbs {
    /// Python `__str__`: the colourised, pretty-printed JSON rendering of this
    /// value as produced by `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__::<&Self>(self)
    }
}
/// Computes the probability-weighted average of the integer scores found in
/// `log_probs`.
///
/// Only entries whose token parses as an integer in `1..=5` take part. Each
/// such entry contributes its score weighted by `exp(logprob)`, and the
/// weighted sum is divided by the total weight.
///
/// Returns `Ok(None)` when the total weight is not greater than zero, for
/// example when no token is a score in range. The function currently has no
/// path that returns `Err`.
pub fn calculate_weighted_score(log_probs: &[TokenLogProbs]) -> Result<Option<f64>, UtilError> {
    let valid_scores: RangeInclusive<u64> = RangeInclusive::new(1, 5);

    let (weighted_sum, total_prob) = log_probs
        .iter()
        .filter_map(|entry| {
            let score = entry.token.parse::<u64>().ok()?;
            valid_scores
                .contains(&score)
                .then(|| (score, entry.logprob.exp()))
        })
        .fold((0.0_f64, 0.0_f64), |(weighted, total), (score, prob)| {
            (weighted + score as f64 * prob, total + prob)
        });

    Ok((total_prob > 0.0).then(|| weighted_sum / total_prob))
}
/// Turns JSON `text` into a Python object using `output_model`.
///
/// First calls `output_model.model_validate_json(text)`. If that call fails, a
/// warning is logged and the text is parsed as generic JSON and converted to
/// the corresponding plain Python value instead.
///
/// # Errors
/// Returns a [`UtilError`] if the fallback JSON parse or the conversion to a
/// Python object fails.
pub fn convert_text_to_structured_output<'py>(
    py: Python<'py>,
    text: String,
    output_model: &Bound<'py, PyAny>,
) -> Result<Bound<'py, PyAny>, UtilError> {
    let err = match output_model.call_method1("model_validate_json", (&text,)) {
        Ok(validated) => return Ok(validated),
        Err(err) => err,
    };

    warn!(
        "Failed to validate model: {}, Attempting fallback to JSON parsing",
        err
    );
    let parsed = serde_json::from_str::<serde_json::Value>(&text)?;
    let fallback = pythonize(py, &parsed)?;
    Ok(fallback)
}
/// Reports whether `obj` is an instance of `pydantic.BaseModel`.
///
/// Returns `Ok(false)` when the `pydantic` module cannot be imported.
///
/// # Errors
/// * A [`UtilError`] converted from the Python error if `BaseModel` cannot be
///   looked up on the module.
/// * [`UtilError::FailedToCheckPydanticModel`] if the instance check itself fails.
pub fn is_pydantic_basemodel(py: Python, obj: &Bound<'_, PyAny>) -> Result<bool, UtilError> {
    let Ok(pydantic) = py.import("pydantic") else {
        return Ok(false);
    };
    let base_model = pydantic.getattr("BaseModel")?;

    obj.is_instance(&base_model)
        .map_err(|e| UtilError::FailedToCheckPydanticModel(e.to_string()))
}
/// Converts a Python `dict` into a JSON object, running every value through
/// `depythonize_object_to_value`.
///
/// Keys must be extractable as strings.
///
/// # Errors
/// Returns a [`UtilError`] if `dict` is not a `dict`, a key is not a string,
/// or a value fails to convert. Conversion stops at the first failure.
fn process_dict_with_nested_models(
    py: Python<'_>,
    dict: &Bound<'_, PyAny>,
) -> Result<Value, UtilError> {
    let entries = dict.cast::<PyDict>()?;

    let object = entries
        .iter()
        .map(|(key, value)| -> Result<(String, Value), UtilError> {
            let name: String = key.extract()?;
            let converted = depythonize_object_to_value(py, &value)?;
            Ok((name, converted))
        })
        .collect::<Result<serde_json::Map<String, Value>, UtilError>>()?;

    Ok(Value::Object(object))
}
pub fn depythonize_object_to_value<'py>(
py: Python<'py>,
value: &Bound<'py, PyAny>,
) -> Result<Value, UtilError> {
let py_value = if is_pydantic_basemodel(py, value)? {
let model = value.call_method0("model_dump")?;
depythonize(&model)?
} else if value.is_instance_of::<PyDict>() {
process_dict_with_nested_models(py, value)?
} else {
depythonize(value)?
};
Ok(py_value)
}
/// Builds the Python-side response object for `text`.
///
/// With an `output_model`, the work is delegated to
/// `convert_text_to_structured_output`. Without one, `text` is wrapped as a
/// JSON string value and converted to its Python counterpart.
///
/// # Errors
/// Returns a [`UtilError`] if the structured conversion or the conversion to
/// a Python object fails.
pub fn construct_structured_response<'py>(
    py: Python<'py>,
    text: String,
    output_model: Option<&Bound<'py, PyAny>>,
) -> Result<Bound<'py, PyAny>, UtilError> {
    if let Some(model) = output_model {
        return convert_text_to_structured_output(py, text, model);
    }

    let plain = pythonize(py, &Value::String(text))?;
    Ok(plain)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TokenLogProbs` entry from a token literal and its log-probability.
    fn entry(token: &str, logprob: f64) -> TokenLogProbs {
        TokenLogProbs {
            token: token.to_string(),
            logprob,
        }
    }

    /// Scores 1, 2 and 3 with descending log-probabilities average out to
    /// roughly 2 once weighted.
    #[test]
    fn test_calculate_weighted_score() {
        let log_probs = [entry("1", 0.9), entry("2", 0.8), entry("3", 0.7)];

        let outcome = calculate_weighted_score(&log_probs);
        assert!(outcome.is_ok());

        let score = outcome.unwrap().unwrap();
        assert_eq!(score.round(), 2.0);
    }

    /// With no entries there is no probability mass, so no score is produced.
    #[test]
    fn test_calculate_weighted_score_empty() {
        let outcome = calculate_weighted_score(&[]);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), None);
    }
}