use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::cosmos::CosmosBackend;
use crate::mcp::error::McpError;
use crate::mcp::expose::ExposeResolver;
use crate::mcp::filter::{cosmos_sql::SqlBuilder, parse};
use super::{OrderBy, SortDir};
// Arguments accepted by the `query` tool and consumed by `run`.
//
// Plain `//` comments are used here on purpose: `///` doc comments on a
// `JsonSchema`-derived type are emitted by schemars as `description`s in the
// generated schema, so adding them would change what MCP clients see.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct QueryRequest {
// Name of the exposed data source; `run` resolves it with `ExposeResolver::resolve`.
pub data_source: String,
// Optional JSON filter, parsed by `filter::parse` and compiled by `SqlBuilder`.
// Serialized as `where`; the Rust field needs the raw identifier `r#where`.
#[serde(rename = "where")]
pub r#where: Option<Value>,
// Sort keys; each `field` must pass `is_valid_field_path` or `run` rejects
// the request with `McpError::InvalidArgument`.
pub order_by: Option<Vec<OrderBy>>,
// Maximum number of items `run` returns, taken from the first page fetched
// from the backend. Defaults to 50 via `default_top`.
#[serde(default = "default_top")]
pub top: usize,
// NOTE(review): not read by `run` in this file, so it currently has no effect
// on the query. TODO confirm whether pagination is meant to use it.
pub cursor: Option<String>,
// When true, `run` issues a COUNT query and returns only `total` (no items).
// Defaults to false.
#[serde(default)]
pub count_only: bool,
// When false (the default), soft-deleted documents are excluded. `run` adds a
// `_deleted` check itself when there is no `where`, and otherwise passes the
// flag to `SqlBuilder`.
#[serde(default)]
pub include_deleted: bool,
}
/// Serde default for `QueryRequest::top`: 50 items when the caller omits `top`.
fn default_top() -> usize {
    const DEFAULT_PAGE_SIZE: usize = 50;
    DEFAULT_PAGE_SIZE
}
/// Output of the `query` tool, as produced by `run`.
#[derive(Debug, Serialize)]
pub struct QueryResponse {
/// Documents from the first result page, truncated to `top`; empty when
/// `count_only` is set.
pub items: Vec<Value>,
/// Continuation token reported by the backend stream after the first page;
/// always `None` when `count_only` is set.
pub next_cursor: Option<String>,
/// With `count_only`, the value returned by `SELECT VALUE COUNT(1)`.
/// Otherwise, the number of documents on the first page *before* truncation
/// to `top`, which is not necessarily the total number of matches.
pub total: u64,
}
pub async fn run(
cosmos: &dyn CosmosBackend,
expose: &ExposeResolver,
req: QueryRequest,
) -> Result<QueryResponse, McpError> {
let resolved = expose.resolve(&req.data_source)?;
let where_ast = match req.r#where {
Some(ref v) => Some(parse(v)?),
None => None,
};
let builder = SqlBuilder::new(req.include_deleted);
let user_filter = match &where_ast {
Some(w) => Some(builder.build(w)?),
None => None,
};
let mut sql = String::from("SELECT * FROM c");
let mut params: Vec<(String, Value)> = Vec::new();
if let Some(uf) = user_filter {
sql.push_str(" WHERE ");
sql.push_str(&uf.sql_fragment);
params.extend(uf.params);
} else if !req.include_deleted {
sql.push_str(" WHERE (NOT IS_DEFINED(c._deleted) OR c._deleted = false)");
}
if let Some(orderings) = &req.order_by {
sql.push_str(" ORDER BY ");
let mut parts: Vec<String> = Vec::with_capacity(orderings.len());
for o in orderings {
if !is_valid_field_path(&o.field) {
return Err(McpError::InvalidArgument(format!(
"order_by.field '{}' contains invalid characters; only \
[A-Za-z0-9_.] are permitted, must start with a letter or \
underscore",
o.field
)));
}
let dir = match o.dir {
SortDir::Asc => "ASC",
SortDir::Desc => "DESC",
};
parts.push(format!("c.{} {}", o.field, dir));
}
sql.push_str(&parts.join(", "));
}
let container = &resolved.backed_by[0].container;
if req.count_only {
let count_sql = if let Some(idx) = sql.find(" WHERE ") {
format!("SELECT VALUE COUNT(1) FROM c{}", &sql[idx..])
} else {
"SELECT VALUE COUNT(1) FROM c".to_string()
};
let count_sql = if let Some(o_idx) = count_sql.find(" ORDER BY") {
count_sql[..o_idx].to_string()
} else {
count_sql
};
let mut stream = cosmos.query(container, &count_sql, params).await?;
let total = if let Some(page) = stream.next_page().await? {
page.first().and_then(Value::as_u64).unwrap_or(0)
} else {
0
};
return Ok(QueryResponse {
items: vec![],
next_cursor: None,
total,
});
}
let mut stream = cosmos.query(container, &sql, params).await?;
let mut items = Vec::new();
let mut total: u64 = 0;
if let Some(page) = stream.next_page().await? {
total = page.len() as u64;
items.extend(page.into_iter().take(req.top));
}
let next_cursor = stream.continuation_token().map(String::from);
Ok(QueryResponse {
items,
next_cursor,
total,
})
}
/// Returns `true` when `s` is a non-empty, dot-separated path in which every
/// segment is an ASCII identifier: a letter or `_` followed by any number of
/// letters, digits, or `_`.
///
/// `run` interpolates `order_by` field paths directly into SQL text, so this
/// whitelist is what keeps quotes, whitespace, brackets and other SQL syntax
/// out of the generated query.
fn is_valid_field_path(s: &str) -> bool {
    let is_identifier = |segment: &str| {
        let mut rest = segment.chars();
        match rest.next() {
            Some(head) if head == '_' || head.is_ascii_alphabetic() => {
                rest.all(|c| c == '_' || c.is_ascii_alphanumeric())
            }
            // Empty segment (leading/trailing/double dot) or a bad first char.
            _ => false,
        }
    };
    !s.is_empty() && s.split('.').all(is_identifier)
}
#[cfg(test)]
mod field_path_tests {
    use super::is_valid_field_path;

    /// Fails, naming the offending input, if any of `paths` is rejected.
    fn assert_accepted(paths: &[&str]) {
        for &path in paths {
            assert!(is_valid_field_path(path), "expected {path:?} to be accepted");
        }
    }

    /// Fails, naming the offending input, if any of `paths` is accepted.
    fn assert_rejected(paths: &[&str]) {
        for &path in paths {
            assert!(!is_valid_field_path(path), "expected {path:?} to be rejected");
        }
    }

    #[test]
    fn accepts_simple_names() {
        assert_accepted(&["status", "project_key", "_deleted"]);
    }

    #[test]
    fn accepts_dotted_paths() {
        assert_accepted(&["assignee.email", "sprint.state", "a.b.c.d"]);
    }

    #[test]
    fn rejects_empty_and_dotted_edge_cases() {
        assert_rejected(&["", ".", ".foo", "foo.", "foo..bar"]);
    }

    #[test]
    fn rejects_starting_with_digit() {
        assert_rejected(&["1status", "foo.1bar"]);
    }

    #[test]
    fn rejects_injection_attempts() {
        assert_rejected(&[
            "status; DROP TABLE issues",
            "status' OR '1'='1",
            "status DESC, id LIMIT 1",
            "status\nFROM",
            "status WHERE x = 1",
            "status[0]",
        ]);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::tools::test_helpers::{
        build_cosmos_with_jira_issues, build_expose, build_expose_jira_issues,
    };
    use serde_json::json;

    /// Baseline request for `data_source`: no filter, no ordering, `top = 50`,
    /// no cursor, rows rather than a count, soft-deleted documents excluded.
    /// Tests override individual fields with struct-update syntax.
    fn request(data_source: &str) -> QueryRequest {
        QueryRequest {
            data_source: data_source.into(),
            r#where: None,
            order_by: None,
            top: 50,
            cursor: None,
            count_only: false,
            include_deleted: false,
        }
    }

    /// The `id` string of every returned document, in response order.
    fn ids(result: &QueryResponse) -> Vec<&str> {
        result
            .items
            .iter()
            .map(|doc| doc["id"].as_str().unwrap())
            .collect()
    }

    #[tokio::test]
    async fn query_returns_matching_docs() {
        let cosmos = build_cosmos_with_jira_issues().await;
        let expose = build_expose_jira_issues();
        let req = QueryRequest {
            r#where: Some(json!({"status": "Open"})),
            ..request("jira_issues")
        };

        let result = run(&cosmos, &expose, req).await.unwrap();

        assert_eq!(result.total, 3);
        assert_eq!(result.items.len(), 3);
    }

    #[tokio::test]
    async fn query_excludes_soft_deleted_by_default() {
        let cosmos = build_cosmos_with_jira_issues().await;
        let expose = build_expose_jira_issues();

        let result = run(&cosmos, &expose, request("jira_issues"))
            .await
            .unwrap();

        assert_eq!(result.total, 5);
        assert!(
            !ids(&result).contains(&"i6"),
            "soft-deleted i6 should be excluded"
        );
    }

    #[tokio::test]
    async fn query_with_include_deleted_returns_tombstoned() {
        let cosmos = build_cosmos_with_jira_issues().await;
        let expose = build_expose_jira_issues();
        let req = QueryRequest {
            include_deleted: true,
            ..request("jira_issues")
        };

        let result = run(&cosmos, &expose, req).await.unwrap();

        assert_eq!(result.total, 6);
        assert!(
            ids(&result).contains(&"i6"),
            "i6 should be included when include_deleted=true"
        );
    }

    #[tokio::test]
    async fn query_returns_forbidden_for_unexposed_data_source() {
        let cosmos = build_cosmos_with_jira_issues().await;
        let expose = build_expose_jira_issues();

        let err = run(&cosmos, &expose, request("jira_sprints"))
            .await
            .unwrap_err();

        assert!(
            matches!(err, McpError::Forbidden(_)),
            "expected Forbidden, got {err:?}"
        );
    }

    #[tokio::test]
    async fn query_count_only_returns_only_total() {
        let cosmos = build_cosmos_with_jira_issues().await;
        let expose = build_expose_jira_issues();
        let req = QueryRequest {
            count_only: true,
            ..request("jira_issues")
        };

        let result = run(&cosmos, &expose, req).await.unwrap();

        assert_eq!(result.total, 5);
        assert!(
            result.items.is_empty(),
            "items should be empty for count_only"
        );
    }

    #[tokio::test]
    async fn query_top_limits_page_size() {
        let cosmos = build_cosmos_with_jira_issues().await;
        let expose = build_expose_jira_issues();
        let req = QueryRequest {
            top: 2,
            ..request("jira_issues")
        };

        let result = run(&cosmos, &expose, req).await.unwrap();

        assert_eq!(result.items.len(), 2);
        assert_eq!(result.total, 5);
    }

    #[tokio::test]
    async fn query_multiple_exposed_sources() {
        let cosmos = build_cosmos_with_jira_issues().await;
        let expose = build_expose(&[
            ("jira_issues", "jira_issue", "jira-issues"),
            ("jira_sprints", "jira_sprint", "jira-sprints"),
        ]);

        let result = run(&cosmos, &expose, request("jira_sprints"))
            .await
            .unwrap();

        assert_eq!(result.total, 0);
    }
}