Skip to main content

mcp_sql/db/mod.rs

1pub mod convert;
2pub mod dialect;
3
4use sqlx::any::AnyPoolOptions;
5use sqlx::AnyPool;
6
7use crate::error::McpSqlError;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq)]
10pub enum DbBackend {
11    Postgres,
12    Sqlite,
13    Mysql,
14}
15
16impl DbBackend {
17    pub fn from_url(url: &str) -> Result<Self, McpSqlError> {
18        if url.starts_with("postgres://") || url.starts_with("postgresql://") {
19            Ok(DbBackend::Postgres)
20        } else if url.starts_with("sqlite:") {
21            Ok(DbBackend::Sqlite)
22        } else if url.starts_with("mysql://") || url.starts_with("mariadb://") {
23            Ok(DbBackend::Mysql)
24        } else {
25            Err(McpSqlError::Other(format!(
26                "Unsupported database URL scheme: {url}"
27            )))
28        }
29    }
30
31    pub fn name(&self) -> &'static str {
32        match self {
33            DbBackend::Postgres => "postgres",
34            DbBackend::Sqlite => "sqlite",
35            DbBackend::Mysql => "mysql",
36        }
37    }
38}
39
40#[derive(Clone)]
41pub struct DatabaseEntry {
42    pub name: String,
43    pub pool: AnyPool,
44    pub backend: DbBackend,
45    pub url_redacted: String,
46}
47
48#[derive(Clone)]
49pub struct DatabaseManager {
50    pub databases: Vec<DatabaseEntry>,
51}
52
53impl DatabaseManager {
54    pub async fn new(urls: &[String]) -> Result<Self, McpSqlError> {
55        let mut databases = Vec::with_capacity(urls.len());
56
57        for url in urls {
58            let backend = DbBackend::from_url(url)?;
59            let name = extract_db_name(url, backend);
60
61            let pool = AnyPoolOptions::new()
62                .max_connections(5)
63                .connect(url)
64                .await?;
65
66            databases.push(DatabaseEntry {
67                name,
68                pool,
69                backend,
70                url_redacted: redact_url(url),
71            });
72        }
73
74        Ok(Self { databases })
75    }
76
77    /// Resolve which database to use. If `database` param is None and there's
78    /// exactly one DB, use it. Otherwise require an explicit name.
79    pub fn resolve(&self, database: Option<&str>) -> Result<&DatabaseEntry, McpSqlError> {
80        match database {
81            Some(name) => self
82                .databases
83                .iter()
84                .find(|d| d.name == name)
85                .ok_or_else(|| {
86                    let available: Vec<&str> = self.databases.iter().map(|d| d.name.as_str()).collect();
87                    McpSqlError::DatabaseNotFound(format!(
88                        "'{name}' not found. Available: {}",
89                        available.join(", ")
90                    ))
91                }),
92            None => {
93                if self.databases.len() == 1 {
94                    Ok(&self.databases[0])
95                } else {
96                    Err(McpSqlError::AmbiguousDatabase)
97                }
98            }
99        }
100    }
101}
102
103/// Extract a human-friendly name from the URL.
104fn extract_db_name(url: &str, backend: DbBackend) -> String {
105    match backend {
106        DbBackend::Sqlite => {
107            let path = url.strip_prefix("sqlite:").unwrap_or(url);
108            if path == ":memory:" || path.is_empty() {
109                return "memory".to_string();
110            }
111            // Use the filename without extension
112            std::path::Path::new(path)
113                .file_stem()
114                .and_then(|s| s.to_str())
115                .unwrap_or("sqlite")
116                .to_string()
117        }
118        DbBackend::Postgres | DbBackend::Mysql => {
119            // Parse as URL, extract the database name from the path
120            if let Ok(parsed) = url::Url::parse(url) {
121                let path = parsed.path().trim_start_matches('/');
122                if !path.is_empty() {
123                    return path.to_string();
124                }
125            }
126            backend.name().to_string()
127        }
128    }
129}
130
131/// Redact password from a database URL.
132fn redact_url(url: &str) -> String {
133    if let Ok(mut parsed) = url::Url::parse(url) {
134        if parsed.password().is_some() {
135            let _ = parsed.set_password(Some("****"));
136        }
137        parsed.to_string()
138    } else {
139        url.to_string()
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported prefix maps to its backend; unknown schemes are rejected.
    #[test]
    fn test_backend_from_url() {
        let accepted = [
            ("postgres://localhost/mydb", DbBackend::Postgres),
            ("postgresql://localhost/mydb", DbBackend::Postgres),
            ("sqlite:test.db", DbBackend::Sqlite),
            ("sqlite::memory:", DbBackend::Sqlite),
            ("mysql://localhost/mydb", DbBackend::Mysql),
        ];
        for &(input, want) in &accepted {
            assert_eq!(DbBackend::from_url(input).unwrap(), want);
        }

        let rejected = DbBackend::from_url("oracle://localhost/mydb");
        assert!(rejected.is_err());
    }

    /// Names come from the URL path (server DBs) or the file stem (SQLite).
    #[test]
    fn test_extract_db_name() {
        let cases = [
            ("postgres://user:pass@localhost/mydb", DbBackend::Postgres, "mydb"),
            ("sqlite:test.db", DbBackend::Sqlite, "test"),
            ("sqlite::memory:", DbBackend::Sqlite, "memory"),
            ("mysql://user:pass@localhost/app", DbBackend::Mysql, "app"),
        ];
        for &(input, backend, want) in &cases {
            assert_eq!(extract_db_name(input, backend), want);
        }
    }

    /// Passwords are masked; URLs without one come back unchanged.
    #[test]
    fn test_redact_url() {
        let cases = [
            (
                "postgres://user:secret@localhost/mydb",
                "postgres://user:****@localhost/mydb",
            ),
            (
                "postgres://user@localhost/mydb",
                "postgres://user@localhost/mydb",
            ),
            ("sqlite:test.db", "sqlite:test.db"),
        ];
        for &(input, want) in &cases {
            assert_eq!(redact_url(input), want);
        }
    }
}