1pub mod convert;
2pub mod dialect;
3
4use sqlx::any::AnyPoolOptions;
5use sqlx::AnyPool;
6
7use crate::error::McpSqlError;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq)]
10pub enum DbBackend {
11 Postgres,
12 Sqlite,
13 Mysql,
14}
15
16impl DbBackend {
17 pub fn from_url(url: &str) -> Result<Self, McpSqlError> {
18 if url.starts_with("postgres://") || url.starts_with("postgresql://") {
19 Ok(DbBackend::Postgres)
20 } else if url.starts_with("sqlite:") {
21 Ok(DbBackend::Sqlite)
22 } else if url.starts_with("mysql://") || url.starts_with("mariadb://") {
23 Ok(DbBackend::Mysql)
24 } else {
25 Err(McpSqlError::Other(format!(
26 "Unsupported database URL scheme: {url}"
27 )))
28 }
29 }
30
31 pub fn name(&self) -> &'static str {
32 match self {
33 DbBackend::Postgres => "postgres",
34 DbBackend::Sqlite => "sqlite",
35 DbBackend::Mysql => "mysql",
36 }
37 }
38}
39
/// One configured database: its connection pool plus the metadata needed to
/// address and describe it.
#[derive(Clone)]
pub struct DatabaseEntry {
    /// Name used to select this database in `DatabaseManager::resolve`.
    pub name: String,
    /// Connection pool opened for this database's URL.
    pub pool: AnyPool,
    /// Backend detected from the URL scheme.
    pub backend: DbBackend,
    /// The connection URL after passing through `redact_url`.
    pub url_redacted: String,
}
47
48#[derive(Clone)]
49pub struct DatabaseManager {
50 pub databases: Vec<DatabaseEntry>,
51}
52
53impl DatabaseManager {
54 pub async fn new(urls: &[String]) -> Result<Self, McpSqlError> {
55 let mut databases = Vec::with_capacity(urls.len());
56
57 for url in urls {
58 let backend = DbBackend::from_url(url)?;
59 let name = extract_db_name(url, backend);
60
61 let pool = AnyPoolOptions::new()
62 .max_connections(5)
63 .connect(url)
64 .await?;
65
66 databases.push(DatabaseEntry {
67 name,
68 pool,
69 backend,
70 url_redacted: redact_url(url),
71 });
72 }
73
74 Ok(Self { databases })
75 }
76
77 pub fn resolve(&self, database: Option<&str>) -> Result<&DatabaseEntry, McpSqlError> {
80 match database {
81 Some(name) => self
82 .databases
83 .iter()
84 .find(|d| d.name == name)
85 .ok_or_else(|| {
86 let available: Vec<&str> = self.databases.iter().map(|d| d.name.as_str()).collect();
87 McpSqlError::DatabaseNotFound(format!(
88 "'{name}' not found. Available: {}",
89 available.join(", ")
90 ))
91 }),
92 None => {
93 if self.databases.len() == 1 {
94 Ok(&self.databases[0])
95 } else {
96 Err(McpSqlError::AmbiguousDatabase)
97 }
98 }
99 }
100 }
101}
102
103fn extract_db_name(url: &str, backend: DbBackend) -> String {
105 match backend {
106 DbBackend::Sqlite => {
107 let path = url.strip_prefix("sqlite:").unwrap_or(url);
108 if path == ":memory:" || path.is_empty() {
109 return "memory".to_string();
110 }
111 std::path::Path::new(path)
113 .file_stem()
114 .and_then(|s| s.to_str())
115 .unwrap_or("sqlite")
116 .to_string()
117 }
118 DbBackend::Postgres | DbBackend::Mysql => {
119 if let Ok(parsed) = url::Url::parse(url) {
121 let path = parsed.path().trim_start_matches('/');
122 if !path.is_empty() {
123 return path.to_string();
124 }
125 }
126 backend.name().to_string()
127 }
128 }
129}
130
131fn redact_url(url: &str) -> String {
133 if let Ok(mut parsed) = url::Url::parse(url) {
134 if parsed.password().is_some() {
135 let _ = parsed.set_password(Some("****"));
136 }
137 parsed.to_string()
138 } else {
139 url.to_string()
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Every accepted scheme spelling maps to its backend; anything else is rejected.
    #[test]
    fn test_backend_from_url() {
        let accepted = [
            ("postgres://localhost/mydb", DbBackend::Postgres),
            ("postgresql://localhost/mydb", DbBackend::Postgres),
            ("sqlite:test.db", DbBackend::Sqlite),
            ("sqlite::memory:", DbBackend::Sqlite),
            ("mysql://localhost/mydb", DbBackend::Mysql),
            // The MariaDB alias shares the MySQL backend.
            ("mariadb://localhost/mydb", DbBackend::Mysql),
        ];
        for (url, expected) in accepted {
            assert_eq!(DbBackend::from_url(url).unwrap(), expected, "url: {url}");
        }

        assert!(DbBackend::from_url("oracle://localhost/mydb").is_err());
        assert!(DbBackend::from_url("").is_err());
    }

    /// Backend identifiers are stable lowercase strings.
    #[test]
    fn test_backend_name() {
        assert_eq!(DbBackend::Postgres.name(), "postgres");
        assert_eq!(DbBackend::Sqlite.name(), "sqlite");
        assert_eq!(DbBackend::Mysql.name(), "mysql");
    }

    /// Names come from the URL path (server databases) or the file stem (SQLite).
    #[test]
    fn test_extract_db_name() {
        assert_eq!(
            extract_db_name("postgres://user:pass@localhost/mydb", DbBackend::Postgres),
            "mydb"
        );
        assert_eq!(
            extract_db_name("sqlite:test.db", DbBackend::Sqlite),
            "test"
        );
        assert_eq!(
            extract_db_name("sqlite::memory:", DbBackend::Sqlite),
            "memory"
        );
        assert_eq!(
            extract_db_name("mysql://user:pass@localhost/app", DbBackend::Mysql),
            "app"
        );
        // Directories in a SQLite path do not end up in the name.
        assert_eq!(
            extract_db_name("sqlite:data/app.sqlite3", DbBackend::Sqlite),
            "app"
        );
        // With no database in the path, the backend name is the fallback.
        assert_eq!(
            extract_db_name("postgres://localhost/", DbBackend::Postgres),
            "postgres"
        );
    }

    /// Passwords are masked; URLs without one are left untouched.
    #[test]
    fn test_redact_url() {
        assert_eq!(
            redact_url("postgres://user:secret@localhost/mydb"),
            "postgres://user:****@localhost/mydb"
        );
        assert_eq!(
            redact_url("postgres://user@localhost/mydb"),
            "postgres://user@localhost/mydb"
        );
        assert_eq!(redact_url("sqlite:test.db"), "sqlite:test.db");
    }
}