#![cfg(any(test, feature = "dev_mock"))]
use async_trait::async_trait;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Mutex;
use pmcp_server_toolkit::sql::{
translate_placeholders, ConnectorError, Dialect, SqlConnector, TranslatedSql,
};
/// In-memory stand-in for a MySQL-backed [`SqlConnector`], compiled only for
/// tests and the `dev_mock` feature.
///
/// Queries are answered from `tables` by a small substring-matching engine
/// instead of a real database. The SQL and positional arguments seen by the
/// most recent `execute` call are recorded, so tests can assert on placeholder
/// translation.
#[derive(Debug)]
pub struct MysqlMock {
    /// Table name → rows. Rows are JSON values; the built-in fixture uses
    /// JSON objects with `id` and `name` keys.
    pub tables: HashMap<String, Vec<Value>>,
    /// SQL text of the most recent `execute` call after placeholder
    /// translation; `None` until the first call.
    pub last_translated_sql: Mutex<Option<String>>,
    /// Positional arguments bound by the most recent `execute` call, in
    /// placeholder order; `None` until the first call.
    pub last_positional_args: Mutex<Option<Vec<Value>>>,
}
impl MysqlMock {
    /// Creates a mock whose only table is `employees`, holding two rows:
    /// `{"id": 1, "name": "Ada Lovelace"}` and `{"id": 2, "name": "Alan Turing"}`.
    ///
    /// Nothing has been recorded yet: both `last_*` slots start as `None`.
    #[must_use]
    pub fn employee_directory() -> Self {
        let employee_rows = vec![
            json!({"id": 1, "name": "Ada Lovelace"}),
            json!({"id": 2, "name": "Alan Turing"}),
        ];

        Self {
            tables: HashMap::from([(String::from("employees"), employee_rows)]),
            // `Option::default()` is `None`, so both slots start empty.
            last_translated_sql: Mutex::default(),
            last_positional_args: Mutex::default(),
        }
    }
}
#[async_trait]
impl SqlConnector for MysqlMock {
    /// Always reports the MySQL dialect.
    fn dialect(&self) -> Dialect {
        Dialect::MySql
    }

    /// Runs `sql` against the in-memory tables.
    ///
    /// The statement is passed through `translate_placeholders` for
    /// `Dialect::MySql`. The named `params` are then laid out positionally in
    /// the order reported by the translation. A name with no entry in `params`
    /// is bound as `Value::Null`. The translated SQL and the positional
    /// arguments are stored in `last_translated_sql` / `last_positional_args`
    /// before the query is answered.
    ///
    /// # Errors
    ///
    /// Returns whatever the in-memory query engine returns for the translated
    /// statement.
    async fn execute(
        &self,
        sql: &str,
        params: &[(String, Value)],
    ) -> Result<Vec<Value>, ConnectorError> {
        let TranslatedSql {
            sql: translated,
            ordered_params,
        } = translate_placeholders(sql, Dialect::MySql);

        // Bind each placeholder name, in the order the translated SQL expects,
        // to its value; names absent from `params` bind as NULL.
        let positional: Vec<Value> = ordered_params
            .iter()
            .map(|name| {
                params
                    .iter()
                    .find(|(key, _)| key == name)
                    .map_or(Value::Null, |(_, value)| value.clone())
            })
            .collect();

        // Record even when the mutex is poisoned. Poisoning only means an
        // earlier holder panicked (e.g. a failed assertion while inspecting the
        // slot). The whole Option is overwritten here, so the old contents do
        // not matter. Skipping the write would leave stale values behind.
        *self
            .last_translated_sql
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(translated.clone());
        *self
            .last_positional_args
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(positional.clone());

        cheap_query_engine(&self.tables, &translated, &positional)
    }

    /// Returns a fixed `SHOW CREATE TABLE`-style schema text describing
    /// `employees` and `departments` tables.
    ///
    /// # Errors
    ///
    /// Never fails.
    async fn schema_text(&self) -> Result<String, ConnectorError> {
        Ok("-- SHOW CREATE TABLE (MySQL):\n\
            CREATE TABLE `employees` (\n \
            `id` INT NOT NULL,\n \
            `name` VARCHAR(255) NOT NULL\n\
            ) ENGINE=InnoDB;\n\
            CREATE TABLE `departments` (\n \
            `id` INT NOT NULL,\n \
            `name` VARCHAR(255) NOT NULL\n\
            ) ENGINE=InnoDB;\n"
            .into())
    }
}
fn cheap_query_engine(
tables: &HashMap<String, Vec<Value>>,
sql: &str,
args: &[Value],
) -> Result<Vec<Value>, ConnectorError> {
if sql.contains("FROM employees WHERE id = ?") {
let id = args
.first()
.and_then(serde_json::Value::as_i64)
.unwrap_or(-1);
let rows = tables.get("employees").cloned().unwrap_or_default();
return Ok(rows
.into_iter()
.filter(|r| r["id"].as_i64() == Some(id))
.collect());
}
if sql.contains("SELECT * FROM employees") {
return Ok(tables.get("employees").cloned().unwrap_or_default());
}
Ok(vec![])
}