use std::path::Path;
use anyhow::Context;
use rmcp::{
ServiceExt,
service::{RoleServer, RunningService},
transport::IntoTransport,
};
use rusqlite::OpenFlags;
use crate::access_control::AuthorizationResolver;
use crate::cli::Cli;
use crate::mcp::McpServerSqlite;
/// Opens the configured SQLite database, prepares it, and starts serving MCP requests over
/// `transport`.
///
/// Startup runs in this order:
///
/// 1. Decide whether the database is new *before* anything opens it. The pool opens it with
///    `SQLITE_OPEN_CREATE`, which creates a missing file.
/// 2. Build an r2d2 connection pool (`SQLITE_OPEN_URI | READ_WRITE | CREATE`). Some databases
///    are private to each connection: `:memory:`, the empty filename, and in-memory `file:`
///    URIs. These get a pool of exactly one connection that is never retired, so the init
///    scripts and every request share the same data.
/// 3. Run the `init_sql` scripts, only when the database is new.
/// 4. Build the `AuthorizationResolver` from `preset`, then apply every `allow` selector
///    followed by every `deny` selector.
/// 5. Convert `timeout_ms` into an optional query timeout, hand everything to
///    `McpServerSqlite`, and start the rmcp service on `transport`.
///
/// # Errors
///
/// Returns an error if the connection pool cannot be created, an init script fails, or the MCP
/// service fails to start on `transport`.
///
/// NOTE(review): for an on-disk database, a failing init script leaves the freshly created
/// file behind. On the next start it is no longer "new", so the scripts are skipped. Confirm
/// this is acceptable.
pub async fn serve<T, E, A>(
    cli: Cli,
    transport: T,
) -> anyhow::Result<RunningService<RoleServer, McpServerSqlite>>
where
    T: IntoTransport<RoleServer, E, A>,
    E: std::error::Error + Send + Sync + 'static,
{
    let Cli {
        database,
        init_sql,
        preset,
        allow,
        deny,
        timeout_ms,
    } = cli;

    // Must run before the pool is built: SQLITE_OPEN_CREATE below creates a missing file.
    let is_new = is_new_database(&database);
    tracing::info!(database = %database, preset = %preset, "starting server");

    let flags = OpenFlags::SQLITE_OPEN_URI
        | OpenFlags::SQLITE_OPEN_READ_WRITE
        | OpenFlags::SQLITE_OPEN_CREATE;
    let manager = r2d2_sqlite::SqliteConnectionManager::file(&database).with_flags(flags);
    let built = if is_ephemeral_database(&database) {
        // Every connection to such a database sees its own private copy. r2d2 also retires
        // connections by default (idle timeout / max lifetime), which would discard the data.
        // Keep one connection alive for the whole run so all requests share one database.
        // NOTE(review): with a single connection, concurrent requests queue for it (r2d2
        // connection_timeout applies).
        r2d2::Pool::builder()
            .max_size(1)
            .idle_timeout(None)
            .max_lifetime(None)
            .build(manager)
    } else {
        r2d2::Pool::new(manager)
    };
    let pool = built.context("Failed to create the connection pool")?;

    if is_new && !init_sql.is_empty() {
        run_init_scripts(&pool, &init_sql)?;
    }

    tracing::info!(
        allow_rules = allow.len(),
        deny_rules = deny.len(),
        "access control configured"
    );
    // Start from the preset, then apply all allow selectors, then all deny selectors.
    let resolver = allow
        .into_iter()
        .map(|selector| (selector, true))
        .chain(deny.into_iter().map(|selector| (selector, false)))
        .fold(
            AuthorizationResolver::from(preset),
            |resolver, (selector, allow)| resolver.with_selector(selector, allow),
        );

    let query_timeout = timeout_ms.map(std::time::Duration::from_millis);
    if let Some(timeout) = query_timeout {
        tracing::info!(
            timeout_ms = timeout.as_millis(),
            "query timeout configured"
        );
    }

    let server = McpServerSqlite::new(pool, resolver, query_timeout);
    let service = server.serve(transport).await?;
    tracing::info!("server ready");
    Ok(service)
}

/// Returns `true` when SQLite gives each connection opened on `database` its own private
/// database that disappears when that connection closes. A multi-connection, recycling pool
/// would split and periodically lose the contents of such a database.
///
/// This covers:
///
/// * `:memory:`;
/// * the empty filename, which is a private temporary on-disk database;
/// * `file:` URIs whose path is empty or `:memory:`, or that carry `mode=memory`.
///
/// Shared-cache URIs (`cache=shared`) with such a path are included too. SQLite drops a shared
/// in-memory database once its last connection closes, and a recycling pool can trigger that.
fn is_ephemeral_database(database: &str) -> bool {
    let Some(uri) = database.strip_prefix("file:") else {
        // Not a URI: SQLite uses the whole string verbatim as the filename.
        return database.is_empty() || database == ":memory:";
    };
    // `#` ends the URI; `?` starts its query string.
    let uri = uri.split('#').next().unwrap_or(uri);
    let (path, query) = uri.split_once('?').unwrap_or((uri, ""));
    path.is_empty()
        || path == ":memory:"
        || query.split('&').any(|param| param == "mode=memory")
}
/// Reports whether opening `database` will start from an empty database, i.e. whether the
/// init scripts should run.
///
/// `database` is interpreted the way SQLite interprets it when opened with `SQLITE_OPEN_URI`:
///
/// * A plain filename (anything not starting with `file:`) is used verbatim. Characters such
///   as `?` and `#` are part of the name.
/// * A `file:` URI has its `#fragment` and `?query` removed. An optional `//authority` (SQLite
///   accepts an empty one or `localhost`) is stripped from the path. A `mode=memory` query
///   parameter makes it an in-memory database.
///
/// `:memory:`, the empty filename (a private temporary database), and in-memory URIs are
/// always new. Any other database is new when nothing exists at its path yet.
///
/// NOTE(review): percent-escapes in URI paths are not decoded. For example,
/// `file:my%20db.sqlite` is checked as the literal path `my%20db.sqlite`.
fn is_new_database(database: &str) -> bool {
    let Some(uri) = database.strip_prefix("file:") else {
        // Not a URI: SQLite uses the whole string as the filename, so nothing is stripped.
        return database.is_empty() || database == ":memory:" || !Path::new(database).exists();
    };
    // `#` ends the URI; `?` starts its query string.
    let uri = uri.split('#').next().unwrap_or(uri);
    let (path, query) = uri.split_once('?').unwrap_or((uri, ""));
    // A pure in-memory database never reads the file, whatever the path says.
    if query.split('&').any(|param| param == "mode=memory") {
        return true;
    }
    // `file://authority/abs/path`: drop the authority and keep the absolute path after it.
    let path = match path.strip_prefix("//") {
        Some(rest) => rest.find('/').map_or("", |slash| &rest[slash..]),
        None => path,
    };
    path.is_empty() || path == ":memory:" || !Path::new(path).exists()
}
/// Executes each SQL script in `scripts`, in order, on one connection checked out of `pool`.
///
/// Each file is read fully into memory and handed to SQLite as a single batch, so a script may
/// contain several statements.
///
/// # Errors
///
/// Fails if no connection can be acquired, if a script cannot be read, or if SQLite rejects a
/// script. Processing stops at the first failure; later scripts are not run.
fn run_init_scripts(
    pool: &r2d2::Pool<r2d2_sqlite::SqliteConnectionManager>,
    scripts: &[std::path::PathBuf],
) -> anyhow::Result<()> {
    let connection = pool
        .get()
        .context("Failed to acquire a connection for init scripts")?;
    scripts.iter().try_for_each(|script| {
        tracing::info!(path = %script.display(), "executing init script");
        let batch = std::fs::read_to_string(script)
            .with_context(|| format!("Failed to read {}", script.display()))?;
        connection
            .execute_batch(&batch)
            .with_context(|| format!("Failed to execute {}", script.display()))
    })
}