1use crate::error::UtilError;
2
3use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
4use pyo3::prelude::*;
5
6use pyo3::types::{PyAny, PyDict, PyList};
7use pyo3::IntoPyObjectExt;
8use serde::Serialize;
9use serde_json::Value;
10use serde_json::Value::{Null, Object};
11use std::ops::RangeInclusive;
12use std::path::Path;
13use uuid::Uuid;
14pub fn create_uuid7() -> String {
15 Uuid::now_v7().to_string()
16}
17use pythonize::{depythonize, pythonize};
18use tracing::warn;
/// Field-less type used purely as a namespace for the associated helper
/// functions defined in its `impl` block (string rendering, JSON export,
/// Python object conversion).
pub struct PyHelperFuncs {}
20
21impl PyHelperFuncs {
22 pub fn to_bound_py_object<'py, T>(
29 py: Python<'py>,
30 object: &T,
31 ) -> Result<Bound<'py, PyAny>, UtilError>
32 where
33 T: IntoPyObject<'py> + Clone,
34 {
35 Ok(object.clone().into_bound_py_any(py)?)
36 }
37 pub fn __str__<T: Serialize>(object: T) -> String {
38 match ColoredFormatter::with_styler(
39 PrettyFormatter::default(),
40 Styler {
41 key: Color::Rgb(75, 57, 120).foreground(),
42 string_value: Color::Rgb(4, 205, 155).foreground(),
43 float_value: Color::Rgb(4, 205, 155).foreground(),
44 integer_value: Color::Rgb(4, 205, 155).foreground(),
45 bool_value: Color::Rgb(4, 205, 155).foreground(),
46 nil_value: Color::Rgb(4, 205, 155).foreground(),
47 ..Default::default()
48 },
49 )
50 .to_colored_json(&object, ColorMode::On)
51 {
52 Ok(json) => json,
53 Err(e) => format!("Failed to serialize to json: {e}"),
54 }
55 }
57
58 pub fn __json__<T: Serialize>(object: T) -> String {
59 match serde_json::to_string_pretty(&object) {
60 Ok(json) => json,
61 Err(e) => format!("Failed to serialize to json: {e}"),
62 }
63 }
64
65 pub fn save_to_json<T>(model: T, path: &Path) -> Result<(), UtilError>
81 where
82 T: Serialize,
83 {
84 let json =
86 serde_json::to_string_pretty(&model).map_err(|_| UtilError::SerializationError)?;
87
88 let path = path.with_extension("json");
90
91 if !path.exists() {
92 let parent_path = path.parent().ok_or(UtilError::GetParentPathError)?;
94
95 std::fs::create_dir_all(parent_path).map_err(|_| UtilError::CreateDirectoryError)?;
96 }
97
98 std::fs::write(path, json).map_err(|_| UtilError::WriteError)?;
99
100 Ok(())
101 }
102}
103
104pub fn vec_to_py_object<'py>(
105 py: Python<'py>,
106 vec: &Vec<Value>,
107) -> Result<Bound<'py, PyList>, UtilError> {
108 let py_list = PyList::empty(py);
109 for item in vec {
110 let py_item = pythonize(py, item)?;
111 py_list.append(py_item)?;
112 }
113 Ok(py_list)
114}
115
116pub fn version() -> String {
117 env!("CARGO_PKG_VERSION").to_string()
118}
119
120pub fn update_serde_value(value: &mut Value, key: &str, new_value: Value) -> Result<(), UtilError> {
121 if let Value::Object(map) = value {
122 map.insert(key.to_string(), new_value);
123 Ok(())
124 } else {
125 Err(UtilError::RootMustBeObjectError)
126 }
127}
128
129pub fn update_serde_map_with(
140 dest: &mut serde_json::Value,
141 src: &serde_json::Value,
142) -> Result<(), UtilError> {
143 match (dest, src) {
144 (&mut Object(ref mut map_dest), Object(ref map_src)) => {
145 for (key, value) in map_src {
147 *map_dest.entry(key.clone()).or_insert(Null) = value.clone();
150 }
151 Ok(())
152 }
153 (_, _) => Err(UtilError::RootMustBeObjectError),
154 }
155}
156
157pub fn extract_string_value(py_value: &Bound<'_, PyAny>) -> Result<String, UtilError> {
159 if let Ok(string_val) = py_value.extract::<String>() {
161 return Ok(string_val);
162 }
163 if let Ok(bool_val) = py_value.extract::<bool>() {
165 return Ok(bool_val.to_string());
166 }
167
168 if let Ok(int_val) = py_value.extract::<i64>() {
170 return Ok(int_val.to_string());
171 }
172
173 if let Ok(float_val) = py_value.extract::<f64>() {
175 return Ok(float_val.to_string());
176 }
177
178 let json_value = depythonize(py_value)?;
180
181 match json_value {
182 Value::String(s) => Ok(s),
183 Value::Number(n) => Ok(n.to_string()),
184 Value::Bool(b) => Ok(b.to_string()),
185 Value::Null => Ok("null".to_string()),
186 _ => {
187 let json_string = serde_json::to_string(&json_value)?;
189 Ok(json_string)
190 }
191 }
192}
193
/// A token paired with its log-probability, exposed to Python as a class.
#[pyclass]
#[derive(Debug, Serialize, Clone)]
pub struct TokenLogProbs {
    /// The token text; readable from Python through the generated getter.
    #[pyo3(get)]
    pub token: String,

    /// Log-probability associated with `token`; readable from Python through
    /// the generated getter. `calculate_weighted_score` applies `exp()` to this
    /// value to obtain a probability.
    #[pyo3(get)]
    pub logprob: f64,
}
203
/// A list of per-token log-probabilities, exposed to Python as a class.
#[pyclass]
#[derive(Debug, Serialize, Clone)]
pub struct ResponseLogProbs {
    /// The token/log-probability entries; readable from Python through the
    /// generated getter.
    #[pyo3(get)]
    pub tokens: Vec<TokenLogProbs>,
}
210
#[pymethods]
impl ResponseLogProbs {
    /// Python `__str__` implementation; delegates to `PyHelperFuncs::__str__`,
    /// which renders the value as colorized, pretty-printed JSON.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
217
218pub fn calculate_weighted_score(log_probs: &[TokenLogProbs]) -> Result<Option<f64>, UtilError> {
220 let score_range = RangeInclusive::new(1, 5);
221 let mut score_probs = Vec::new();
222 let mut weighted_sum = 0.0;
223 let mut total_prob = 0.0;
224
225 for log_prob in log_probs {
226 let token = log_prob.token.parse::<u64>().ok();
227
228 if let Some(token_val) = token {
229 if score_range.contains(&token_val) {
230 let prob = log_prob.logprob.exp();
231 score_probs.push((token_val, prob));
232 }
233 }
234 }
235
236 for (score, logprob) in score_probs {
237 weighted_sum += score as f64 * logprob;
238 total_prob += logprob;
239 }
240
241 if total_prob > 0.0 {
242 Ok(Some(weighted_sum / total_prob))
243 } else {
244 Ok(None)
245 }
246}
247
248pub fn convert_text_to_structured_output<'py>(
261 py: Python<'py>,
262 text: String,
263 output_model: &Bound<'py, PyAny>,
264) -> Result<Bound<'py, PyAny>, UtilError> {
265 let output = output_model.call_method1("model_validate_json", (&text,));
266 match output {
267 Ok(obj) => {
268 Ok(obj)
270 }
271 Err(err) => {
272 warn!(
275 "Failed to validate model: {}, Attempting fallback to JSON parsing",
276 err
277 );
278 let val = serde_json::from_str::<serde_json::Value>(&text)?;
279 Ok(pythonize(py, &val)?)
280 }
281 }
282}
283
284pub fn is_pydantic_basemodel(py: Python, obj: &Bound<'_, PyAny>) -> Result<bool, UtilError> {
285 let pydantic = match py.import("pydantic") {
286 Ok(module) => module,
287 Err(_) => return Ok(false),
289 };
290
291 let basemodel = pydantic.getattr("BaseModel")?;
292
293 let is_basemodel = obj
295 .is_instance(&basemodel)
296 .map_err(|e| UtilError::FailedToCheckPydanticModel(e.to_string()))?;
297
298 Ok(is_basemodel)
299}
300
301fn process_dict_with_nested_models(
302 py: Python<'_>,
303 dict: &Bound<'_, PyAny>,
304) -> Result<Value, UtilError> {
305 let py_dict = dict.cast::<PyDict>()?;
306 let mut result = serde_json::Map::new();
307
308 for (key, value) in py_dict.iter() {
309 let key_str: String = key.extract()?;
310 let processed_value = depythonize_object_to_value(py, &value)?;
311 result.insert(key_str, processed_value);
312 }
313
314 Ok(Value::Object(result))
315}
316
317pub fn depythonize_object_to_value<'py>(
318 py: Python<'py>,
319 value: &Bound<'py, PyAny>,
320) -> Result<Value, UtilError> {
321 let py_value = if is_pydantic_basemodel(py, value)? {
322 let model = value.call_method0("model_dump")?;
323 depythonize(&model)?
324 } else if value.is_instance_of::<PyDict>() {
325 process_dict_with_nested_models(py, value)?
326 } else {
327 depythonize(value)?
328 };
329 Ok(py_value)
330}
331
332pub fn construct_structured_response<'py>(
343 py: Python<'py>,
344 text: String,
345 output_model: Option<&Bound<'py, PyAny>>,
346) -> Result<Bound<'py, PyAny>, UtilError> {
347 match output_model {
348 Some(model) => convert_text_to_structured_output(py, text, model),
349 None => {
350 let val = Value::String(text);
352 Ok(pythonize(py, &val)?)
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for a `TokenLogProbs` fixture entry.
    fn entry(token: &str, logprob: f64) -> TokenLogProbs {
        TokenLogProbs {
            token: token.into(),
            logprob,
        }
    }

    /// Three in-range scores with similar weights should average to about 2.
    #[test]
    fn test_calculate_weighted_score() {
        let log_probs = vec![entry("1", 0.9), entry("2", 0.8), entry("3", 0.7)];

        let result = calculate_weighted_score(&log_probs);
        assert!(result.is_ok());

        let score = result.unwrap().unwrap();
        assert_eq!(score.round(), 2.0);
    }

    /// Empty input has no contributing tokens, so there is no score.
    #[test]
    fn test_calculate_weighted_score_empty() {
        let no_entries: Vec<TokenLogProbs> = Vec::new();

        let result = calculate_weighted_score(&no_entries);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), None);
    }
}