Skip to main content

sery_mcp/
lib.rs

1//! # sery-mcp — local-files MCP server.
2//!
3//! `sery-mcp` is primarily a **binary** — most users run
4//! `cargo install sery-mcp` and configure their MCP client to spawn
5//! it. The library surface exists for the rare downstream that wants
6//! to embed the same tool implementations into their own MCP server
7//! (e.g. Sery Link's desktop app spawning the logic in-process
8//! instead of as a subprocess).
9//!
10//! ## v0.3.0 surface
11//!
12//! - [`SeryMcpServer`] — the configured MCP server. Construct via
13//!   [`SeryMcpServer::new`] with a single `--root` path; serve with
14//!   [`rmcp::ServiceExt::serve`] and your transport of choice.
15//! - **Six tools**, all read-only:
16//!   - `list_folder` — enumerate files (scankit)
17//!   - `search_files` — filename + extension search with scoring (scankit)
18//!   - `get_schema` — column names + types + row count (tabkit)
19//!   - `sample_rows` — N rows of sampled data, header-keyed (tabkit)
20//!   - `read_document` — DOCX/PDF/PPTX/HTML/IPYNB → markdown (mdkit)
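//!   - `query_sql` — read-only SQL queries against CSV / TSV / Parquet
//!     files (in-memory DuckDB session). Single-file mode registers the
//!     file as table `data`; multi-file mode registers each entry of a
//!     `tables` map under its own name so queries can JOIN across files.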
24//!
25//! ## Privacy + threat model
26//!
27//! `sery-mcp` opens no sockets and makes no outbound network calls.
28//! All file reads are bounded by `--root`: any tool argument that
29//! tries to escape via `..` or absolute paths is rejected before the
30//! filesystem call. Tools are read-only by design — no `write_file`,
31//! no `delete`, no `execute`.
32
33#![doc(html_root_url = "https://docs.rs/sery-mcp/0.4.3")]
34#![cfg_attr(docsrs, feature(doc_cfg))]
35// Pedantic lints we deliberately accept:
36//   * doc_markdown — prose mentions SQL keywords, library names, and
37//     filesystem path patterns that aren't always worth backticking.
38//   * items_after_statements — `use ...` inside match arms keeps
39//     type imports next to the arms that consume them; moving them
40//     to function-top harms locality in `arrow_value_to_json`.
41//   * case_sensitive_file_extension_comparisons — we lowercase the
42//     path string before `ends_with` checks; clippy can't see through
43//     the local rebinding.
44#![allow(
45    clippy::doc_markdown,
46    clippy::items_after_statements,
47    clippy::case_sensitive_file_extension_comparisons
48)]
49
50use std::path::{Component, Path, PathBuf};
51use std::sync::OnceLock;
52
53use rmcp::{
54    handler::server::{router::tool::ToolRouter, wrapper::Parameters},
55    model::{
56        CallToolResult, Content, Implementation, ProtocolVersion, ServerCapabilities, ServerInfo,
57    },
58    schemars, tool, tool_handler, tool_router, ErrorData as McpError, ServerHandler,
59};
60
/// The crate version as reported by Cargo at build time.
///
/// Expanded from the `CARGO_PKG_VERSION` environment variable by `env!`
/// during compilation, so it always matches the `version` of the
/// `Cargo.toml` this crate was built from.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
63
/// Upper bound, in bytes, on files accepted by `read_document` (50 MiB).
///
/// Chosen to match Sery Link's scanner default. Beyond this size mdkit's
/// pandoc / libpdfium backends start to run into memory limits, and the
/// extracted markdown would not fit an LLM context window anyway.
/// It is a hard cap: `ReadDocumentRequest` has no per-call override.
const MAX_DOCUMENT_BYTES: u64 = 50 << 20;
70
71// ---------------------------------------------------------------------------
72// Lazy backends
73// ---------------------------------------------------------------------------
74
75fn mdkit_engine() -> &'static mdkit::Engine {
76    static ENGINE: OnceLock<mdkit::Engine> = OnceLock::new();
77    ENGINE.get_or_init(mdkit::Engine::with_defaults)
78}
79
80fn tabkit_engine() -> &'static tabkit::Engine {
81    static ENGINE: OnceLock<tabkit::Engine> = OnceLock::new();
82    ENGINE.get_or_init(tabkit::Engine::with_defaults)
83}
84
85// ---------------------------------------------------------------------------
86// Tool input schemas
87// ---------------------------------------------------------------------------
88
/// Input for the `list_folder` tool.
// Review note: `schemars::JsonSchema` can copy `///` doc comments into the
// generated tool schema, so explanatory notes here use `//` comments to
// keep the schema unchanged.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct ListFolderRequest {
    /// Subdirectory under `--root`. Defaults to the root.
    // `None`, "" and "." all resolve to the root in `resolve_subpath`.
    #[serde(default)]
    #[schemars(
        description = "Subdirectory under the configured --root. Must be relative — no '..' segments, no absolute paths. Defaults to the root."
    )]
    pub path: Option<String>,
    /// Cap on the number of returned entries. Defaults to 1000.
    // serde only yields `None` when absent; the 1000 default is applied in
    // `list_folder`.
    #[serde(default)]
    #[schemars(description = "Maximum entries to return. Defaults to 1000.")]
    pub limit: Option<usize>,
}
103
/// Input for the `search_files` tool.
// Review note: `schemars::JsonSchema` can copy `///` doc comments into the
// generated tool schema, so explanatory notes here use `//` comments.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct SearchFilesRequest {
    /// The query string (case-insensitive substring of basename).
    // `search_files` trims and lowercases this and rejects it if empty.
    #[schemars(
        description = "Search term (case-insensitive). Matched against file basenames; whole-path matches score lower."
    )]
    pub query: String,
    /// Optional extension filter. Only files matching one of these
    /// extensions (lowercase, no leading dot) are considered.
    // `search_files` lowercases each entry before comparing it with the
    // scanned file's extension.
    #[serde(default)]
    #[schemars(
        description = "Restrict to files whose extension matches one of these (lowercase, no leading dot, e.g. ['csv','parquet'])."
    )]
    pub extensions: Option<Vec<String>>,
    /// Cap on results. Defaults to 50.
    // The 50 default is applied in `search_files`, after sorting.
    #[serde(default)]
    #[schemars(description = "Maximum results to return. Defaults to 50.")]
    pub limit: Option<usize>,
}
124
/// Input for the `get_schema` tool.
// Review note: `schemars::JsonSchema` can copy `///` doc comments into the
// generated tool schema, so explanatory notes here use `//` comments.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct GetSchemaRequest {
    /// Path to a tabular file under `--root`.
    // Resolved by `resolve_required_file`; must name an existing regular file.
    #[schemars(
        description = "Relative path to a tabular file (CSV / TSV / Parquet / XLSX / XLS / XLSB / XLSM / ODS) under --root."
    )]
    pub path: String,
    /// Optional sheet name for multi-sheet workbooks.
    // Forwarded to `tabkit::ReadOptions::sheet_name` when present.
    #[serde(default)]
    #[schemars(
        description = "For multi-sheet XLSX / ODS files: which sheet to inspect. Defaults to the first non-empty sheet."
    )]
    pub sheet: Option<String>,
}
140
/// Input for the `sample_rows` tool.
// Review note: `schemars::JsonSchema` can copy `///` doc comments into the
// generated tool schema, so explanatory notes here use `//` comments.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct SampleRowsRequest {
    /// Path to a tabular file under `--root`.
    #[schemars(description = "Relative path to a tabular file under --root.")]
    pub path: String,
    /// How many rows to return. Defaults to 5; capped at 100.
    // Both the default and the cap are applied in `sample_rows`, not by serde.
    #[serde(default)]
    #[schemars(description = "Sample-row count. Defaults to 5; capped at 100.")]
    pub limit: Option<usize>,
    /// Optional sheet name for multi-sheet workbooks.
    // Forwarded to `tabkit::ReadOptions::sheet_name` when present.
    #[serde(default)]
    #[schemars(description = "For multi-sheet XLSX / ODS files: which sheet to sample.")]
    pub sheet: Option<String>,
}
156
/// Input for the `read_document` tool.
// Review note: `schemars::JsonSchema` can copy `///` doc comments into the
// generated tool schema, so explanatory notes here use `//` comments.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct ReadDocumentRequest {
    /// Path to a document under `--root`.
    // The "50 MB cap" is enforced in `read_document` against
    // `MAX_DOCUMENT_BYTES` before extraction.
    #[schemars(
        description = "Relative path to a document file (DOCX / PDF / PPTX / HTML / IPYNB / EPUB / RTF / ODT) under --root. 50 MB cap."
    )]
    pub path: String,
}
166
/// Input for the `query_sql` tool.
///
/// **Single-file mode**: pass `path` and reference the file as table `data`.
/// **Multi-file mode**: pass `tables` mapping LLM-chosen names → relative paths,
/// then JOIN them in `sql`. Mutually exclusive — pick one shape per call.
///
/// Glob patterns (`*`, `?`, `[...]`) are supported in both modes — the
/// SQL backend expands them at read time. They stay bounded by
/// `--root` because the path validator rejects `..` and absolute
/// paths up-front.
// Review notes (kept as `//` so the schemars-generated schema is unchanged):
// - Mutual exclusivity, the non-empty check and the 16-table cap are
//   enforced in `resolve_table_specs`, not by serde.
// - NOTE(review): the path validator only sees `path` / `tables`; whether
//   file-reading table functions written directly inside `sql` are blocked
//   depends on `validate_query_sql` (not shown here) — confirm.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct QuerySqlRequest {
    /// Single-file shortcut. Registered as table `data`.
    #[serde(default)]
    #[schemars(
        description = "Single-file mode. Relative path (or glob pattern like '2024/*.csv') under --root. The file(s) are registered as table `data` for the duration of this query. Mutually exclusive with `tables`."
    )]
    pub path: Option<String>,
    /// Multi-file mode: map of table name → relative path (or glob).
    /// Each entry is registered as a SQL table in the same query
    /// session so the LLM can JOIN across files.
    // A `HashMap`, so iteration (and therefore registration) order is
    // unspecified; the response's `format` comes from whichever spec ends
    // up first.
    #[serde(default)]
    #[schemars(
        description = "Multi-file mode. Map of {table_name: relative_path} — each path becomes a SQL table you can JOIN. Names must be valid SQL identifiers ([a-zA-Z_][a-zA-Z0-9_]*). Cap of 16 tables per call. Mutually exclusive with `path`."
    )]
    pub tables: Option<std::collections::HashMap<String, String>>,
    /// The SQL query.
    // Checked by `validate_query_sql` before execution, then wrapped in an
    // outer `SELECT * FROM (...) LIMIT` by `query_sql`.
    #[schemars(
        description = "SQL query — supports window functions, CTEs, glob reads, JOINs across the registered tables. Read-only — INSERT/UPDATE/DELETE/DDL/ATTACH/COPY/PRAGMA all rejected at validation time."
    )]
    pub sql: String,
    /// Cap on returned rows. Defaults to 100; capped at 1000.
    // Default and cap are applied in `query_sql`.
    #[serde(default)]
    #[schemars(
        description = "Maximum rows to return. Defaults to 100, capped at 1000. Use SQL LIMIT for tighter caps."
    )]
    pub limit: Option<usize>,
}
205
206// ---------------------------------------------------------------------------
207// Tool output shapes
208// ---------------------------------------------------------------------------
209
/// One entry in a `list_folder` response.
///
/// Serialised as a JSON object whose keys are the field names below.
#[derive(Debug, serde::Serialize)]
pub struct FileEntry {
    /// Path relative to the configured `--root`.
    pub relative_path: String,
    /// File size in bytes at walk time.
    pub size_bytes: u64,
    /// Last-modified timestamp as RFC 3339, when the filesystem reports one.
    /// Serialised as JSON `null` when absent (the field is never skipped).
    pub modified: Option<String>,
    /// Lowercase, dot-less extension. Empty when the file has none.
    pub extension: String,
}
222
/// One result in a `search_files` response.
#[derive(Debug, serde::Serialize)]
pub struct SearchHit {
    /// Path relative to the configured `--root`, as produced by
    /// `path_to_forward_slash`.
    pub relative_path: String,
    /// File size in bytes at walk time.
    pub size_bytes: u64,
    /// Lowercase, dot-less extension.
    pub extension: String,
    /// Match score in `[0.0, 1.0]`. See `search_files` doc for the rubric.
    /// `search_files` currently assigns exactly one of 1.0 (exact
    /// stem/basename), 0.8 (basename prefix), 0.5 (basename substring)
    /// or 0.2 (relative-path substring).
    pub score: f64,
    /// Short human-readable explanation of the match category.
    pub why_matched: &'static str,
}
237
/// One column in a `get_schema` response.
#[derive(Debug, serde::Serialize)]
pub struct ColumnInfo {
    /// Column header. Falls back to `column_<idx>` when the source has none.
    pub name: String,
    /// Inferred type as a stable lowercase string (`"integer"`, `"text"`, …).
    /// Serialised under the JSON key `"type"` (see the `rename` below).
    #[serde(rename = "type")]
    pub data_type: &'static str,
    /// `true` when any sample row had a null/empty cell in this position.
    // NOTE(review): `get_schema` reads with `max_sample_rows(0)`; confirm how
    // tabkit derives `nullable` when no sample rows are requested.
    pub nullable: bool,
}
249
/// `get_schema` response.
#[derive(Debug, serde::Serialize)]
pub struct SchemaResponse {
    /// The path the caller passed in, echoed back for tool-call audit.
    pub relative_path: String,
    /// Lowercase extension (`"csv"`, `"parquet"`, …).
    pub format: String,
    /// Columns in source order.
    pub columns: Vec<ColumnInfo>,
    /// Total row count when known. `None` when the backend skipped a
    /// full scan. Serialised as JSON `null` in that case.
    pub row_count: Option<u64>,
    /// Backend metadata — for XLSX this carries `"sheet"`, for CSV
    /// it can carry `"delimiter"`. Stable keys are documented in
    /// tabkit's per-backend docs. Omitted from the JSON entirely when
    /// the map is empty.
    #[serde(skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub metadata: std::collections::HashMap<String, String>,
}
268
/// `sample_rows` response.
#[derive(Debug, serde::Serialize)]
pub struct SamplesResponse {
    /// The path the caller passed in.
    pub relative_path: String,
    /// Lowercase extension (`"csv"`, `"parquet"`, …).
    pub format: String,
    /// Column headers in source order.
    pub columns: Vec<String>,
    /// Sample rows as JSON objects keyed by column header. A cell that
    /// is missing from a short row is emitted as `null`.
    pub rows: Vec<serde_json::Map<String, serde_json::Value>>,
    /// Total row count when known; JSON `null` otherwise.
    pub row_count: Option<u64>,
}
283
/// `read_document` response.
#[derive(Debug, serde::Serialize)]
pub struct DocumentResponse {
    /// The path the caller passed in.
    pub relative_path: String,
    /// Lowercase extension (`"pdf"`, `"docx"`, …).
    pub format: String,
    /// Extracted markdown text — the whole document.
    pub markdown: String,
    /// Document title when the backend could derive one.
    pub title: Option<String>,
    /// Extracted markdown length in Unicode scalar values
    /// (`str::chars().count()`), not bytes.
    pub char_count: usize,
    /// Source file size in bytes, as reported by `std::fs::metadata`.
    pub size_bytes: u64,
}
300
/// `query_sql` response.
#[derive(Debug, serde::Serialize)]
pub struct QueryResponse {
    /// What the caller passed in, echoed back in human-readable form.
    /// For single-file mode: `"path/to/file.csv"`. For multi-file:
    /// `"customers=customers.csv, orders=orders.parquet"`.
    pub input: String,
    /// Lowercase extension of the (first) queried file. Empty when
    /// glob patterns mix multiple formats.
    ///
    /// NOTE(review): `query_sql` takes this from the first table spec. In
    /// multi-file mode the specs come from iterating a `HashMap`, so which
    /// table is "first" is unspecified. Whether the value can actually be
    /// empty depends on `format_for_query_sql` (not visible here).
    pub format: String,
    /// Result column names in the order they appear in the projection.
    pub columns: Vec<String>,
    /// Result rows as JSON objects keyed by column name.
    pub rows: Vec<serde_json::Map<String, serde_json::Value>>,
    /// Number of rows returned (after the row cap).
    pub row_count: usize,
    /// `true` when the result was capped by the row limit and the
    /// underlying query produced more rows. The LLM should use this
    /// to decide whether to refine the SQL with a tighter `WHERE`
    /// or `LIMIT`.
    pub truncated: bool,
}
323
324// ---------------------------------------------------------------------------
325// Server
326// ---------------------------------------------------------------------------
327
/// A configured MCP server. Cheap to construct + clone; share a single
/// instance across the rmcp serve loop.
///
/// Holds the directory every tool resolves paths against, plus the tool
/// router that dispatches MCP `tools/call` requests. `Clone` copies both.
#[derive(Clone)]
pub struct SeryMcpServer {
    /// Directory every tool-supplied path is joined onto; stored exactly
    /// as passed to `new`.
    root: PathBuf,
    // The router is consumed by the `#[tool_handler]` macro that
    // implements `ServerHandler` below — it dispatches incoming
    // tool/call requests to the right `#[tool]` method via this field.
    // Rust's dead-code analysis can't see that cross-macro usage so
    // we suppress the lint at the field rather than file-wide.
    #[allow(dead_code)]
    tool_router: ToolRouter<SeryMcpServer>,
}
341
342#[tool_router]
343impl SeryMcpServer {
344    /// Construct a new server with the given filesystem root.
345    pub fn new(root: PathBuf) -> Self {
346        Self {
347            root,
348            tool_router: Self::tool_router(),
349        }
350    }
351
352    /// Returns the canonical root this server is exposing.
353    pub fn root(&self) -> &Path {
354        &self.root
355    }
356
357    // ── Tools ─────────────────────────────────────────────────────
358
359    #[tool(
360        description = "List files under the configured --root (or a sub-path). Returns one JSON object per file with relative_path, size_bytes, modified (ISO 8601), and extension. Read-only; never returns file contents. Path-traversal rejected."
361    )]
362    fn list_folder(
363        &self,
364        Parameters(req): Parameters<ListFolderRequest>,
365    ) -> Result<CallToolResult, McpError> {
366        let target = self.resolve_subpath(req.path.as_deref())?;
367        let limit = req.limit.unwrap_or(1000);
368        let entries = self.walk_entries(&target, limit)?;
369        as_json_result(&entries)
370    }
371
372    #[tool(
373        description = "Search files by name. Case-insensitive substring match against the basename, ranked: exact basename match (1.0), basename startswith (0.8), basename contains (0.5), path contains (0.2). Optional `extensions` filter restricts to specific file types. Returns up to `limit` hits sorted by score then path."
374    )]
375    fn search_files(
376        &self,
377        Parameters(req): Parameters<SearchFilesRequest>,
378    ) -> Result<CallToolResult, McpError> {
379        let limit = req.limit.unwrap_or(50);
380        let query = req.query.trim().to_lowercase();
381        if query.is_empty() {
382            return Err(McpError::invalid_params("'query' must not be empty", None));
383        }
384        let ext_filter: Option<Vec<String>> = req
385            .extensions
386            .map(|v| v.into_iter().map(|s| s.to_ascii_lowercase()).collect());
387
388        let scanner = scankit::Scanner::new(scankit::ScanConfig::default().follow_symlinks(false))
389            .map_err(|e| McpError::internal_error(format!("scankit init: {e}"), None))?;
390
391        let mut hits: Vec<SearchHit> = Vec::new();
392        for result in scanner.walk(&self.root) {
393            let Ok(entry) = result else { continue };
394            if let Some(filter) = ext_filter.as_ref() {
395                if !filter.iter().any(|e| e == &entry.extension) {
396                    continue;
397                }
398            }
399            let basename = entry
400                .path
401                .file_name()
402                .and_then(|s| s.to_str())
403                .map(str::to_lowercase)
404                .unwrap_or_default();
405            let stem = entry
406                .path
407                .file_stem()
408                .and_then(|s| s.to_str())
409                .map(str::to_lowercase)
410                .unwrap_or_default();
411            let relative_path =
412                path_to_forward_slash(entry.path.strip_prefix(&self.root).unwrap_or(&entry.path));
413            let relative_lower = relative_path.to_lowercase();
414
415            let (score, why) = if stem == query || basename == query {
416                (1.0, "exact basename match")
417            } else if basename.starts_with(&query) {
418                (0.8, "basename starts with query")
419            } else if basename.contains(&query) {
420                (0.5, "basename contains query")
421            } else if relative_lower.contains(&query) {
422                (0.2, "path contains query")
423            } else {
424                continue;
425            };
426
427            hits.push(SearchHit {
428                relative_path,
429                size_bytes: entry.size_bytes,
430                extension: entry.extension,
431                score,
432                why_matched: why,
433            });
434        }
435        hits.sort_by(|a, b| {
436            b.score
437                .partial_cmp(&a.score)
438                .unwrap_or(std::cmp::Ordering::Equal)
439                .then_with(|| a.relative_path.cmp(&b.relative_path))
440        });
441        hits.truncate(limit);
442        as_json_result(&hits)
443    }
444
445    #[tool(
446        description = "Return column names + inferred types + row count for a tabular file (CSV / TSV / Parquet / XLSX / XLS / XLSB / XLSM / ODS). Backed by tabkit. row_count is null for very large files where a full scan was skipped. Specify `sheet` for multi-sheet workbooks."
447    )]
448    fn get_schema(
449        &self,
450        Parameters(req): Parameters<GetSchemaRequest>,
451    ) -> Result<CallToolResult, McpError> {
452        let path = self.resolve_required_file(&req.path)?;
453        let mut options = tabkit::ReadOptions::default().max_sample_rows(0);
454        if let Some(sheet) = req.sheet {
455            options = options.sheet_name(sheet);
456        }
457        let table = tabkit_engine()
458            .read(&path, &options)
459            .map_err(|e| McpError::internal_error(format!("tabkit read: {e}"), None))?;
460        let response = SchemaResponse {
461            relative_path: req.path,
462            format: extension_of(&path),
463            columns: table
464                .columns
465                .iter()
466                .map(|c| ColumnInfo {
467                    name: c.name.clone(),
468                    data_type: data_type_str(c.data_type),
469                    nullable: c.nullable,
470                })
471                .collect(),
472            row_count: table.row_count,
473            metadata: table.metadata,
474        };
475        as_json_result(&response)
476    }
477
478    #[tool(
479        description = "Return the first N rows of a tabular file as header-keyed JSON objects. Defaults to 5 rows; capped at 100. Specify `sheet` for multi-sheet workbooks. Use sparingly — sample rows can contain PII; this tool returns raw cell values without redaction."
480    )]
481    fn sample_rows(
482        &self,
483        Parameters(req): Parameters<SampleRowsRequest>,
484    ) -> Result<CallToolResult, McpError> {
485        let path = self.resolve_required_file(&req.path)?;
486        let limit = req.limit.unwrap_or(5).min(100);
487        let mut options = tabkit::ReadOptions::default().max_sample_rows(limit);
488        if let Some(sheet) = req.sheet {
489            options = options.sheet_name(sheet);
490        }
491        let table = tabkit_engine()
492            .read(&path, &options)
493            .map_err(|e| McpError::internal_error(format!("tabkit read: {e}"), None))?;
494        let column_names: Vec<String> = table.columns.iter().map(|c| c.name.clone()).collect();
495        let rows = table
496            .sample_rows
497            .iter()
498            .map(|row| {
499                let mut obj = serde_json::Map::new();
500                for (i, col) in column_names.iter().enumerate() {
501                    let v = row.get(i).map_or(serde_json::Value::Null, value_to_json);
502                    obj.insert(col.clone(), v);
503                }
504                obj
505            })
506            .collect();
507        let response = SamplesResponse {
508            relative_path: req.path,
509            format: extension_of(&path),
510            columns: column_names,
511            rows,
512            row_count: table.row_count,
513        };
514        as_json_result(&response)
515    }
516
517    #[tool(
518        description = "Run a read-only SQL query against one or more CSV / TSV / Parquet files. \
519                       Single-file: pass `path`, reference as table `data` in your SQL. \
520                       Multi-file: pass `tables` (a {name: path} map), reference each name as a SQL table — lets you JOIN across files. \
521                       Glob patterns (`*`, `?`) are supported in both — expanded at read time. \
522                       Full SQL dialect: window functions, CTEs, smart CSV sniffing, native XLSX. \
523                       Read-only by design — INSERT/UPDATE/DELETE/DDL/ATTACH/COPY/PRAGMA are rejected at validation time. \
524                       Returns header-keyed JSON rows; capped at 1000 (default 100). Set `truncated: true` when more rows exist."
525    )]
526    fn query_sql(
527        &self,
528        Parameters(req): Parameters<QuerySqlRequest>,
529    ) -> Result<CallToolResult, McpError> {
530        let limit = req.limit.unwrap_or(100).min(1000);
531        let table_specs = self.resolve_table_specs(&req)?;
532        validate_query_sql(&req.sql)?;
533
534        let conn = duckdb::Connection::open_in_memory()
535            .map_err(|e| McpError::internal_error(format!("sql backend open: {e}"), None))?;
536
537        // Register each (name, path) as a SQL view in the session.
538        // CREATE OR REPLACE VIEW <name> AS SELECT * FROM read_csv_auto / read_parquet
539        for spec in &table_specs {
540            let setup = build_register_view(&spec.table, &spec.path_for_sql, spec.format)?;
541            conn.execute_batch(&setup).map_err(|e| {
542                McpError::internal_error(format!("register table {}: {e}", spec.table), None)
543            })?;
544        }
545
546        // Wrap in an outer LIMIT for truncation detection. We ask for
547        // limit+1 rows; if we hit limit, set truncated=true and drop the
548        // extra. This is cheap because the SQL planner pushes the
549        // limit down past the user's projection.
550        let wrapped_sql = format!("SELECT * FROM ({}) LIMIT {}", req.sql, limit + 1);
551
552        let mut stmt = conn
553            .prepare(&wrapped_sql)
554            .map_err(|e| McpError::invalid_params(format!("sql prepare: {e}"), None))?;
555
556        // `query_arrow` calls `execute` internally; we read column
557        // names from its schema rather than from `stmt.column_names()`
558        // (which panics on this version of the binding when the
559        // prepared statement hasn't been executed yet).
560        let arrow_iter = stmt
561            .query_arrow(duckdb::params![])
562            .map_err(|e| McpError::invalid_params(format!("sql execute: {e}"), None))?;
563        let schema = arrow_iter.get_schema();
564        let columns: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();
565
566        let mut rows: Vec<serde_json::Map<String, serde_json::Value>> = Vec::with_capacity(limit);
567        let mut truncated = false;
568        'outer: for batch in arrow_iter {
569            for row_idx in 0..batch.num_rows() {
570                if rows.len() == limit {
571                    truncated = true;
572                    break 'outer;
573                }
574                let mut obj = serde_json::Map::with_capacity(columns.len());
575                for (col_idx, col_name) in columns.iter().enumerate() {
576                    let array = batch.column(col_idx);
577                    obj.insert(
578                        col_name.clone(),
579                        arrow_value_to_json(array.as_ref(), row_idx),
580                    );
581                }
582                rows.push(obj);
583            }
584        }
585
586        let response = QueryResponse {
587            input: describe_input(&table_specs),
588            format: table_specs
589                .first()
590                .map(|s| s.format.to_string())
591                .unwrap_or_default(),
592            row_count: rows.len(),
593            columns,
594            rows,
595            truncated,
596        };
597        as_json_result(&response)
598    }
599
600    #[tool(
601        description = "Convert a document file (DOCX / PDF / PPTX / HTML / IPYNB / EPUB / RTF / ODT) to markdown. Backed by mdkit (libpdfium for PDF, pandoc for office formats, html2md for HTML). 50 MB file size cap; larger files return an error. Returns the full extracted text — pair with a chunk-aware caller if your LLM context window can't hold the whole document."
602    )]
603    fn read_document(
604        &self,
605        Parameters(req): Parameters<ReadDocumentRequest>,
606    ) -> Result<CallToolResult, McpError> {
607        let path = self.resolve_required_file(&req.path)?;
608        let metadata = std::fs::metadata(&path)
609            .map_err(|e| McpError::internal_error(format!("stat: {e}"), None))?;
610        if metadata.len() > MAX_DOCUMENT_BYTES {
611            return Err(McpError::invalid_params(
612                format!(
613                    "file is {} bytes; read_document caps at {} bytes (50 MB)",
614                    metadata.len(),
615                    MAX_DOCUMENT_BYTES
616                ),
617                None,
618            ));
619        }
620        let document = mdkit_engine()
621            .extract(&path)
622            .map_err(|e| McpError::internal_error(format!("mdkit extract: {e}"), None))?;
623        let format = extension_of(&path);
624        let response = DocumentResponse {
625            char_count: document.markdown.chars().count(),
626            relative_path: req.path,
627            format,
628            title: document.title,
629            markdown: document.markdown,
630            size_bytes: metadata.len(),
631        };
632        as_json_result(&response)
633    }
634
635    // ── Internals ─────────────────────────────────────────────────
636
637    /// Resolve a tool-supplied sub-path against `self.root` for the
638    /// "may be omitted, defaults to root" case (used by `list_folder`).
639    fn resolve_subpath(&self, sub: Option<&str>) -> Result<PathBuf, McpError> {
640        let raw = match sub {
641            None => return Ok(self.root.clone()),
642            Some(s) if s.is_empty() || s == "." => return Ok(self.root.clone()),
643            Some(s) => s,
644        };
645        validate_relative_components(raw)?;
646        Ok(self.root.join(raw))
647    }
648
649    /// Resolve a tool-supplied path that **must** point to a regular
650    /// file under `self.root` (used by `get_schema`, `sample_rows`,
651    /// `read_document`).
652    fn resolve_required_file(&self, sub: &str) -> Result<PathBuf, McpError> {
653        if sub.is_empty() {
654            return Err(McpError::invalid_params("'path' must not be empty", None));
655        }
656        validate_relative_components(sub)?;
657        let joined = self.root.join(sub);
658        let metadata = std::fs::metadata(&joined)
659            .map_err(|e| McpError::invalid_params(format!("path not readable: {e}"), None))?;
660        if !metadata.is_file() {
661            return Err(McpError::invalid_params(
662                "'path' must refer to a regular file (not a directory or symlink)",
663                None,
664            ));
665        }
666        Ok(joined)
667    }
668
669    /// Resolve a path that may be a regular file OR a glob pattern
670    /// (`*`, `?`, `[...]`). Used by `query_sql`, where the SQL
671    /// backend does its own glob expansion at read time.
672    fn resolve_required_path_or_glob(&self, sub: &str) -> Result<PathBuf, McpError> {
673        if sub.is_empty() {
674            return Err(McpError::invalid_params("path must not be empty", None));
675        }
676        validate_relative_components(sub)?;
677        let joined = self.root.join(sub);
678        if !is_glob_pattern(sub) {
679            // Non-glob: enforce file-exists up-front so the LLM gets
680            // a clean error instead of "no files" buried inside a
681            // SQL execution error from the backend.
682            let metadata = std::fs::metadata(&joined)
683                .map_err(|e| McpError::invalid_params(format!("path not readable: {e}"), None))?;
684            if !metadata.is_file() {
685                return Err(McpError::invalid_params(
686                    "path must refer to a regular file or a glob pattern",
687                    None,
688                ));
689            }
690        }
691        Ok(joined)
692    }
693
694    /// Translate the `path` / `tables` fields of a [`QuerySqlRequest`]
695    /// into a normalised list of [`TableSpec`]s ready to register
696    /// with the SQL backend. Enforces:
697    ///
698    /// - Exactly one of `path` / `tables` is set.
699    /// - At least one table.
700    /// - At most 16 tables.
701    /// - Every table name is a valid SQL identifier.
702    /// - Every path resolves under `--root`.
703    /// - File extension is csv / tsv / parquet.
704    fn resolve_table_specs(&self, req: &QuerySqlRequest) -> Result<Vec<TableSpec>, McpError> {
705        match (&req.path, &req.tables) {
706            (Some(_), Some(_)) => Err(McpError::invalid_params(
707                "pass either `path` (single-file) or `tables` (multi-file), not both",
708                None,
709            )),
710            (None, None) => Err(McpError::invalid_params(
711                "must pass either `path` or `tables`",
712                None,
713            )),
714            (Some(path), None) => {
715                let resolved = self.resolve_required_path_or_glob(path)?;
716                let format = format_for_query_sql(path)?;
717                Ok(vec![TableSpec {
718                    table: "data".to_string(),
719                    path_for_sql: resolved.to_string_lossy().into_owned(),
720                    relative_path: path.clone(),
721                    format,
722                }])
723            }
724            (None, Some(tables)) => {
725                if tables.is_empty() {
726                    return Err(McpError::invalid_params("`tables` must not be empty", None));
727                }
728                if tables.len() > 16 {
729                    return Err(McpError::invalid_params("at most 16 tables per call", None));
730                }
731                let mut specs: Vec<TableSpec> = Vec::with_capacity(tables.len());
732                for (name, path) in tables {
733                    if !is_valid_sql_identifier(name) {
734                        return Err(McpError::invalid_params(
735                            format!(
736                                "table name '{name}' is not a valid SQL identifier \
737                                 ([a-zA-Z_][a-zA-Z0-9_]*)"
738                            ),
739                            None,
740                        ));
741                    }
742                    let resolved = self.resolve_required_path_or_glob(path)?;
743                    let format = format_for_query_sql(path)?;
744                    specs.push(TableSpec {
745                        table: name.clone(),
746                        path_for_sql: resolved.to_string_lossy().into_owned(),
747                        relative_path: path.clone(),
748                        format,
749                    });
750                }
751                // Sort for deterministic registration order — makes
752                // tests + logs reproducible. HashMap iteration order
753                // would otherwise be random.
754                specs.sort_by(|a, b| a.table.cmp(&b.table));
755                Ok(specs)
756            }
757        }
758    }
759
760    /// Walk `target` via [`scankit::Scanner`], capping output at
761    /// `limit` entries. Errors from individual `scankit::walk` items
762    /// (permission denied, transient I/O) are silently dropped.
763    fn walk_entries(&self, target: &Path, limit: usize) -> Result<Vec<FileEntry>, McpError> {
764        let scanner = scankit::Scanner::new(scankit::ScanConfig::default().follow_symlinks(false))
765            .map_err(|e| McpError::internal_error(format!("scankit init: {e}"), None))?;
766
767        let mut out = Vec::new();
768        for result in scanner.walk(target) {
769            if out.len() >= limit {
770                break;
771            }
772            let Ok(entry) = result else { continue };
773            let relative =
774                path_to_forward_slash(entry.path.strip_prefix(&self.root).unwrap_or(&entry.path));
775            out.push(FileEntry {
776                relative_path: relative,
777                size_bytes: entry.size_bytes,
778                modified: entry
779                    .modified
780                    .map(|t| chrono::DateTime::<chrono::Utc>::from(t).to_rfc3339()),
781                extension: entry.extension,
782            });
783        }
784        Ok(out)
785    }
786}
787
788// ---------------------------------------------------------------------------
789// ServerHandler — protocol metadata
790// ---------------------------------------------------------------------------
791
792#[tool_handler]
793impl ServerHandler for SeryMcpServer {
794    fn get_info(&self) -> ServerInfo {
795        // We build `Implementation` by hand rather than calling
796        // `Implementation::from_build_env()` because the latter
797        // captures rmcp's crate name + version at rmcp's compile
798        // time — clients would see `serverInfo.name = "rmcp"`. The
799        // CARGO_PKG_* macros expand against the crate currently
800        // being compiled (sery-mcp), giving the right identity.
801        let mut server_info =
802            Implementation::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
803        server_info.description = Some(env!("CARGO_PKG_DESCRIPTION").to_string());
804        let homepage = env!("CARGO_PKG_HOMEPAGE");
805        if !homepage.is_empty() {
806            server_info.website_url = Some(homepage.to_string());
807        }
808
809        ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
810            .with_server_info(server_info)
811            .with_protocol_version(ProtocolVersion::V_2024_11_05)
812            .with_instructions(
813                "sery-mcp exposes the local files under the configured --root as MCP tools. \
814                 All tools are read-only. Path arguments are validated to fall under --root \
815                 (no .. escape, no absolute paths). v0.3 ships six tools: list_folder, \
816                 search_files, get_schema, sample_rows, read_document (DOCX/PDF/PPTX/HTML/IPYNB \
817                 → markdown), and query_sql (DataFusion-backed SQL on CSV/Parquet — file is \
818                 registered as table `data` for the duration of the call). \
819                 See https://github.com/seryai/sery-mcp."
820                    .to_string(),
821            )
822    }
823}
824
825// ---------------------------------------------------------------------------
826// Free helpers
827// ---------------------------------------------------------------------------
828
829/// Reject absolute paths, `..` segments, drive prefixes, and root
830/// anchors. Cheaper + safer than `canonicalize()` (no symlink TOCTOU).
831fn validate_relative_components(raw: &str) -> Result<(), McpError> {
832    let p = Path::new(raw);
833    // `Path::is_absolute()` is platform-aware: on Windows it only
834    // returns true for `C:\...` / UNC. We additionally reject leading
835    // `/` and `\` on every platform so a Unix-style absolute path
836    // like `/etc/passwd` can't sneak through on Windows (where
837    // is_absolute() would say false but the path still escapes the
838    // configured `--root` semantically).
839    if p.is_absolute() || raw.starts_with('/') || raw.starts_with('\\') {
840        return Err(McpError::invalid_params(
841            "'path' must be relative to --root (no absolute paths)",
842            None,
843        ));
844    }
845    for component in p.components() {
846        match component {
847            Component::ParentDir => {
848                return Err(McpError::invalid_params(
849                    "'path' must not contain '..' (no escaping the configured --root)",
850                    None,
851                ));
852            }
853            Component::Prefix(_) | Component::RootDir => {
854                // Reachable on Windows when the input is a drive
855                // prefix like `C:\Foo` — Path::is_absolute() returns
856                // true so the earlier branch fires first; this arm
857                // is defensive against future Component variants.
858                return Err(McpError::invalid_params(
859                    "'path' must be relative to --root (no absolute paths)",
860                    None,
861                ));
862            }
863            _ => {}
864        }
865    }
866    Ok(())
867}
868
/// Render `path` as a `/`-separated string, whatever the host
/// platform's separator.
///
/// Every path handed to the MCP client goes through here. This gives the
/// LLM one path shape on every OS: a Windows host would otherwise emit
/// `data\finance\sales.csv`. That would confuse cross-platform prompts
/// and break round-trips into tools that substring-match on paths.
/// Non-UTF-8 bytes are replaced lossily.
///
/// No inverse is needed on input: `PathBuf::join` already accepts
/// either separator on Windows.
fn path_to_forward_slash(path: &Path) -> String {
    let lossy = path.to_string_lossy();
    match std::path::MAIN_SEPARATOR {
        '/' => lossy.into_owned(),
        native => lossy.replace(native, "/"),
    }
}
889
/// Return `path`'s extension without the leading dot, ASCII-lowercased.
/// Returns an empty string when there is no extension or it is not valid
/// UTF-8.
fn extension_of(path: &Path) -> String {
    match path.extension().and_then(std::ffi::OsStr::to_str) {
        Some(ext) => ext.to_ascii_lowercase(),
        None => String::new(),
    }
}
898
899/// Map a `tabkit::DataType` to a stable, lowercase JSON string.
900fn data_type_str(t: tabkit::DataType) -> &'static str {
901    match t {
902        tabkit::DataType::Bool => "boolean",
903        tabkit::DataType::Integer => "integer",
904        tabkit::DataType::Float => "float",
905        tabkit::DataType::Date => "date",
906        tabkit::DataType::DateTime => "datetime",
907        tabkit::DataType::Text => "text",
908        // Covers `Unknown` plus any future `#[non_exhaustive]`
909        // additions tabkit ships in a minor version.
910        _ => "unknown",
911    }
912}
913
914/// Convert a tabkit cell value to a JSON value for sample-row output.
915fn value_to_json(v: &tabkit::Value) -> serde_json::Value {
916    match v {
917        tabkit::Value::Bool(b) => serde_json::Value::Bool(*b),
918        tabkit::Value::Integer(i) => serde_json::Value::Number((*i).into()),
919        tabkit::Value::Float(f) => serde_json::Number::from_f64(*f)
920            .map_or(serde_json::Value::Null, serde_json::Value::Number),
921        tabkit::Value::Date(s) | tabkit::Value::DateTime(s) | tabkit::Value::Text(s) => {
922            serde_json::Value::String(s.clone())
923        }
924        // Covers `Null` plus any future `#[non_exhaustive]` additions
925        // — all map cleanly to JSON null.
926        _ => serde_json::Value::Null,
927    }
928}
929
/// One file registered as a SQL table inside the session `query_sql`
/// opens. `table` is what the LLM uses in its SQL; `path_for_sql` is
/// the absolute filesystem path (or glob) we interpolate into the
/// backend's `read_csv_auto` / `read_parquet` calls.
#[derive(Debug)]
struct TableSpec {
    /// View name referenced by the SQL. It is `data` in single-`path`
    /// mode; otherwise it is a key of `tables` that passed
    /// `is_valid_sql_identifier` (so it is safe to interpolate unquoted).
    table: String,
    /// `--root` joined with the caller's path, converted lossily to
    /// UTF-8. It may contain glob characters. `build_register_view`
    /// escapes it with `sql_string_literal` before interpolation.
    path_for_sql: String,
    /// The path exactly as the caller passed it; echoed back by
    /// `describe_input`.
    relative_path: String,
    /// `"csv"`, `"tsv"`, or `"parquet"`, as produced by
    /// `format_for_query_sql`.
    format: &'static str,
}
941
942/// Reject SQL that contains DDL / DML / admin keywords.
943///
944/// We tokenise the query on non-alphanumeric chars (so
945/// `SELECT "INSERTION" FROM data` doesn't false-positive on
946/// `INSERT`), then check each token against the blacklist. False
947/// positives are possible only when a query *literally* references
948/// a forbidden keyword as a string value (`WHERE name = 'INSERT'`);
949/// the LLM can reword in those rare cases. False negatives — which
950/// would be security holes — aren't possible because every
951/// dangerous SQL statement starts with one of these keywords.
952fn validate_query_sql(sql: &str) -> Result<(), McpError> {
953    let trimmed = sql.trim();
954    if trimmed.is_empty() {
955        return Err(McpError::invalid_params("`sql` must not be empty", None));
956    }
957    let upper = trimmed.to_ascii_uppercase();
958    if !upper.starts_with("SELECT") && !upper.starts_with("WITH") {
959        return Err(McpError::invalid_params(
960            "sql must start with SELECT or WITH (read-only queries only)",
961            None,
962        ));
963    }
964
965    const FORBIDDEN: &[&str] = &[
966        "INSERT",
967        "UPDATE",
968        "DELETE",
969        "CREATE",
970        "DROP",
971        "ALTER",
972        "ATTACH",
973        "DETACH",
974        "COPY",
975        "PRAGMA",
976        "INSTALL",
977        "LOAD",
978        "EXPORT",
979        "IMPORT",
980        "CHECKPOINT",
981        "VACUUM",
982        "ANALYZE",
983        "TRUNCATE",
984        "GRANT",
985        "REVOKE",
986        "BEGIN",
987        "COMMIT",
988        "ROLLBACK",
989        "SAVEPOINT",
990    ];
991    let tokens: std::collections::HashSet<&str> = upper
992        .split(|c: char| !c.is_ascii_alphanumeric() && c != '_')
993        .filter(|t| !t.is_empty())
994        .collect();
995    for kw in FORBIDDEN {
996        if tokens.contains(*kw) {
997            return Err(McpError::invalid_params(
998                format!("forbidden SQL keyword: {kw} (query_sql is read-only)"),
999                None,
1000            ));
1001        }
1002    }
1003    Ok(())
1004}
1005
1006/// Build the `CREATE OR REPLACE VIEW <table> AS SELECT * FROM
1007/// read_csv_auto(...) / read_parquet(...)` setup statement the SQL
1008/// backend runs to register a file as a queryable table.
1009fn build_register_view(table: &str, path_for_sql: &str, format: &str) -> Result<String, McpError> {
1010    let escaped_path = sql_string_literal(path_for_sql);
1011    let read_call = match format {
1012        "csv" => format!("read_csv_auto({escaped_path})"),
1013        "tsv" => format!("read_csv_auto({escaped_path}, delim='\\t')"),
1014        "parquet" => format!("read_parquet({escaped_path})"),
1015        other => {
1016            return Err(McpError::invalid_params(
1017                format!(
1018                    "query_sql supports csv / tsv / parquet only — got '{other}'. \
1019                     Use get_schema or sample_rows for XLSX/ODS files."
1020                ),
1021                None,
1022            ));
1023        }
1024    };
1025    // `table` is already validated as a SQL identifier; safe to
1026    // interpolate without quotes.
1027    Ok(format!(
1028        "CREATE OR REPLACE VIEW {table} AS SELECT * FROM {read_call}"
1029    ))
1030}
1031
1032/// Echo back what the caller registered, in human-readable form.
1033/// Single file: `"sales.csv"`. Multi file: `"customers=customers.csv,
1034/// orders=orders.parquet"`.
1035fn describe_input(specs: &[TableSpec]) -> String {
1036    if specs.len() == 1 && specs[0].table == "data" {
1037        return specs[0].relative_path.clone();
1038    }
1039    specs
1040        .iter()
1041        .map(|s| format!("{}={}", s.table, s.relative_path))
1042        .collect::<Vec<_>>()
1043        .join(", ")
1044}
1045
/// Whether a `query_sql` path should be treated as a glob: it contains
/// `*` (including `**`), `?`, or `[`. The SQL backend expands these
/// itself. A literal filename containing one of these characters is
/// also treated as a glob.
fn is_glob_pattern(s: &str) -> bool {
    s.chars().any(|c| matches!(c, '*' | '?' | '['))
}
1051
/// Accept only plain ASCII identifiers matching `[a-zA-Z_][a-zA-Z0-9_]*`.
///
/// Quoted identifiers (e.g. `"my table"`) are intentionally not
/// supported as table names. Anything passing this check can be
/// interpolated into SQL unquoted.
fn is_valid_sql_identifier(name: &str) -> bool {
    // Any non-ASCII char encodes to bytes >= 0x80, which fail both
    // ASCII predicates, so a byte-wise check is equivalent to a char-wise one.
    match name.as_bytes().split_first() {
        Some((&head, tail)) => {
            (head.is_ascii_alphabetic() || head == b'_')
                && tail.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'_')
        }
        None => false,
    }
}
1065
1066/// Resolve a `query_sql` path argument's format.
1067fn format_for_query_sql(path: &str) -> Result<&'static str, McpError> {
1068    let lower = path.to_ascii_lowercase();
1069    if lower.ends_with(".csv") {
1070        Ok("csv")
1071    } else if lower.ends_with(".tsv") {
1072        Ok("tsv")
1073    } else if lower.ends_with(".parquet") {
1074        Ok("parquet")
1075    } else {
1076        Err(McpError::invalid_params(
1077            format!(
1078                "query_sql expects a path / glob ending in .csv, .tsv, or .parquet — got '{path}'"
1079            ),
1080            None,
1081        ))
1082    }
1083}
1084
/// Wrap `s` in single quotes as a SQL string literal, doubling each
/// embedded `'`. Used for filesystem paths, which may legally contain
/// quotes on macOS/Linux (rare in practice).
fn sql_string_literal(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str("''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
1091
1092/// Convert one cell of an Arrow array to a JSON value.
1093///
1094/// The SQL backend's Arrow output uses the standard `arrow` crate
1095/// types — the same matching code works against any Arrow-emitting
1096/// engine. Numeric / boolean types map to native JSON. Date /
1097/// timestamp types serialise as ISO 8601 strings (round-trips
1098/// cleanly through MCP / JSON / the LLM). Anything we don't recognise
1099/// downgrades to its `Display` representation via Arrow's
1100/// `ArrayFormatter` — keeps `query_sql` resilient to the backend
1101/// returning new types in minor versions.
1102#[allow(clippy::too_many_lines)] // exhaustive type-match by design — splitting harms readability
1103fn arrow_value_to_json(array: &dyn duckdb::arrow::array::Array, row: usize) -> serde_json::Value {
1104    use duckdb::arrow::array::{
1105        BooleanArray, Date32Array, Date64Array, Decimal128Array, Float32Array, Float64Array,
1106        Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray,
1107        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
1108        TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
1109    };
1110    use duckdb::arrow::datatypes::DataType;
1111
1112    if array.is_null(row) {
1113        return serde_json::Value::Null;
1114    }
1115
1116    macro_rules! number {
1117        ($arr:ty) => {{
1118            let typed = array
1119                .as_any()
1120                .downcast_ref::<$arr>()
1121                .expect("downcast matches the matched DataType");
1122            serde_json::json!(typed.value(row))
1123        }};
1124    }
1125
1126    match array.data_type() {
1127        DataType::Boolean => {
1128            let typed = array
1129                .as_any()
1130                .downcast_ref::<BooleanArray>()
1131                .expect("BooleanArray");
1132            serde_json::Value::Bool(typed.value(row))
1133        }
1134        DataType::Int8 => number!(Int8Array),
1135        DataType::Int16 => number!(Int16Array),
1136        DataType::Int32 => number!(Int32Array),
1137        DataType::Int64 => number!(Int64Array),
1138        DataType::UInt8 => number!(UInt8Array),
1139        DataType::UInt16 => number!(UInt16Array),
1140        DataType::UInt32 => number!(UInt32Array),
1141        DataType::UInt64 => number!(UInt64Array),
1142        DataType::Float32 => {
1143            let typed = array
1144                .as_any()
1145                .downcast_ref::<Float32Array>()
1146                .expect("Float32Array");
1147            serde_json::Number::from_f64(f64::from(typed.value(row)))
1148                .map_or(serde_json::Value::Null, serde_json::Value::Number)
1149        }
1150        DataType::Float64 => {
1151            let typed = array
1152                .as_any()
1153                .downcast_ref::<Float64Array>()
1154                .expect("Float64Array");
1155            serde_json::Number::from_f64(typed.value(row))
1156                .map_or(serde_json::Value::Null, serde_json::Value::Number)
1157        }
1158        DataType::Utf8 => {
1159            let typed = array
1160                .as_any()
1161                .downcast_ref::<StringArray>()
1162                .expect("StringArray");
1163            serde_json::Value::String(typed.value(row).to_string())
1164        }
1165        DataType::LargeUtf8 => {
1166            let typed = array
1167                .as_any()
1168                .downcast_ref::<LargeStringArray>()
1169                .expect("LargeStringArray");
1170            serde_json::Value::String(typed.value(row).to_string())
1171        }
1172        DataType::Date32 => {
1173            let typed = array
1174                .as_any()
1175                .downcast_ref::<Date32Array>()
1176                .expect("Date32Array");
1177            typed
1178                .value_as_date(row)
1179                .map_or(serde_json::Value::Null, |d| {
1180                    serde_json::Value::String(d.format("%Y-%m-%d").to_string())
1181                })
1182        }
1183        DataType::Date64 => {
1184            let typed = array
1185                .as_any()
1186                .downcast_ref::<Date64Array>()
1187                .expect("Date64Array");
1188            typed
1189                .value_as_date(row)
1190                .map_or(serde_json::Value::Null, |d| {
1191                    serde_json::Value::String(d.format("%Y-%m-%d").to_string())
1192                })
1193        }
1194        DataType::Decimal128(_, scale) => {
1195            // SUM/AVG of integer columns returns HUGEINT (a 128-bit
1196            // integer), which Arrow encodes as Decimal128(38, 0). For
1197            // scale-0 values that fit in i64, emit a JSON number so
1198            // the LLM gets `100` (not `"100"`). Larger or non-zero-
1199            // scale values fall through to a string preserving full
1200            // precision.
1201            let typed = array
1202                .as_any()
1203                .downcast_ref::<Decimal128Array>()
1204                .expect("Decimal128Array");
1205            let raw = typed.value(row);
1206            if *scale == 0 {
1207                if let Ok(fits) = i64::try_from(raw) {
1208                    return serde_json::Value::Number(fits.into());
1209                }
1210            }
1211            use duckdb::arrow::util::display::{ArrayFormatter, FormatOptions};
1212            ArrayFormatter::try_new(array, &FormatOptions::default()).map_or_else(
1213                |_| serde_json::Value::String(format!("(decimal {raw})")),
1214                |fmt| serde_json::Value::String(fmt.value(row).to_string()),
1215            )
1216        }
1217        DataType::Timestamp(_, _) => {
1218            // Cover all four precision variants by trying the most
1219            // common first. DataFusion CSV/Parquet readers emit
1220            // microsecond timestamps by default.
1221            if let Some(typed) = array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
1222                return typed
1223                    .value_as_datetime(row)
1224                    .map_or(serde_json::Value::Null, |d| {
1225                        serde_json::Value::String(d.and_utc().to_rfc3339())
1226                    });
1227            }
1228            if let Some(typed) = array.as_any().downcast_ref::<TimestampMillisecondArray>() {
1229                return typed
1230                    .value_as_datetime(row)
1231                    .map_or(serde_json::Value::Null, |d| {
1232                        serde_json::Value::String(d.and_utc().to_rfc3339())
1233                    });
1234            }
1235            if let Some(typed) = array.as_any().downcast_ref::<TimestampNanosecondArray>() {
1236                return typed
1237                    .value_as_datetime(row)
1238                    .map_or(serde_json::Value::Null, |d| {
1239                        serde_json::Value::String(d.and_utc().to_rfc3339())
1240                    });
1241            }
1242            if let Some(typed) = array.as_any().downcast_ref::<TimestampSecondArray>() {
1243                return typed
1244                    .value_as_datetime(row)
1245                    .map_or(serde_json::Value::Null, |d| {
1246                        serde_json::Value::String(d.and_utc().to_rfc3339())
1247                    });
1248            }
1249            serde_json::Value::String(format!("(unsupported timestamp at row {row})"))
1250        }
1251        // For decimals, lists, structs, dictionaries, etc. — fall back
1252        // to DataFusion's `pretty_format_value`-style display rather
1253        // than panicking. Keeps query_sql resilient to schemas we
1254        // didn't anticipate.
1255        _ => {
1256            use duckdb::arrow::util::display::{ArrayFormatter, FormatOptions};
1257            ArrayFormatter::try_new(array, &FormatOptions::default()).map_or_else(
1258                |_| {
1259                    serde_json::Value::String(format!("(unrenderable {} value)", array.data_type()))
1260                },
1261                |fmt| serde_json::Value::String(fmt.value(row).to_string()),
1262            )
1263        }
1264    }
1265}
1266
1267/// Serialize any `Serialize` value to pretty JSON wrapped in a
1268/// `CallToolResult::success`. Centralised so all tools format the
1269/// same way.
1270fn as_json_result<T: serde::Serialize>(value: &T) -> Result<CallToolResult, McpError> {
1271    let json = serde_json::to_string_pretty(value)
1272        .map_err(|e| McpError::internal_error(format!("serialize result: {e}"), None))?;
1273    Ok(CallToolResult::success(vec![Content::text(json)]))
1274}
1275
1276// ---------------------------------------------------------------------------
1277// Tests
1278// ---------------------------------------------------------------------------
1279
1280#[cfg(test)]
1281mod tests {
1282    use super::*;
1283    use std::fs;
1284    use tempfile::TempDir;
1285
1286    fn make_server(root: &Path) -> SeryMcpServer {
1287        SeryMcpServer::new(root.canonicalize().expect("temp dir must canonicalise"))
1288    }
1289
1290    // ── path-resolution ──
1291
1292    #[test]
1293    fn resolve_subpath_defaults_to_root() {
1294        let dir = TempDir::new().unwrap();
1295        let server = make_server(dir.path());
1296        for input in [None, Some(""), Some(".")] {
1297            let resolved = server.resolve_subpath(input).unwrap();
1298            assert_eq!(resolved, server.root);
1299        }
1300    }
1301
1302    #[test]
1303    fn resolve_subpath_rejects_absolute() {
1304        let dir = TempDir::new().unwrap();
1305        let server = make_server(dir.path());
1306        let err = server.resolve_subpath(Some("/etc/passwd")).unwrap_err();
1307        assert!(format!("{err:?}").contains("absolute"));
1308    }
1309
1310    #[test]
1311    fn resolve_subpath_rejects_parent_dir() {
1312        let dir = TempDir::new().unwrap();
1313        let server = make_server(dir.path());
1314        let err = server.resolve_subpath(Some("../etc")).unwrap_err();
1315        assert!(format!("{err:?}").contains(".."));
1316    }
1317
1318    #[test]
1319    fn resolve_required_file_rejects_directory() {
1320        let dir = TempDir::new().unwrap();
1321        fs::create_dir(dir.path().join("sub")).unwrap();
1322        let server = make_server(dir.path());
1323        let err = server.resolve_required_file("sub").unwrap_err();
1324        assert!(format!("{err:?}").contains("regular file"));
1325    }
1326
1327    #[test]
1328    fn resolve_required_file_rejects_missing() {
1329        let dir = TempDir::new().unwrap();
1330        let server = make_server(dir.path());
1331        let err = server.resolve_required_file("nope.csv").unwrap_err();
1332        assert!(format!("{err:?}").contains("not readable"));
1333    }
1334
1335    #[test]
1336    fn resolve_required_file_accepts_real_file() {
1337        let dir = TempDir::new().unwrap();
1338        fs::write(dir.path().join("a.csv"), "x,y\n").unwrap();
1339        let server = make_server(dir.path());
1340        let resolved = server.resolve_required_file("a.csv").unwrap();
1341        assert_eq!(resolved, server.root.join("a.csv"));
1342    }
1343
1344    // ── walk_entries ──
1345
1346    #[test]
1347    fn walk_entries_emits_files_under_root() {
1348        let dir = TempDir::new().unwrap();
1349        fs::write(dir.path().join("a.csv"), "x,y\n1,2\n").unwrap();
1350        fs::write(dir.path().join("b.txt"), "hello").unwrap();
1351        let server = make_server(dir.path());
1352        let entries = server.walk_entries(server.root(), 100).unwrap();
1353        assert_eq!(entries.len(), 2);
1354        let names: Vec<_> = entries.iter().map(|e| e.relative_path.clone()).collect();
1355        assert!(names.contains(&"a.csv".to_string()));
1356        assert!(names.contains(&"b.txt".to_string()));
1357    }
1358
1359    #[test]
1360    fn walk_entries_respects_limit() {
1361        let dir = TempDir::new().unwrap();
1362        for i in 0..10 {
1363            fs::write(dir.path().join(format!("f{i}.txt")), "x").unwrap();
1364        }
1365        let server = make_server(dir.path());
1366        let entries = server.walk_entries(server.root(), 3).unwrap();
1367        assert_eq!(entries.len(), 3);
1368    }
1369
1370    #[test]
1371    fn walk_entries_lowercases_extension() {
1372        let dir = TempDir::new().unwrap();
1373        fs::write(dir.path().join("REPORT.PDF"), "%PDF-").unwrap();
1374        let server = make_server(dir.path());
1375        let entries = server.walk_entries(server.root(), 100).unwrap();
1376        assert_eq!(entries.len(), 1);
1377        assert_eq!(entries[0].extension, "pdf");
1378    }
1379
1380    // ── get_schema (tabkit) ──
1381
1382    #[test]
1383    fn get_schema_returns_csv_columns() {
1384        let dir = TempDir::new().unwrap();
1385        fs::write(
1386            dir.path().join("sales.csv"),
1387            "id,name,amount\n1,alice,99.5\n2,bob,150.0\n",
1388        )
1389        .unwrap();
1390        let server = make_server(dir.path());
1391        let result = server
1392            .get_schema(Parameters(GetSchemaRequest {
1393                path: "sales.csv".into(),
1394                sheet: None,
1395            }))
1396            .unwrap();
1397        let payload = result_text(&result);
1398        let parsed: SchemaResponseDe = serde_json::from_str(&payload).unwrap();
1399        assert_eq!(parsed.format, "csv");
1400        assert_eq!(parsed.columns.len(), 3);
1401        let names: Vec<_> = parsed.columns.iter().map(|c| c.name.as_str()).collect();
1402        assert_eq!(names, vec!["id", "name", "amount"]);
1403    }
1404
1405    // ── sample_rows ──
1406
1407    #[test]
1408    fn sample_rows_returns_header_keyed_objects() {
1409        let dir = TempDir::new().unwrap();
1410        fs::write(
1411            dir.path().join("sales.csv"),
1412            "id,name,amount\n1,alice,99.5\n2,bob,150.0\n3,eve,200.0\n",
1413        )
1414        .unwrap();
1415        let server = make_server(dir.path());
1416        let result = server
1417            .sample_rows(Parameters(SampleRowsRequest {
1418                path: "sales.csv".into(),
1419                limit: Some(2),
1420                sheet: None,
1421            }))
1422            .unwrap();
1423        let payload = result_text(&result);
1424        let parsed: SamplesResponseDe = serde_json::from_str(&payload).unwrap();
1425        assert_eq!(parsed.columns, vec!["id", "name", "amount"]);
1426        assert_eq!(parsed.rows.len(), 2);
1427        assert_eq!(parsed.rows[0].get("name").unwrap().as_str(), Some("alice"));
1428    }
1429
1430    // ── search_files ──
1431
1432    #[test]
1433    fn search_files_ranks_basename_match_above_path_match() {
1434        let dir = TempDir::new().unwrap();
1435        fs::create_dir_all(dir.path().join("data/finance")).unwrap();
1436        fs::write(dir.path().join("data/finance/sales.csv"), "x").unwrap();
1437        fs::write(dir.path().join("salesreport.csv"), "x").unwrap();
1438        fs::write(dir.path().join("revenue.csv"), "x").unwrap();
1439        let server = make_server(dir.path());
1440        let result = server
1441            .search_files(Parameters(SearchFilesRequest {
1442                query: "sales".into(),
1443                extensions: None,
1444                limit: None,
1445            }))
1446            .unwrap();
1447        let payload = result_text(&result);
1448        let hits: Vec<SearchHitDe> = serde_json::from_str(&payload).unwrap();
1449        assert_eq!(hits.len(), 2);
1450        // sales.csv (exact stem) should outrank salesreport.csv (startswith)
1451        assert_eq!(hits[0].relative_path, "data/finance/sales.csv");
1452        assert!(hits[0].score > hits[1].score);
1453    }
1454
1455    #[test]
1456    fn search_files_extension_filter() {
1457        let dir = TempDir::new().unwrap();
1458        fs::write(dir.path().join("notes.csv"), "x").unwrap();
1459        fs::write(dir.path().join("notes.txt"), "x").unwrap();
1460        let server = make_server(dir.path());
1461        let result = server
1462            .search_files(Parameters(SearchFilesRequest {
1463                query: "notes".into(),
1464                extensions: Some(vec!["csv".into()]),
1465                limit: None,
1466            }))
1467            .unwrap();
1468        let hits: Vec<SearchHitDe> = serde_json::from_str(&result_text(&result)).unwrap();
1469        assert_eq!(hits.len(), 1);
1470        assert_eq!(hits[0].extension, "csv");
1471    }
1472
1473    // ── query_sql ──
1474
1475    fn query_req(
1476        path: Option<&str>,
1477        tables: Option<std::collections::HashMap<String, String>>,
1478        sql: &str,
1479        limit: Option<usize>,
1480    ) -> QuerySqlRequest {
1481        QuerySqlRequest {
1482            path: path.map(String::from),
1483            tables,
1484            sql: sql.into(),
1485            limit,
1486        }
1487    }
1488
1489    #[test]
1490    fn query_sql_csv_happy_path() {
1491        let dir = TempDir::new().unwrap();
1492        fs::write(
1493            dir.path().join("sales.csv"),
1494            "id,name,amount\n1,alice,100\n2,bob,250\n3,eve,50\n",
1495        )
1496        .unwrap();
1497        let server = make_server(dir.path());
1498        let result = server
1499            .query_sql(Parameters(query_req(
1500                Some("sales.csv"),
1501                None,
1502                "SELECT name, amount FROM data WHERE amount > 75 ORDER BY amount",
1503                None,
1504            )))
1505            .unwrap();
1506        let parsed: QueryResponseDe = serde_json::from_str(&result_text(&result)).unwrap();
1507        assert_eq!(parsed.format, "csv");
1508        assert_eq!(parsed.columns, vec!["name", "amount"]);
1509        assert_eq!(parsed.row_count, 2);
1510        assert!(!parsed.truncated);
1511        assert_eq!(parsed.rows[0].get("name").unwrap().as_str(), Some("alice"));
1512        assert_eq!(parsed.rows[1].get("name").unwrap().as_str(), Some("bob"));
1513        assert_eq!(parsed.input, "sales.csv");
1514    }
1515
1516    #[test]
1517    fn query_sql_truncates_at_limit() {
1518        use std::fmt::Write as _;
1519        let dir = TempDir::new().unwrap();
1520        let mut csv = String::from("n\n");
1521        for i in 0..20 {
1522            writeln!(csv, "{i}").unwrap();
1523        }
1524        fs::write(dir.path().join("nums.csv"), csv).unwrap();
1525        let server = make_server(dir.path());
1526        let result = server
1527            .query_sql(Parameters(query_req(
1528                Some("nums.csv"),
1529                None,
1530                "SELECT n FROM data",
1531                Some(5),
1532            )))
1533            .unwrap();
1534        let parsed: QueryResponseDe = serde_json::from_str(&result_text(&result)).unwrap();
1535        assert_eq!(parsed.row_count, 5);
1536        assert!(parsed.truncated);
1537    }
1538
1539    #[test]
1540    fn query_sql_rejects_unsupported_format() {
1541        let dir = TempDir::new().unwrap();
1542        fs::write(dir.path().join("notes.txt"), "hi").unwrap();
1543        let server = make_server(dir.path());
1544        let err = server
1545            .query_sql(Parameters(query_req(
1546                Some("notes.txt"),
1547                None,
1548                "SELECT 1",
1549                None,
1550            )))
1551            .unwrap_err();
1552        assert!(format!("{err:?}").to_lowercase().contains(".csv"));
1553    }
1554
1555    #[test]
1556    fn query_sql_surfaces_sql_parse_errors() {
1557        let dir = TempDir::new().unwrap();
1558        fs::write(dir.path().join("a.csv"), "x\n1\n").unwrap();
1559        let server = make_server(dir.path());
1560        let err = server
1561            .query_sql(Parameters(query_req(
1562                Some("a.csv"),
1563                None,
1564                "SELEKT * FROM data",
1565                None,
1566            )))
1567            .unwrap_err();
1568        let msg = format!("{err:?}").to_lowercase();
1569        assert!(msg.contains("sql") || msg.contains("read-only"));
1570    }
1571
1572    #[test]
1573    fn query_sql_blocks_ddl() {
1574        let dir = TempDir::new().unwrap();
1575        fs::write(dir.path().join("a.csv"), "x\n1\n").unwrap();
1576        let server = make_server(dir.path());
1577        for evil in [
1578            "DROP TABLE data",
1579            "ATTACH '/etc/passwd' AS p",
1580            "INSERT INTO data VALUES (1)",
1581            "PRAGMA table_info('data')",
1582        ] {
1583            let err = server
1584                .query_sql(Parameters(query_req(Some("a.csv"), None, evil, None)))
1585                .unwrap_err();
1586            let msg = format!("{err:?}").to_lowercase();
1587            assert!(
1588                msg.contains("forbidden") || msg.contains("read-only"),
1589                "expected SQL '{evil}' to be rejected; got {err:?}"
1590            );
1591        }
1592    }
1593
1594    #[test]
1595    fn query_sql_multi_file_join() {
1596        let dir = TempDir::new().unwrap();
1597        fs::write(
1598            dir.path().join("customers.csv"),
1599            "id,name\n1,Alice\n2,Bob\n",
1600        )
1601        .unwrap();
1602        fs::write(
1603            dir.path().join("orders.csv"),
1604            "customer_id,amount\n1,100\n1,50\n2,200\n",
1605        )
1606        .unwrap();
1607        let server = make_server(dir.path());
1608        let mut tables = std::collections::HashMap::new();
1609        tables.insert("customers".into(), "customers.csv".into());
1610        tables.insert("orders".into(), "orders.csv".into());
1611        let result = server
1612            .query_sql(Parameters(query_req(
1613                None,
1614                Some(tables),
1615                "SELECT c.name, SUM(o.amount) AS total \
1616                 FROM customers c JOIN orders o ON c.id = o.customer_id \
1617                 GROUP BY c.name ORDER BY total DESC",
1618                None,
1619            )))
1620            .unwrap();
1621        let parsed: QueryResponseDe = serde_json::from_str(&result_text(&result)).unwrap();
1622        assert_eq!(parsed.columns, vec!["name", "total"]);
1623        assert_eq!(parsed.row_count, 2);
1624        assert_eq!(parsed.rows[0].get("name").unwrap().as_str(), Some("Bob"));
1625        assert_eq!(parsed.rows[1].get("name").unwrap().as_str(), Some("Alice"));
1626        // input echoes alphabetised because resolve_table_specs sorts.
1627        assert!(parsed.input.contains("customers=customers.csv"));
1628        assert!(parsed.input.contains("orders=orders.csv"));
1629    }
1630
1631    #[test]
1632    fn query_sql_glob_pattern() {
1633        let dir = TempDir::new().unwrap();
1634        fs::write(dir.path().join("jan.csv"), "amt\n10\n20\n").unwrap();
1635        fs::write(dir.path().join("feb.csv"), "amt\n30\n40\n").unwrap();
1636        let server = make_server(dir.path());
1637        let result = server
1638            .query_sql(Parameters(query_req(
1639                Some("*.csv"),
1640                None,
1641                "SELECT SUM(amt) AS total FROM data",
1642                None,
1643            )))
1644            .unwrap();
1645        let parsed: QueryResponseDe = serde_json::from_str(&result_text(&result)).unwrap();
1646        assert_eq!(parsed.row_count, 1);
1647        assert_eq!(
1648            parsed.rows[0]
1649                .get("total")
1650                .and_then(serde_json::Value::as_i64),
1651            Some(100)
1652        );
1653    }
1654
1655    #[test]
1656    fn query_sql_rejects_both_path_and_tables() {
1657        let dir = TempDir::new().unwrap();
1658        fs::write(dir.path().join("a.csv"), "x\n1\n").unwrap();
1659        let server = make_server(dir.path());
1660        let mut tables = std::collections::HashMap::new();
1661        tables.insert("t".into(), "a.csv".into());
1662        let err = server
1663            .query_sql(Parameters(query_req(
1664                Some("a.csv"),
1665                Some(tables),
1666                "SELECT 1",
1667                None,
1668            )))
1669            .unwrap_err();
1670        assert!(format!("{err:?}").contains("either"));
1671    }
1672
1673    #[test]
1674    fn query_sql_rejects_invalid_table_name() {
1675        let dir = TempDir::new().unwrap();
1676        fs::write(dir.path().join("a.csv"), "x\n1\n").unwrap();
1677        let server = make_server(dir.path());
1678        let mut tables = std::collections::HashMap::new();
1679        tables.insert("evil; DROP TABLE x".into(), "a.csv".into());
1680        let err = server
1681            .query_sql(Parameters(query_req(None, Some(tables), "SELECT 1", None)))
1682            .unwrap_err();
1683        assert!(format!("{err:?}").contains("identifier"));
1684    }
1685
1686    #[test]
1687    fn search_files_rejects_empty_query() {
1688        let dir = TempDir::new().unwrap();
1689        let server = make_server(dir.path());
1690        let err = server
1691            .search_files(Parameters(SearchFilesRequest {
1692                query: "   ".into(),
1693                extensions: None,
1694                limit: None,
1695            }))
1696            .unwrap_err();
1697        assert!(format!("{err:?}").contains("empty"));
1698    }
1699
1700    // ── helpers used only by tests ──
1701
1702    fn result_text(result: &CallToolResult) -> String {
1703        let first = result.content.first().expect("at least one content item");
1704        // CallToolResult.content[i] is a `Content`; downcast to text via
1705        // serde round-trip is overkill — the SDK exposes the raw text via
1706        // `as_text()`. Fall back to JSON-serialising if not text.
1707        if let Some(text) = first.as_text() {
1708            text.text.clone()
1709        } else {
1710            serde_json::to_string(&first).unwrap()
1711        }
1712    }
1713
1714    /// Owned mirror of `SchemaResponse` (the source struct lives behind
1715    /// `pub use` and serialises field-by-field; we want a deserialiser
1716    /// for tests).
1717    #[derive(serde::Deserialize)]
1718    struct SchemaResponseDe {
1719        #[allow(dead_code)]
1720        relative_path: String,
1721        format: String,
1722        columns: Vec<ColumnInfoDe>,
1723        #[allow(dead_code)]
1724        row_count: Option<u64>,
1725    }
1726
1727    #[derive(serde::Deserialize)]
1728    struct ColumnInfoDe {
1729        name: String,
1730        #[serde(rename = "type")]
1731        #[allow(dead_code)]
1732        data_type: String,
1733        #[allow(dead_code)]
1734        nullable: bool,
1735    }
1736
1737    #[derive(serde::Deserialize)]
1738    struct SamplesResponseDe {
1739        #[allow(dead_code)]
1740        relative_path: String,
1741        #[allow(dead_code)]
1742        format: String,
1743        columns: Vec<String>,
1744        rows: Vec<serde_json::Map<String, serde_json::Value>>,
1745        #[allow(dead_code)]
1746        row_count: Option<u64>,
1747    }
1748
1749    #[derive(serde::Deserialize)]
1750    struct SearchHitDe {
1751        relative_path: String,
1752        #[allow(dead_code)]
1753        size_bytes: u64,
1754        extension: String,
1755        score: f64,
1756        #[allow(dead_code)]
1757        why_matched: String,
1758    }
1759
    /// Owned, deserialisable mirror of the `query_sql` response payload.
    #[derive(serde::Deserialize)]
    struct QueryResponseDe {
        /// Echo of what was queried. For a single `path` it is the path itself
        /// (e.g. `sales.csv`); with `tables` it contains `name=path` specs.
        input: String,
        /// Detected input format, e.g. `"csv"`.
        format: String,
        /// Result column names, in projection order.
        columns: Vec<String>,
        /// One JSON object per result row, keyed by column name.
        rows: Vec<serde_json::Map<String, serde_json::Value>>,
        /// Number of rows returned; equals `limit` when the result was capped.
        row_count: usize,
        /// Whether the result was cut off at `limit`.
        truncated: bool,
    }
1769}