use serde::{Deserialize, Serialize};
use crate::{
ComparisonOp, Engine, EngineError, Predicate, QueryAst, QueryStep, RetrievalModality,
ScalarValue, SearchHit, SearchHitSource, SearchMatchMode, SearchRows, TextQuery,
compile_retrieval_plan, compile_search_plan, compile_search_plan_from_queries,
};
use fathomdb_query::CompileError;
/// Which search pipeline a [`PySearchRequest`] should run.
///
/// Wire form is the snake_case variant name (`"search"`, `"text_search"`,
/// `"fallback_search"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PySearchMode {
    /// Compiled with `compile_retrieval_plan` and run via `execute_retrieval_plan`.
    Search,
    /// `strict_query` is parsed as a `TextQuery` and compiled with `compile_search_plan`.
    TextSearch,
    /// Strict and optional relaxed queries are compiled with
    /// `compile_search_plan_from_queries`.
    FallbackSearch,
}
/// A search request as sent from Python, decoded from JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct PySearchRequest {
    /// Root node kind of the query AST.
    pub root_kind: String,
    /// Primary query text.
    pub strict_query: String,
    /// Optional relaxed query; only read by the `fallback_search` mode.
    #[serde(default)]
    pub relaxed_query: Option<String>,
    /// Which pipeline to run.
    pub mode: PySearchMode,
    /// Limit passed to the search step (and to the fallback compiler).
    pub limit: usize,
    /// Filters lowered to `QueryStep::Filter` steps; empty when omitted.
    #[serde(default)]
    pub filters: Vec<PySearchFilter>,
    /// Whether per-hit match attribution is requested; `false` when omitted.
    #[serde(default)]
    pub attribution_requested: bool,
}
/// A filter in wire form.
///
/// The JSON object's `type` field selects the variant (snake_case of the variant
/// name, e.g. `{"type":"filter_kind_eq","kind":"Goal"}`). Each variant becomes one
/// `QueryStep::Filter` step via its `From` impl.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PySearchFilter {
    /// Lowers to `Predicate::KindEq`.
    FilterKindEq { kind: String },
    /// Lowers to `Predicate::LogicalIdEq`.
    FilterLogicalIdEq { logical_id: String },
    /// Lowers to `Predicate::SourceRefEq`.
    FilterSourceRefEq { source_ref: String },
    /// Lowers to `Predicate::ContentRefEq`.
    FilterContentRefEq { content_ref: String },
    /// Lowers to `Predicate::ContentRefNotNull`; carries no payload.
    FilterContentRefNotNull {},
    /// Lowers to `Predicate::JsonPathEq` with a text operand.
    FilterJsonTextEq { path: String, value: String },
    /// Lowers to `Predicate::JsonPathEq` with a boolean operand.
    FilterJsonBoolEq { path: String, value: bool },
    /// `JsonPathCompare` with `Gt` and an integer operand.
    FilterJsonIntegerGt { path: String, value: i64 },
    /// `JsonPathCompare` with `Gte` and an integer operand.
    FilterJsonIntegerGte { path: String, value: i64 },
    /// `JsonPathCompare` with `Lt` and an integer operand.
    FilterJsonIntegerLt { path: String, value: i64 },
    /// `JsonPathCompare` with `Lte` and an integer operand.
    FilterJsonIntegerLte { path: String, value: i64 },
    /// Same lowering as `FilterJsonIntegerGt`.
    FilterJsonTimestampGt { path: String, value: i64 },
    /// Same lowering as `FilterJsonIntegerGte`.
    FilterJsonTimestampGte { path: String, value: i64 },
    /// Same lowering as `FilterJsonIntegerLt`.
    FilterJsonTimestampLt { path: String, value: i64 },
    /// Same lowering as `FilterJsonIntegerLte`.
    FilterJsonTimestampLte { path: String, value: i64 },
    /// Lowers to `Predicate::JsonPathFusedEq`.
    FilterJsonFusedTextEq { path: String, value: String },
    /// `JsonPathFusedTimestampCmp` with `Gt`.
    FilterJsonFusedTimestampGt { path: String, value: i64 },
    /// `JsonPathFusedTimestampCmp` with `Gte`.
    FilterJsonFusedTimestampGte { path: String, value: i64 },
    /// `JsonPathFusedTimestampCmp` with `Lt`.
    FilterJsonFusedTimestampLt { path: String, value: i64 },
    /// `JsonPathFusedTimestampCmp` with `Lte`.
    FilterJsonFusedTimestampLte { path: String, value: i64 },
}
impl From<PySearchFilter> for QueryStep {
fn from(value: PySearchFilter) -> Self {
match value {
PySearchFilter::FilterKindEq { kind } => QueryStep::Filter(Predicate::KindEq(kind)),
PySearchFilter::FilterLogicalIdEq { logical_id } => {
QueryStep::Filter(Predicate::LogicalIdEq(logical_id))
}
PySearchFilter::FilterSourceRefEq { source_ref } => {
QueryStep::Filter(Predicate::SourceRefEq(source_ref))
}
PySearchFilter::FilterContentRefEq { content_ref } => {
QueryStep::Filter(Predicate::ContentRefEq(content_ref))
}
PySearchFilter::FilterContentRefNotNull {} => {
QueryStep::Filter(Predicate::ContentRefNotNull)
}
PySearchFilter::FilterJsonTextEq { path, value } => {
QueryStep::Filter(Predicate::JsonPathEq {
path,
value: ScalarValue::Text(value),
})
}
PySearchFilter::FilterJsonBoolEq { path, value } => {
QueryStep::Filter(Predicate::JsonPathEq {
path,
value: ScalarValue::Bool(value),
})
}
PySearchFilter::FilterJsonIntegerGt { path, value }
| PySearchFilter::FilterJsonTimestampGt { path, value } => {
QueryStep::Filter(Predicate::JsonPathCompare {
path,
op: ComparisonOp::Gt,
value: ScalarValue::Integer(value),
})
}
PySearchFilter::FilterJsonIntegerGte { path, value }
| PySearchFilter::FilterJsonTimestampGte { path, value } => {
QueryStep::Filter(Predicate::JsonPathCompare {
path,
op: ComparisonOp::Gte,
value: ScalarValue::Integer(value),
})
}
PySearchFilter::FilterJsonIntegerLt { path, value }
| PySearchFilter::FilterJsonTimestampLt { path, value } => {
QueryStep::Filter(Predicate::JsonPathCompare {
path,
op: ComparisonOp::Lt,
value: ScalarValue::Integer(value),
})
}
PySearchFilter::FilterJsonIntegerLte { path, value }
| PySearchFilter::FilterJsonTimestampLte { path, value } => {
QueryStep::Filter(Predicate::JsonPathCompare {
path,
op: ComparisonOp::Lte,
value: ScalarValue::Integer(value),
})
}
PySearchFilter::FilterJsonFusedTextEq { path, value } => {
QueryStep::Filter(Predicate::JsonPathFusedEq { path, value })
}
PySearchFilter::FilterJsonFusedTimestampGt { path, value } => {
QueryStep::Filter(Predicate::JsonPathFusedTimestampCmp {
path,
op: ComparisonOp::Gt,
value,
})
}
PySearchFilter::FilterJsonFusedTimestampGte { path, value } => {
QueryStep::Filter(Predicate::JsonPathFusedTimestampCmp {
path,
op: ComparisonOp::Gte,
value,
})
}
PySearchFilter::FilterJsonFusedTimestampLt { path, value } => {
QueryStep::Filter(Predicate::JsonPathFusedTimestampCmp {
path,
op: ComparisonOp::Lt,
value,
})
}
PySearchFilter::FilterJsonFusedTimestampLte { path, value } => {
QueryStep::Filter(Predicate::JsonPathFusedTimestampCmp {
path,
op: ComparisonOp::Lte,
value,
})
}
}
}
}
/// Python-facing mirror of [`SearchHitSource`]; serialized as snake_case strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PySearchHitSource {
    /// `"chunk"`
    Chunk,
    /// `"property"`
    Property,
    /// `"vector"`
    Vector,
}

impl From<SearchHitSource> for PySearchHitSource {
    fn from(source: SearchHitSource) -> Self {
        match source {
            SearchHitSource::Chunk => PySearchHitSource::Chunk,
            SearchHitSource::Property => PySearchHitSource::Property,
            SearchHitSource::Vector => PySearchHitSource::Vector,
        }
    }
}
/// Python-facing mirror of [`SearchMatchMode`]; serialized as snake_case strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PySearchMatchMode {
    /// `"strict"`
    Strict,
    /// `"relaxed"`
    Relaxed,
}

impl From<SearchMatchMode> for PySearchMatchMode {
    fn from(mode: SearchMatchMode) -> Self {
        match mode {
            SearchMatchMode::Strict => PySearchMatchMode::Strict,
            SearchMatchMode::Relaxed => PySearchMatchMode::Relaxed,
        }
    }
}
/// Python-facing mirror of [`RetrievalModality`]; serialized as snake_case strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PyRetrievalModality {
    /// `"text"`
    Text,
    /// `"vector"`
    Vector,
}

impl From<RetrievalModality> for PyRetrievalModality {
    fn from(modality: RetrievalModality) -> Self {
        match modality {
            RetrievalModality::Text => PyRetrievalModality::Text,
            RetrievalModality::Vector => PyRetrievalModality::Vector,
        }
    }
}
/// Node fields carried on each [`PySearchHit`], copied from the engine hit's `node`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PySearchNodeRow {
    pub row_id: String,
    pub logical_id: String,
    pub kind: String,
    /// Raw properties string, passed through unchanged (presumably JSON-encoded —
    /// the tests use `{"name":"test"}`; confirm against the engine).
    pub properties: String,
    pub content_ref: Option<String>,
    pub last_accessed_at: Option<i64>,
}

/// Per-hit attribution: the paths reported as matched by the engine.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PyHitAttribution {
    pub matched_paths: Vec<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct PySearchHit {
pub node: PySearchNodeRow,
pub score: f64,
pub modality: PyRetrievalModality,
pub source: PySearchHitSource,
pub match_mode: Option<PySearchMatchMode>,
pub snippet: Option<String>,
pub written_at: i64,
pub projection_row_id: Option<String>,
pub vector_distance: Option<f64>,
pub attribution: Option<PyHitAttribution>,
}
impl From<SearchHit> for PySearchHit {
fn from(value: SearchHit) -> Self {
Self {
node: PySearchNodeRow {
row_id: value.node.row_id,
logical_id: value.node.logical_id,
kind: value.node.kind,
properties: value.node.properties,
content_ref: value.node.content_ref,
last_accessed_at: value.node.last_accessed_at,
},
score: value.score,
modality: value.modality.into(),
source: value.source.into(),
match_mode: value.match_mode.map(Into::into),
snippet: value.snippet,
written_at: value.written_at,
projection_row_id: value.projection_row_id,
vector_distance: value.vector_distance,
attribution: value.attribution.map(|a| PyHitAttribution {
matched_paths: a.matched_paths,
}),
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct PySearchRows {
pub hits: Vec<PySearchHit>,
pub was_degraded: bool,
pub fallback_used: bool,
pub strict_hit_count: usize,
pub relaxed_hit_count: usize,
pub vector_hit_count: usize,
}
impl From<SearchRows> for PySearchRows {
fn from(value: SearchRows) -> Self {
Self {
hits: value.hits.into_iter().map(PySearchHit::from).collect(),
was_degraded: value.was_degraded,
fallback_used: value.fallback_used,
strict_hit_count: value.strict_hit_count,
relaxed_hit_count: value.relaxed_hit_count,
vector_hit_count: value.vector_hit_count,
}
}
}
/// Failure from [`execute_search_json`], tagged by the pipeline stage that failed.
#[derive(Debug)]
pub enum SearchFfiError {
    /// The request JSON could not be decoded into a [`PySearchRequest`].
    Parse(serde_json::Error),
    /// Plan compilation rejected the AST built from the request.
    Compile(CompileError),
    /// The engine coordinator returned an error while executing the plan.
    Engine(EngineError),
    /// The response rows could not be encoded as JSON.
    Serialize(serde_json::Error),
}

impl std::fmt::Display for SearchFfiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Parse(e) => write!(f, "search request JSON parse error: {e}"),
            // `CompileError` is rendered with its `Debug` form here.
            Self::Compile(e) => write!(f, "search plan compile error: {e:?}"),
            Self::Engine(e) => write!(f, "search execution error: {e}"),
            Self::Serialize(e) => write!(f, "search response serialize error: {e}"),
        }
    }
}

impl std::error::Error for SearchFfiError {
    /// Exposes the wrapped `serde_json::Error` so error-chain walkers (logging,
    /// `anyhow`-style reporters, Python exception chaining) can reach it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Parse(e) | Self::Serialize(e) => Some(e),
            // NOTE(review): `CompileError` is only formatted via `Debug` above and
            // `EngineError`'s `Error` impl is not visible here, so neither is exposed
            // as a source. Wire them up if both implement `std::error::Error`.
            Self::Compile(_) | Self::Engine(_) => None,
        }
    }
}
/// Builds a [`QueryAst`] rooted at `request.root_kind` whose steps are the
/// request's filters, in order, each lowered to a `QueryStep`.
///
/// The AST has no expansions and no final limit. `execute_search_json` prepends
/// the search step to it at index 0.
fn build_filter_ast(request: &PySearchRequest) -> QueryAst {
    QueryAst {
        root_kind: request.root_kind.clone(),
        steps: request
            .filters
            .iter()
            .map(|filter| QueryStep::from(filter.clone()))
            .collect(),
        expansions: Vec::new(),
        final_limit: None,
    }
}
/// Runs a JSON-encoded [`PySearchRequest`] against `engine` and returns the
/// resulting [`PySearchRows`] encoded as JSON.
///
/// Dispatch by `mode`:
/// - `search`: compiles with `compile_retrieval_plan` and executes with
///   `execute_retrieval_plan`, passing the raw `strict_query` string.
/// - `text_search`: parses `strict_query` as a [`TextQuery`] and compiles with
///   `compile_search_plan`.
/// - `fallback_search`: parses `strict_query` and, if present, `relaxed_query`,
///   and compiles them with `compile_search_plan_from_queries`.
///
/// `relaxed_query` is only read in `fallback_search` mode. In every mode,
/// `attribution_requested` is copied onto the plan's strict branch and, when
/// present, its relaxed branch.
///
/// # Errors
///
/// - [`SearchFfiError::Parse`] if `request_json` is malformed.
/// - [`SearchFfiError::Compile`] if plan compilation fails.
/// - [`SearchFfiError::Engine`] if execution fails.
/// - [`SearchFfiError::Serialize`] if the response cannot be encoded.
pub fn execute_search_json(engine: &Engine, request_json: &str) -> Result<String, SearchFfiError> {
    let request = serde_json::from_str::<PySearchRequest>(request_json)
        .map_err(SearchFfiError::Parse)?;
    let limit = request.limit;
    let attribution = request.attribution_requested;

    let rows: SearchRows = match request.mode {
        PySearchMode::Search => {
            let mut retrieval_ast = build_filter_ast(&request);
            // Prepend the search step so it precedes every filter step.
            retrieval_ast.steps.insert(
                0,
                QueryStep::Search {
                    query: request.strict_query.clone(),
                    limit,
                },
            );
            let mut retrieval_plan =
                compile_retrieval_plan(&retrieval_ast).map_err(SearchFfiError::Compile)?;
            retrieval_plan.text.strict.attribution_requested = attribution;
            if let Some(branch) = retrieval_plan.text.relaxed.as_mut() {
                branch.attribution_requested = attribution;
            }
            engine
                .coordinator()
                .execute_retrieval_plan(&retrieval_plan, &request.strict_query)
                .map_err(SearchFfiError::Engine)?
        }
        PySearchMode::TextSearch | PySearchMode::FallbackSearch => {
            let strict_text = TextQuery::parse(&request.strict_query);
            let mut text_ast = build_filter_ast(&request);
            let mut text_plan = if request.mode == PySearchMode::TextSearch {
                text_ast.steps.insert(
                    0,
                    QueryStep::TextSearch {
                        query: strict_text,
                        limit,
                    },
                );
                compile_search_plan(&text_ast).map_err(SearchFfiError::Compile)?
            } else {
                let relaxed_text = request.relaxed_query.as_deref().map(TextQuery::parse);
                // Sentinel text step carrying `TextQuery::Empty`. The real strict and
                // relaxed queries are handed to the compiler as separate arguments.
                text_ast.steps.insert(
                    0,
                    QueryStep::TextSearch {
                        query: TextQuery::Empty,
                        limit,
                    },
                );
                compile_search_plan_from_queries(
                    &text_ast,
                    strict_text,
                    relaxed_text,
                    limit,
                    attribution,
                )
                .map_err(SearchFfiError::Compile)?
            };
            text_plan.strict.attribution_requested = attribution;
            if let Some(branch) = text_plan.relaxed.as_mut() {
                branch.attribution_requested = attribution;
            }
            engine
                .coordinator()
                .execute_compiled_search_plan(&text_plan)
                .map_err(SearchFfiError::Engine)?
        }
    };

    serde_json::to_string(&PySearchRows::from(rows)).map_err(SearchFfiError::Serialize)
}
#[cfg(test)]
// Tests unwrap serde results with `expect`; a failure there is a test failure.
#[allow(clippy::expect_used, clippy::panic)]
mod tests {
    use super::{
        PyHitAttribution, PyRetrievalModality, PySearchFilter, PySearchHit, PySearchHitSource,
        PySearchMatchMode, PySearchMode, PySearchNodeRow, PySearchRequest, PySearchRows,
    };

    /// Encodes `rows` to JSON, decodes it back, and asserts nothing was lost.
    fn assert_roundtrip(rows: &PySearchRows) {
        let encoded = serde_json::to_string(rows).expect("serialize");
        let decoded: PySearchRows = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(*rows, decoded);
    }

    /// Serializes `value` to its JSON wire form.
    fn wire<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).expect("serialize")
    }

    #[test]
    fn search_rows_serde_roundtrip_empty() {
        assert_roundtrip(&PySearchRows {
            hits: Vec::new(),
            was_degraded: false,
            fallback_used: false,
            strict_hit_count: 0,
            relaxed_hit_count: 0,
            vector_hit_count: 0,
        });
    }

    #[test]
    fn search_rows_serde_roundtrip_with_hit() {
        let node = PySearchNodeRow {
            row_id: "row-1".into(),
            logical_id: "node-1".into(),
            kind: "Goal".into(),
            properties: r#"{"name":"test"}"#.into(),
            content_ref: Some("s3://x".into()),
            last_accessed_at: Some(1_700_000_000),
        };
        let hit = PySearchHit {
            node,
            score: 1.25,
            modality: PyRetrievalModality::Text,
            source: PySearchHitSource::Chunk,
            match_mode: Some(PySearchMatchMode::Strict),
            snippet: Some("... <b>test</b> ...".into()),
            written_at: 1_700_000_001,
            projection_row_id: Some("chunk-1".into()),
            vector_distance: None,
            attribution: Some(PyHitAttribution {
                matched_paths: vec!["$.name".into()],
            }),
        };
        assert_roundtrip(&PySearchRows {
            hits: vec![hit],
            was_degraded: false,
            fallback_used: true,
            strict_hit_count: 1,
            relaxed_hit_count: 0,
            vector_hit_count: 0,
        });
    }

    #[test]
    fn retrieval_modality_snake_case_wire_form() {
        for (modality, expected) in [
            (PyRetrievalModality::Text, "\"text\""),
            (PyRetrievalModality::Vector, "\"vector\""),
        ] {
            assert_eq!(wire(&modality), expected);
        }
    }

    #[test]
    fn search_hit_source_snake_case_wire_form() {
        for (source, expected) in [
            (PySearchHitSource::Chunk, "\"chunk\""),
            (PySearchHitSource::Property, "\"property\""),
            (PySearchHitSource::Vector, "\"vector\""),
        ] {
            assert_eq!(wire(&source), expected);
        }
    }

    #[test]
    fn search_match_mode_snake_case_wire_form() {
        for (mode, expected) in [
            (PySearchMatchMode::Strict, "\"strict\""),
            (PySearchMatchMode::Relaxed, "\"relaxed\""),
        ] {
            assert_eq!(wire(&mode), expected);
        }
    }

    #[test]
    fn search_request_deserializes_text_search_shape() {
        let request: PySearchRequest = serde_json::from_str(
            r#"{
"mode": "text_search",
"root_kind": "Goal",
"strict_query": "budget",
"limit": 10,
"filters": [{"type":"filter_kind_eq","kind":"Goal"}],
"attribution_requested": true
}"#,
        )
        .expect("parse");

        assert!(matches!(request.mode, PySearchMode::TextSearch));
        assert_eq!(request.root_kind, "Goal");
        assert_eq!(request.strict_query, "budget");
        assert_eq!(request.limit, 10);
        assert!(request.attribution_requested);
        // `relaxed_query` was omitted and must default to `None`.
        assert!(request.relaxed_query.is_none());
        assert_eq!(request.filters.len(), 1);
        assert!(matches!(
            request.filters[0],
            PySearchFilter::FilterKindEq { ref kind } if kind == "Goal"
        ));
    }

    #[test]
    fn search_request_deserializes_fallback_search_shape() {
        let request: PySearchRequest = serde_json::from_str(
            r#"{
"mode": "fallback_search",
"root_kind": "Goal",
"strict_query": "budget",
"relaxed_query": "budget OR alpha",
"limit": 5,
"filters": []
}"#,
        )
        .expect("parse");

        assert!(matches!(request.mode, PySearchMode::FallbackSearch));
        assert_eq!(request.relaxed_query.as_deref(), Some("budget OR alpha"));
        // `attribution_requested` was omitted and must default to `false`.
        assert!(!request.attribution_requested);
    }
}