use crate::llm::OllamaClient;
use serde_json::{Value, json};
/// Handle a `memory_expand_query` call: ask the LLM to expand `params["query"]`
/// into related search terms.
///
/// On success returns `{"original": <query>, <EXPANDED_TERMS>: <terms>}`, where
/// the second key is `crate::models::field_names::EXPANDED_TERMS` and `<terms>`
/// is whatever `OllamaClient::expand_query` produced.
///
/// # Errors
///
/// - A tier-gating message when `llm` is `None` (no Ollama client available).
/// - `crate::errors::msg::QUERY_REQUIRED` when `params["query"]` is missing,
///   not a string, or blank (empty / whitespace-only).
/// - The stringified LLM error when `expand_query` fails.
pub fn handle_expand_query(llm: Option<&OllamaClient>, params: &Value) -> Result<Value, String> {
    // Tier gate comes first so lower-tier callers always see the tier message,
    // regardless of what params they sent.
    let llm = llm.ok_or("query expansion requires smart or autonomous tier (Ollama LLM)")?;
    // Indexing a non-object `Value` yields `Null`, so a missing key, a non-string
    // value and a non-object `params` all land on QUERY_REQUIRED. A blank query
    // would still cost an LLM round-trip and can only produce noise, so it is
    // rejected the same way.
    let query = params["query"]
        .as_str()
        .filter(|q| !q.trim().is_empty())
        .ok_or(crate::errors::msg::QUERY_REQUIRED)?;
    let terms = llm.expand_query(query).map_err(|e| e.to_string())?;
    Ok(json!({"original": query, (crate::models::field_names::EXPANDED_TERMS): terms}))
}
use crate::mcp::registry::McpTool;
use schemars::JsonSchema;
use serde::Deserialize;
// Typed mirror of the `memory_expand_query` input. The tool's MCP input schema
// is derived from this struct (`input_schema_for::<ExpandQueryRequest>()`), and
// the parity tests compare that derived schema against the registry, so the
// field set here is effectively the wire contract.
// NOTE: plain `//` comments on purpose. `///` doc comments are picked up by the
// `JsonSchema` derive as schema descriptions and could break description parity.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
// The handler reads the raw `Value` rather than deserializing into this type,
// so presumably it is never constructed at runtime. Hence the lint allowance.
#[allow(dead_code)]
pub struct ExpandQueryRequest {
    // The search query to expand. Non-`Option`, so it should be marked required
    // in the derived schema (confirm against the parity fixture).
    pub query: String,
}
#[allow(dead_code)]
pub struct ExpandQueryTool;
impl McpTool for ExpandQueryTool {
fn name() -> &'static str {
crate::mcp::registry::tool_names::MEMORY_EXPAND_QUERY
}
fn description() -> &'static str {
"LLM-expand a search query into related terms (smart/autonomous tier)."
}
fn docs() -> &'static str {
"LLM query expansion. Smart/autonomous tier."
}
fn input_schema() -> Value {
crate::mcp::registry::input_schema_for::<ExpandQueryRequest>()
}
fn family() -> &'static str {
crate::profile::Family::Power.name()
}
}
#[cfg(test)]
mod d1_5_986_tests {
    //! Schema-parity and metadata checks for the `memory_expand_query` tool.
    use super::*;
    use crate::mcp::parity_test_helpers::{
        assert_descriptions_match, assert_property_set_parity, derived_props_for,
    };

    /// Wire name the registry is expected to publish for this tool.
    const TOOL: &str = "memory_expand_query";

    /// Properties and descriptions derived from `ExpandQueryRequest` must match
    /// the registry's definition of the tool.
    #[test]
    fn expand_query_parity_986() {
        let props = derived_props_for::<ExpandQueryRequest>();
        assert_property_set_parity(TOOL, &props);
        assert_descriptions_match(TOOL, &props);
    }

    /// `ExpandQueryTool` reports the expected name and family.
    #[test]
    fn expand_query_tool_metadata_986() {
        let (name, family) = (ExpandQueryTool::name(), ExpandQueryTool::family());
        assert_eq!(name, TOOL);
        assert_eq!(family, "power");
    }
}
#[cfg(test)]
mod tests {
    //! Handler-level tests for `handle_expand_query`, run against a wiremock
    //! stand-in for the Ollama HTTP API.
    use super::handle_expand_query;
    use crate::llm::OllamaClient;
    use serde_json::{Value, json};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Answer `GET /api/tags` with an empty model list. Every test that builds a
    /// client mounts this first (presumably client construction probes it).
    async fn mount_tags_ok(server: &MockServer) {
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
            .mount(server)
            .await;
    }

    /// Mount `POST /api/chat` so it replies with `response`.
    async fn mount_chat(server: &MockServer, response: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(response)
            .mount(server)
            .await;
    }

    /// Build an `OllamaClient` pointed at `server` and run the handler with
    /// `params`, returning the handler's own `Result`.
    ///
    /// The client is driven synchronously, so it runs under `spawn_blocking`
    /// rather than on an async worker thread (presumably it uses a blocking
    /// HTTP client; confirm against `crate::llm`).
    async fn run_handler(server: &MockServer, params: Value) -> Result<Value, String> {
        let uri = server.uri();
        tokio::task::spawn_blocking(move || {
            let client = OllamaClient::new_with_url(&uri, "test-model")
                .expect("client should build against the mock server");
            handle_expand_query(Some(&client), &params)
        })
        .await
        .expect("blocking handler task panicked")
    }

    #[test]
    fn rejects_when_llm_absent() {
        let err = handle_expand_query(None, &json!({"query": "anything"})).unwrap_err();
        assert!(
            err.contains("smart") || err.contains("autonomous") || err.contains("Ollama"),
            "expected tier-gating error, got: {err}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn rejects_when_query_missing() {
        let server = MockServer::start().await;
        mount_tags_ok(&server).await;
        // `expect_err` surfaces an unexpected Ok(value) instead of an empty string.
        let err = run_handler(&server, json!({}))
            .await
            .expect_err("missing query must be rejected");
        assert!(err.contains("query"), "expected query-required, got: {err}");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn rejects_when_query_is_not_string() {
        let server = MockServer::start().await;
        mount_tags_ok(&server).await;
        let err = run_handler(&server, json!({"query": 42}))
            .await
            .expect_err("non-string query must be rejected");
        assert!(err.contains("query"), "expected query-required, got: {err}");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn success_shapes_expanded_terms() {
        let server = MockServer::start().await;
        mount_tags_ok(&server).await;
        mount_chat(
            &server,
            ResponseTemplate::new(200).set_body_json(json!({
                "message": {"content": "alpha\nbeta\ngamma\n"},
            })),
        )
        .await;
        let value = run_handler(&server, json!({"query": "neural networks"}))
            .await
            .expect("handler should succeed");
        assert_eq!(value["original"], "neural networks");
        assert_eq!(value["expanded_terms"], json!(["alpha", "beta", "gamma"]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn success_with_empty_response_yields_no_terms() {
        let server = MockServer::start().await;
        mount_tags_ok(&server).await;
        mount_chat(
            &server,
            ResponseTemplate::new(200).set_body_json(json!({
                "message": {"content": "\n\n \n"},
            })),
        )
        .await;
        let value = run_handler(&server, json!({"query": "x"}))
            .await
            .expect("handler should succeed");
        assert_eq!(
            value["expanded_terms"],
            json!([]),
            "blank-only response collapses to []"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn surfaces_llm_500_error_through_envelope() {
        let server = MockServer::start().await;
        mount_tags_ok(&server).await;
        mount_chat(&server, ResponseTemplate::new(500).set_body_string("boom")).await;
        let err = run_handler(&server, json!({"query": "q"}))
            .await
            .expect_err("upstream 500 must surface as an error");
        assert!(
            err.contains("500") || err.contains("Chat generate failed"),
            "expected upstream error, got: {err}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn surfaces_llm_malformed_json_error() {
        let server = MockServer::start().await;
        mount_tags_ok(&server).await;
        mount_chat(
            &server,
            ResponseTemplate::new(200)
                .set_body_string("{ not valid")
                .insert_header(crate::HEADER_CONTENT_TYPE, crate::MIME_JSON),
        )
        .await;
        let err = run_handler(&server, json!({"query": "q"}))
            .await
            .expect_err("malformed upstream JSON must surface as an error");
        let lowered = err.to_lowercase();
        assert!(
            lowered.contains("parse") || lowered.contains("json"),
            "expected parse-error, got: {err}"
        );
    }
}