use std::{error::Error, sync::Arc};

use async_trait::async_trait;
use tokio::sync::Mutex;

use crate::{
    language_models::{llm::LLM, GenerateResult},
    memory::SimpleMemory,
    prompt::{HumanMessagePromptTemplate, PromptArgs},
    schemas::{memory::BaseMemory, messages::Message},
    template_fstring,
};

use self::prompt::DEFAULT_TEMPLATE;

use super::{
    chain_trait::Chain,
    llm_chain::{LLMChain, LLMChainBuilder},
};

pub mod builder;
mod prompt;
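
/// A conversational chain: wraps an `LLMChain` and injects the running
/// conversation history into every prompt, storing each turn in memory.
///
/// A minimal usage sketch (the paths and the `OpenAI` default are assumptions
/// drawn from the tests below; a configured API key is required):
///
/// ```rust,ignore
/// use crate::{chain::chain_trait::Chain, llm::openai::OpenAI, prompt_args};
///
/// let chain = ConversationalChain::new(OpenAI::default())?;
/// let first = chain.invoke(prompt_args! { "input" => "I'm from Peru" }).await?;
/// // The second turn can reference the first, because the chain injects its
/// // memory into the prompt as `history`.
/// let second = chain
///     .invoke(prompt_args! { "input" => "What are typical dishes there?" })
///     .await?;
/// ```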
pub struct ConversationalChain {
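    /// The inner LLM chain (prompt plus model) that produces each reply.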
    llm: LLMChain,
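    /// Conversation history, behind an async mutex so the chain can be shared.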
    memory: Arc<Mutex<dyn BaseMemory>>,
}

impl ConversationalChain {
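    /// Builds the chain from any `LLM`, using the crate's default
    /// conversation template and a fresh, empty `SimpleMemory`.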
    pub fn new<L: LLM + 'static>(llm: L) -> Result<Self, Box<dyn Error>> {
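        // `template_fstring!` binds the template's two placeholders: the
        // running `history` and the current user `input`.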
        let prompt = HumanMessagePromptTemplate::new(template_fstring!(
            DEFAULT_TEMPLATE,
            "history",
            "input"
        ));
        // Build the inner LLM chain without its own memory; the
        // conversational chain manages the conversation history itself.
        let llm_chain = LLMChainBuilder::new().prompt(prompt).llm(llm).build()?;
        Ok(Self {
            llm: llm_chain,
            memory: Arc::new(Mutex::new(SimpleMemory::new())),
        })
    }

    /// Replaces the default memory with a caller-provided one, e.g. a memory
    /// shared with other chains.
    pub fn with_memory(mut self, memory: Arc<Mutex<dyn BaseMemory>>) -> Self {
        self.memory = memory;
        self
    }
}

#[async_trait]
impl Chain for ConversationalChain {
    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, Box<dyn Error>> {
        let mut memory = self.memory.lock().await;
        let mut input_variables = input_variables;
        // Render the stored conversation into the prompt's `history` slot.
        input_variables.insert("history".to_string(), memory.to_string().into());
        let result = self.llm.call(input_variables.clone()).await?;
        // Persist this turn: the user's input as a human message, then the
        // model's reply as an AI message.
        memory.add_message(Message::new_human_message(&input_variables["input"]));
        memory.add_message(Message::new_ai_message(&result.generation));
        Ok(result)
    }

    async fn invoke(&self, input_variables: PromptArgs) -> Result<String, Box<dyn Error>> {
        let mut memory = self.memory.lock().await;
        let mut input_variables = input_variables;
        input_variables.insert("history".to_string(), memory.to_string().into());
        let result = self.llm.invoke(input_variables.clone()).await?;
        // As in `call`: record the human input, then the AI reply.
        memory.add_message(Message::new_human_message(&input_variables["input"]));
        memory.add_message(Message::new_ai_message(&result));
        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        chain::conversational::builder::ConversationalChainBuilder,
        llm::openai::{OpenAI, OpenAIModel},
        prompt_args,
    };

    use super::*;

    #[tokio::test]
    async fn test_invoke_conversational() {
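        // Assumes OpenAI credentials are available in the environment.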
        let llm = OpenAI::default().with_model(OpenAIModel::Gpt35);
        let chain = ConversationalChainBuilder::new()
            .llm(llm)
            .build()
            .expect("Error building ConversationalChain");

        let input_variables = prompt_args! {
            "input" => "Soy de peru",
        };
        match chain.invoke(input_variables).await {
            Ok(result) => {
                println!("Result: {:?}", result);
            }
            Err(e) => panic!("Error invoking ConversationalChain: {:?}", e),
        }

        // The follow-up references "my country", which only resolves if the
        // first turn was stored in memory and injected as `history`.
        let input_variables = prompt_args! {
            "input" => "What are the typical dishes of my country?",
        };
        match chain.invoke(input_variables).await {
            Ok(result) => {
                println!("Result: {:?}", result);
            }
            Err(e) => panic!("Error invoking ConversationalChain: {:?}", e),
        }
    }
}