langchain_rust/chain/sql_datbase/
chain.rs

1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::Stream;
5use serde_json::Value;
6
7use crate::{
8    chain::{chain_trait::Chain, llm_chain::LLMChain, ChainError},
9    language_models::{GenerateResult, TokenUsage},
10    prompt::PromptArgs,
11    prompt_args,
12    schemas::StreamData,
13    tools::SQLDatabase,
14};
15
16use super::{
17    QUERY_PREFIX_WITH, SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY, SQL_CHAIN_DEFAULT_INPUT_KEY_TABLE_NAMES,
18    STOP_WORD,
19};
20
/// Builder for the prompt arguments consumed by the SQL database chain.
///
/// Holds the natural-language question until it is turned into prompt
/// arguments. `Default` yields the same empty builder as the zero-argument
/// constructor.
#[derive(Debug, Clone, Default)]
pub struct SqlChainPromptBuilder {
    /// The user's natural-language question; empty until set.
    query: String,
}
24impl SqlChainPromptBuilder {
25    pub fn new() -> Self {
26        Self {
27            query: "".to_string(),
28        }
29    }
30
31    pub fn query<S: Into<String>>(mut self, input: S) -> Self {
32        self.query = input.into();
33        self
34    }
35
36    pub fn build(self) -> PromptArgs {
37        prompt_args! {
38          SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY  => self.query,
39        }
40    }
41}
42
43pub struct SQLDatabaseChain {
44    pub(crate) llmchain: LLMChain,
45    pub(crate) top_k: usize,
46    pub(crate) database: SQLDatabase,
47}
48
/// `SQLDatabaseChain` lets you interact with a database in natural language.
///
/// The input variable name is `query`.
///
/// # Example
/// ```rust,ignore
/// # async {
/// let options = ChainCallOptions::default();
/// let llm = OpenAI::default();
///
/// let db = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
/// let engine = PostgreSQLEngine::new(&db).await.unwrap();
/// let db = SQLDatabaseBuilder::new(engine).build().await.unwrap();
/// let chain = SQLDatabaseChainBuilder::new()
///     .llm(llm)
///     .top_k(4)
///     .database(db)
///     .options(options)
///     .build()
///     .expect("Failed to build LLMChain");
///
/// let input_variables = prompt_args! {
///     "query" => "Whats the phone number of luis"
/// };
/// // Or, using the prompt builder:
/// let input_variables = chain.prompt_builder()
///     .query("Whats the phone number of luis")
///     .build();
/// match chain.invoke(input_variables).await {
///     Ok(result) => {
///         println!("Result: {:?}", result);
///     }
///     Err(e) => panic!("Error invoking LLMChain: {:?}", e),
/// }
/// # };
/// ```
85impl SQLDatabaseChain {
86    pub fn prompt_builder(&self) -> SqlChainPromptBuilder {
87        SqlChainPromptBuilder::new()
88    }
89
90    async fn call_builder_chains(
91        &self,
92        input_variables: &PromptArgs,
93    ) -> Result<(PromptArgs, Option<TokenUsage>), ChainError> {
94        let mut token_usage: Option<TokenUsage> = None;
95
96        let query = input_variables
97            .get(SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY)
98            .ok_or_else(|| {
99                ChainError::MissingInputVariable(SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY.to_string())
100            })?
101            .to_string();
102
103        let mut tables: Vec<String> = Vec::new();
104        if let Some(value) = input_variables.get(SQL_CHAIN_DEFAULT_INPUT_KEY_TABLE_NAMES) {
105            if let serde_json::Value::Array(array) = value {
106                for item in array {
107                    if let serde_json::Value::String(str) = item {
108                        tables.push(str.clone());
109                    }
110                }
111            }
112        }
113
114        let tables_info = self
115            .database
116            .table_info(&tables)
117            .await
118            .map_err(|e| ChainError::DatabaseError(e.to_string()))?;
119
120        let mut llm_inputs = prompt_args! {
121            "input"=> query.clone() + QUERY_PREFIX_WITH,
122            "top_k"=> self.top_k,
123            "dialect"=> self.database.dialect().to_string(),
124            "table_info"=> tables_info,
125
126        };
127
128        let output = self.llmchain.call(llm_inputs.clone()).await?;
129        if let Some(tokens) = output.tokens {
130            token_usage = Some(tokens);
131        }
132
133        let sql_query = output.generation.trim();
134        log::debug!("output: {:?}", sql_query);
135        let query_result = self
136            .database
137            .query(sql_query)
138            .await
139            .map_err(|e| ChainError::DatabaseError(e.to_string()))?;
140
141        llm_inputs.insert(
142            "input".to_string(),
143            Value::from(format!(
144                "{}{}{}{}{}",
145                &query, QUERY_PREFIX_WITH, sql_query, STOP_WORD, &query_result,
146            )),
147        );
148        Ok((llm_inputs, token_usage))
149    }
150}
151
152#[async_trait]
153impl Chain for SQLDatabaseChain {
154    fn get_input_keys(&self) -> Vec<String> {
155        self.llmchain.get_input_keys()
156    }
157
158    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
159        let (llm_inputs, mut token_usage) = self.call_builder_chains(&input_variables).await?;
160        let output = self.llmchain.call(llm_inputs).await?;
161        if let Some(tokens) = output.tokens {
162            if let Some(general_result) = token_usage.as_mut() {
163                general_result.completion_tokens += tokens.completion_tokens;
164                general_result.total_tokens += tokens.total_tokens;
165            }
166        }
167
168        let strs: Vec<&str> = output
169            .generation
170            .split("\n\n")
171            .next()
172            .unwrap_or("")
173            .split("Answer:")
174            .collect();
175        let mut output = strs[0];
176        if strs.len() > 1 {
177            output = strs[1];
178        }
179        output = output.trim();
180        Ok(GenerateResult {
181            generation: output.to_string(),
182            tokens: token_usage,
183        })
184    }
185
186    async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
187        let result = self.call(input_variables).await?;
188        Ok(result.generation)
189    }
190
191    async fn stream(
192        &self,
193        input_variables: PromptArgs,
194    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
195    {
196        let (llm_inputs, _) = self.call_builder_chains(&input_variables).await?;
197
198        self.llmchain.stream(llm_inputs).await
199    }
200}