langchain_rust/chain/sql_datbase/
builder.rs

1use crate::{
2    chain::{
3        llm_chain::LLMChainBuilder, options::ChainCallOptions, ChainError, DEFAULT_OUTPUT_KEY,
4    },
5    language_models::llm::LLM,
6    output_parsers::OutputParser,
7    prompt::HumanMessagePromptTemplate,
8    template_jinja2,
9    tools::SQLDatabase,
10};
11
12use super::{
13    chain::SQLDatabaseChain,
14    prompt::{DEFAULT_SQLSUFFIX, DEFAULT_SQLTEMPLATE},
15    STOP_WORD,
16};
17
18pub struct SQLDatabaseChainBuilder {
19    llm: Option<Box<dyn LLM>>,
20    options: Option<ChainCallOptions>,
21    top_k: Option<usize>,
22    database: Option<SQLDatabase>,
23    output_key: Option<String>,
24    output_parser: Option<Box<dyn OutputParser>>,
25}
26
27impl SQLDatabaseChainBuilder {
28    pub fn new() -> Self {
29        Self {
30            llm: None,
31            options: None,
32            top_k: None,
33            database: None,
34            output_key: None,
35            output_parser: None,
36        }
37    }
38
39    pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
40        self.llm = Some(llm.into());
41        self
42    }
43
44    pub fn output_key<S: Into<String>>(mut self, output_key: S) -> Self {
45        self.output_key = Some(output_key.into());
46        self
47    }
48
49    pub fn output_parser<P: Into<Box<dyn OutputParser>>>(mut self, output_parser: P) -> Self {
50        self.output_parser = Some(output_parser.into());
51        self
52    }
53
54    pub fn options(mut self, options: ChainCallOptions) -> Self {
55        self.options = Some(options);
56        self
57    }
58
59    pub fn top_k(mut self, top_k: usize) -> Self {
60        self.top_k = Some(top_k);
61        self
62    }
63
64    pub fn database(mut self, database: SQLDatabase) -> Self {
65        self.database = Some(database);
66        self
67    }
68
69    pub fn build(self) -> Result<SQLDatabaseChain, ChainError> {
70        let llm = self
71            .llm
72            .ok_or_else(|| ChainError::MissingObject("LLM must be set".into()))?;
73        let top_k = self
74            .top_k
75            .ok_or_else(|| ChainError::MissingObject("Top K must be set".into()))?;
76        let database = self
77            .database
78            .ok_or_else(|| ChainError::MissingObject("Database must be set".into()))?;
79
80        let prompt = HumanMessagePromptTemplate::new(template_jinja2!(
81            format!("{}{}", DEFAULT_SQLTEMPLATE, DEFAULT_SQLSUFFIX),
82            "dialect",
83            "table_info",
84            "top_k",
85            "input"
86        ));
87
88        let llm_chain = {
89            let mut builder = LLMChainBuilder::new()
90                .prompt(prompt)
91                .output_key(self.output_key.unwrap_or_else(|| DEFAULT_OUTPUT_KEY.into()))
92                .llm(llm);
93
94            let mut options = self.options.unwrap_or_default();
95            options = options.with_stop_words(vec![STOP_WORD.to_string()]);
96            builder = builder.options(options);
97
98            if let Some(output_parser) = self.output_parser {
99                builder = builder.output_parser(output_parser);
100            }
101
102            builder.build()?
103        };
104
105        Ok(SQLDatabaseChain {
106            llmchain: llm_chain,
107            top_k,
108            database,
109        })
110    }
111}