use crate::service::{errors::*, mcp_service::*};
use rmcp::{
ErrorData as McpError,
handler::server::wrapper::Parameters,
model::{CallToolResult, Content, CreateMessageRequestParams, Role, SamplingMessage},
};
use rudof_lib::{
InputSpec,
query::{detect_query_type, execute_query},
query_result_format::ResultQueryFormat,
query_type::QueryType,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::io::Cursor;
use std::str::FromStr;
use super::helpers::*;
/// Parameters accepted by the SPARQL query execution tool.
///
/// Exactly one of `query` or `query_natural_language` must be provided; supplying both,
/// or neither, is reported back to the caller as a tool error.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ExecuteSparqlQueryRequest {
    /// SPARQL query text to execute (SELECT, CONSTRUCT or ASK).
    /// Mutually exclusive with `query_natural_language`.
    pub query: Option<String>,
    /// Natural-language description of the desired query. The server asks the client's LLM
    /// (via MCP sampling) to translate it into SPARQL before executing it.
    /// Mutually exclusive with `query`.
    pub query_natural_language: Option<String>,
    /// Output format for the query results. Defaults to `Internal` when omitted.
    pub result_format: Option<String>,
}
/// Structured content returned by the SPARQL query execution tool on success.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct QueryExecutionResponse {
    /// Query type detected from the query text.
    pub query_type: String,
    /// Result format name as requested by the caller, or `Internal` when none was given.
    pub result_format: String,
    /// Execution status; the tool sets this to `"success"` on its success path.
    pub status: String,
    /// Query results serialized in `result_format`.
    pub results: String,
    /// Length of `results` in bytes (UTF-8).
    pub result_size_bytes: usize,
    /// Number of lines in `results`.
    pub result_lines: usize,
}
/// Extracts the body of the first Markdown code fence in `text`, if there is one.
///
/// LLMs often wrap answers in ```` ```sparql … ``` ```` despite being told not to, sometimes
/// with leading whitespace or a short preamble line. Without a fence, the trimmed text is
/// returned unchanged.
fn strip_markdown_code_fence(text: &str) -> String {
    let is_fence = |line: &str| line.trim_start().starts_with("```");
    let trimmed = text.trim();
    let mut lines = trimmed.lines();
    // `any` consumes everything up to and including the opening fence (and its language tag).
    if lines.by_ref().any(is_fence) {
        lines
            .take_while(|&line| !is_fence(line))
            .collect::<Vec<_>>()
            .join("\n")
            .trim()
            .to_string()
    } else {
        trimmed.to_string()
    }
}

/// Asks the connected MCP client's LLM (via `sampling/createMessage`) to translate a
/// natural-language request into a SPARQL query.
///
/// # Errors
/// Returns an [`McpError`] if there is no current request context, the sampling request
/// fails, or the client's reply is not text.
async fn generate_sparql_from_natural_language(
    service: &RudofMcpService,
    natural_language: &str,
) -> Result<String, McpError> {
    let system_message = r#"You are a SPARQL query expert. Convert natural language questions into valid SPARQL queries.
- Only output the SPARQL query, no explanations or markdown formatting
- Use standard SPARQL syntax (SELECT, CONSTRUCT, ASK, or DESCRIBE)
- Include appropriate prefixes if needed
- Make the query efficient and correct
- If you need to guess prefixes, use common ones like rdf:, rdfs:, xsd:, ex:, etc."#;
    let user_message = format!("Generate a SPARQL query for: {}", natural_language);

    // Take our own handle to the client peer and release the read guard immediately, so the
    // lock is not held for the whole LLM round trip below (which would block any writer of
    // `current_context` until sampling finishes).
    let context_guard = service.current_context.read().await;
    let peer = context_guard
        .as_ref()
        .ok_or_else(|| {
            internal_error(
                "Sampling context error",
                "Request context not found",
                Some(json!({"operation":"generate_sparql_from_natural_language","phase":"get_context"})),
            )
        })?
        .peer
        .clone();
    drop(context_guard);

    let sampling_request = CreateMessageRequestParams {
        meta: None,
        task: None,
        // NOTE(review): the instructions are sent as a user message rather than via
        // `system_prompt`; the MCP spec lets clients modify or omit `system_prompt`.
        messages: vec![
            SamplingMessage {
                role: Role::User,
                content: Content::text(system_message),
            },
            SamplingMessage {
                role: Role::User,
                content: Content::text(user_message),
            },
        ],
        model_preferences: None,
        system_prompt: None,
        include_context: None,
        // Low temperature: we want a deterministic, syntactically valid query, not creativity.
        temperature: Some(0.3),
        max_tokens: 512,
        stop_sequences: None,
        metadata: None,
    };

    let response = peer.create_message(sampling_request).await.map_err(|e| {
        internal_error(
            "Query generation error",
            e.to_string(),
            Some(json!({"operation":"generate_sparql_from_natural_language","phase":"create_message"})),
        )
    })?;

    let Some(text_content) = response.message.content.as_text() else {
        return Err(internal_error(
            "Sampling response error",
            "Expected text response from LLM",
            Some(json!({"operation":"generate_sparql_from_natural_language","phase":"extract_response_text"})),
        ));
    };
    let cleaned_query = strip_markdown_code_fence(&text_content.text);

    tracing::debug!(
        natural_language = %natural_language,
        generated_query = %cleaned_query,
        "Generated SPARQL query from natural language via rmcp sampling"
    );
    Ok(cleaned_query)
}
/// Maps a result format name to the Markdown code-fence language tag used to display it.
///
/// Matching is case-insensitive; unrecognized formats get an untagged fence.
fn results_fence_language(result_format: &str) -> &'static str {
    match result_format.to_lowercase().as_str() {
        "csv" => "csv",
        "jsonld" | "json" => "json",
        "turtle" | "n3" => "turtle",
        "ntriples" | "nquads" => "ntriples",
        "rdfxml" => "xml",
        "trig" => "trig",
        _ => "",
    }
}

/// Executes a SPARQL query with the service's `rudof` instance and returns the results.
///
/// The query is taken verbatim from `query`, or generated from `query_natural_language`
/// through MCP sampling; exactly one of the two must be supplied. `result_format` defaults
/// to `Internal`.
///
/// # Returns
/// On success, a [`CallToolResult`] with three text parts (Markdown summary, executed query,
/// results) plus a [`QueryExecutionResponse`] as structured content. Invalid user input
/// (missing/conflicting query fields, undetectable query type, unknown result format) is
/// returned as a tool-level error result, not as `Err`.
///
/// # Errors
/// Returns an [`McpError`] if natural-language generation fails, query execution fails, the
/// output is not valid UTF-8, or the structured response cannot be serialized.
pub async fn execute_sparql_query_impl(
    service: &RudofMcpService,
    Parameters(ExecuteSparqlQueryRequest {
        query,
        query_natural_language,
        result_format,
    }): Parameters<ExecuteSparqlQueryRequest>,
) -> Result<CallToolResult, McpError> {
    let sparql_query = match (query, query_natural_language) {
        (Some(q), None) => q,
        (None, Some(nl)) => generate_sparql_from_natural_language(service, &nl).await?,
        (Some(_), Some(_)) => {
            return Ok(ToolExecutionError::new(
                "Cannot provide both 'query' and 'query_natural_language'. Choose one.",
            )
            .into_call_tool_result());
        },
        (None, None) => {
            return Ok(ToolExecutionError::with_hint(
                "No query provided",
                "Provide either 'query' with a SPARQL query string, or 'query_natural_language' with a description",
            )
            .into_call_tool_result());
        },
    };

    let Some(query_type_str) = detect_query_type(&sparql_query) else {
        return Ok(ToolExecutionError::with_hint(
            "Cannot determine query type",
            "Ensure the query starts with SELECT, CONSTRUCT, ASK",
        )
        .into_call_tool_result());
    };
    let parsed_query_type = match QueryType::from_str(&query_type_str) {
        Ok(qt) => qt,
        Err(e) => {
            return Ok(ToolExecutionError::with_hint(
                format!("Invalid query type: {}", e),
                "Supported query types: SELECT, CONSTRUCT, ASK",
            )
            .into_call_tool_result());
        },
    };

    let result_format_str = result_format.as_deref().unwrap_or("Internal");
    let parsed_result_format = match ResultQueryFormat::from_str(result_format_str) {
        Ok(fmt) => fmt,
        Err(e) => {
            return Ok(ToolExecutionError::with_hint(
                format!("Invalid result format '{}': {}", result_format_str, e),
                format!("Supported formats: {}", SPARQL_RESULT_FORMATS),
            )
            .into_call_tool_result());
        },
    };

    let query_spec = InputSpec::Str(sparql_query.clone());
    let output_bytes = {
        // Hold the shared `rudof` lock only while the query runs; the UTF-8 conversion,
        // serialization and formatting below don't need it.
        let mut rudof = service.rudof.lock().await;
        let mut output_buffer = Cursor::new(Vec::new());
        execute_query(
            &mut rudof,
            &query_spec,
            &parsed_query_type,
            &parsed_result_format,
            &mut output_buffer,
        )
        .map_err(|e| {
            internal_error(
                "Query execution error",
                e.to_string(),
                Some(json!({"operation":"execute_sparql_query_impl", "phase":"execute_query"})),
            )
        })?;
        output_buffer.into_inner()
    };

    let output_str = String::from_utf8(output_bytes).map_err(|e| {
        internal_error(
            "Conversion error",
            e.to_string(),
            Some(json!({"operation":"execute_sparql_query_impl", "phase":"utf8_conversion"})),
        )
    })?;
    let result_size_bytes = output_str.len();
    let result_lines = output_str.lines().count();

    // Build the human-readable parts first so the (possibly large) result string can be
    // moved into the structured response instead of copied.
    let summary = format!(
        "# SPARQL Query Execution\n\n\
         **Status:** ✓ Success\n\
         **Query Type:** {}\n\
         **Result Format:** {}\n\
         **Result Size:** {} bytes\n\
         **Result Lines:** {}\n",
        query_type_str, result_format_str, result_size_bytes, result_lines
    );
    let query_display = format!("## Query\n\n```sparql\n{}\n```", sparql_query);
    let results_display = format!(
        "## Results\n\n```{}\n{}\n```",
        results_fence_language(result_format_str),
        output_str
    );

    let response = QueryExecutionResponse {
        query_type: query_type_str,
        result_format: result_format_str.to_string(),
        status: "success".to_string(),
        results: output_str,
        result_size_bytes,
        result_lines,
    };
    let structured = serde_json::to_value(&response).map_err(|e| {
        internal_error(
            "Serialization error",
            e.to_string(),
            Some(json!({"operation":"execute_sparql_query_impl", "phase":"serialize_response"})),
        )
    })?;

    let mut result = CallToolResult::success(vec![
        Content::text(summary),
        Content::text(query_display),
        Content::text(results_display),
    ]);
    result.structured_content = Some(structured);
    Ok(result)
}