1use crate::error::UtilError;
2
3use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
4use pyo3::prelude::*;
5
6use pyo3::types::{
7 PyAny, PyBool, PyDict, PyDictMethods, PyFloat, PyInt, PyList, PyString, PyTuple,
8};
9use pyo3::IntoPyObjectExt;
10use serde::Serialize;
11use serde_json::json;
12use serde_json::Value;
13use serde_json::Value::{Null, Object};
14use std::ops::RangeInclusive;
15use std::path::Path;
16use uuid::Uuid;
17pub fn create_uuid7() -> String {
18 Uuid::now_v7().to_string()
19}
20
21pub struct PyHelperFuncs {}
22
23impl PyHelperFuncs {
24 pub fn __str__<T: Serialize>(object: T) -> String {
25 match ColoredFormatter::with_styler(
26 PrettyFormatter::default(),
27 Styler {
28 key: Color::Rgb(75, 57, 120).foreground(),
29 string_value: Color::Rgb(4, 205, 155).foreground(),
30 float_value: Color::Rgb(4, 205, 155).foreground(),
31 integer_value: Color::Rgb(4, 205, 155).foreground(),
32 bool_value: Color::Rgb(4, 205, 155).foreground(),
33 nil_value: Color::Rgb(4, 205, 155).foreground(),
34 ..Default::default()
35 },
36 )
37 .to_colored_json(&object, ColorMode::On)
38 {
39 Ok(json) => json,
40 Err(e) => format!("Failed to serialize to json: {e}"),
41 }
42 }
44
45 pub fn __json__<T: Serialize>(object: T) -> String {
46 match serde_json::to_string_pretty(&object) {
47 Ok(json) => json,
48 Err(e) => format!("Failed to serialize to json: {e}"),
49 }
50 }
51
52 pub fn save_to_json<T>(model: T, path: &Path) -> Result<(), UtilError>
68 where
69 T: Serialize,
70 {
71 let json =
73 serde_json::to_string_pretty(&model).map_err(|_| UtilError::SerializationError)?;
74
75 let path = path.with_extension("json");
77
78 if !path.exists() {
79 let parent_path = path.parent().ok_or(UtilError::GetParentPathError)?;
81
82 std::fs::create_dir_all(parent_path).map_err(|_| UtilError::CreateDirectoryError)?;
83 }
84
85 std::fs::write(path, json).map_err(|_| UtilError::WriteError)?;
86
87 Ok(())
88 }
89}
90
91pub fn json_to_pydict<'py>(
92 py: Python,
93 value: &Value,
94 dict: &Bound<'py, PyDict>,
95) -> Result<Bound<'py, PyDict>, UtilError> {
96 match value {
97 Value::Object(map) => {
98 for (k, v) in map {
99 let py_value = match v {
100 Value::Null => py.None(),
101 Value::Bool(b) => b.into_py_any(py)?,
102 Value::Number(n) => {
103 if let Some(i) = n.as_i64() {
104 i.into_py_any(py)?
105 } else if let Some(f) = n.as_f64() {
106 f.into_py_any(py)?
107 } else {
108 return Err(UtilError::InvalidNumber);
109 }
110 }
111 Value::String(s) => s.into_py_any(py)?,
112 Value::Array(arr) => {
113 let py_list = PyList::empty(py);
114 for item in arr {
115 let py_item = json_to_pyobject(py, item)?;
116 py_list.append(py_item)?;
117 }
118 py_list.into_py_any(py)?
119 }
120 Value::Object(_) => {
121 let nested_dict = PyDict::new(py);
122 json_to_pydict(py, v, &nested_dict)?;
123 nested_dict.into_py_any(py)?
124 }
125 };
126 dict.set_item(k, py_value)?;
127 }
128 }
129 _ => return Err(UtilError::RootMustBeObjectError),
130 }
131
132 Ok(dict.clone())
133}
134
135pub fn json_to_pyobject(py: Python, value: &Value) -> Result<Py<PyAny>, UtilError> {
150 Ok(match value {
151 Value::Null => py.None(),
152 Value::Bool(b) => b.into_py_any(py)?,
153 Value::Number(n) => {
154 if let Some(i) = n.as_i64() {
155 i.into_py_any(py)?
156 } else if let Some(f) = n.as_f64() {
157 f.into_py_any(py)?
158 } else {
159 return Err(UtilError::InvalidNumber);
160 }
161 }
162 Value::String(s) => s.into_py_any(py)?,
163 Value::Array(arr) => {
164 let py_list = PyList::empty(py);
165 for item in arr {
166 let py_item = json_to_pyobject(py, item)?;
167 py_list.append(py_item)?;
168 }
169 py_list.into_py_any(py)?
170 }
171 Value::Object(_) => {
172 let nested_dict = PyDict::new(py);
173 json_to_pydict(py, value, &nested_dict)?;
174 nested_dict.into_py_any(py)?
175 }
176 })
177}
178
179pub fn vec_to_py_object<'py>(
180 py: Python<'py>,
181 vec: &Vec<Value>,
182) -> Result<Bound<'py, PyList>, UtilError> {
183 let py_list = PyList::empty(py);
184 for item in vec {
185 let py_item = json_to_pyobject(py, item)?;
186 py_list.append(py_item)?;
187 }
188 Ok(py_list)
189}
190
191pub fn pyobject_to_json(obj: &Bound<'_, PyAny>) -> Result<Value, UtilError> {
192 if obj.is_instance_of::<PyDict>() {
193 let dict = obj.downcast::<PyDict>()?;
194 let mut map = serde_json::Map::new();
195 for (key, value) in dict.iter() {
196 let key_str = key.extract::<String>()?;
197 let json_value = pyobject_to_json(&value)?;
198 map.insert(key_str, json_value);
199 }
200 Ok(Value::Object(map))
201 } else if obj.is_instance_of::<PyList>() {
202 let list = obj.downcast::<PyList>()?;
203 let mut vec = Vec::new();
204 for item in list.iter() {
205 vec.push(pyobject_to_json(&item)?);
206 }
207 Ok(Value::Array(vec))
208 } else if obj.is_instance_of::<PyTuple>() {
209 let tuple = obj.downcast::<PyTuple>()?;
210 let mut vec = Vec::new();
211 for item in tuple.iter() {
212 vec.push(pyobject_to_json(&item)?);
213 }
214 Ok(Value::Array(vec))
215 } else if obj.is_instance_of::<PyString>() {
216 let s = obj.extract::<String>()?;
217 Ok(Value::String(s))
218 } else if obj.is_instance_of::<PyFloat>() {
219 let f = obj.extract::<f64>()?;
220 Ok(json!(f))
221 } else if obj.is_instance_of::<PyBool>() {
222 let b = obj.extract::<bool>()?;
223 Ok(json!(b))
224 } else if obj.is_instance_of::<PyInt>() {
225 let i = obj.extract::<i64>()?;
226 Ok(json!(i))
227 } else if obj.is_none() {
228 Ok(Value::Null)
229 } else {
230 let obj_str = match obj.str() {
234 Ok(s) => s
235 .extract::<String>()
236 .unwrap_or_else(|_| "unsupported type".to_string()),
237 Err(_) => "unsupported type".to_string(),
238 };
239
240 Ok(Value::String(obj_str))
241 }
242}
243
244pub fn version() -> String {
245 env!("CARGO_PKG_VERSION").to_string()
246}
247
248pub fn update_serde_value(value: &mut Value, key: &str, new_value: Value) -> Result<(), UtilError> {
249 if let Value::Object(map) = value {
250 map.insert(key.to_string(), new_value);
251 Ok(())
252 } else {
253 Err(UtilError::RootMustBeObjectError)
254 }
255}
256
257pub fn update_serde_map_with(
268 dest: &mut serde_json::Value,
269 src: &serde_json::Value,
270) -> Result<(), UtilError> {
271 match (dest, src) {
272 (&mut Object(ref mut map_dest), Object(ref map_src)) => {
273 for (key, value) in map_src {
275 *map_dest.entry(key.clone()).or_insert(Null) = value.clone();
278 }
279 Ok(())
280 }
281 (_, _) => Err(UtilError::RootMustBeObjectError),
282 }
283}
284
285pub fn extract_string_value(py_value: &Bound<'_, PyAny>) -> Result<String, UtilError> {
287 if let Ok(string_val) = py_value.extract::<String>() {
289 return Ok(string_val);
290 }
291
292 if let Ok(int_val) = py_value.extract::<i64>() {
294 return Ok(int_val.to_string());
295 }
296
297 if let Ok(float_val) = py_value.extract::<f64>() {
299 return Ok(float_val.to_string());
300 }
301
302 if let Ok(bool_val) = py_value.extract::<bool>() {
304 return Ok(bool_val.to_string());
305 }
306
307 let json_value = pyobject_to_json(py_value)?;
309
310 match json_value {
311 Value::String(s) => Ok(s),
312 Value::Number(n) => Ok(n.to_string()),
313 Value::Bool(b) => Ok(b.to_string()),
314 Value::Null => Ok("null".to_string()),
315 _ => {
316 let json_string = serde_json::to_string(&json_value)?;
318 Ok(json_string)
319 }
320 }
321}
322
/// A token together with its log-probability.
///
/// Both fields are exposed to Python through `#[pyo3(get)]`.
#[pyclass]
#[derive(Debug, Serialize, Clone)]
pub struct ResponseLogProbs {
    /// Token text. `calculate_weighted_score` tries to parse it as an integer score.
    #[pyo3(get)]
    pub token: String,

    /// Log-probability of `token`; `calculate_weighted_score` applies `exp()` to it
    /// to obtain the weight.
    #[pyo3(get)]
    pub logprob: f64,
}
332
/// A collection of [`ResponseLogProbs`] entries, exposed to Python as a class.
#[pyclass]
#[derive(Debug, Serialize, Clone)]
pub struct LogProbs {
    /// The token/log-probability entries, readable from Python via `#[pyo3(get)]`.
    #[pyo3(get)]
    pub tokens: Vec<ResponseLogProbs>,
}
339
#[pymethods]
impl LogProbs {
    /// Returns the string form of this object by delegating to
    /// `PyHelperFuncs::__str__`, which renders it as colored, pretty-printed JSON.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
346
347pub fn calculate_weighted_score(log_probs: &[ResponseLogProbs]) -> Result<Option<f64>, UtilError> {
349 let score_range = RangeInclusive::new(1, 5);
350 let mut score_probs = Vec::new();
351 let mut weighted_sum = 0.0;
352 let mut total_prob = 0.0;
353
354 for log_prob in log_probs {
355 let token = log_prob.token.parse::<u64>().ok();
356
357 if let Some(token_val) = token {
358 if score_range.contains(&token_val) {
359 let prob = log_prob.logprob.exp();
360 score_probs.push((token_val, prob));
361 }
362 }
363 }
364
365 for (score, logprob) in score_probs {
366 weighted_sum += score as f64 * logprob;
367 total_prob += logprob;
368 }
369
370 if total_prob > 0.0 {
371 Ok(Some(weighted_sum / total_prob))
372 } else {
373 Ok(None)
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one fixture entry from a token and its log-probability.
    fn entry(token: &str, logprob: f64) -> ResponseLogProbs {
        ResponseLogProbs {
            token: token.into(),
            logprob,
        }
    }

    /// Three in-range tokens must yield a weighted score that rounds to 2.
    #[test]
    fn test_calculate_weighted_score() {
        let log_probs = [entry("1", 0.9), entry("2", 0.8), entry("3", 0.7)];

        let outcome = calculate_weighted_score(&log_probs);
        assert!(outcome.is_ok());

        let score = outcome.unwrap().unwrap();
        assert_eq!(score.round(), 2.0);
    }

    /// With no entries there is no probability mass, so the score is `None`.
    #[test]
    fn test_calculate_weighted_score_empty() {
        let no_entries: Vec<ResponseLogProbs> = Vec::new();

        let outcome = calculate_weighted_score(&no_entries);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), None);
    }
}