use serde::{Deserialize, Serialize};
use crate::frame::FormatAndExecuteError;
use crate::output::Output;
use crate::{
frame::Frame, serialization::StorableEntity, step::Step, traits::Executor, Parameters,
};
#[derive(thiserror::Error, Debug)]
pub enum SequentialChainError {
#[error("ExecutorError: {0}")]
FormatAndExecuteError(#[from] FormatAndExecuteError),
#[error("The vector of steps was empty")]
NoSteps,
}
/// A sequential chain: its steps are executed one after another, and the
/// text output of each step (except the last) is handed to the next step as
/// its text parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chain {
    /// Steps in execution order; the last one produces the chain's output.
    steps: Vec<Step>,
}
impl Chain {
    /// Creates a sequential chain that runs `steps` in the given order.
    pub fn new(steps: Vec<Step>) -> Chain {
        Chain { steps }
    }

    /// Creates a sequential chain consisting of exactly one step.
    pub fn of_one(step: Step) -> Chain {
        Chain { steps: vec![step] }
    }

    /// Runs every step of the chain in order using `executor`.
    ///
    /// Every step except the last one is handled the same way:
    /// 1. It is formatted with the current parameters and executed.
    /// 2. Its output is awaited to completion via `to_immediate`.
    /// 3. The last body of the resulting content (or the default value if
    ///    there is none) replaces the text of the parameters passed to the
    ///    next step, via `Parameters::with_text`.
    ///
    /// The last step's [`Output`] is returned as produced by
    /// `format_and_execute`, without being forced to an immediate value.
    ///
    /// # Errors
    ///
    /// - [`SequentialChainError::NoSteps`] if the chain has no steps.
    /// - [`SequentialChainError::FormatAndExecuteError`] if any step fails to
    ///   format or execute. This includes failures while awaiting an
    ///   intermediate step's output, which are wrapped as
    ///   `FormatAndExecuteError::Execute`.
    pub async fn run<E>(
        &self,
        parameters: Parameters,
        executor: &E,
    ) -> Result<Output, SequentialChainError>
    where
        E: Executor,
    {
        // `split_last` both rejects an empty chain and separates the final
        // step, so no index arithmetic or `unwrap` is needed.
        let (last_step, leading_steps) = self
            .steps
            .split_last()
            .ok_or(SequentialChainError::NoSteps)?;

        let mut current_params = parameters;
        for step in leading_steps {
            let body = Frame::new(executor, step)
                .format_and_execute(&current_params)
                .await?
                .to_immediate()
                .await
                // `?` lifts this into `SequentialChainError` via its `#[from]` impl.
                .map_err(FormatAndExecuteError::Execute)?
                .as_content()
                .extract_last_body()
                .cloned()
                .unwrap_or_default();
            current_params = current_params.with_text(body);
        }

        let output = Frame::new(executor, last_step)
            .format_and_execute(&current_params)
            .await?;
        Ok(output)
    }
}
impl StorableEntity for Chain {
    /// Returns the metadata pairs that tag a stored entity as a sequential chain.
    ///
    /// NOTE(review): presumably attached by the `serialization` module when a
    /// `Chain` is stored and checked when it is loaded; confirm against that module.
    fn get_metadata() -> Vec<(String, String)> {
        vec![(
            "chain-type".to_string(),
            "llm-chain::chains::sequential::Chain".to_string(),
        )]
    }
}