use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::mcp::error::McpError;
use crate::mcp::expose::ExposeResolver;
use crate::mcp::filter::{odata, parse};
use super::search_api::{RawHit, SearchApiAdapter};
/// How much content the `search` tool returns for each hit.
///
/// Serialized in snake_case: `"snippet"`, `"full"` or `"agentic_answer"`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum IncludeContent {
    /// Return only the snippet provided by the search backend (the default).
    #[default]
    Snippet,
    /// Also return the full document body of each hit.
    Full,
    /// Route the query through the knowledge base and return a synthesised
    /// answer with citations. Rejected when agentic search is disabled on the
    /// server.
    AgenticAnswer,
}
/// Arguments of the `search` tool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct SearchRequest {
    /// Free-text query.
    pub query: String,
    /// Names of the data sources to search. When omitted, every exposed data
    /// source whose kind is searchable is used.
    pub data_sources: Option<Vec<String>>,
    /// Structured filter (JSON field `where`), compiled to an OData filter.
    #[serde(rename = "where")]
    pub r#where: Option<Value>,
    /// Maximum number of items to return. Defaults to 25 and is clamped to the
    /// server's configured maximum.
    #[serde(default = "default_top")]
    pub top: usize,
    /// Pagination cursor from a previous response's `next_cursor`. With several
    /// data sources it is expected to be a base64url-encoded JSON object of
    /// per-source cursors.
    pub cursor: Option<String>,
    /// When true, soft-deleted documents are not filtered out.
    #[serde(default)]
    pub include_deleted: bool,
    /// How much content to return per hit. Defaults to `snippet`.
    #[serde(default)]
    pub include_content: IncludeContent,
}
/// Serde default for `SearchRequest::top` when the caller omits the field.
fn default_top() -> usize {
    const FALLBACK_TOP: usize = 25;
    FALLBACK_TOP
}
/// Server-side settings for the `search` tool.
#[derive(Debug, Clone)]
pub struct SearchToolConfig {
    /// When true, `agentic_answer` requests are rejected with
    /// `McpError::InvalidArgument` and the knowledge base is never queried.
    pub disable_agentic: bool,
    /// Knowledge base queried for `agentic_answer` requests.
    pub knowledge_base_name: String,
    /// NOTE(review): not read by `run` in this file; the per-request default
    /// comes from `default_top()` via serde. Confirm whether it is used elsewhere.
    pub default_top: usize,
    /// Upper bound applied to the requested `top`.
    pub max_top: usize,
}
impl Default for SearchToolConfig {
fn default() -> Self {
Self {
disable_agentic: false,
knowledge_base_name: "quelch-kb".to_string(),
default_top: 25,
max_top: 100,
}
}
}
/// Result of the `search` tool.
#[derive(Debug, Serialize)]
pub struct SearchResponse {
    /// Matching items.
    pub items: Vec<SearchItem>,
    /// Synthesised answer; only populated on the knowledge-base
    /// (`agentic_answer`) path.
    pub answer: Option<String>,
    /// Citations backing `answer`; only populated on the knowledge-base path.
    pub citations: Option<Vec<Value>>,
    /// Cursor for the next page. Always `None` for multi-source index searches.
    pub next_cursor: Option<String>,
    /// Backend-reported estimate of total matches, summed across sources for
    /// index searches.
    pub total_estimate: u64,
}
/// One search hit as returned to the client.
#[derive(Debug, Serialize)]
pub struct SearchItem {
    /// Document id reported by the backend.
    pub id: String,
    /// Relevance score reported by the backend; merged index results are
    /// ordered by it, highest first.
    pub score: f64,
    /// Name of the exposed data source the hit is attributed to.
    pub data_source: String,
    /// The hit's `source_link` field, or a synthesised
    /// `urn:quelch:<data_source>:<id>` when the hit provides no usable link.
    pub source_link: String,
    /// Snippet from the backend, if any.
    pub snippet: Option<String>,
    /// Full document body; only populated for `include_content = full`.
    pub body: Option<String>,
    /// Raw document fields as returned by the backend.
    pub fields: Value,
}
pub async fn run(
api: &dyn SearchApiAdapter,
expose: &ExposeResolver,
schema: &crate::mcp::schema::SchemaCatalog,
config: &SearchToolConfig,
req: SearchRequest,
) -> Result<SearchResponse, McpError> {
if config.disable_agentic && req.include_content == IncludeContent::AgenticAnswer {
return Err(McpError::InvalidArgument(
"agentic_answer is unavailable when disable_agentic is set".to_string(),
));
}
let top = req.top.min(config.max_top);
let sources: Vec<(String, String)> = match &req.data_sources {
Some(names) => {
let mut resolved = Vec::new();
for name in names {
let ds = expose.resolve(name)?;
resolved.push((name.clone(), ds.kind.clone()));
}
resolved
}
None => {
expose
.list_all()
.iter()
.filter(|(_, ds)| {
schema
.lookup(&ds.kind)
.map(|k| k.searchable)
.unwrap_or(false)
})
.map(|(name, ds)| (name.clone(), ds.kind.clone()))
.collect()
}
};
if sources.is_empty() {
return Ok(SearchResponse {
items: vec![],
answer: None,
citations: None,
next_cursor: None,
total_estimate: 0,
});
}
let odata_filter: Option<String> = match &req.r#where {
Some(v) => {
let ast = parse(v)?;
Some(odata::build(&ast, req.include_deleted)?)
}
None => {
if req.include_deleted {
None
} else {
Some("_deleted ne true".to_string())
}
}
};
let use_knowledge_base =
!config.disable_agentic && req.include_content == IncludeContent::AgenticAnswer;
let include_full_body = matches!(req.include_content, IncludeContent::Full);
let include_synthesis = matches!(req.include_content, IncludeContent::AgenticAnswer);
if use_knowledge_base {
let raw = api
.search_knowledge_base(
&config.knowledge_base_name,
&req.query,
odata_filter.as_deref(),
top,
req.cursor.as_deref(),
include_synthesis,
include_full_body,
)
.await?;
let data_source_name = sources
.first()
.map(|(n, _)| n.as_str())
.unwrap_or("unknown");
let items = map_hits(raw.hits, data_source_name, req.include_content);
return Ok(SearchResponse {
total_estimate: raw.total_estimate,
next_cursor: raw.next_cursor,
answer: raw.answer,
citations: raw.citations,
items,
});
}
let mut all_hits: Vec<(String, RawHit)> = Vec::new();
let mut total_estimate: u64 = 0;
let mut last_cursor: Option<String> = None;
let cursor_map = if sources.len() > 1 {
decode_cursor_map(req.cursor.as_deref())
} else {
None
};
let single_cursor = if sources.len() == 1 {
req.cursor.as_deref()
} else {
None
};
for (source_name, _kind) in &sources {
let per_source_cursor = if sources.len() == 1 {
single_cursor
} else {
cursor_map
.as_ref()
.and_then(|m| m.get(source_name.as_str()))
.and_then(Value::as_str)
};
let raw = api
.search_index(
source_name,
&req.query,
odata_filter.as_deref(),
top,
per_source_cursor,
include_full_body,
)
.await?;
total_estimate = total_estimate.saturating_add(raw.total_estimate);
if let Some(token) = raw.next_cursor {
last_cursor = Some(token);
}
for hit in raw.hits {
all_hits.push((source_name.clone(), hit));
}
}
all_hits.sort_by(|(_, a), (_, b)| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
all_hits.truncate(top);
let items: Vec<SearchItem> = all_hits
.into_iter()
.map(|(source, hit)| map_single_hit(hit, &source, req.include_content))
.collect();
let next_cursor = if sources.len() == 1 {
last_cursor
} else {
None
};
Ok(SearchResponse {
items,
answer: None,
citations: None,
next_cursor,
total_estimate,
})
}
fn decode_cursor_map(cursor: Option<&str>) -> Option<serde_json::Map<String, Value>> {
use base64::Engine;
let encoded = cursor?;
let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(encoded)
.ok()?;
let val: Value = serde_json::from_slice(&bytes).ok()?;
val.as_object().cloned()
}
/// Converts every raw hit into a `SearchItem`, attributing all of them to
/// `data_source`.
fn map_hits(hits: Vec<RawHit>, data_source: &str, content: IncludeContent) -> Vec<SearchItem> {
    let mut items = Vec::with_capacity(hits.len());
    for raw_hit in hits {
        items.push(map_single_hit(raw_hit, data_source, content));
    }
    items
}
/// Converts a backend hit into a client-facing `SearchItem`.
///
/// `source_link` is taken from the hit's `source_link` string field. When that
/// field is missing, not a string, or empty, a stable
/// `urn:quelch:<data_source>:<id>` is synthesised instead. The snippet is
/// always forwarded; the body only for `IncludeContent::Full`.
fn map_single_hit(hit: RawHit, data_source: &str, content: IncludeContent) -> SearchItem {
    let source_link = match hit.fields.get("source_link").and_then(Value::as_str) {
        Some(link) if !link.is_empty() => link.to_string(),
        _ => format!("urn:quelch:{}:{}", data_source, hit.id),
    };
    // Exhaustive on purpose: a new variant must decide whether it exposes bodies.
    let body = match content {
        IncludeContent::Full => hit.body,
        IncludeContent::Snippet | IncludeContent::AgenticAnswer => None,
    };
    SearchItem {
        id: hit.id,
        score: hit.score,
        data_source: data_source.to_string(),
        source_link,
        snippet: hit.snippet,
        body,
        fields: hit.fields,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::schema::SchemaCatalog;
    use crate::mcp::tools::search_api::mock::{MockSearchApi, SearchCall};
    use crate::mcp::tools::test_helpers::build_expose;

    /// Two exposed sources: Jira issues and Confluence pages.
    fn build_expose_searchable() -> ExposeResolver {
        build_expose(&[
            ("jira_issues", "jira_issue", "jira-issues"),
            ("confluence_pages", "confluence_page", "confluence-pages"),
        ])
    }

    /// A single exposed source: Jira issues.
    fn build_expose_single() -> ExposeResolver {
        build_expose(&[("jira_issues", "jira_issue", "jira-issues")])
    }

    /// Jira issues plus Jira sprints (the latter expected to be non-searchable).
    fn build_expose_with_non_searchable() -> ExposeResolver {
        build_expose(&[
            ("jira_issues", "jira_issue", "jira-issues"),
            ("jira_sprints", "jira_sprint", "jira-sprints"),
        ])
    }

    fn default_config() -> SearchToolConfig {
        SearchToolConfig {
            disable_agentic: false,
            knowledge_base_name: "test-kb".to_string(),
            default_top: 25,
            max_top: 100,
        }
    }

    fn agentic_req(query: &str) -> SearchRequest {
        SearchRequest {
            query: query.to_string(),
            data_sources: Some(vec!["jira_issues".to_string()]),
            r#where: None,
            top: 10,
            cursor: None,
            include_deleted: false,
            include_content: IncludeContent::AgenticAnswer,
        }
    }

    fn snippet_req(query: &str) -> SearchRequest {
        SearchRequest {
            query: query.to_string(),
            data_sources: Some(vec!["jira_issues".to_string()]),
            r#where: None,
            top: 10,
            cursor: None,
            include_deleted: false,
            include_content: IncludeContent::Snippet,
        }
    }

    #[tokio::test]
    async fn search_routes_through_knowledge_base() {
        let api = MockSearchApi::new().with_kb_response(MockSearchApi::agentic_response);
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = default_config();
        let resp = run(&api, &expose, &schema, &config, agentic_req("test query"))
            .await
            .unwrap();
        let calls = api.calls_snapshot();
        assert_eq!(calls.len(), 1, "expected exactly one API call");
        assert!(
            matches!(&calls[0], SearchCall::KnowledgeBase { .. }),
            "expected KB call, got: {:?}",
            calls[0]
        );
        assert!(resp.answer.is_some(), "answer should be populated");
    }

    #[tokio::test]
    async fn search_disable_agentic_routes_through_index() {
        let api = MockSearchApi::new();
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let resp = run(&api, &expose, &schema, &config, snippet_req("find bugs"))
            .await
            .unwrap();
        let calls = api.calls_snapshot();
        assert_eq!(calls.len(), 1);
        assert!(
            matches!(&calls[0], SearchCall::Index { .. }),
            "expected index call when disable_agentic=true, got: {:?}",
            calls[0]
        );
        assert!(resp.answer.is_none());
    }

    #[tokio::test]
    async fn search_snippet_routes_through_index_not_kb() {
        let api = MockSearchApi::new();
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = default_config();
        run(&api, &expose, &schema, &config, snippet_req("text"))
            .await
            .unwrap();
        let calls = api.calls_snapshot();
        assert_eq!(calls.len(), 1, "expected exactly one API call");
        assert!(
            matches!(&calls[0], SearchCall::Index { .. }),
            "snippet mode should use direct index search"
        );
    }

    #[tokio::test]
    async fn search_include_content_full_returns_body() {
        let api = MockSearchApi::new().with_index_response(MockSearchApi::full_body_response);
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let req = SearchRequest {
            query: "test".to_string(),
            data_sources: Some(vec!["jira_issues".to_string()]),
            r#where: None,
            top: 10,
            cursor: None,
            include_deleted: false,
            include_content: IncludeContent::Full,
        };
        let resp = run(&api, &expose, &schema, &config, req).await.unwrap();
        let calls = api.calls_snapshot();
        assert!(
            matches!(
                &calls[0],
                SearchCall::Index {
                    include_full_body: true,
                    ..
                }
            ),
            "expected include_full_body=true in index call"
        );
        assert!(!resp.items.is_empty(), "expected at least one item");
        assert!(
            resp.items[0].body.is_some(),
            "items should have body populated for include_content=full"
        );
    }

    #[tokio::test]
    async fn search_include_content_agentic_answer_returns_answer_field() {
        let api = MockSearchApi::new().with_kb_response(MockSearchApi::agentic_response);
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = default_config();
        let resp = run(
            &api,
            &expose,
            &schema,
            &config,
            agentic_req("what is done?"),
        )
        .await
        .unwrap();
        assert!(
            resp.answer.is_some(),
            "answer should be present for agentic_answer"
        );
        assert_eq!(resp.answer.unwrap(), "The synthesised answer is: 42.");
        assert!(resp.citations.is_some(), "citations should be present");
    }

    #[tokio::test]
    async fn search_excludes_soft_deleted_by_default() {
        let api = MockSearchApi::new();
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let req = SearchRequest {
            query: "bugs".to_string(),
            data_sources: Some(vec!["jira_issues".to_string()]),
            r#where: None,
            top: 10,
            cursor: None,
            include_deleted: false,
            include_content: IncludeContent::Snippet,
        };
        run(&api, &expose, &schema, &config, req).await.unwrap();
        let calls = api.calls_snapshot();
        if let SearchCall::Index { odata_filter, .. } = &calls[0] {
            let filter = odata_filter.as_deref().unwrap_or("");
            assert!(
                filter.contains("_deleted ne true"),
                "filter should exclude soft-deleted; got: {filter}"
            );
        } else {
            panic!("expected Index call");
        }
    }

    #[tokio::test]
    async fn search_includes_soft_deleted_when_set() {
        let api = MockSearchApi::new();
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let req = SearchRequest {
            query: "bugs".to_string(),
            data_sources: Some(vec!["jira_issues".to_string()]),
            r#where: None,
            top: 10,
            cursor: None,
            include_deleted: true,
            include_content: IncludeContent::Snippet,
        };
        run(&api, &expose, &schema, &config, req).await.unwrap();
        let calls = api.calls_snapshot();
        if let SearchCall::Index { odata_filter, .. } = &calls[0] {
            let has_deleted_guard = odata_filter
                .as_deref()
                .map(|f| f.contains("_deleted"))
                .unwrap_or(false);
            assert!(
                !has_deleted_guard,
                "filter should NOT contain _deleted predicate when include_deleted=true; got: {odata_filter:?}"
            );
        } else {
            panic!("expected Index call");
        }
    }

    #[tokio::test]
    async fn search_forbidden_for_unexposed_data_source() {
        let api = MockSearchApi::new();
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let req = SearchRequest {
            query: "test".to_string(),
            data_sources: Some(vec!["confluence_pages".to_string()]),
            r#where: None,
            top: 10,
            cursor: None,
            include_deleted: false,
            include_content: IncludeContent::Snippet,
        };
        let err = run(&api, &expose, &schema, &config, req).await.unwrap_err();
        assert!(
            matches!(err, McpError::Forbidden(_)),
            "expected Forbidden for unexposed source, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn search_uses_all_searchable_when_data_sources_omitted() {
        let api = MockSearchApi::new();
        let expose = build_expose_with_non_searchable();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let req = SearchRequest {
            query: "open issues".to_string(),
            data_sources: None,
            r#where: None,
            top: 10,
            cursor: None,
            include_deleted: false,
            include_content: IncludeContent::Snippet,
        };
        run(&api, &expose, &schema, &config, req).await.unwrap();
        let calls = api.calls_snapshot();
        let index_calls: Vec<&str> = calls
            .iter()
            .filter_map(|c| {
                if let SearchCall::Index { index_name, .. } = c {
                    Some(index_name.as_str())
                } else {
                    None
                }
            })
            .collect();
        assert!(
            index_calls.contains(&"jira_issues"),
            "should call jira_issues; got: {index_calls:?}"
        );
        assert!(
            !index_calls.contains(&"jira_sprints"),
            "should NOT call jira_sprints (not searchable); got: {index_calls:?}"
        );
    }

    #[tokio::test]
    async fn search_uses_all_exposed_searchable_when_no_data_sources() {
        let api = MockSearchApi::new();
        let expose = build_expose_searchable();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let req = SearchRequest {
            query: "team".to_string(),
            data_sources: None,
            r#where: None,
            top: 10,
            cursor: None,
            include_deleted: false,
            include_content: IncludeContent::Snippet,
        };
        run(&api, &expose, &schema, &config, req).await.unwrap();
        let calls = api.calls_snapshot();
        let index_calls: Vec<&str> = calls
            .iter()
            .filter_map(|c| {
                if let SearchCall::Index { index_name, .. } = c {
                    Some(index_name.as_str())
                } else {
                    None
                }
            })
            .collect();
        assert_eq!(index_calls.len(), 2, "should call both searchable sources");
    }

    #[tokio::test]
    async fn search_paginates_via_cursor() {
        let api = MockSearchApi::new().with_index_response(|| {
            let mut r = MockSearchApi::default_response();
            r.next_cursor = Some("next-token-1".to_string());
            r
        });
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let req1 = SearchRequest {
            query: "paginate".to_string(),
            data_sources: Some(vec!["jira_issues".to_string()]),
            r#where: None,
            top: 1,
            cursor: None,
            include_deleted: false,
            include_content: IncludeContent::Snippet,
        };
        let resp1 = run(&api, &expose, &schema, &config, req1).await.unwrap();
        // A single source passes the backend cursor through unchanged.
        assert_eq!(
            resp1.next_cursor.as_deref(),
            Some("next-token-1"),
            "first page should carry the backend's next_cursor"
        );
        let cursor = resp1.next_cursor.unwrap();
        let api2 = MockSearchApi::new();
        let req2 = SearchRequest {
            query: "paginate".to_string(),
            data_sources: Some(vec!["jira_issues".to_string()]),
            r#where: None,
            top: 1,
            cursor: Some(cursor),
            include_deleted: false,
            include_content: IncludeContent::Snippet,
        };
        run(&api2, &expose, &schema, &config, req2).await.unwrap();
        let calls2 = api2.calls_snapshot();
        if let SearchCall::Index { cursor, .. } = &calls2[0] {
            assert_eq!(
                cursor.as_deref(),
                Some("next-token-1"),
                "second request should pass the cursor to the API unchanged"
            );
        } else {
            panic!("expected Index call, got: {:?}", calls2[0]);
        }
    }

    #[tokio::test]
    async fn search_agentic_answer_with_disable_agentic_returns_invalid_argument() {
        let api = MockSearchApi::new();
        let expose = build_expose_single();
        let schema = SchemaCatalog::new();
        let config = SearchToolConfig {
            disable_agentic: true,
            ..default_config()
        };
        let err = run(&api, &expose, &schema, &config, agentic_req("test"))
            .await
            .unwrap_err();
        assert!(
            matches!(err, McpError::InvalidArgument(_)),
            "expected InvalidArgument, got: {err:?}"
        );
    }
}