//! The `Chain` module models a conversation between an entity and an LLM.
//! It manages the conversation state and provides methods for sending messages and receiving responses.
//!
//! It relies on the `traits::Executor` trait to execute prompts and handle LLM interactions.

use crate::output::Output;
use crate::prompt::{ChatMessage, ChatMessageCollection, Prompt, PromptTemplate};
use crate::step::Step;
use crate::tokens::{PromptTokensError, TokenizerError};
use crate::traits::{self, ExecutorError};
use crate::{parameters, Parameters};
use serde::{Deserialize, Serialize};

/// `Chain` represents a conversation between an entity and an LLM.
///
/// It holds the conversation state and provides methods for sending messages and receiving responses.
#[derive(Serialize, Deserialize)]
pub struct Chain<E: traits::Executor> {
    state: ChatMessageCollection<String>,
    _phantom: std::marker::PhantomData<E>,
}

impl<E> Default for Chain<E>
where
    E: traits::Executor,
{
    /// Constructs a new `Chain` with an empty conversation state.
    fn default() -> Self {
        Self {
            state: ChatMessageCollection::new(),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<E: traits::Executor> Chain<E> {
    /// Constructs a new `Chain` with the given conversation state.
    ///
    /// # Arguments
    /// * `state` - The initial prompt state to use.
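    ///
    /// # Example
    ///
    /// A minimal sketch; `MyExecutor` stands in for any concrete `traits::Executor`
    /// implementation and `template` for a `PromptTemplate` built elsewhere (both
    /// hypothetical here).
    ///
    /// ```ignore
    /// let chain: Chain<MyExecutor> = Chain::new(template)?;
    /// ```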
    pub fn new(state: PromptTemplate) -> Result<Chain<E>, Error<E::Error>> {
        let state = state.format(&parameters!())?.to_chat();
        Ok(Self {
            state,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Constructs a new `Chain` from an existing `ChatMessageCollection<String>`,
    /// which is cloned into the initial conversation state.
    ///
    /// # Arguments
    /// * `state` - The initial prompt state to use.
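    ///
    /// # Example
    ///
    /// A minimal sketch; `MyExecutor` is a hypothetical `traits::Executor`
    /// implementation.
    ///
    /// ```ignore
    /// let mut history = ChatMessageCollection::new();
    /// // ... populate `history` with prior messages ...
    /// let chain: Chain<MyExecutor> = Chain::new_with_message_collection(&history);
    /// ```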
    pub fn new_with_message_collection(state: &ChatMessageCollection<String>) -> Chain<E> {
        Self {
            state: state.clone(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Sends a message to the LLM and returns the response.
    ///
    /// This method sends a message to the LLM, adding it and the response to the internal state.
    ///
    /// # Arguments
    /// * `step` - The step to send.
    /// * `parameters` - The parameters to use when formatting the step.
    /// * `exec` - The executor to use.
    ///
    /// # Returns
    /// A `Result` containing the LLM's response as `E::Output` on success or an `Error` variant on failure.
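    ///
    /// # Example
    ///
    /// A minimal sketch; `MyExecutor`, `exec`, and `step` are hypothetical
    /// placeholders for a concrete executor and a prepared `Step`.
    ///
    /// ```ignore
    /// let mut chain: Chain<MyExecutor> = Chain::default();
    /// let res = chain.send_message(step, &parameters!(), &exec).await?;
    /// println!("{:?}", res.primary_textual_output().await);
    /// ```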
    pub async fn send_message(
        &mut self,
        step: Step<E>,
        parameters: &Parameters,
        exec: &E,
    ) -> Result<E::Output, Error<E::Error>> {
        let fmt = step.format(parameters)?;
        self.send_message_raw(step.options(), &fmt, step.is_streaming(), exec)
            .await
    }

    /// Sends a message to the LLM and returns the response.
    ///
    /// This method takes a ready prompt and options and sends it to the LLM, adding it and the response to the internal state.
    ///
    /// # Arguments
    /// * `options` - The options to use when executing the prompt.
    /// * `prompt` - The prompt to send.
    /// * `is_streaming` - Whether to request a streaming response, if the executor supports it.
    /// * `exec` - The executor to use.
    ///
    /// # Returns
    /// A `Result` containing the LLM's response as `E::Output` on success or an `Error` variant on failure.
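    ///
    /// # Example
    ///
    /// A minimal sketch; `chain`, `options`, `prompt`, and `exec` are assumed to
    /// have been prepared by the caller, and `None` is passed for `is_streaming`.
    ///
    /// ```ignore
    /// let res = chain.send_message_raw(options, &prompt, None, &exec).await?;
    /// ```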
    pub async fn send_message_raw(
        &mut self,
        options: Option<&<E as traits::Executor>::PerInvocationOptions>,
        prompt: &Prompt,
        is_streaming: Option<bool>,
        exec: &E,
    ) -> Result<E::Output, Error<E::Error>> {
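        // Budget the context window: ask the executor how many tokens the prompt
        // consumes, then trim the oldest history so the remainder fits.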
        let tok = exec.tokens_used(options, prompt)?;
        let tokens_remaining = tok.tokens_remaining();
        let tokenizer = exec.get_tokenizer(options)?;
        self.state.trim_context(&tokenizer, tokens_remaining)?;

        // Combine the conversation history with the new prompt.
        let prompt_with_history = Prompt::Chat(self.state.clone()).combine(prompt);

        // Execute the prompt and retrieve the LLM's response.
        let res = exec
            .execute(options, &prompt_with_history, is_streaming)
            .await?;

        // Create a ChatMessage from the response and add it to the conversation state.
        let response_message = ChatMessage::new(
            res.get_chat_role()
                .await
                .unwrap_or(crate::prompt::ChatRole::Assistant),
            res.primary_textual_output()
                .await
                .ok_or(Error::NoModelOutput)?,
        );
        self.state = prompt_with_history.to_chat();
        self.state.add_message(response_message);

        Ok(res)
    }
}

/// Errors that can occur while interacting with a `Chain`.
#[derive(thiserror::Error, Debug)]
pub enum Error<E: ExecutorError> {
    #[error("PromptTokensError: {0}")]
    PromptTokens(#[from] PromptTokensError),
    #[error("TokenizerError: {0}")]
    Tokenizer(#[from] TokenizerError),
    #[error("ExecutorError: {0}")]
    Executor(#[from] E),
    #[error("No model output")]
    NoModelOutput,
    #[error("StringTemplateError: {0}")]
    StringTemplate(#[from] crate::prompt::StringTemplateError),
}