1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#[cfg(feature = "serialization")]
use serde::{Deserialize, Serialize};

use crate::{
    frame::Frame,
    traits::{Executor, ExecutorError, Step},
    Parameters,
};

#[cfg(feature = "serialization")]
use crate::serialization::StorableEntity;

/// Errors produced while running a sequential [`Chain`].
#[derive(thiserror::Error, Debug)]
pub enum SequentialChainError<Err>
where
    Err: ExecutorError,
{
    /// Wraps an error returned by the executor while formatting/executing a step.
    #[error("ExecutorError: {0}")]
    ExecutorError(#[from] Err),
    /// The chain was run with an empty list of steps.
    #[error("The vector of steps was empty")]
    NoSteps,
}
/// A sequential chain: its steps are executed in order, with the output of
/// each step made available to the next one.
pub struct Chain<S>
where
    S: Step,
{
    steps: Vec<S>,
}

impl<S: Step> Chain<S> {
    /// Creates a chain that runs `steps` in the given order.
    ///
    /// An empty `steps` vector is accepted here; running such a chain returns
    /// [`SequentialChainError::NoSteps`].
    pub fn new(steps: Vec<S>) -> Chain<S> {
        Chain { steps }
    }

    /// Creates a chain consisting of exactly one step.
    pub fn of_one(step: S) -> Chain<S> {
        Chain { steps: vec![step] }
    }

    /// Runs every step in order and returns the output of the last step.
    ///
    /// Each step is formatted and executed through a [`Frame`] built from
    /// `executor` and that step. After every step except the last, the
    /// parameters for the next step are derived from the step's output via
    /// `Parameters::with_text_from_output`.
    ///
    /// # Errors
    ///
    /// - [`SequentialChainError::NoSteps`] if the chain contains no steps.
    /// - [`SequentialChainError::ExecutorError`] if any step fails; the
    ///   remaining steps are not run.
    pub async fn run<E: Executor<Step = S>>(
        &self,
        parameters: Parameters,
        executor: &E,
    ) -> Result<E::Output, SequentialChainError<E::Error>> {
        // Splitting off the last step makes the "at least one step" invariant
        // structural, so no `Option` bookkeeping or panicking `expect` is needed.
        let (last, leading) = match self.steps.split_last() {
            Some(split) => split,
            None => return Err(SequentialChainError::NoSteps),
        };

        let mut current_params = parameters;
        for step in leading {
            let frame = Frame::new(executor, step);
            let res = frame.format_and_execute(&current_params).await?;
            current_params = current_params.with_text_from_output(&res).await;
        }

        // The final output is returned as-is; deriving parameters from it
        // would be discarded work since no step follows.
        let frame = Frame::new(executor, last);
        Ok(frame.format_and_execute(&current_params).await?)
    }
}

#[cfg(feature = "serialization")]
impl<S: Step + Serialize> Serialize for Chain<S> {
    /// Serializes the chain transparently as the plain sequence of its steps.
    fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
    where
        Ser: serde::Serializer,
    {
        self.steps.serialize(serializer)
    }
}

#[cfg(feature = "serialization")]
impl<'de, S: Step + Deserialize<'de>> Deserialize<'de> for Chain<S> {
    /// Deserializes a chain from a bare sequence of steps (the format written
    /// by the `Serialize` impl).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let steps = Vec::<S>::deserialize(deserializer)?;
        Ok(Chain { steps })
    }
}

#[cfg(feature = "serialization")]
impl<S: Step + StorableEntity> StorableEntity for Chain<S> {
    /// Returns the "chain-type" entry identifying this chain kind, followed by
    /// the metadata reported by the step type `S`.
    fn get_metadata() -> Vec<(String, String)> {
        let chain_type = (
            "chain-type".to_string(),
            "llm-chain::chains::sequential::Chain".to_string(),
        );
        std::iter::once(chain_type)
            .chain(S::get_metadata())
            .collect()
    }
}