1use regex::Regex;
33use serde::{Deserialize, Serialize};
34use serde_json::Value;
35use std::collections::HashMap;
36use std::fmt;
37use std::sync::Arc;
38
39use crate::error::{ClientError, Result};
40
/// A single transformation bound to a dot-separated field path inside a JSON value.
///
/// Serializable so it can live in configuration files; a
/// [`TransformType::Custom`] transform is marked `#[serde(skip)]` and therefore
/// cannot be serialized or deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueTransform {
    /// Dot-separated path of the target value, e.g. `response.status.code`
    /// or `items.0.status`.
    field: String,
    /// Transformation applied to the value found at `field`.
    transform: TransformType,
}
49
50impl ValueTransform {
51 pub fn new(field: impl Into<String>, transform: TransformType) -> Self {
53 Self {
54 field: field.into(),
55 transform,
56 }
57 }
58
59 pub fn field(&self) -> &str {
61 &self.field
62 }
63
64 pub fn transform(&self) -> &TransformType {
66 &self.transform
67 }
68
69 pub fn apply(&self, root: &mut Value) -> Result<()> {
71 if let Some(value) = get_value_mut(root, &self.field) {
72 *value = self.transform.apply(value.take())?;
73 }
74 Ok(())
75 }
76}
77
/// The transformation applied to a single JSON value.
///
/// Serialized with snake_case variant names (`map`, `regex`, `lowercase`,
/// `uppercase`, `parse_int`, `parse_float`). `Debug` is implemented by hand
/// because `Custom` wraps an opaque closure.
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TransformType {
    /// Replaces a string value with its entry in the table; strings without
    /// an entry and non-string values pass through unchanged.
    Map(HashMap<String, String>),

    /// Runs `replace_all` with `pattern` over string values; non-strings pass
    /// through unchanged. `replacement` may reference capture groups such as
    /// `$1`. The pattern is compiled when the transform is applied, so an
    /// invalid pattern surfaces as an error only at that point.
    Regex {
        pattern: String,
        replacement: String,
    },

    /// Lowercases string values; non-strings pass through unchanged.
    Lowercase,

    /// Uppercases string values; non-strings pass through unchanged.
    Uppercase,

    /// Parses a string as `i64`. Existing JSON numbers are kept as they are
    /// (including non-integers); unparsable strings and all other values
    /// become `null`.
    ParseInt,

    /// Parses a string as `f64`. Existing JSON numbers are kept as they are.
    /// Unparsable strings, values `serde_json::Number::from_f64` rejects
    /// (NaN/infinite), and all other values become `null`.
    ParseFloat,

    /// Arbitrary user-supplied transformation. Skipped by serde, so it cannot
    /// be serialized or deserialized.
    #[serde(skip)]
    Custom(Arc<dyn Fn(Value) -> Value + Send + Sync>),
}
119
120impl fmt::Debug for TransformType {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 match self {
123 TransformType::Map(m) => f.debug_tuple("Map").field(m).finish(),
124 TransformType::Regex {
125 pattern,
126 replacement,
127 } => f
128 .debug_struct("Regex")
129 .field("pattern", pattern)
130 .field("replacement", replacement)
131 .finish(),
132 TransformType::Lowercase => write!(f, "Lowercase"),
133 TransformType::Uppercase => write!(f, "Uppercase"),
134 TransformType::ParseInt => write!(f, "ParseInt"),
135 TransformType::ParseFloat => write!(f, "ParseFloat"),
136 TransformType::Custom(_) => write!(f, "Custom(<fn>)"),
137 }
138 }
139}
140
141impl TransformType {
142 pub fn apply(&self, value: Value) -> Result<Value> {
144 match self {
145 TransformType::Map(mappings) => {
146 if let Value::String(s) = &value {
147 if let Some(mapped) = mappings.get(s) {
148 return Ok(Value::String(mapped.clone()));
149 }
150 }
151 Ok(value)
152 }
153
154 TransformType::Regex {
155 pattern,
156 replacement,
157 } => {
158 if let Value::String(s) = value {
159 let re = Regex::new(pattern).map_err(|e| {
160 ClientError::ValidationError(format!("Invalid regex pattern: {}", e))
161 })?;
162 let result = re.replace_all(&s, replacement.as_str());
163 Ok(Value::String(result.into_owned()))
164 } else {
165 Ok(value)
166 }
167 }
168
169 TransformType::Lowercase => {
170 if let Value::String(s) = value {
171 Ok(Value::String(s.to_lowercase()))
172 } else {
173 Ok(value)
174 }
175 }
176
177 TransformType::Uppercase => {
178 if let Value::String(s) = value {
179 Ok(Value::String(s.to_uppercase()))
180 } else {
181 Ok(value)
182 }
183 }
184
185 TransformType::ParseInt => {
186 if let Value::String(s) = &value {
187 if let Ok(n) = s.parse::<i64>() {
188 return Ok(Value::Number(n.into()));
189 }
190 }
191 if value.is_number() {
193 return Ok(value);
194 }
195 Ok(Value::Null)
196 }
197
198 TransformType::ParseFloat => {
199 if let Value::String(s) = &value {
200 if let Ok(n) = s.parse::<f64>() {
201 if let Some(num) = serde_json::Number::from_f64(n) {
202 return Ok(Value::Number(num));
203 }
204 }
205 }
206 if value.is_number() {
208 return Ok(value);
209 }
210 Ok(Value::Null)
211 }
212
213 TransformType::Custom(f) => Ok(f(value)),
214 }
215 }
216}
217
218pub struct TransformBuilder {
231 field: String,
232 transform_type: TransformBuilderType,
233}
234
235enum TransformBuilderType {
236 Map(HashMap<String, String>),
237 Regex { pattern: String, replacement: String },
238 Lowercase,
239 Uppercase,
240 ParseInt,
241 ParseFloat,
242 Custom(Arc<dyn Fn(Value) -> Value + Send + Sync>),
243}
244
245impl TransformBuilder {
246 pub fn field(name: impl Into<String>) -> Self {
248 Self {
249 field: name.into(),
250 transform_type: TransformBuilderType::Map(HashMap::new()),
251 }
252 }
253
254 pub fn map(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
258 if let TransformBuilderType::Map(ref mut mappings) = self.transform_type {
259 mappings.insert(from.into(), to.into());
260 } else {
261 let mut mappings = HashMap::new();
263 mappings.insert(from.into(), to.into());
264 self.transform_type = TransformBuilderType::Map(mappings);
265 }
266 self
267 }
268
269 pub fn regex(mut self, pattern: impl Into<String>, replacement: impl Into<String>) -> Self {
273 self.transform_type = TransformBuilderType::Regex {
274 pattern: pattern.into(),
275 replacement: replacement.into(),
276 };
277 self
278 }
279
280 pub fn lowercase(mut self) -> Self {
284 self.transform_type = TransformBuilderType::Lowercase;
285 self
286 }
287
288 pub fn uppercase(mut self) -> Self {
292 self.transform_type = TransformBuilderType::Uppercase;
293 self
294 }
295
296 pub fn parse_int(mut self) -> Self {
300 self.transform_type = TransformBuilderType::ParseInt;
301 self
302 }
303
304 pub fn parse_float(mut self) -> Self {
308 self.transform_type = TransformBuilderType::ParseFloat;
309 self
310 }
311
312 pub fn custom<F>(mut self, f: F) -> Self
316 where
317 F: Fn(Value) -> Value + Send + Sync + 'static,
318 {
319 self.transform_type = TransformBuilderType::Custom(Arc::new(f));
320 self
321 }
322
323 pub fn build(self) -> ValueTransform {
325 let transform = match self.transform_type {
326 TransformBuilderType::Map(mappings) => TransformType::Map(mappings),
327 TransformBuilderType::Regex {
328 pattern,
329 replacement,
330 } => TransformType::Regex {
331 pattern,
332 replacement,
333 },
334 TransformBuilderType::Lowercase => TransformType::Lowercase,
335 TransformBuilderType::Uppercase => TransformType::Uppercase,
336 TransformBuilderType::ParseInt => TransformType::ParseInt,
337 TransformBuilderType::ParseFloat => TransformType::ParseFloat,
338 TransformBuilderType::Custom(f) => TransformType::Custom(f),
339 };
340 ValueTransform {
341 field: self.field,
342 transform,
343 }
344 }
345}
346
/// An ordered list of [`ValueTransform`]s applied one after another to the
/// same JSON value.
///
/// Order matters: later transforms see the output of earlier ones (for
/// example, lowercasing a field and then running a regex over it).
#[derive(Debug, Clone, Default)]
pub struct TransformEngine {
    /// Transforms in application order.
    transforms: Vec<ValueTransform>,
}
354
355impl TransformEngine {
356 pub fn new(transforms: Vec<ValueTransform>) -> Self {
358 Self { transforms }
359 }
360
361 pub fn empty() -> Self {
363 Self::default()
364 }
365
366 pub fn add(&mut self, transform: ValueTransform) {
368 self.transforms.push(transform);
369 }
370
371 pub fn apply(&self, value: &mut Value) -> Result<()> {
373 for transform in &self.transforms {
374 transform.apply(value)?;
375 }
376 Ok(())
377 }
378
379 pub fn is_empty(&self) -> bool {
381 self.transforms.is_empty()
382 }
383
384 pub fn len(&self) -> usize {
386 self.transforms.len()
387 }
388
389 pub fn iter(&self) -> impl Iterator<Item = &ValueTransform> {
391 self.transforms.iter()
392 }
393}
394
395impl FromIterator<ValueTransform> for TransformEngine {
396 fn from_iter<I: IntoIterator<Item = ValueTransform>>(iter: I) -> Self {
397 Self {
398 transforms: iter.into_iter().collect(),
399 }
400 }
401}
402
/// Declarative transform configuration, keyed by field path.
///
/// Thanks to `#[serde(flatten)]`, field paths appear as top-level keys:
///
/// ```yaml
/// finish_reason:
///   map:
///     STOP: stop
///     MAX_TOKENS: length
/// role:
///   map:
///     model: assistant
/// ```
///
/// Convert it into a runnable engine with `to_engine`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TransformConfig {
    /// Field path -> transform configuration for that field.
    #[serde(flatten)]
    pub fields: HashMap<String, FieldTransformConfig>,
}
425
/// Configuration for the transform of a single field.
///
/// The enum is `#[serde(untagged)]`: serde tries each variant in declaration
/// order and keeps the first whose shape matches, so the order below matters.
/// For the boolean-flag variants, `false` disables the transform.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum FieldTransformConfig {
    /// `map: { FROM: to, ... }` - string-to-string mapping table.
    Map { map: HashMap<String, String> },
    /// `pattern: ...` plus `replacement: ...` - regex replacement.
    Regex { pattern: String, replacement: String },
    /// `lowercase: true` - lowercase string values.
    Lowercase { lowercase: bool },
    /// `uppercase: true` - uppercase string values.
    Uppercase { uppercase: bool },
    /// `parse_int: true` - parse strings as integers.
    ParseInt { parse_int: bool },
    /// `parse_float: true` - parse strings as floats.
    ParseFloat { parse_float: bool },
}
443
444impl TransformConfig {
445 pub fn to_engine(&self) -> TransformEngine {
447 let transforms: Vec<ValueTransform> = self
448 .fields
449 .iter()
450 .map(|(field, config)| {
451 let transform_type = match config {
452 FieldTransformConfig::Map { map } => TransformType::Map(map.clone()),
453 FieldTransformConfig::Regex {
454 pattern,
455 replacement,
456 } => TransformType::Regex {
457 pattern: pattern.clone(),
458 replacement: replacement.clone(),
459 },
460 FieldTransformConfig::Lowercase { lowercase: true } => TransformType::Lowercase,
461 FieldTransformConfig::Lowercase { lowercase: false } => {
462 TransformType::Map(HashMap::new())
463 }
464 FieldTransformConfig::Uppercase { uppercase: true } => TransformType::Uppercase,
465 FieldTransformConfig::Uppercase { uppercase: false } => {
466 TransformType::Map(HashMap::new())
467 }
468 FieldTransformConfig::ParseInt { parse_int: true } => TransformType::ParseInt,
469 FieldTransformConfig::ParseInt { parse_int: false } => {
470 TransformType::Map(HashMap::new())
471 }
472 FieldTransformConfig::ParseFloat { parse_float: true } => {
473 TransformType::ParseFloat
474 }
475 FieldTransformConfig::ParseFloat { parse_float: false } => {
476 TransformType::Map(HashMap::new())
477 }
478 };
479 ValueTransform::new(field.clone(), transform_type)
480 })
481 .collect();
482 TransformEngine::new(transforms)
483 }
484}
485
486fn get_value_mut<'a>(root: &'a mut Value, path: &str) -> Option<&'a mut Value> {
490 let parts: Vec<&str> = path.split('.').collect();
491 let mut current = root;
492
493 for part in parts {
494 if let Ok(index) = part.parse::<usize>() {
496 current = current.get_mut(index)?;
497 } else {
498 current = current.get_mut(part)?;
499 }
500 }
501
502 Some(current)
503}
504
505pub fn get_value<'a>(root: &'a Value, path: &str) -> Option<&'a Value> {
510 let parts: Vec<&str> = path.split('.').collect();
511 let mut current = root;
512
513 for part in parts {
514 if let Ok(index) = part.parse::<usize>() {
516 current = current.get(index)?;
517 } else {
518 current = current.get(part)?;
519 }
520 }
521
522 Some(current)
523}
524
525#[allow(dead_code)]
527fn set_value(root: &mut Value, path: &str, value: Value) {
528 let parts: Vec<&str> = path.split('.').collect();
529
530 if parts.is_empty() {
531 return;
532 }
533
534 let mut current = root;
535 for part in &parts[..parts.len() - 1] {
536 if !current.get(part).is_some() {
538 current[*part] = Value::Object(serde_json::Map::new());
539 }
540 current = current.get_mut(part).unwrap();
541 }
542
543 let last_part = parts[parts.len() - 1];
544 current[last_part] = value;
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_map_transform() {
        let transform = TransformBuilder::field("finish_reason")
            .map("STOP", "stop")
            .map("MAX_TOKENS", "length")
            .map("SAFETY", "content_filter")
            .build();

        let mut value = json!({"finish_reason": "STOP"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["finish_reason"], "stop");

        let mut value = json!({"finish_reason": "MAX_TOKENS"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["finish_reason"], "length");

        // Strings without a mapping pass through unchanged.
        let mut value = json!({"finish_reason": "UNKNOWN"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["finish_reason"], "UNKNOWN");
    }

    #[test]
    fn test_map_transform_google_role() {
        let transform = TransformBuilder::field("role")
            .map("model", "assistant")
            .build();

        let mut value = json!({"role": "model"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["role"], "assistant");

        // Roles without a mapping are left alone.
        let mut value = json!({"role": "user"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["role"], "user");
    }

    #[test]
    fn test_regex_transform() {
        let transform = TransformBuilder::field("text")
            .regex(r"\d+", "NUMBER")
            .build();

        let mut value = json!({"text": "I have 42 apples and 7 oranges"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["text"], "I have NUMBER apples and NUMBER oranges");
    }

    #[test]
    fn test_regex_transform_with_capture_groups() {
        let transform = TransformBuilder::field("text")
            .regex(r"(\w+)@(\w+)\.com", "[$1 at $2]")
            .build();

        let mut value = json!({"text": "Contact: user@example.com"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["text"], "Contact: [user at example]");
    }

    #[test]
    fn test_lowercase_transform() {
        let transform = TransformBuilder::field("status").lowercase().build();

        let mut value = json!({"status": "ACTIVE"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["status"], "active");
    }

    #[test]
    fn test_uppercase_transform() {
        let transform = TransformBuilder::field("status").uppercase().build();

        let mut value = json!({"status": "pending"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["status"], "PENDING");
    }

    #[test]
    fn test_parse_int_transform() {
        let transform = TransformBuilder::field("count").parse_int().build();

        let mut value = json!({"count": "42"});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["count"], 42);

        // Unparsable strings become null.
        let mut value = json!({"count": "not_a_number"});
        transform.apply(&mut value).unwrap();
        assert!(value["count"].is_null());

        // Values that are already numbers are kept.
        let mut value = json!({"count": 42});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["count"], 42);
    }

    #[test]
    fn test_parse_float_transform() {
        let transform = TransformBuilder::field("ratio").parse_float().build();

        // Deliberately not a PI-like literal: clippy's deny-by-default
        // `approx_constant` lint rejects values such as 3.14159.
        let mut value = json!({"ratio": "2.5"});
        transform.apply(&mut value).unwrap();
        assert!((value["ratio"].as_f64().unwrap() - 2.5).abs() < 0.0001);

        // Unparsable strings become null.
        let mut value = json!({"ratio": "not_a_float"});
        transform.apply(&mut value).unwrap();
        assert!(value["ratio"].is_null());
    }

    #[test]
    fn test_parse_float_keeps_numbers_and_nulls_others() {
        let transform = TransformBuilder::field("ratio").parse_float().build();

        let mut value = json!({"ratio": 1.5});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["ratio"], 1.5);

        let mut value = json!({"ratio": true});
        transform.apply(&mut value).unwrap();
        assert!(value["ratio"].is_null());
    }

    #[test]
    fn test_custom_transform() {
        let transform = TransformBuilder::field("data")
            .custom(|v| {
                if let Value::Array(arr) = v {
                    Value::Number(arr.len().into())
                } else {
                    v
                }
            })
            .build();

        let mut value = json!({"data": [1, 2, 3, 4, 5]});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["data"], 5);
    }

    #[test]
    fn test_nested_field_transform() {
        let transform = TransformBuilder::field("response.status.code")
            .map("OK", "success")
            .build();

        let mut value = json!({
            "response": {
                "status": {
                    "code": "OK"
                }
            }
        });
        transform.apply(&mut value).unwrap();
        assert_eq!(value["response"]["status"]["code"], "success");
    }

    #[test]
    fn test_transform_engine() {
        let engine = TransformEngine::new(vec![
            TransformBuilder::field("finish_reason")
                .map("STOP", "stop")
                .map("MAX_TOKENS", "length")
                .build(),
            TransformBuilder::field("role")
                .map("model", "assistant")
                .build(),
        ]);

        let mut value = json!({
            "finish_reason": "STOP",
            "role": "model",
            "content": "Hello!"
        });

        engine.apply(&mut value).unwrap();

        assert_eq!(value["finish_reason"], "stop");
        assert_eq!(value["role"], "assistant");
        assert_eq!(value["content"], "Hello!");
    }

    #[test]
    fn test_transform_engine_from_iter() {
        let transforms = [
            TransformBuilder::field("a").lowercase().build(),
            TransformBuilder::field("b").uppercase().build(),
        ];

        let engine: TransformEngine = transforms.into_iter().collect();
        assert_eq!(engine.len(), 2);
    }

    #[test]
    fn test_transform_config_parsing() {
        let yaml = r#"
finish_reason:
  map:
    STOP: stop
    MAX_TOKENS: length
role:
  map:
    model: assistant
"#;

        let config: TransformConfig = serde_yaml::from_str(yaml).unwrap();
        let engine = config.to_engine();
        assert_eq!(engine.len(), 2);

        let mut value = json!({
            "finish_reason": "STOP",
            "role": "model"
        });
        engine.apply(&mut value).unwrap();
        assert_eq!(value["finish_reason"], "stop");
        assert_eq!(value["role"], "assistant");
    }

    #[test]
    fn test_transform_missing_field() {
        let transform = TransformBuilder::field("nonexistent")
            .map("a", "b")
            .build();

        let mut value = json!({"other_field": "value"});
        transform.apply(&mut value).unwrap();
        // A path that does not resolve leaves the document untouched.
        assert_eq!(value, json!({"other_field": "value"}));
    }

    #[test]
    fn test_transform_non_string_value() {
        let transform = TransformBuilder::field("count")
            .map("42", "forty-two")
            .build();

        // Map only matches string values, so the number 42 is not rewritten.
        let mut value = json!({"count": 42});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["count"], 42);
    }

    #[test]
    fn test_invalid_regex_pattern_errors_on_strings() {
        let transform = TransformBuilder::field("text").regex("(", "x").build();

        let mut value = json!({"text": "abc"});
        assert!(transform.apply(&mut value).is_err());
    }

    #[test]
    fn test_regex_skips_non_string_values() {
        // The pattern is only compiled for string values, so an invalid
        // pattern is not reported for other value types.
        let transform = TransformBuilder::field("text").regex("(", "x").build();

        let mut value = json!({"text": 42});
        transform.apply(&mut value).unwrap();
        assert_eq!(value["text"], 42);
    }

    #[test]
    fn test_get_value_array_index() {
        let value = json!({
            "items": [
                {"name": "first"},
                {"name": "second"}
            ]
        });

        assert_eq!(get_value(&value, "items.0.name").unwrap(), "first");
        assert_eq!(get_value(&value, "items.1.name").unwrap(), "second");
    }

    #[test]
    fn test_get_value_missing_paths() {
        let value = json!({"a": 1, "items": [10]});

        assert!(get_value(&value, "a.b").is_none());
        assert!(get_value(&value, "items.5").is_none());
        assert!(get_value(&value, "missing").is_none());
    }

    #[test]
    fn test_transform_array_element() {
        let transform = TransformBuilder::field("items.0.status")
            .map("ACTIVE", "active")
            .build();

        let mut value = json!({
            "items": [
                {"status": "ACTIVE"},
                {"status": "INACTIVE"}
            ]
        });
        transform.apply(&mut value).unwrap();
        assert_eq!(value["items"][0]["status"], "active");
        assert_eq!(value["items"][1]["status"], "INACTIVE");
    }

    #[test]
    fn test_serialization_roundtrip() {
        let transform = ValueTransform::new(
            "finish_reason",
            TransformType::Map({
                let mut m = HashMap::new();
                m.insert("STOP".to_string(), "stop".to_string());
                m.insert("MAX_TOKENS".to_string(), "length".to_string());
                m
            }),
        );

        let json = serde_json::to_string(&transform).unwrap();
        let deserialized: ValueTransform = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.field, "finish_reason");
        if let TransformType::Map(map) = &deserialized.transform {
            assert_eq!(map.get("STOP"), Some(&"stop".to_string()));
        } else {
            panic!("Expected Map transform type");
        }
    }

    #[test]
    fn test_yaml_serialization() {
        let transform = ValueTransform::new(
            "status",
            TransformType::Regex {
                pattern: r"\d+".to_string(),
                replacement: "NUM".to_string(),
            },
        );

        let yaml = serde_yaml::to_string(&transform).unwrap();
        let deserialized: ValueTransform = serde_yaml::from_str(&yaml).unwrap();

        assert_eq!(deserialized.field, "status");
        if let TransformType::Regex {
            pattern,
            replacement,
        } = &deserialized.transform
        {
            assert_eq!(pattern, r"\d+");
            assert_eq!(replacement, "NUM");
        } else {
            panic!("Expected Regex transform type");
        }
    }

    #[test]
    fn test_debug_format() {
        assert_eq!(format!("{:?}", TransformType::Lowercase), "Lowercase");

        let custom = TransformType::Custom(Arc::new(|v: Value| v));
        assert_eq!(format!("{:?}", custom), "Custom(<fn>)");
    }

    #[test]
    fn test_google_finish_reason_transforms() {
        // Normalizes Google-style finish reasons and roles.
        let engine = TransformEngine::new(vec![
            TransformBuilder::field("finish_reason")
                .map("STOP", "stop")
                .map("MAX_TOKENS", "length")
                .map("SAFETY", "content_filter")
                .map("RECITATION", "content_filter")
                .map("OTHER", "other")
                .build(),
            TransformBuilder::field("role")
                .map("model", "assistant")
                .build(),
        ]);

        let test_cases = [
            ("STOP", "stop"),
            ("MAX_TOKENS", "length"),
            ("SAFETY", "content_filter"),
            ("RECITATION", "content_filter"),
            ("OTHER", "other"),
        ];

        for (input, expected) in test_cases {
            let mut value = json!({"finish_reason": input, "role": "model"});
            engine.apply(&mut value).unwrap();
            assert_eq!(
                value["finish_reason"], expected,
                "Failed for input: {}",
                input
            );
            assert_eq!(value["role"], "assistant");
        }
    }

    #[test]
    fn test_openai_finish_reason_transforms() {
        // Identity mappings: values that are already normalized stay as-is.
        let engine = TransformEngine::new(vec![TransformBuilder::field("finish_reason")
            .map("stop", "stop")
            .map("length", "length")
            .map("tool_calls", "tool_calls")
            .map("content_filter", "content_filter")
            .build()]);

        let mut value = json!({"finish_reason": "stop"});
        engine.apply(&mut value).unwrap();
        assert_eq!(value["finish_reason"], "stop");

        let mut value = json!({"finish_reason": "tool_calls"});
        engine.apply(&mut value).unwrap();
        assert_eq!(value["finish_reason"], "tool_calls");
    }

    #[test]
    fn test_chained_transforms() {
        // Later transforms see the output of earlier ones on the same field.
        let engine = TransformEngine::new(vec![
            TransformBuilder::field("text").lowercase().build(),
            TransformBuilder::field("text")
                .regex(r"\s+", "_")
                .build(),
        ]);

        let mut value = json!({"text": "Hello World"});
        engine.apply(&mut value).unwrap();
        assert_eq!(value["text"], "hello_world");
    }

    #[test]
    fn test_empty_engine() {
        let engine = TransformEngine::empty();
        assert!(engine.is_empty());
        assert_eq!(engine.len(), 0);

        let mut value = json!({"test": "value"});
        engine.apply(&mut value).unwrap();
        assert_eq!(value["test"], "value");
    }

    #[test]
    fn test_engine_add() {
        let mut engine = TransformEngine::empty();
        engine.add(TransformBuilder::field("test").lowercase().build());
        assert_eq!(engine.len(), 1);
    }
}