use super::{PipelineContext, PipelineError, PipelineResult};
use std::sync::Arc;
/// A named, two-directional text transform.
///
/// Implementors must be `Send + Sync`. Each direction receives the text to
/// rewrite plus a shared [`PipelineContext`] and returns the rewritten text
/// or an error through [`PipelineResult`].
///
/// NOTE(review): the method names suggest `transform_input` runs on text
/// going to a model and `transform_output` on text coming back — confirm
/// against the pipeline code that drives this trait.
pub trait ModelTransform: Send + Sync {
/// Returns the name identifying this transform.
fn name(&self) -> &str;
/// Rewrites `input`, returning the new text or an error.
fn transform_input(&self, input: &str, context: &PipelineContext) -> PipelineResult<String>;
/// Rewrites `output`, returning the new text or an error.
fn transform_output(&self, output: &str, context: &PipelineContext) -> PipelineResult<String>;
}
/// A transform that hands text back unchanged in both directions.
#[allow(dead_code)]
pub struct IdentityTransform;

impl ModelTransform for IdentityTransform {
    /// Always reports the fixed name `"identity"`.
    fn name(&self) -> &str {
        "identity"
    }

    /// Returns an owned copy of `input`; the context is not consulted.
    fn transform_input(&self, input: &str, _context: &PipelineContext) -> PipelineResult<String> {
        let unchanged = String::from(input);
        Ok(unchanged)
    }

    /// Returns an owned copy of `output`; the context is not consulted.
    fn transform_output(&self, output: &str, _context: &PipelineContext) -> PipelineResult<String> {
        let unchanged = String::from(output);
        Ok(unchanged)
    }
}
/// A transform that surrounds text with fixed prefix/suffix strings.
///
/// Input and output each have their own prefix and suffix; all four default
/// to the empty string.
#[allow(dead_code)]
pub struct WrapTransform {
    /// Name reported through `ModelTransform::name`.
    name: String,
    /// Text placed before the input.
    input_prefix: String,
    /// Text placed after the input.
    input_suffix: String,
    /// Text placed before the output.
    output_prefix: String,
    /// Text placed after the output.
    output_suffix: String,
}

#[allow(dead_code)]
impl WrapTransform {
    /// Creates a wrapper called `name` whose prefixes and suffixes are all empty.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            input_prefix: String::default(),
            input_suffix: String::default(),
            output_prefix: String::default(),
            output_suffix: String::default(),
        }
    }

    /// Replaces the input prefix and returns the updated builder.
    pub fn with_input_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            input_prefix: prefix.into(),
            ..self
        }
    }

    /// Replaces the input suffix and returns the updated builder.
    pub fn with_input_suffix(self, suffix: impl Into<String>) -> Self {
        Self {
            input_suffix: suffix.into(),
            ..self
        }
    }

    /// Replaces the output prefix and returns the updated builder.
    pub fn with_output_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            output_prefix: prefix.into(),
            ..self
        }
    }

    /// Replaces the output suffix and returns the updated builder.
    pub fn with_output_suffix(self, suffix: impl Into<String>) -> Self {
        Self {
            output_suffix: suffix.into(),
            ..self
        }
    }
}
impl ModelTransform for WrapTransform {
    /// Returns the name supplied at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns `input_prefix + input + input_suffix`; the context is not consulted.
    fn transform_input(&self, input: &str, _context: &PipelineContext) -> PipelineResult<String> {
        let pieces = [
            self.input_prefix.as_str(),
            input,
            self.input_suffix.as_str(),
        ];
        Ok(pieces.concat())
    }

    /// Returns `output_prefix + output + output_suffix`; the context is not consulted.
    fn transform_output(&self, output: &str, _context: &PipelineContext) -> PipelineResult<String> {
        let pieces = [
            self.output_prefix.as_str(),
            output,
            self.output_suffix.as_str(),
        ];
        Ok(pieces.concat())
    }
}
/// A transform that rewrites text using lists of regex replacement rules.
#[allow(dead_code)]
pub struct RegexTransform {
    /// Name reported through `ModelTransform::name`.
    name: String,
    /// `(pattern, replacement)` rules for input text, kept in the order added.
    input_patterns: Vec<(regex::Regex, String)>,
    /// `(pattern, replacement)` rules for output text, kept in the order added.
    output_patterns: Vec<(regex::Regex, String)>,
}

#[allow(dead_code)]
impl RegexTransform {
    /// Creates a transform called `name` with no replacement rules.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            input_patterns: Vec::default(),
            output_patterns: Vec::default(),
        }
    }

    /// Compiles `pattern` and pairs it with its replacement text.
    ///
    /// # Errors
    ///
    /// Returns `PipelineError::TransformError` when `pattern` is not a valid regex.
    fn compile_rule(
        pattern: &str,
        replacement: impl Into<String>,
    ) -> PipelineResult<(regex::Regex, String)> {
        let compiled = regex::Regex::new(pattern)
            .map_err(|cause| PipelineError::TransformError(format!("Invalid regex: {}", cause)))?;
        Ok((compiled, replacement.into()))
    }

    /// Appends a rule that is applied to input text.
    ///
    /// # Errors
    ///
    /// Returns `PipelineError::TransformError` when `pattern` is not a valid regex.
    pub fn with_input_replace(
        mut self,
        pattern: &str,
        replacement: impl Into<String>,
    ) -> PipelineResult<Self> {
        let rule = Self::compile_rule(pattern, replacement)?;
        self.input_patterns.push(rule);
        Ok(self)
    }

    /// Appends a rule that is applied to output text.
    ///
    /// # Errors
    ///
    /// Returns `PipelineError::TransformError` when `pattern` is not a valid regex.
    pub fn with_output_replace(
        mut self,
        pattern: &str,
        replacement: impl Into<String>,
    ) -> PipelineResult<Self> {
        let rule = Self::compile_rule(pattern, replacement)?;
        self.output_patterns.push(rule);
        Ok(self)
    }
}
impl ModelTransform for RegexTransform {
    /// Returns the name supplied at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Runs every input rule over `input`, in the order the rules were added.
    ///
    /// Each rule sees the text produced by the previous one. The context is
    /// not consulted.
    fn transform_input(&self, input: &str, _context: &PipelineContext) -> PipelineResult<String> {
        let rewritten = self
            .input_patterns
            .iter()
            .fold(input.to_owned(), |text, (rule, with)| {
                rule.replace_all(&text, with.as_str()).into_owned()
            });
        Ok(rewritten)
    }

    /// Runs every output rule over `output`, in the order the rules were added.
    ///
    /// Each rule sees the text produced by the previous one. The context is
    /// not consulted.
    fn transform_output(&self, output: &str, _context: &PipelineContext) -> PipelineResult<String> {
        let rewritten = self
            .output_patterns
            .iter()
            .fold(output.to_owned(), |text, (rule, with)| {
                rule.replace_all(&text, with.as_str()).into_owned()
            });
        Ok(rewritten)
    }
}
/// A transform whose output side extracts a JSON value from surrounding text.
#[allow(dead_code)]
pub struct JsonExtractTransform {
    /// Name reported through `ModelTransform::name`.
    name: String,
}

#[allow(dead_code)]
impl JsonExtractTransform {
    /// Creates an extractor called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self { name }
    }
}
// Private helpers for the output-side extraction. The allowance mirrors the
// one on the type itself, which may be unused in some builds.
#[allow(dead_code)]
impl JsonExtractTransform {
    /// Opening marker of a fenced block explicitly tagged as JSON.
    const JSON_FENCE: &'static str = "```json";
    /// Marker that opens or closes any fenced block.
    const FENCE: &'static str = "```";

    /// Returns an owned copy of `candidate` when it parses as a JSON value.
    fn validated(candidate: &str) -> Option<String> {
        if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
            Some(candidate.to_string())
        } else {
            None
        }
    }

    /// Extracts the trimmed body of the first "```json" fence, if it is valid JSON.
    fn from_json_fence(output: &str) -> Option<String> {
        let body_start = output.find(Self::JSON_FENCE)? + Self::JSON_FENCE.len();
        let body = &output[body_start..];
        let body_len = body.find(Self::FENCE)?;
        Self::validated(body[..body_len].trim())
    }

    /// Extracts the trimmed body of the first fence of any kind, if it is valid JSON.
    ///
    /// The remainder of the opening fence line (for example a language tag)
    /// is skipped; when the fence is not followed by a newline the body
    /// starts right after the marker.
    fn from_any_fence(output: &str) -> Option<String> {
        let after_marker = output.find(Self::FENCE)? + Self::FENCE.len();
        let body_start = output[after_marker..]
            .find('\n')
            .map_or(after_marker, |newline| after_marker + newline + 1);
        let body = &output[body_start..];
        let body_len = body.find(Self::FENCE)?;
        Self::validated(body[..body_len].trim())
    }

    /// Returns the byte length of the balanced `open`…`close` span that starts
    /// at the beginning of `text`, or `None` when it never closes.
    ///
    /// Delimiters inside double-quoted string literals are ignored, and a
    /// backslash inside a string escapes the next character, so values such
    /// as `{"a": "}"}` are measured correctly.
    fn balanced_len(text: &str, open: char, close: char) -> Option<usize> {
        let mut depth = 0usize;
        let mut in_string = false;
        let mut escaped = false;
        for (offset, ch) in text.char_indices() {
            if in_string {
                if escaped {
                    escaped = false;
                } else if ch == '\\' {
                    escaped = true;
                } else if ch == '"' {
                    in_string = false;
                }
                continue;
            }
            if ch == '"' {
                in_string = true;
            } else if ch == open {
                depth += 1;
            } else if ch == close {
                // A closer with nothing open means `text` did not start at an
                // opener; give up instead of underflowing.
                depth = depth.checked_sub(1)?;
                if depth == 0 {
                    return Some(offset + ch.len_utf8());
                }
            }
        }
        None
    }

    /// Extracts the first balanced object, or failing that the first balanced
    /// array, that parses as JSON.
    ///
    /// NOTE(review): objects are tried before arrays regardless of which one
    /// appears first, so a top-level array of objects yields only its first
    /// object. This ordering is kept for compatibility with existing callers.
    fn from_inline(output: &str) -> Option<String> {
        for (open, close) in [('{', '}'), ('[', ']')] {
            let Some(start) = output.find(open) else {
                continue;
            };
            let tail = &output[start..];
            let found = Self::balanced_len(tail, open, close)
                .and_then(|len| Self::validated(&tail[..len]));
            if found.is_some() {
                return found;
            }
        }
        None
    }
}

impl ModelTransform for JsonExtractTransform {
    /// Returns the name supplied at construction.
    fn name(&self) -> &str {
        &self.name
    }

    /// Returns an owned copy of `input`; the context is not consulted.
    fn transform_input(&self, input: &str, _context: &PipelineContext) -> PipelineResult<String> {
        Ok(input.to_string())
    }

    /// Extracts a JSON value embedded in `output`.
    ///
    /// Candidates are tried in this order and the first one that parses as
    /// JSON is returned:
    ///
    /// 1. the trimmed body of the first "```json" fenced block,
    /// 2. the trimmed body of the first fenced block of any kind,
    /// 3. the first balanced `{…}` span, then the first balanced `[…]` span.
    ///
    /// When nothing parses, `output` is returned unchanged. This method never
    /// returns an error, and the context is not consulted.
    fn transform_output(&self, output: &str, _context: &PipelineContext) -> PipelineResult<String> {
        let extracted = Self::from_json_fence(output)
            .or_else(|| Self::from_any_fence(output))
            .or_else(|| Self::from_inline(output));
        Ok(extracted.unwrap_or_else(|| output.to_string()))
    }
}
/// An ordered list of transforms that is itself usable as one transform.
pub struct TransformChain {
    /// Name reported through `ModelTransform::name`.
    name: String,
    /// Member transforms, in the order they were added.
    transforms: Vec<Arc<dyn ModelTransform>>,
}

impl TransformChain {
    /// Creates an empty chain called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            transforms: Vec::new(),
        }
    }

    /// Appends `transform` to the end of the chain and returns the chain.
    #[allow(clippy::should_implement_trait)]
    pub fn add(mut self, transform: impl ModelTransform + 'static) -> Self {
        let stage: Arc<dyn ModelTransform> = Arc::new(transform);
        self.transforms.push(stage);
        self
    }

    /// Returns how many transforms the chain holds.
    pub fn len(&self) -> usize {
        self.transforms.len()
    }

    /// Returns `true` when the chain holds no transforms.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl ModelTransform for TransformChain {
    /// Returns the name supplied at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Feeds `input` through every member in insertion order.
    ///
    /// Each member receives the previous member's result. The first error
    /// stops the chain and is returned; an empty chain returns a copy of
    /// `input`.
    fn transform_input(&self, input: &str, context: &PipelineContext) -> PipelineResult<String> {
        self.transforms
            .iter()
            .try_fold(input.to_owned(), |text, stage| {
                stage.transform_input(&text, context)
            })
    }

    /// Feeds `output` through every member in reverse insertion order.
    ///
    /// Each member receives the previous member's result. The first error
    /// stops the chain and is returned; an empty chain returns a copy of
    /// `output`.
    fn transform_output(&self, output: &str, context: &PipelineContext) -> PipelineResult<String> {
        self.transforms
            .iter()
            .rev()
            .try_fold(output.to_owned(), |text, stage| {
                stage.transform_output(&text, context)
            })
    }
}
/// Boxed callback type used for both directions of a `CustomTransform`.
type TransformFn = Box<dyn Fn(&str, &PipelineContext) -> PipelineResult<String> + Send + Sync>;

/// A transform whose behavior is supplied by caller-provided callbacks.
#[allow(dead_code)]
pub struct CustomTransform {
    /// Name reported through `ModelTransform::name`.
    name: String,
    /// Callback invoked for input text.
    input_fn: TransformFn,
    /// Callback invoked for output text.
    output_fn: TransformFn,
}

#[allow(dead_code)]
impl CustomTransform {
    /// Creates a transform called `name` from one callback per direction.
    pub fn new<I, O>(name: impl Into<String>, input_fn: I, output_fn: O) -> Self
    where
        I: Fn(&str, &PipelineContext) -> PipelineResult<String> + Send + Sync + 'static,
        O: Fn(&str, &PipelineContext) -> PipelineResult<String> + Send + Sync + 'static,
    {
        let input_fn: TransformFn = Box::new(input_fn);
        let output_fn: TransformFn = Box::new(output_fn);
        Self {
            name: name.into(),
            input_fn,
            output_fn,
        }
    }

    /// Callback for the unused direction: returns an owned copy of `text`.
    fn passthrough(text: &str, _context: &PipelineContext) -> PipelineResult<String> {
        Ok(text.to_owned())
    }

    /// Creates a transform that applies `f` to input and leaves output unchanged.
    pub fn input_only<F>(name: impl Into<String>, f: F) -> Self
    where
        F: Fn(&str, &PipelineContext) -> PipelineResult<String> + Send + Sync + 'static,
    {
        Self::new(name, f, Self::passthrough)
    }

    /// Creates a transform that applies `f` to output and leaves input unchanged.
    pub fn output_only<F>(name: impl Into<String>, f: F) -> Self
    where
        F: Fn(&str, &PipelineContext) -> PipelineResult<String> + Send + Sync + 'static,
    {
        Self::new(name, Self::passthrough, f)
    }
}
impl ModelTransform for CustomTransform {
    /// Returns the name supplied at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Delegates to the stored input callback and returns its result as-is.
    fn transform_input(&self, input: &str, context: &PipelineContext) -> PipelineResult<String> {
        let callback = &self.input_fn;
        callback(input, context)
    }

    /// Delegates to the stored output callback and returns its result as-is.
    fn transform_output(&self, output: &str, context: &PipelineContext) -> PipelineResult<String> {
        let callback = &self.output_fn;
        callback(output, context)
    }
}
// Unit tests for the transforms defined in this module. Every test builds a
// fresh `PipelineContext` with `PipelineContext::new()` and passes it to the
// transform under test.
#[cfg(test)]
mod tests {
use super::*;
// The identity transform reports its fixed name and returns text unchanged
// in both directions.
#[test]
fn test_identity_transform() {
let transform = IdentityTransform;
let ctx = PipelineContext::new();
assert_eq!(transform.name(), "identity");
assert_eq!(transform.transform_input("hello", &ctx).unwrap(), "hello");
assert_eq!(transform.transform_output("world", &ctx).unwrap(), "world");
}
// Input and output are wrapped with their own, independent prefix/suffix pairs.
#[test]
fn test_wrap_transform() {
let transform = WrapTransform::new("wrapper")
.with_input_prefix("<<")
.with_input_suffix(">>")
.with_output_prefix("[")
.with_output_suffix("]");
let ctx = PipelineContext::new();
assert_eq!(transform.transform_input("text", &ctx).unwrap(), "<<text>>");
assert_eq!(
transform.transform_output("result", &ctx).unwrap(),
"[result]"
);
}
// Input rules and output rules are kept separately: the word-boundary rule
// rewrites every "foo" in the input, the digit rule rewrites the output.
#[test]
fn test_regex_transform() {
let transform = RegexTransform::new("regex")
.with_input_replace(r"\bfoo\b", "bar")
.unwrap()
.with_output_replace(r"\d+", "NUMBER")
.unwrap();
let ctx = PipelineContext::new();
assert_eq!(
transform.transform_input("foo is foo", &ctx).unwrap(),
"bar is bar"
);
assert_eq!(
transform.transform_output("count: 42", &ctx).unwrap(),
"count: NUMBER"
);
}
// Covers three extraction paths: a fenced JSON block, an inline object
// surrounded by prose, and text with no JSON (returned unchanged).
#[test]
fn test_json_extract_transform() {
let transform = JsonExtractTransform::new("json");
let ctx = PipelineContext::new();
// The raw string below is the fixture itself; its line breaks are significant.
let output = r#"Here's the JSON:
```json
{"key": "value"}
```
Done!"#;
assert_eq!(
transform.transform_output(output, &ctx).unwrap(),
r#"{"key": "value"}"#
);
let output = r#"The result is {"a": 1, "b": 2} which is good."#;
assert_eq!(
transform.transform_output(output, &ctx).unwrap(),
r#"{"a": 1, "b": 2}"#
);
let output = "Just plain text";
assert_eq!(
transform.transform_output(output, &ctx).unwrap(),
"Just plain text"
);
}
// Input transforms run in insertion order: wrap1 produces "AX", then wrap2
// prefixes that result, giving "BAX".
#[test]
fn test_transform_chain() {
let chain = TransformChain::new("chain")
.add(WrapTransform::new("wrap1").with_input_prefix("A"))
.add(WrapTransform::new("wrap2").with_input_prefix("B"));
let ctx = PipelineContext::new();
assert_eq!(chain.transform_input("X", &ctx).unwrap(), "BAX");
assert_eq!(chain.len(), 2);
}
// Each direction invokes its own closure: uppercase on input, lowercase on output.
#[test]
fn test_custom_transform() {
let transform = CustomTransform::new(
"custom",
|s, _| Ok(s.to_uppercase()),
|s, _| Ok(s.to_lowercase()),
);
let ctx = PipelineContext::new();
assert_eq!(transform.name(), "custom");
assert_eq!(transform.transform_input("Hello", &ctx).unwrap(), "HELLO");
assert_eq!(transform.transform_output("WORLD", &ctx).unwrap(), "world");
}
// `input_only` applies the closure to input and leaves output untouched.
#[test]
fn test_custom_transform_input_only() {
let transform = CustomTransform::input_only("input-only", |s, _| Ok(format!("[{}]", s)));
let ctx = PipelineContext::new();
assert_eq!(transform.transform_input("test", &ctx).unwrap(), "[test]");
assert_eq!(transform.transform_output("test", &ctx).unwrap(), "test");
}
}