langchain_rust/chain/sql_datbase/
chain.rs1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::Stream;
5use serde_json::Value;
6
7use crate::{
8 chain::{chain_trait::Chain, llm_chain::LLMChain, ChainError},
9 language_models::{GenerateResult, TokenUsage},
10 prompt::PromptArgs,
11 prompt_args,
12 schemas::StreamData,
13 tools::SQLDatabase,
14};
15
16use super::{
17 QUERY_PREFIX_WITH, SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY, SQL_CHAIN_DEFAULT_INPUT_KEY_TABLE_NAMES,
18 STOP_WORD,
19};
20
/// Builder for the prompt arguments consumed by `SQLDatabaseChain`.
///
/// Holds the natural-language question; `build` turns it into `PromptArgs`.
#[derive(Debug, Clone)]
pub struct SqlChainPromptBuilder {
    /// The user's question, stored under `SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY` by `build`.
    query: String,
}
24impl SqlChainPromptBuilder {
25 pub fn new() -> Self {
26 Self {
27 query: "".to_string(),
28 }
29 }
30
31 pub fn query<S: Into<String>>(mut self, input: S) -> Self {
32 self.query = input.into();
33 self
34 }
35
36 pub fn build(self) -> PromptArgs {
37 prompt_args! {
38 SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY => self.query,
39 }
40 }
41}
42
/// Chain that answers natural-language questions against a SQL database.
///
/// `llmchain` is invoked twice per request: first to generate a SQL query, which is then
/// executed against `database`, and again with the question, query and result so it can
/// produce the final answer.
pub struct SQLDatabaseChain {
    /// LLM chain used both to generate the SQL query and to produce the final answer.
    pub(crate) llmchain: LLMChain,
    /// Passed into the prompt as the `top_k` variable; presumably the maximum number of rows
    /// the generated query should return — TODO confirm against the prompt template.
    pub(crate) top_k: usize,
    /// Database used to fetch table info, report its SQL dialect, and run generated queries.
    pub(crate) database: SQLDatabase,
}
48
49impl SQLDatabaseChain {
86 pub fn prompt_builder(&self) -> SqlChainPromptBuilder {
87 SqlChainPromptBuilder::new()
88 }
89
90 async fn call_builder_chains(
91 &self,
92 input_variables: &PromptArgs,
93 ) -> Result<(PromptArgs, Option<TokenUsage>), ChainError> {
94 let mut token_usage: Option<TokenUsage> = None;
95
96 let query = input_variables
97 .get(SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY)
98 .ok_or_else(|| {
99 ChainError::MissingInputVariable(SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY.to_string())
100 })?
101 .to_string();
102
103 let mut tables: Vec<String> = Vec::new();
104 if let Some(value) = input_variables.get(SQL_CHAIN_DEFAULT_INPUT_KEY_TABLE_NAMES) {
105 if let serde_json::Value::Array(array) = value {
106 for item in array {
107 if let serde_json::Value::String(str) = item {
108 tables.push(str.clone());
109 }
110 }
111 }
112 }
113
114 let tables_info = self
115 .database
116 .table_info(&tables)
117 .await
118 .map_err(|e| ChainError::DatabaseError(e.to_string()))?;
119
120 let mut llm_inputs = prompt_args! {
121 "input"=> query.clone() + QUERY_PREFIX_WITH,
122 "top_k"=> self.top_k,
123 "dialect"=> self.database.dialect().to_string(),
124 "table_info"=> tables_info,
125
126 };
127
128 let output = self.llmchain.call(llm_inputs.clone()).await?;
129 if let Some(tokens) = output.tokens {
130 token_usage = Some(tokens);
131 }
132
133 let sql_query = output.generation.trim();
134 log::debug!("output: {:?}", sql_query);
135 let query_result = self
136 .database
137 .query(sql_query)
138 .await
139 .map_err(|e| ChainError::DatabaseError(e.to_string()))?;
140
141 llm_inputs.insert(
142 "input".to_string(),
143 Value::from(format!(
144 "{}{}{}{}{}",
145 &query, QUERY_PREFIX_WITH, sql_query, STOP_WORD, &query_result,
146 )),
147 );
148 Ok((llm_inputs, token_usage))
149 }
150}
151
152#[async_trait]
153impl Chain for SQLDatabaseChain {
154 fn get_input_keys(&self) -> Vec<String> {
155 self.llmchain.get_input_keys()
156 }
157
158 async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
159 let (llm_inputs, mut token_usage) = self.call_builder_chains(&input_variables).await?;
160 let output = self.llmchain.call(llm_inputs).await?;
161 if let Some(tokens) = output.tokens {
162 if let Some(general_result) = token_usage.as_mut() {
163 general_result.completion_tokens += tokens.completion_tokens;
164 general_result.total_tokens += tokens.total_tokens;
165 }
166 }
167
168 let strs: Vec<&str> = output
169 .generation
170 .split("\n\n")
171 .next()
172 .unwrap_or("")
173 .split("Answer:")
174 .collect();
175 let mut output = strs[0];
176 if strs.len() > 1 {
177 output = strs[1];
178 }
179 output = output.trim();
180 Ok(GenerateResult {
181 generation: output.to_string(),
182 tokens: token_usage,
183 })
184 }
185
186 async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
187 let result = self.call(input_variables).await?;
188 Ok(result.generation)
189 }
190
191 async fn stream(
192 &self,
193 input_variables: PromptArgs,
194 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
195 {
196 let (llm_inputs, _) = self.call_builder_chains(&input_variables).await?;
197
198 self.llmchain.stream(llm_inputs).await
199 }
200}