// mcp_sqlite/server.rs

/*!
 * # SQLite MCP服务器实现
 *
 * 本模块实现了SQLite MCP服务器的核心功能,包括:
 *
 * - SQLite连接管理
 * - MCP方法实现(query, execute, executemany, executescript)
 * - 参数处理和结果格式化
 *
 * ## 主要组件
 *
 * - [`SQLiteRouter`][]: 实现MCP路由器接口,处理客户端请求
 *
 * ## 支持的MCP方法
 *
 * ### `query`
 *
 * 执行SQL查询并返回结果。
 *
 * #### 查询参数
 *
 * - `query`:要执行的SQL查询
 * - `params`:(可选)绑定到查询的位置参数数组
 *
 * #### 查询返回值
 *
 * - `columns`:结果列名列表
 * - `rows`:查询返回的行,每行是以列名为键的对象(BLOB值编码为Base64字符串)
 *
 * ### `execute`
 *
 * 执行SQL语句。
 *
 * #### 执行参数
 *
 * - `statement`:要执行的SQL语句
 * - `params`:(可选)绑定到语句的位置参数数组
 *
 * #### 执行返回值
 *
 * - `rowcount`:受影响的行数
 * - `lastrowid`:最后插入行的ID(如适用)
 *
 * ### `executemany`
 *
 * 使用不同参数多次执行同一条SQL语句。
 *
 * #### 批量执行参数
 *
 * - `statement`:要执行的SQL语句
 * - `params_list`:参数数组的列表,每个元素是一组绑定到语句的位置参数
 *
 * #### 批量执行返回值
 *
 * - `rowcount`:所有执行累计受影响的行数
 *
 * ### `executescript`
 *
 * 执行SQL脚本(不接受绑定参数)。
 *
 * #### 脚本参数
 *
 * - `script`:要执行的SQL脚本
 *
 * #### 脚本返回值
 *
 * - `rowcount`:受影响的行数
 */
69
70use std::{future::Future, pin::Pin, sync::Arc};
71
72use base64::{engine::general_purpose::STANDARD, Engine};
73use mcp_core_fishcode2025::{
74    handler::{PromptError, ResourceError, ToolError},
75    prompt::Prompt,
76    protocol::ServerCapabilities,
77    Content, Resource, Tool,
78};
79use mcp_server_fishcode2025::router::CapabilitiesBuilder;
80use rusqlite::{Connection, Row, ToSql};
81use serde_json::{json, Value};
82use tokio::sync::Mutex;
83use tracing::debug;
84
85/// SQLite MCP服务器路由器
86///
87/// 负责处理MCP客户端请求,执行SQL操作,并返回结果
88pub struct SQLiteRouter {
89    /// SQLite数据库连接
90    conn: Arc<Mutex<Connection>>,
91}
92
93impl SQLiteRouter {
94    /// 创建一个新的SQLite MCP服务器路由器
95    ///
96    /// # 参数
97    ///
98    /// * `db_path` - SQLite数据库文件路径,使用":memory:"表示内存数据库
99    ///
100    /// # 返回值
101    ///
102    /// 成功时返回`SQLiteRouter`实例,失败时返回SQLite错误
103    ///
104    /// # 示例
105    ///
106    /// ```
107    /// use mcp_sqlite::server::SQLiteRouter;
108    ///
109    /// let router = SQLiteRouter::new(":memory:").expect("创建路由器失败");
110    /// ```
111    pub fn new(db_path: &str) -> Result<Self, rusqlite::Error> {
112        let conn = Connection::open(db_path)?;
113        Ok(Self {
114            conn: Arc::new(Mutex::new(conn)),
115        })
116    }
117
118    /// 执行SQL查询并返回结果
119    ///
120    /// # 参数
121    ///
122    /// * `params` - 包含查询参数的JSON对象,必须包含"query"字段,可选包含"params"字段
123    ///
124    /// # 返回值
125    ///
126    /// 成功时返回包含查询结果的JSON对象,失败时返回工具错误
127    async fn query(&self, params: Value) -> Result<Value, ToolError> {
128        // 获取查询参数
129        let query = match params.get("query") {
130            Some(Value::String(q)) => q,
131            _ => {
132                return Err(ToolError::InvalidParameters(
133                    "Missing required parameter: query".into(),
134                ))
135            }
136        };
137
138        // 获取绑定参数
139        let params_json = json!([]);
140        let bind_params = params.get("params").unwrap_or(&params_json);
141        let bind_params = match bind_params {
142            Value::Array(arr) => arr,
143            _ => {
144                return Err(ToolError::InvalidParameters(
145                    "params must be an array".into(),
146                ))
147            }
148        };
149
150        // 执行查询
151        let conn = self.conn.lock().await;
152
153        let mut stmt = match conn.prepare(query) {
154            Ok(stmt) => stmt,
155            Err(e) => {
156                return Err(ToolError::ExecutionError(format!(
157                    "Failed to prepare query: {}",
158                    e
159                )))
160            }
161        };
162
163        // 将JSON参数转换为SQLite参数
164        let sql_params: Vec<Box<dyn ToSql>> =
165            bind_params.iter().map(|v| json_value_to_sql(v)).collect();
166
167        let sql_params_refs: Vec<&dyn ToSql> = sql_params.iter().map(|p| p.as_ref()).collect();
168
169        // 先获取列名,避免借用冲突
170        let column_names: Vec<String> = {
171            let names = stmt.column_names();
172            names.iter().map(|s| s.to_string()).collect()
173        };
174
175        // 执行查询
176        let mut rows = match stmt.query(sql_params_refs.as_slice()) {
177            Ok(rows) => rows,
178            Err(e) => {
179                return Err(ToolError::ExecutionError(format!(
180                    "Failed to execute query: {}",
181                    e
182                )))
183            }
184        };
185
186        // 获取结果行
187        let mut result_rows = Vec::new();
188        while let Ok(Some(row)) = rows.next() {
189            let row_values = extract_row_values(row, &column_names);
190            result_rows.push(row_values);
191        }
192
193        Ok(json!({
194            "columns": column_names,
195            "rows": result_rows,
196        }))
197    }
198
199    /// 执行SQL语句
200    async fn execute(&self, params: Value) -> Result<Value, ToolError> {
201        // 获取语句参数
202        let statement = match params.get("statement") {
203            Some(Value::String(s)) => s,
204            _ => {
205                return Err(ToolError::InvalidParameters(
206                    "Missing required parameter: statement".into(),
207                ))
208            }
209        };
210
211        // 获取绑定参数
212        let params_json = json!([]);
213        let bind_params = params.get("params").unwrap_or(&params_json);
214        let bind_params = match bind_params {
215            Value::Array(arr) => arr,
216            _ => {
217                return Err(ToolError::InvalidParameters(
218                    "params must be an array".into(),
219                ))
220            }
221        };
222
223        // 执行语句
224        let conn = self.conn.lock().await;
225
226        // 将JSON参数转换为SQLite参数
227        let sql_params: Vec<Box<dyn ToSql>> =
228            bind_params.iter().map(|v| json_value_to_sql(v)).collect();
229
230        let sql_params_refs: Vec<&dyn ToSql> = sql_params.iter().map(|p| p.as_ref()).collect();
231
232        // 执行语句
233        let result = conn.execute(statement, sql_params_refs.as_slice());
234
235        match result {
236            Ok(rows_affected) => {
237                // 获取最后插入的行ID
238                let last_insert_id = conn.last_insert_rowid();
239
240                Ok(json!({
241                    "rowcount": rows_affected,
242                    "lastrowid": last_insert_id,
243                }))
244            }
245            Err(e) => Err(ToolError::ExecutionError(format!(
246                "Failed to execute statement: {}",
247                e
248            ))),
249        }
250    }
251
252    /// 执行多个SQL语句
253    async fn executemany(&self, params: Value) -> Result<Value, ToolError> {
254        // 获取语句参数
255        let statement = match params.get("statement") {
256            Some(Value::String(s)) => s,
257            _ => {
258                return Err(ToolError::InvalidParameters(
259                    "Missing required parameter: statement".into(),
260                ))
261            }
262        };
263
264        // 获取参数列表
265        let params_list = match params.get("params_list") {
266            Some(Value::Array(list)) => list,
267            _ => {
268                return Err(ToolError::InvalidParameters(
269                    "Missing required parameter: params_list".into(),
270                ))
271            }
272        };
273
274        // 执行语句
275        let conn = self.conn.lock().await;
276
277        let mut stmt = match conn.prepare(statement) {
278            Ok(stmt) => stmt,
279            Err(e) => {
280                return Err(ToolError::ExecutionError(format!(
281                    "Failed to prepare statement: {}",
282                    e
283                )))
284            }
285        };
286
287        let mut rows_affected = 0;
288
289        for params_item in params_list {
290            match params_item {
291                Value::Array(params) => {
292                    // 将JSON参数转换为SQLite参数
293                    let sql_params: Vec<Box<dyn ToSql>> =
294                        params.iter().map(|v| json_value_to_sql(v)).collect();
295
296                    let sql_params_refs: Vec<&dyn ToSql> =
297                        sql_params.iter().map(|p| p.as_ref()).collect();
298
299                    match stmt.execute(sql_params_refs.as_slice()) {
300                        Ok(count) => rows_affected += count,
301                        Err(e) => {
302                            return Err(ToolError::ExecutionError(format!(
303                                "Failed to execute statement: {}",
304                                e
305                            )))
306                        }
307                    }
308                }
309                _ => {
310                    return Err(ToolError::InvalidParameters(
311                        "params_list must contain arrays".into(),
312                    ))
313                }
314            }
315        }
316
317        Ok(json!({
318            "rowcount": rows_affected,
319        }))
320    }
321
322    /// 执行SQL脚本
323    async fn executescript(&self, params: Value) -> Result<Value, ToolError> {
324        // 获取脚本参数
325        let script = match params.get("script") {
326            Some(Value::String(s)) => s,
327            _ => {
328                return Err(ToolError::InvalidParameters(
329                    "Missing required parameter: script".into(),
330                ))
331            }
332        };
333
334        // 执行脚本
335        let conn = self.conn.lock().await;
336
337        match conn.execute_batch(script) {
338            Ok(_) => {
339                // 由于execute_batch不返回受影响的行数,我们返回0
340                Ok(json!({
341                    "rowcount": 0,
342                }))
343            }
344            Err(e) => Err(ToolError::ExecutionError(format!(
345                "Failed to execute script: {}",
346                e
347            ))),
348        }
349    }
350}
351
352impl mcp_server_fishcode2025::Router for SQLiteRouter {
353    fn name(&self) -> String {
354        "sqlite".to_string()
355    }
356
357    fn instructions(&self) -> String {
358        "SQLite数据库访问服务,提供执行SQL查询和语句的能力。".to_string()
359    }
360
361    fn capabilities(&self) -> ServerCapabilities {
362        CapabilitiesBuilder::new().with_tools(true).build()
363    }
364
365    fn list_tools(&self) -> Vec<Tool> {
366        vec![
367            Tool::new(
368                "query".to_string(),
369                "执行SQL查询并返回结果".to_string(),
370                json!({
371                    "type": "object",
372                    "required": ["query"],
373                    "properties": {
374                        "query": {
375                            "type": "string",
376                            "description": "要执行的SQL查询"
377                        },
378                        "params": {
379                            "type": "array",
380                            "description": "绑定到查询的参数"
381                        }
382                    }
383                }),
384            ),
385            Tool::new(
386                "execute".to_string(),
387                "执行SQL语句".to_string(),
388                json!({
389                    "type": "object",
390                    "required": ["statement"],
391                    "properties": {
392                        "statement": {
393                            "type": "string",
394                            "description": "要执行的SQL语句"
395                        },
396                        "params": {
397                            "type": "array",
398                            "description": "绑定到语句的参数"
399                        }
400                    }
401                }),
402            ),
403            Tool::new(
404                "executemany".to_string(),
405                "使用不同参数多次执行SQL语句".to_string(),
406                json!({
407                    "type": "object",
408                    "required": ["statement", "params_list"],
409                    "properties": {
410                        "statement": {
411                            "type": "string",
412                            "description": "要执行的SQL语句"
413                        },
414                        "params_list": {
415                            "type": "array",
416                            "description": "绑定到语句的参数列表"
417                        }
418                    }
419                }),
420            ),
421            Tool::new(
422                "executescript".to_string(),
423                "执行SQL脚本".to_string(),
424                json!({
425                    "type": "object",
426                    "required": ["script"],
427                    "properties": {
428                        "script": {
429                            "type": "string",
430                            "description": "要执行的SQL脚本"
431                        }
432                    }
433                }),
434            ),
435        ]
436    }
437
438    fn call_tool(
439        &self,
440        tool_name: &str,
441        arguments: Value,
442    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
443        let self_clone = self.clone();
444        let tool_name = tool_name.to_string(); // 获取所有权
445
446        Box::pin(async move {
447            debug!("Calling tool: {}", tool_name);
448
449            let result = match tool_name.as_str() {
450                "query" => self_clone.query(arguments).await?,
451                "execute" => self_clone.execute(arguments).await?,
452                "executemany" => self_clone.executemany(arguments).await?,
453                "executescript" => self_clone.executescript(arguments).await?,
454                _ => return Err(ToolError::NotFound(format!("Unknown tool: {}", tool_name))),
455            };
456
457            // 使用Content::text方法将JSON转换为字符串
458            let json_string = serde_json::to_string(&result).unwrap_or_default();
459            Ok(vec![Content::text(json_string)])
460        })
461    }
462
463    fn list_resources(&self) -> Vec<Resource> {
464        vec![]
465    }
466
467    fn read_resource(
468        &self,
469        _uri: &str,
470    ) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
471        Box::pin(async { Err(ResourceError::NotFound("Resource not found".into())) })
472    }
473
474    fn list_prompts(&self) -> Vec<Prompt> {
475        vec![]
476    }
477
478    fn get_prompt(
479        &self,
480        _prompt_name: &str,
481    ) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
482        Box::pin(async { Err(PromptError::NotFound("Prompt not found".into())) })
483    }
484}
485
486impl Clone for SQLiteRouter {
487    fn clone(&self) -> Self {
488        Self {
489            conn: Arc::clone(&self.conn),
490        }
491    }
492}
493
494/// 将JSON值转换为SQLite参数
495fn json_value_to_sql(value: &Value) -> Box<dyn ToSql> {
496    match value {
497        Value::Null => Box::new(Option::<String>::None),
498        Value::Bool(b) => Box::new(*b),
499        Value::Number(n) => {
500            if n.is_i64() {
501                Box::new(n.as_i64().unwrap())
502            } else if n.is_u64() {
503                Box::new(n.as_u64().unwrap() as i64)
504            } else {
505                Box::new(n.as_f64().unwrap())
506            }
507        }
508        Value::String(s) => Box::new(s.clone()),
509        Value::Array(_) => Box::new(value.to_string()),
510        Value::Object(_) => Box::new(value.to_string()),
511    }
512}
513
514/// 从SQLite行中提取值
515fn extract_row_values(row: &Row, column_names: &[String]) -> Value {
516    let mut values = serde_json::Map::new();
517
518    for (i, name) in column_names.iter().enumerate() {
519        let value = match row.get_ref(i) {
520            Ok(rusqlite::types::ValueRef::Null) => Value::Null,
521            Ok(rusqlite::types::ValueRef::Integer(i)) => Value::Number(i.into()),
522            Ok(rusqlite::types::ValueRef::Real(f)) => {
523                if let Some(n) = serde_json::Number::from_f64(f) {
524                    Value::Number(n)
525                } else {
526                    Value::Null
527                }
528            }
529            Ok(rusqlite::types::ValueRef::Text(t)) => {
530                Value::String(String::from_utf8_lossy(t).to_string())
531            }
532            Ok(rusqlite::types::ValueRef::Blob(b)) => Value::String(STANDARD.encode(b)),
533            Err(_) => Value::Null,
534        };
535
536        values.insert(name.clone(), value);
537    }
538
539    Value::Object(values)
540}