langchain_rust/tools/sql/
sql.rs

1use std::{collections::HashSet, error::Error};
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5
6#[derive(Serialize, Deserialize)]
7pub enum Dialect {
8    #[serde(rename = "mysql")]
9    MySQL,
10    #[serde(rename = "sqlite")]
11    SQLite,
12    #[serde(rename = "postgresql")]
13    PostgreSQL,
14}
15impl ToString for Dialect {
16    fn to_string(&self) -> String {
17        match self {
18            Dialect::MySQL => "mysql".to_string(),
19            Dialect::SQLite => "sqlite".to_string(),
20            Dialect::PostgreSQL => "postgresql".to_string(),
21        }
22    }
23}
24
25#[async_trait]
26pub trait Engine: Send + Sync {
27    // Dialect returns the dialect(e.g. mysql, sqlite, postgre) of the database.
28    fn dialect(&self) -> Dialect;
29    // Query executes the query and returns the columns and results.
30    async fn query(&self, query: &str) -> Result<(Vec<String>, Vec<Vec<String>>), Box<dyn Error>>;
31    // TableNames returns all the table names of the database.
32    async fn table_names(&self) -> Result<Vec<String>, Box<dyn Error>>;
33    // TableInfo returns the table information of the database.
34    // Typically, it returns the CREATE TABLE statement.
35    async fn table_info(&self, tables: &str) -> Result<String, Box<dyn Error>>;
36    // Close closes the database.
37    fn close(&self) -> Result<(), Box<dyn Error>>;
38}
39
40pub struct SQLDatabase {
41    pub engine: Box<dyn Engine>,
42    pub sample_rows_number: i32,
43    pub all_tables: HashSet<String>,
44}
45
46pub struct SQLDatabaseBuilder {
47    engine: Box<dyn Engine>,
48    sample_rows_number: i32,
49    ignore_tables: HashSet<String>,
50}
51
52impl SQLDatabaseBuilder {
53    pub fn new<E>(engine: E) -> Self
54    where
55        E: Engine + 'static,
56    {
57        SQLDatabaseBuilder {
58            engine: Box::new(engine),
59            sample_rows_number: 3, // Default value
60            ignore_tables: HashSet::new(),
61        }
62    }
63
64    // Function to set custom number of sample rows
65    pub fn custom_sample_rows_number(mut self, number: i32) -> Self {
66        self.sample_rows_number = number;
67        self
68    }
69
70    // Function to set tables to ignore
71    pub fn ignore_tables(mut self, ignore_tables: HashSet<String>) -> Self {
72        self.ignore_tables = ignore_tables;
73        self
74    }
75
76    // Function to build the SQLDatabase instance
77    pub async fn build(self) -> Result<SQLDatabase, Box<dyn Error>> {
78        let table_names_result = self.engine.table_names().await;
79
80        // Handle potential error from table_names call
81        let table_names = match table_names_result {
82            Ok(names) => names,
83            Err(error) => {
84                return Err(error);
85            }
86        };
87
88        // Filter out ignored tables
89        let all_tables: HashSet<String> = table_names
90            .into_iter()
91            .filter(|name| !self.ignore_tables.contains(name))
92            .collect();
93
94        Ok(SQLDatabase {
95            engine: self.engine,
96            sample_rows_number: self.sample_rows_number,
97            all_tables,
98        })
99    }
100}
101
102impl SQLDatabase {
103    pub fn dialect(&self) -> Dialect {
104        self.engine.dialect()
105    }
106
107    pub fn table_names(&self) -> Vec<String> {
108        self.all_tables.iter().cloned().collect()
109    }
110
111    pub async fn table_info(&self, tables: &[String]) -> Result<String, Box<dyn Error>> {
112        let mut tables: HashSet<String> = tables.to_vec().into_iter().collect();
113        if tables.is_empty() {
114            tables = self.all_tables.clone();
115        }
116        let mut info = String::new();
117        for table in tables {
118            let table_info = self.engine.table_info(&table).await?;
119            info.push_str(&table_info);
120            info.push_str("\n\n");
121
122            if self.sample_rows_number > 0 {
123                let sample_rows = self.sample_rows(&table).await?;
124                info.push_str("/*\n");
125                info.push_str(&sample_rows);
126                info.push_str("*/ \n\n");
127            }
128        }
129        Ok(info)
130    }
131
132    pub async fn query(&self, query: &str) -> Result<String, Box<dyn Error>> {
133        log::debug!("Query: {}", query);
134        let (cols, results) = self.engine.query(query).await?;
135        let mut str = cols.join("\t") + "\n";
136        for row in results {
137            str += &row.join("\t");
138            str.push('\n');
139        }
140        Ok(str)
141    }
142
143    pub fn close(&self) -> Result<(), Box<dyn Error>> {
144        self.engine.close()
145    }
146
147    pub async fn sample_rows(&self, table: &str) -> Result<String, Box<dyn Error>> {
148        let query = format!("SELECT * FROM {} LIMIT {}", table, self.sample_rows_number);
149        log::debug!("Sample Rows Query: {}", query);
150        self.query(&query).await
151    }
152}