1use crate::error::UtilError;
2
3use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
4use pyo3::prelude::*;
5
6use pyo3::types::{
7 PyAny, PyBool, PyDict, PyDictMethods, PyFloat, PyInt, PyList, PyString, PyTuple,
8};
9use pyo3::IntoPyObjectExt;
10use serde::Serialize;
11use serde_json::json;
12use serde_json::Value;
13use serde_json::Value::{Null, Object};
14use std::ops::RangeInclusive;
15use std::path::Path;
16use uuid::Uuid;
17pub fn create_uuid7() -> String {
18 Uuid::now_v7().to_string()
19}
20
21pub struct PyHelperFuncs {}
22
23impl PyHelperFuncs {
24 pub fn __str__<T: Serialize>(object: T) -> String {
25 match ColoredFormatter::with_styler(
26 PrettyFormatter::default(),
27 Styler {
28 key: Color::Rgb(75, 57, 120).foreground(),
29 string_value: Color::Rgb(4, 205, 155).foreground(),
30 float_value: Color::Rgb(4, 205, 155).foreground(),
31 integer_value: Color::Rgb(4, 205, 155).foreground(),
32 bool_value: Color::Rgb(4, 205, 155).foreground(),
33 nil_value: Color::Rgb(4, 205, 155).foreground(),
34 ..Default::default()
35 },
36 )
37 .to_colored_json(&object, ColorMode::On)
38 {
39 Ok(json) => json,
40 Err(e) => format!("Failed to serialize to json: {e}"),
41 }
42 }
44
45 pub fn __json__<T: Serialize>(object: T) -> String {
46 match serde_json::to_string_pretty(&object) {
47 Ok(json) => json,
48 Err(e) => format!("Failed to serialize to json: {e}"),
49 }
50 }
51
52 pub fn save_to_json<T>(model: T, path: &Path) -> Result<(), UtilError>
68 where
69 T: Serialize,
70 {
71 let json =
73 serde_json::to_string_pretty(&model).map_err(|_| UtilError::SerializationError)?;
74
75 let path = path.with_extension("json");
77
78 if !path.exists() {
79 let parent_path = path.parent().ok_or(UtilError::GetParentPathError)?;
81
82 std::fs::create_dir_all(parent_path).map_err(|_| UtilError::CreateDirectoryError)?;
83 }
84
85 std::fs::write(path, json).map_err(|_| UtilError::WriteError)?;
86
87 Ok(())
88 }
89}
90
91pub fn json_to_pydict<'py>(
92 py: Python,
93 value: &Value,
94 dict: &Bound<'py, PyDict>,
95) -> Result<Bound<'py, PyDict>, UtilError> {
96 match value {
97 Value::Object(map) => {
98 for (k, v) in map {
99 let py_value = match v {
100 Value::Null => py.None(),
101 Value::Bool(b) => b.into_py_any(py)?,
102 Value::Number(n) => {
103 if let Some(i) = n.as_i64() {
104 i.into_py_any(py)?
105 } else if let Some(f) = n.as_f64() {
106 f.into_py_any(py)?
107 } else {
108 return Err(UtilError::InvalidNumber);
109 }
110 }
111 Value::String(s) => s.into_py_any(py)?,
112 Value::Array(arr) => {
113 let py_list = PyList::empty(py);
114 for item in arr {
115 let py_item = json_to_pyobject(py, item)?;
116 py_list.append(py_item)?;
117 }
118 py_list.into_py_any(py)?
119 }
120 Value::Object(_) => {
121 let nested_dict = PyDict::new(py);
122 json_to_pydict(py, v, &nested_dict)?;
123 nested_dict.into_py_any(py)?
124 }
125 };
126 dict.set_item(k, py_value)?;
127 }
128 }
129 _ => return Err(UtilError::RootMustBeObjectError),
130 }
131
132 Ok(dict.clone())
133}
134
135pub fn json_to_pyobject(py: Python, value: &Value) -> Result<PyObject, UtilError> {
150 Ok(match value {
151 Value::Null => py.None(),
152 Value::Bool(b) => b.into_py_any(py)?,
153 Value::Number(n) => {
154 if let Some(i) = n.as_i64() {
155 i.into_py_any(py)?
156 } else if let Some(f) = n.as_f64() {
157 f.into_py_any(py)?
158 } else {
159 return Err(UtilError::InvalidNumber);
160 }
161 }
162 Value::String(s) => s.into_py_any(py)?,
163 Value::Array(arr) => {
164 let py_list = PyList::empty(py);
165 for item in arr {
166 let py_item = json_to_pyobject(py, item)?;
167 py_list.append(py_item)?;
168 }
169 py_list.into_py_any(py)?
170 }
171 Value::Object(_) => {
172 let nested_dict = PyDict::new(py);
173 json_to_pydict(py, value, &nested_dict)?;
174 nested_dict.into_py_any(py)?
175 }
176 })
177}
178
179pub fn pyobject_to_json(obj: &Bound<'_, PyAny>) -> Result<Value, UtilError> {
180 if obj.is_instance_of::<PyDict>() {
181 let dict = obj.downcast::<PyDict>()?;
182 let mut map = serde_json::Map::new();
183 for (key, value) in dict.iter() {
184 let key_str = key.extract::<String>()?;
185 let json_value = pyobject_to_json(&value)?;
186 map.insert(key_str, json_value);
187 }
188 Ok(Value::Object(map))
189 } else if obj.is_instance_of::<PyList>() {
190 let list = obj.downcast::<PyList>()?;
191 let mut vec = Vec::new();
192 for item in list.iter() {
193 vec.push(pyobject_to_json(&item)?);
194 }
195 Ok(Value::Array(vec))
196 } else if obj.is_instance_of::<PyTuple>() {
197 let tuple = obj.downcast::<PyTuple>()?;
198 let mut vec = Vec::new();
199 for item in tuple.iter() {
200 vec.push(pyobject_to_json(&item)?);
201 }
202 Ok(Value::Array(vec))
203 } else if obj.is_instance_of::<PyString>() {
204 let s = obj.extract::<String>()?;
205 Ok(Value::String(s))
206 } else if obj.is_instance_of::<PyFloat>() {
207 let f = obj.extract::<f64>()?;
208 Ok(json!(f))
209 } else if obj.is_instance_of::<PyBool>() {
210 let b = obj.extract::<bool>()?;
211 Ok(json!(b))
212 } else if obj.is_instance_of::<PyInt>() {
213 let i = obj.extract::<i64>()?;
214 Ok(json!(i))
215 } else if obj.is_none() {
216 Ok(Value::Null)
217 } else {
218 let obj_str = match obj.str() {
222 Ok(s) => s
223 .extract::<String>()
224 .unwrap_or_else(|_| "unsupported type".to_string()),
225 Err(_) => "unsupported type".to_string(),
226 };
227
228 Ok(Value::String(obj_str))
229 }
230}
231
232pub fn version() -> String {
233 env!("CARGO_PKG_VERSION").to_string()
234}
235
236pub fn update_serde_value(value: &mut Value, key: &str, new_value: Value) -> Result<(), UtilError> {
237 if let Value::Object(map) = value {
238 map.insert(key.to_string(), new_value);
239 Ok(())
240 } else {
241 Err(UtilError::RootMustBeObjectError)
242 }
243}
244
245pub fn update_serde_map_with(
256 dest: &mut serde_json::Value,
257 src: &serde_json::Value,
258) -> Result<(), UtilError> {
259 match (dest, src) {
260 (&mut Object(ref mut map_dest), Object(ref map_src)) => {
261 for (key, value) in map_src {
263 *map_dest.entry(key.clone()).or_insert(Null) = value.clone();
266 }
267 Ok(())
268 }
269 (_, _) => Err(UtilError::RootMustBeObjectError),
270 }
271}
272
273pub fn extract_string_value(py_value: &Bound<'_, PyAny>) -> Result<String, UtilError> {
275 if let Ok(string_val) = py_value.extract::<String>() {
277 return Ok(string_val);
278 }
279
280 if let Ok(int_val) = py_value.extract::<i64>() {
282 return Ok(int_val.to_string());
283 }
284
285 if let Ok(float_val) = py_value.extract::<f64>() {
287 return Ok(float_val.to_string());
288 }
289
290 if let Ok(bool_val) = py_value.extract::<bool>() {
292 return Ok(bool_val.to_string());
293 }
294
295 let json_value = pyobject_to_json(py_value)?;
297
298 match json_value {
299 Value::String(s) => Ok(s),
300 Value::Number(n) => Ok(n.to_string()),
301 Value::Bool(b) => Ok(b.to_string()),
302 Value::Null => Ok("null".to_string()),
303 _ => {
304 let json_string = serde_json::to_string(&json_value)?;
306 Ok(json_string)
307 }
308 }
309}
310
/// A single token paired with its log-probability, exposed to Python as a read-only object.
// NOTE(review): field order here is also the JSON field order produced by `Serialize`.
#[pyclass]
#[derive(Debug, Serialize, Clone)]
pub struct ResponseLogProbs {
    /// Token text as produced by the model.
    #[pyo3(get)]
    pub token: String,

    /// Log-probability of the token. `calculate_weighted_score` passes it through `exp()`,
    /// so it is presumably a natural log — TODO confirm with the producer.
    #[pyo3(get)]
    pub logprob: f64,
}

/// A collection of per-token log-probabilities, exposed to Python.
#[pyclass]
#[derive(Debug, Serialize, Clone)]
pub struct LogProbs {
    /// The per-token entries.
    // NOTE(review): a pyo3 `get` on a `Vec` field hands Python a converted copy, so in-place
    // edits to the returned list do not write back — confirm if callers rely on that.
    #[pyo3(get)]
    pub tokens: Vec<ResponseLogProbs>,
}

#[pymethods]
impl LogProbs {
    /// Python `str()` hook: colorized, pretty-printed JSON via `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
334
335pub fn calculate_weighted_score(log_probs: &[ResponseLogProbs]) -> Result<Option<f64>, UtilError> {
337 let score_range = RangeInclusive::new(1, 5);
338 let mut score_probs = Vec::new();
339 let mut weighted_sum = 0.0;
340 let mut total_prob = 0.0;
341
342 for log_prob in log_probs {
343 let token = log_prob.token.parse::<u64>().ok();
344
345 if let Some(token_val) = token {
346 if score_range.contains(&token_val) {
347 let prob = log_prob.logprob.exp();
348 score_probs.push((token_val, prob));
349 }
350 }
351 }
352
353 for (score, logprob) in score_probs {
354 weighted_sum += score as f64 * logprob;
355 total_prob += logprob;
356 }
357
358 if total_prob > 0.0 {
359 Ok(Some(weighted_sum / total_prob))
360 } else {
361 Ok(None)
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ResponseLogProbs` fixture.
    fn entry(token: &str, logprob: f64) -> ResponseLogProbs {
        ResponseLogProbs {
            token: token.into(),
            logprob,
        }
    }

    #[test]
    fn test_calculate_weighted_score() {
        let log_probs = [entry("1", 0.9), entry("2", 0.8), entry("3", 0.7)];

        let result = calculate_weighted_score(&log_probs);
        assert!(result.is_ok());

        let val = result.unwrap().unwrap();
        // Weights are exp(logprob), normalised by their sum.
        let (w1, w2, w3) = (0.9_f64.exp(), 0.8_f64.exp(), 0.7_f64.exp());
        let expected = (w1 + 2.0 * w2 + 3.0 * w3) / (w1 + w2 + w3);
        assert!((val - expected).abs() < 1e-12, "got {val}, expected {expected}");
        assert_eq!(val.round(), 2.0);
    }

    #[test]
    fn test_calculate_weighted_score_empty() {
        let log_probs: Vec<ResponseLogProbs> = vec![];
        let result = calculate_weighted_score(&log_probs);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), None);
    }

    #[test]
    fn test_calculate_weighted_score_ignores_non_score_tokens() {
        let log_probs = [
            entry("0", 0.0),
            entry("6", 0.0),
            entry("-1", 0.0),
            entry("abc", 0.0),
            entry("4", -0.5),
        ];
        assert_eq!(calculate_weighted_score(&log_probs).unwrap(), Some(4.0));
    }

    #[test]
    fn test_calculate_weighted_score_no_valid_tokens() {
        let log_probs = [entry("yes", -0.1), entry("10", -0.2)];
        assert_eq!(calculate_weighted_score(&log_probs).unwrap(), None);
    }
}