langchain_rust/chain/sql_datbase/
builder.rs1use crate::{
2 chain::{
3 llm_chain::LLMChainBuilder, options::ChainCallOptions, ChainError, DEFAULT_OUTPUT_KEY,
4 },
5 language_models::llm::LLM,
6 output_parsers::OutputParser,
7 prompt::HumanMessagePromptTemplate,
8 template_jinja2,
9 tools::SQLDatabase,
10};
11
12use super::{
13 chain::SQLDatabaseChain,
14 prompt::{DEFAULT_SQLSUFFIX, DEFAULT_SQLTEMPLATE},
15 STOP_WORD,
16};
17
18pub struct SQLDatabaseChainBuilder {
19 llm: Option<Box<dyn LLM>>,
20 options: Option<ChainCallOptions>,
21 top_k: Option<usize>,
22 database: Option<SQLDatabase>,
23 output_key: Option<String>,
24 output_parser: Option<Box<dyn OutputParser>>,
25}
26
27impl SQLDatabaseChainBuilder {
28 pub fn new() -> Self {
29 Self {
30 llm: None,
31 options: None,
32 top_k: None,
33 database: None,
34 output_key: None,
35 output_parser: None,
36 }
37 }
38
39 pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
40 self.llm = Some(llm.into());
41 self
42 }
43
44 pub fn output_key<S: Into<String>>(mut self, output_key: S) -> Self {
45 self.output_key = Some(output_key.into());
46 self
47 }
48
49 pub fn output_parser<P: Into<Box<dyn OutputParser>>>(mut self, output_parser: P) -> Self {
50 self.output_parser = Some(output_parser.into());
51 self
52 }
53
54 pub fn options(mut self, options: ChainCallOptions) -> Self {
55 self.options = Some(options);
56 self
57 }
58
59 pub fn top_k(mut self, top_k: usize) -> Self {
60 self.top_k = Some(top_k);
61 self
62 }
63
64 pub fn database(mut self, database: SQLDatabase) -> Self {
65 self.database = Some(database);
66 self
67 }
68
69 pub fn build(self) -> Result<SQLDatabaseChain, ChainError> {
70 let llm = self
71 .llm
72 .ok_or_else(|| ChainError::MissingObject("LLM must be set".into()))?;
73 let top_k = self
74 .top_k
75 .ok_or_else(|| ChainError::MissingObject("Top K must be set".into()))?;
76 let database = self
77 .database
78 .ok_or_else(|| ChainError::MissingObject("Database must be set".into()))?;
79
80 let prompt = HumanMessagePromptTemplate::new(template_jinja2!(
81 format!("{}{}", DEFAULT_SQLTEMPLATE, DEFAULT_SQLSUFFIX),
82 "dialect",
83 "table_info",
84 "top_k",
85 "input"
86 ));
87
88 let llm_chain = {
89 let mut builder = LLMChainBuilder::new()
90 .prompt(prompt)
91 .output_key(self.output_key.unwrap_or_else(|| DEFAULT_OUTPUT_KEY.into()))
92 .llm(llm);
93
94 let mut options = self.options.unwrap_or_default();
95 options = options.with_stop_words(vec![STOP_WORD.to_string()]);
96 builder = builder.options(options);
97
98 if let Some(output_parser) = self.output_parser {
99 builder = builder.output_parser(output_parser);
100 }
101
102 builder.build()?
103 };
104
105 Ok(SQLDatabaseChain {
106 llmchain: llm_chain,
107 top_k,
108 database,
109 })
110 }
111}