// langchain_rust/tools/sql/sql.rs
use std::{collections::HashSet, error::Error, fmt};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

6#[derive(Serialize, Deserialize)]
7pub enum Dialect {
8 #[serde(rename = "mysql")]
9 MySQL,
10 #[serde(rename = "sqlite")]
11 SQLite,
12 #[serde(rename = "postgresql")]
13 PostgreSQL,
14}
15impl ToString for Dialect {
16 fn to_string(&self) -> String {
17 match self {
18 Dialect::MySQL => "mysql".to_string(),
19 Dialect::SQLite => "sqlite".to_string(),
20 Dialect::PostgreSQL => "postgresql".to_string(),
21 }
22 }
23}
24
25#[async_trait]
26pub trait Engine: Send + Sync {
27 fn dialect(&self) -> Dialect;
29 async fn query(&self, query: &str) -> Result<(Vec<String>, Vec<Vec<String>>), Box<dyn Error>>;
31 async fn table_names(&self) -> Result<Vec<String>, Box<dyn Error>>;
33 async fn table_info(&self, tables: &str) -> Result<String, Box<dyn Error>>;
36 fn close(&self) -> Result<(), Box<dyn Error>>;
38}
39
40pub struct SQLDatabase {
41 pub engine: Box<dyn Engine>,
42 pub sample_rows_number: i32,
43 pub all_tables: HashSet<String>,
44}
45
46pub struct SQLDatabaseBuilder {
47 engine: Box<dyn Engine>,
48 sample_rows_number: i32,
49 ignore_tables: HashSet<String>,
50}
51
52impl SQLDatabaseBuilder {
53 pub fn new<E>(engine: E) -> Self
54 where
55 E: Engine + 'static,
56 {
57 SQLDatabaseBuilder {
58 engine: Box::new(engine),
59 sample_rows_number: 3, ignore_tables: HashSet::new(),
61 }
62 }
63
64 pub fn custom_sample_rows_number(mut self, number: i32) -> Self {
66 self.sample_rows_number = number;
67 self
68 }
69
70 pub fn ignore_tables(mut self, ignore_tables: HashSet<String>) -> Self {
72 self.ignore_tables = ignore_tables;
73 self
74 }
75
76 pub async fn build(self) -> Result<SQLDatabase, Box<dyn Error>> {
78 let table_names_result = self.engine.table_names().await;
79
80 let table_names = match table_names_result {
82 Ok(names) => names,
83 Err(error) => {
84 return Err(error);
85 }
86 };
87
88 let all_tables: HashSet<String> = table_names
90 .into_iter()
91 .filter(|name| !self.ignore_tables.contains(name))
92 .collect();
93
94 Ok(SQLDatabase {
95 engine: self.engine,
96 sample_rows_number: self.sample_rows_number,
97 all_tables,
98 })
99 }
100}
101
102impl SQLDatabase {
103 pub fn dialect(&self) -> Dialect {
104 self.engine.dialect()
105 }
106
107 pub fn table_names(&self) -> Vec<String> {
108 self.all_tables.iter().cloned().collect()
109 }
110
111 pub async fn table_info(&self, tables: &[String]) -> Result<String, Box<dyn Error>> {
112 let mut tables: HashSet<String> = tables.to_vec().into_iter().collect();
113 if tables.is_empty() {
114 tables = self.all_tables.clone();
115 }
116 let mut info = String::new();
117 for table in tables {
118 let table_info = self.engine.table_info(&table).await?;
119 info.push_str(&table_info);
120 info.push_str("\n\n");
121
122 if self.sample_rows_number > 0 {
123 let sample_rows = self.sample_rows(&table).await?;
124 info.push_str("/*\n");
125 info.push_str(&sample_rows);
126 info.push_str("*/ \n\n");
127 }
128 }
129 Ok(info)
130 }
131
132 pub async fn query(&self, query: &str) -> Result<String, Box<dyn Error>> {
133 log::debug!("Query: {}", query);
134 let (cols, results) = self.engine.query(query).await?;
135 let mut str = cols.join("\t") + "\n";
136 for row in results {
137 str += &row.join("\t");
138 str.push('\n');
139 }
140 Ok(str)
141 }
142
143 pub fn close(&self) -> Result<(), Box<dyn Error>> {
144 self.engine.close()
145 }
146
147 pub async fn sample_rows(&self, table: &str) -> Result<String, Box<dyn Error>> {
148 let query = format!("SELECT * FROM {} LIMIT {}", table, self.sample_rows_number);
149 log::debug!("Sample Rows Query: {}", query);
150 self.query(&query).await
151 }
152}