mod transform;
pub use transform::{ModelTransform, TransformChain};
use crate::client::AsyncForgeClient;
use crate::error::ForgeError;
use crate::types::{ChatCompletionRequest, Message};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
type StepTransformFn =
Box<dyn Fn(&str, &mut PipelineContext) -> PipelineResult<String> + Send + Sync>;
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
#[error("Pipeline step '{0}' failed: {1}")]
StepFailed(String, String),
#[error("Invalid pipeline configuration: {0}")]
InvalidConfig(String),
#[error("Transform error: {0}")]
TransformError(String),
#[error("LLM error: {0}")]
LlmError(#[from] ForgeError),
#[error("{0}")]
Other(String),
}
/// Result alias used throughout the pipeline API, with [`PipelineError`] as the error type.
pub type PipelineResult<T> = std::result::Result<T, PipelineError>;
/// Mutable state threaded through every step of a pipeline run.
#[derive(Default, Clone, Debug)]
pub struct PipelineContext {
    /// Free-form values shared between steps; [`Pipeline::execute`] writes
    /// `"last_output"` here.
    pub data: HashMap<String, serde_json::Value>,
    /// Conversation history; [`LlmStep`] sends it to the model and appends its reply.
    pub messages: Vec<Message>,
    /// Auxiliary key/value annotations, kept separate from `data`.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl PipelineContext {
    /// Returns an empty context: no data, no messages, no metadata.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a context whose conversation history starts as `messages`;
    /// `data` and `metadata` start empty.
    pub fn with_messages(messages: Vec<Message>) -> Self {
        let mut context = Self::new();
        context.messages = messages;
        context
    }

    /// Stores `value` under `key` in `data`, replacing any previous value.
    pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
        let owned_key: String = key.into();
        self.data.insert(owned_key, value);
    }

    /// Looks up `key` in `data`.
    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
        self.data.get(key)
    }

    /// Looks up `key` in `data`, returning it only when it is a JSON string.
    pub fn get_string(&self, key: &str) -> Option<&str> {
        self.get(key)?.as_str()
    }

    /// Appends `message` to the end of the conversation history.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Stores `value` under `key` in `metadata`, replacing any previous value.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
        let owned_key: String = key.into();
        self.metadata.insert(owned_key, value);
    }
}
/// Result of running a single [`PipelineStep`].
#[derive(Clone, Debug)]
pub struct StepOutput {
    /// Text produced by the step.
    pub text: String,
    /// Optional structured payload attached to the output.
    pub data: Option<serde_json::Value>,
    /// When `false`, [`Pipeline::execute`] stops after recording this output.
    pub continue_pipeline: bool,
}
impl StepOutput {
    /// Creates an output carrying `text`, with no data, that lets the pipeline continue.
    #[must_use]
    pub fn new(text: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            data: None,
            continue_pipeline: true,
        }
    }

    /// Attaches a structured payload, replacing any existing one.
    #[must_use]
    pub fn with_data(mut self, data: serde_json::Value) -> Self {
        self.data = Some(data);
        self
    }

    /// Marks this output as final: [`Pipeline::execute`] stops after the step that returns it.
    ///
    /// Consumes `self` and returns the modified value. Discarding the return
    /// value (e.g. `out.stop();`) would silently lose the stop request, hence `#[must_use]`.
    #[must_use = "`stop` returns the modified output; discarding it loses the stop request"]
    pub fn stop(mut self) -> Self {
        self.continue_pipeline = false;
        self
    }
}
/// One unit of work in a [`Pipeline`].
///
/// Implementors must be `Send + Sync`; pipelines hold steps as `Arc<dyn PipelineStep>`.
#[async_trait]
pub trait PipelineStep: Send + Sync {
    /// Identifier used when recording this step's output and when reporting its failure.
    fn name(&self) -> &str;

    /// Runs the step against `context`, which it may read and modify, using
    /// `client` for any model calls.
    async fn execute(
        &self,
        client: &AsyncForgeClient,
        context: &mut PipelineContext,
    ) -> PipelineResult<StepOutput>;
}
/// Pipeline step that sends the conversation so far to the model and records its reply.
pub struct LlmStep {
    /// Step identifier reported by [`PipelineStep::name`].
    name: String,
    /// Optional system message sent ahead of the conversation.
    system_prompt: Option<String>,
    /// Model override; when `None`, the value of `client.model()` is used.
    model: Option<String>,
    /// Sampling temperature, forwarded to the request only when set.
    temperature: Option<f32>,
    /// Token limit, forwarded to the request only when set.
    max_tokens: Option<u32>,
}
impl LlmStep {
    /// Creates an LLM step named `name` with no system prompt and no
    /// model, temperature or token-limit overrides.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            system_prompt: None,
            model: None,
            temperature: None,
            max_tokens: None,
        }
    }

    /// Sets the system message sent before the conversation history.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Overrides the model; otherwise the client's `model()` is used.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the sampling temperature forwarded to the request.
    #[must_use]
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Sets the maximum number of tokens forwarded to the request.
    #[must_use]
    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = Some(tokens);
        self
    }
}
#[async_trait]
impl PipelineStep for LlmStep {
fn name(&self) -> &str {
&self.name
}
async fn execute(
&self,
client: &AsyncForgeClient,
context: &mut PipelineContext,
) -> PipelineResult<StepOutput> {
let mut messages = Vec::new();
if let Some(ref prompt) = self.system_prompt {
messages.push(Message::system(prompt));
}
messages.extend(context.messages.clone());
let model = self
.model
.clone()
.unwrap_or_else(|| client.model().to_string());
let mut request = ChatCompletionRequest::new(model, messages);
if let Some(temp) = self.temperature {
request = request.temperature(temp);
}
if let Some(max_tokens) = self.max_tokens {
request = request.max_tokens(max_tokens);
}
let response = client.chat_completions(request).await?;
let text = response
.choices
.first()
.and_then(|c| c.message.content.clone())
.unwrap_or_default();
context.add_message(Message::assistant(&text));
Ok(StepOutput::new(text))
}
}
/// Pipeline step that runs a synchronous closure over text; it never calls the model.
pub struct TransformStep {
    /// Step identifier reported by [`PipelineStep::name`].
    name: String,
    /// The user-supplied transformation closure.
    transform: StepTransformFn,
}
impl TransformStep {
pub fn new<F>(name: impl Into<String>, transform: F) -> Self
where
F: Fn(&str, &mut PipelineContext) -> PipelineResult<String> + Send + Sync + 'static,
{
Self {
name: name.into(),
transform: Box::new(transform),
}
}
}
#[async_trait]
impl PipelineStep for TransformStep {
    fn name(&self) -> &str {
        &self.name
    }

    /// Applies the transform closure to the most recent step output and
    /// returns the closure's result as this step's text.
    ///
    /// The input is chosen as follows:
    /// - the `"last_output"` string in the context's data, which
    ///   [`Pipeline::execute`] updates after every step that produces
    ///   non-empty text;
    /// - otherwise (key absent or not a string), the content of the last message.
    ///
    /// Reading `last_output` first is what lets consecutive transforms
    /// compose. This step never appends to `context.messages`, so the last
    /// message can lag behind the previous step's result.
    ///
    /// The output is not added to `context.messages`. A closure that wants a
    /// later LLM step to see it must push a message itself.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the transform closure.
    async fn execute(
        &self,
        _client: &AsyncForgeClient,
        context: &mut PipelineContext,
    ) -> PipelineResult<StepOutput> {
        let input = context
            .get_string("last_output")
            .map(str::to_owned)
            .or_else(|| {
                context
                    .messages
                    .last()
                    .and_then(|m| m.content.as_ref())
                    .cloned()
            })
            .unwrap_or_default();
        let output = (self.transform)(&input, context)?;
        Ok(StepOutput::new(output))
    }
}
/// Pipeline step that picks a sub-step based on a predicate over the context.
pub struct BranchStep {
    /// Step identifier reported by [`PipelineStep::name`].
    name: String,
    /// Predicate evaluated against the context each time the step runs.
    condition: Box<dyn Fn(&PipelineContext) -> bool + Send + Sync>,
    /// Step run when `condition` returns `true`.
    if_true: Arc<dyn PipelineStep>,
    /// Optional step run when `condition` returns `false`.
    if_false: Option<Arc<dyn PipelineStep>>,
}
impl BranchStep {
    /// Creates a branch named `name` that runs `if_true` when `condition`
    /// holds. No else-step is configured until
    /// [`with_else`](Self::with_else) is called.
    #[must_use]
    pub fn new<F>(
        name: impl Into<String>,
        condition: F,
        if_true: impl PipelineStep + 'static,
    ) -> Self
    where
        F: Fn(&PipelineContext) -> bool + Send + Sync + 'static,
    {
        Self {
            name: name.into(),
            condition: Box::new(condition),
            if_true: Arc::new(if_true),
            if_false: None,
        }
    }

    /// Sets the step to run when the condition is `false`, replacing any previous one.
    #[must_use]
    pub fn with_else(mut self, step: impl PipelineStep + 'static) -> Self {
        self.if_false = Some(Arc::new(step));
        self
    }
}
#[async_trait]
impl PipelineStep for BranchStep {
    fn name(&self) -> &str {
        &self.name
    }

    /// Evaluates the condition and delegates to the matching sub-step.
    ///
    /// When the condition is `false` and no else-step is configured, the
    /// previous result is passed through unchanged. That result is:
    /// - the `"last_output"` string in the context's data (updated by
    ///   [`Pipeline::execute`] after each step with non-empty text);
    /// - otherwise, the content of the last message.
    ///
    /// Preferring `last_output` keeps a preceding transform's result, which is
    /// never written to `context.messages`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the selected sub-step.
    async fn execute(
        &self,
        client: &AsyncForgeClient,
        context: &mut PipelineContext,
    ) -> PipelineResult<StepOutput> {
        if (self.condition)(context) {
            return self.if_true.execute(client, context).await;
        }
        if let Some(step) = &self.if_false {
            return step.execute(client, context).await;
        }
        let text = context
            .get_string("last_output")
            .map(str::to_owned)
            .or_else(|| {
                context
                    .messages
                    .last()
                    .and_then(|m| m.content.as_ref())
                    .cloned()
            })
            .unwrap_or_default();
        Ok(StepOutput::new(text))
    }
}
/// An ordered sequence of [`PipelineStep`]s run against a shared [`PipelineContext`].
pub struct Pipeline {
    /// Identifier returned by [`Pipeline::name`].
    name: String,
    /// Free-text description returned by [`Pipeline::description`].
    description: String,
    /// Steps in execution order.
    steps: Vec<Arc<dyn PipelineStep>>,
}
impl Pipeline {
    /// Creates an empty pipeline with the given name and description.
    #[must_use]
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            steps: Vec::new(),
        }
    }

    /// Appends `step` to the end of the pipeline.
    #[must_use]
    pub fn add_step(mut self, step: impl PipelineStep + 'static) -> Self {
        self.steps.push(Arc::new(step));
        self
    }

    /// Returns the pipeline's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the pipeline's description.
    pub fn description(&self) -> &str {
        &self.description
    }

    /// Returns the number of steps in the pipeline.
    pub fn step_count(&self) -> usize {
        self.steps.len()
    }

    /// Runs the steps in order against a fresh context seeded with `input` as a user message.
    ///
    /// Behaviour per step:
    /// - If the step lets the pipeline continue and produced non-empty text,
    ///   that text is stored in the context's data under `"last_output"`.
    /// - A step whose output has `continue_pipeline == false` is still
    ///   recorded, then ends the run.
    ///
    /// [`PipelineOutput::text`] is the text of the last recorded step, or
    /// empty when the pipeline has no steps.
    ///
    /// # Errors
    ///
    /// The first failing step aborts the run with [`PipelineError::StepFailed`].
    /// It carries the step's name and the underlying error's message.
    pub async fn execute(
        &self,
        client: &AsyncForgeClient,
        input: impl Into<String>,
    ) -> PipelineResult<PipelineOutput> {
        let mut context = PipelineContext::new();
        context.add_message(Message::user(input.into()));
        let mut step_outputs = Vec::with_capacity(self.steps.len());
        for step in &self.steps {
            let output = step
                .execute(client, &mut context)
                .await
                .map_err(|e| PipelineError::StepFailed(step.name().to_string(), e.to_string()))?;
            // Decide everything that needs `output` before recording it, so the
            // output can be moved into `step_outputs` instead of cloned.
            let keep_going = output.continue_pipeline;
            if keep_going && !output.text.is_empty() {
                context.set("last_output", serde_json::json!(output.text));
            }
            step_outputs.push((step.name().to_string(), output));
            if !keep_going {
                break;
            }
        }
        let final_text = step_outputs
            .last()
            .map(|(_, o)| o.text.clone())
            .unwrap_or_default();
        Ok(PipelineOutput {
            text: final_text,
            steps: step_outputs,
            context,
        })
    }
}
/// Everything produced by one [`Pipeline::execute`] run.
#[derive(Debug)]
pub struct PipelineOutput {
    /// Text of the last recorded step (empty if no step ran).
    pub text: String,
    /// `(step name, output)` pairs in execution order, including a step that stopped the run.
    pub steps: Vec<(String, StepOutput)>,
    /// State of the context after the last step ran.
    pub context: PipelineContext,
}
/// Fluent builder for a [`Pipeline`]; finish with [`PipelineBuilder::build`].
pub struct PipelineBuilder {
    /// Name given to the built pipeline.
    name: String,
    /// Description given to the built pipeline; empty unless set.
    description: String,
    /// Steps collected so far, in execution order.
    steps: Vec<Arc<dyn PipelineStep>>,
}
impl PipelineBuilder {
    /// Starts a builder for a pipeline called `name`, with an empty description and no steps.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: String::new(),
            steps: Vec::new(),
        }
    }

    /// Sets the pipeline description.
    #[must_use]
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = description.into();
        self
    }

    /// Appends an [`LlmStep`] named `name` with default settings.
    #[must_use]
    pub fn llm(self, name: impl Into<String>) -> Self {
        self.step(LlmStep::new(name))
    }

    /// Appends a pre-configured [`LlmStep`].
    #[must_use]
    pub fn llm_with(self, step: LlmStep) -> Self {
        self.step(step)
    }

    /// Appends a [`TransformStep`] named `name` that runs `f`.
    #[must_use]
    pub fn transform<F>(self, name: impl Into<String>, f: F) -> Self
    where
        F: Fn(&str, &mut PipelineContext) -> PipelineResult<String> + Send + Sync + 'static,
    {
        self.step(TransformStep::new(name, f))
    }

    /// Appends any [`PipelineStep`]; every other step-adding method delegates here.
    #[must_use]
    pub fn step(mut self, step: impl PipelineStep + 'static) -> Self {
        self.steps.push(Arc::new(step));
        self
    }

    /// Consumes the builder and returns the configured [`Pipeline`].
    #[must_use]
    pub fn build(self) -> Pipeline {
        Pipeline {
            name: self.name,
            description: self.description,
            steps: self.steps,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `set`/`get_string` round-trip a string value; `add_message` grows the history.
    #[test]
    fn test_pipeline_context() {
        let mut context = PipelineContext::new();
        context.set("key", serde_json::json!("value"));
        let stored = context.get_string("key");
        assert_eq!(stored, Some("value"));

        context.add_message(Message::user("Hello"));
        assert_eq!(context.messages.len(), 1);
    }

    /// `with_data` attaches a payload; `stop` clears the continue flag.
    #[test]
    fn test_step_output() {
        let payload = serde_json::json!({"count": 1});
        let running = StepOutput::new("Hello").with_data(payload);
        assert_eq!(running.text, "Hello");
        assert!(running.data.is_some());
        assert!(running.continue_pipeline);

        let halted = StepOutput::new("Done").stop();
        assert!(!halted.continue_pipeline);
    }

    /// Every `with_*` builder method stores its value.
    #[test]
    fn test_llm_step_builder() {
        let llm = LlmStep::new("test")
            .with_system_prompt("Be helpful")
            .with_model("gpt-4")
            .with_temperature(0.7)
            .with_max_tokens(100);

        assert_eq!(llm.name, "test");
        assert_eq!(llm.system_prompt.as_deref(), Some("Be helpful"));
        assert_eq!(llm.model.as_deref(), Some("gpt-4"));
        assert_eq!(llm.temperature, Some(0.7));
        assert_eq!(llm.max_tokens, Some(100));
    }

    /// The builder carries name, description and both added steps into the pipeline.
    #[test]
    fn test_pipeline_builder() {
        let built = PipelineBuilder::new("test-pipeline")
            .description("A test pipeline")
            .llm("step1")
            .transform("uppercase", |text, _| Ok(text.to_uppercase()))
            .build();

        assert_eq!(
            (built.name(), built.description(), built.step_count()),
            ("test-pipeline", "A test pipeline", 2)
        );
    }

    /// `Pipeline::new` plus `add_step` yields a one-step pipeline.
    #[test]
    fn test_pipeline_new() {
        let single = Pipeline::new("my-pipeline", "Does something")
            .add_step(LlmStep::new("process"));

        assert_eq!((single.name(), single.step_count()), ("my-pipeline", 1));
    }

    /// A transform step reports the name it was created with.
    #[test]
    fn test_transform_step() {
        let reverser = TransformStep::new("reverse", |text, _| Ok(text.chars().rev().collect()));
        assert_eq!(PipelineStep::name(&reverser), "reverse");
    }
}