use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use crate::datatypes::Value;
use crate::error::KgError;
use crate::graph::dir_graph::DirGraph;
use crate::graph::embedder::Embedder;
use crate::graph::languages::cypher;
use crate::graph::languages::cypher::ast::{CypherQuery, OutputFormat};
use crate::graph::languages::cypher::result::CypherResult;
/// Per-call settings consumed by [`execute_read`] and [`execute_mut`].
///
/// All borrowed fields share the lifetime `'a`, so an instance can be built
/// on the stack around data the caller already owns.
pub struct ExecuteOptions<'a> {
    /// Caller-supplied query parameters. They are handed to the
    /// `text_score` rewrite, the planner, and the executor.
    pub params: &'a HashMap<String, Value>,

    /// Optional deadline forwarded unchanged to the Cypher executor.
    pub deadline: Option<Instant>,

    /// Optional row cap, forwarded to the read executor via `with_max_rows`.
    /// The mutation path in `execute_mut` does not consult it.
    pub max_rows: Option<usize>,

    /// When `true`, the parsed query is run through
    /// `cypher::mark_lazy_eligibility` and the read executor is built with
    /// `with_streaming(true)`.
    pub lazy_eligible: bool,

    /// Planner passes to disable. `None` falls back to
    /// `cypher::planner::empty_disabled_set()`.
    pub disabled_passes: Option<&'a HashSet<String>>,

    /// Embedding model used when the `text_score` rewrite reports texts to
    /// embed. Required in that case unless the query is an EXPLAIN.
    pub embedder: Option<Arc<dyn Embedder>>,
}
impl<'a> ExecuteOptions<'a> {
    /// Builds options around `params` with every other setting at its
    /// neutral value: no deadline, no row cap, streaming off, no disabled
    /// planner passes, and no embedder.
    pub fn new(params: &'a HashMap<String, Value>) -> Self {
        Self {
            params,
            embedder: None,
            disabled_passes: None,
            lazy_eligible: false,
            max_rows: None,
            deadline: None,
        }
    }

    /// Produces the same value as [`ExecuteOptions::new`]; the name spells
    /// out that `lazy_eligible` is `false`.
    pub fn eager(params: &'a HashMap<String, Value>) -> Self {
        Self::new(params)
    }
}
/// What [`execute_read`] and [`execute_mut`] hand back on success.
pub struct ExecuteOutcome {
    /// The executor's result, or the output of
    /// `cypher::generate_explain_result` when `explain` is `true`.
    pub result: CypherResult,

    /// Whether `cypher::is_mutation_query` classified the parsed query as a
    /// mutation. This can be `true` together with `explain`.
    pub is_mutation: bool,

    /// Output format copied from the parsed query.
    pub output_format: OutputFormat,

    /// `true` when the query was an EXPLAIN and therefore was not executed.
    pub explain: bool,
}
/// Parses, plans, and runs a Cypher query against an immutable graph.
///
/// The query first goes through `prepare` (parse, schema validation,
/// `text_score` rewrite, planning). Then:
///
/// * an EXPLAIN query returns the explain result without executing, even
///   when the underlying query is a mutation;
/// * a mutation query is rejected, because this entry point only holds a
///   shared reference to the graph;
/// * anything else runs on `CypherExecutor` with the deadline, row cap, and
///   streaming flag taken from `opts`.
///
/// # Errors
///
/// * Any error produced by `prepare`.
/// * `KgError::Argument` when the query is a non-EXPLAIN mutation.
/// * `KgError::CypherExecution` (with `position: None`) when the executor
///   fails.
pub fn execute_read(
    graph: &DirGraph,
    query: &str,
    opts: &ExecuteOptions<'_>,
) -> Result<ExecuteOutcome, KgError> {
    let (parsed, params) = prepare(graph, query, opts)?;
    let is_mutation = cypher::is_mutation_query(&parsed);

    // EXPLAIN is checked before the mutation guard: nothing is executed, so
    // explaining a mutation through the read-only entry point is allowed.
    if parsed.explain {
        let result = cypher::generate_explain_result(&parsed, graph);
        return Ok(ExecuteOutcome {
            result,
            is_mutation,
            output_format: parsed.output_format,
            explain: true,
        });
    }

    if is_mutation {
        return Err(KgError::Argument(
            "execute_read called with a mutation query (CREATE/SET/DELETE/REMOVE/MERGE) \
             — use execute_mut against a mutable graph view"
                .to_string(),
        ));
    }

    // `params` is the owned map from `prepare`; it may hold embedding
    // parameters in addition to the caller's own.
    let result = cypher::CypherExecutor::with_params(graph, &params, opts.deadline)
        .with_max_rows(opts.max_rows)
        .with_streaming(opts.lazy_eligible)
        .execute(&parsed)
        .map_err(|message| KgError::CypherExecution {
            message,
            position: None,
        })?;

    Ok(ExecuteOutcome {
        result,
        is_mutation: false,
        output_format: parsed.output_format,
        explain: false,
    })
}
/// Parses, plans, and runs a Cypher query against a mutable graph.
///
/// The query first goes through `prepare`. Then:
///
/// * an EXPLAIN query returns the explain result without executing;
/// * a mutation query is dispatched to `cypher::execute_mutable`, which
///   receives the owned parameter map and the deadline (`max_rows` and
///   `lazy_eligible` are not passed on this path);
/// * a read query runs on `CypherExecutor` exactly as in `execute_read`.
///
/// # Errors
///
/// * Any error produced by `prepare`.
/// * `KgError::CypherExecution` (with `position: None`) when either
///   execution path fails.
pub fn execute_mut(
    graph: &mut DirGraph,
    query: &str,
    opts: &ExecuteOptions<'_>,
) -> Result<ExecuteOutcome, KgError> {
    let (parsed, params) = prepare(graph, query, opts)?;
    let is_mutation = cypher::is_mutation_query(&parsed);

    if parsed.explain {
        let result = cypher::generate_explain_result(&parsed, graph);
        return Ok(ExecuteOutcome {
            result,
            is_mutation,
            output_format: parsed.output_format,
            explain: true,
        });
    }

    // Both execution paths report failures as a plain message; wrap it the
    // same way in either case. The closure captures nothing, so it is `Copy`
    // and can be passed to both `map_err` calls.
    let exec_err = |message: String| KgError::CypherExecution {
        message,
        position: None,
    };

    let result = if is_mutation {
        // The mutable executor takes the parameter map by value.
        cypher::execute_mutable(graph, &parsed, params, opts.deadline).map_err(exec_err)?
    } else {
        cypher::CypherExecutor::with_params(graph, &params, opts.deadline)
            .with_max_rows(opts.max_rows)
            .with_streaming(opts.lazy_eligible)
            .execute(&parsed)
            .map_err(exec_err)?
    };

    Ok(ExecuteOutcome {
        result,
        is_mutation,
        output_format: parsed.output_format,
        explain: false,
    })
}
/// Shared front half of [`execute_read`] and [`execute_mut`]: turns query
/// text into a planned `CypherQuery` plus the parameter map to execute with.
///
/// Steps, in order:
///
/// 1. Parse the query text.
/// 2. Validate it against the graph schema, then emit warnings for unknown
///    pattern references.
/// 3. Apply the `text_score` rewrite. If it reports texts to embed and the
///    query is not an EXPLAIN, embed them and merge the vectors into a copy
///    of the caller's parameters.
/// 4. Run the planner, honouring `opts.disabled_passes`.
/// 5. If `opts.lazy_eligible` is set, mark lazy eligibility on the plan.
///
/// The returned map is always owned. When no embedding took place it is a
/// clone of `opts.params`.
///
/// # Errors
///
/// Propagates parse and schema-validation errors. Failures of the
/// `text_score` rewrite or of embedding surface as
/// `KgError::CypherExecution` with `position: None`.
fn prepare(
    graph: &DirGraph,
    query: &str,
    opts: &ExecuteOptions<'_>,
) -> Result<(CypherQuery, HashMap<String, Value>), KgError> {
    let mut parsed = cypher::parse_cypher(query)?;
    cypher::validate_schema(&parsed, graph).map_err(KgError::from)?;
    cypher::warn_unknown_pattern_refs(&parsed, graph);

    let rewrite = cypher::rewrite_text_score(&mut parsed, opts.params).map_err(|message| {
        KgError::CypherExecution {
            message,
            position: None,
        }
    })?;

    // Borrow the caller's map unless embeddings have to be merged in.
    // EXPLAIN never executes, so the embedding step is skipped for it.
    let params: Cow<'_, HashMap<String, Value>> =
        if !rewrite.texts_to_embed.is_empty() && !parsed.explain {
            Cow::Owned(embed_into_params(opts, &rewrite)?)
        } else {
            Cow::Borrowed(opts.params)
        };

    let disabled_default = cypher::planner::empty_disabled_set();
    let disabled_ref = opts.disabled_passes.unwrap_or(disabled_default);
    // `&params` is a `&Cow<HashMap<..>>` and deref-coerces to `&HashMap<..>`.
    cypher::planner::optimize_with_disabled(&mut parsed, graph, &params, disabled_ref);

    if opts.lazy_eligible {
        cypher::mark_lazy_eligibility(&mut parsed);
    }

    Ok((parsed, params.into_owned()))
}
/// Embeds the texts collected by the `text_score` rewrite and returns a copy
/// of `opts.params` extended with one entry per embedded text.
///
/// Each vector is stored under the parameter name supplied by the rewrite,
/// as a `Value::String` of the form `[v0, v1, ...]`.
///
/// The model is loaded before embedding. It is unloaded as soon as the embed
/// call returns, before that call's result is inspected, so a failed embed
/// still unloads the model.
///
/// # Errors
///
/// Returns `KgError::CypherExecution` (with `position: None`) when no
/// embedder is configured, when loading or embedding fails, or when the
/// model returns a different number of vectors than texts.
fn embed_into_params(
    opts: &ExecuteOptions<'_>,
    rewrite: &cypher::planner::simplification::TextScoreRewrite,
) -> Result<HashMap<String, Value>, KgError> {
    // Every failure in this function is reported the same way.
    let exec_err = |message: String| KgError::CypherExecution {
        message,
        position: None,
    };

    let model = match opts.embedder.as_ref() {
        Some(embedder) => embedder,
        None => {
            return Err(exec_err(
                "text_score() requires a registered embedding model. \
                 Call g.set_embedder(model) first (Python) or pass an embedder \
                 via ExecuteOptions::embedder (downstream Rust consumers)."
                    .to_string(),
            ));
        }
    };

    model.load().map_err(exec_err)?;

    let mut texts: Vec<String> = Vec::new();
    for (_, text) in rewrite.texts_to_embed.iter() {
        texts.push(text.clone());
    }

    // Unload before looking at the outcome so the model is released on the
    // error path as well.
    let outcome = model.embed(&texts);
    model.unload();
    let vectors: Vec<Vec<f32>> = outcome.map_err(exec_err)?;

    if vectors.len() != texts.len() {
        return Err(exec_err(format!(
            "text_score: model.embed() returned {} vectors for {} texts",
            vectors.len(),
            texts.len()
        )));
    }

    // The lengths are known to match, so pairing by position is exact.
    let mut merged = opts.params.clone();
    for ((param_name, _), vector) in rewrite.texts_to_embed.iter().zip(vectors.iter()) {
        let components: Vec<String> = vector
            .iter()
            .map(|component| component.to_string())
            .collect();
        let rendered = format!("[{}]", components.join(", "));
        merged.insert(param_name.clone(), Value::String(rendered));
    }
    Ok(merged)
}