synaptic_retrieval/
self_query.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::Value;
5use synaptic_core::{ChatModel, ChatRequest, Message, SynapseError};
6
7use crate::{Document, Retriever};
8
/// Describes one metadata field that the self-query model is allowed to filter on.
///
/// Each entry is shown to the model in the prompt as
/// `- {name} ({field_type}): {description}`, and parsed filters are only kept
/// when their `field` equals some entry's `name`.
#[derive(Debug, Clone)]
pub struct MetadataFieldInfo {
    /// Metadata key; compared exactly against the `field` of parsed filters.
    pub name: String,
    /// Human-readable explanation of the field, included in the prompt.
    pub description: String,
    /// Free-form type label included in the prompt (e.g. `"integer"`).
    pub field_type: String,
}
16
17pub struct SelfQueryRetriever {
20 base: Arc<dyn Retriever>,
21 model: Arc<dyn ChatModel>,
22 field_info: Vec<MetadataFieldInfo>,
23}
24
25impl SelfQueryRetriever {
26 pub fn new(
27 base: Arc<dyn Retriever>,
28 model: Arc<dyn ChatModel>,
29 field_info: Vec<MetadataFieldInfo>,
30 ) -> Self {
31 Self {
32 base,
33 model,
34 field_info,
35 }
36 }
37
38 fn build_prompt(&self, query: &str) -> String {
39 let fields_desc = self
40 .field_info
41 .iter()
42 .map(|f| format!("- {} ({}): {}", f.name, f.field_type, f.description))
43 .collect::<Vec<_>>()
44 .join("\n");
45
46 format!(
47 r#"Given the following user query, extract a search query and any metadata filters.
48
49Available metadata fields:
50{fields_desc}
51
52Respond with a JSON object with two keys:
53- "query": the text query to search for (string)
54- "filters": an array of filter objects, each with "field", "op" (one of "eq", "gt", "lt", "gte", "lte", "contains"), and "value"
55
56If no filters apply, use an empty array.
57
58User query: {query}
59
60Respond with ONLY the JSON object, no explanation."#
61 )
62 }
63
64 async fn parse_query(&self, query: &str) -> Result<(String, Vec<Filter>), SynapseError> {
65 let prompt = self.build_prompt(query);
66 let request = ChatRequest::new(vec![Message::human(prompt)]);
67 let response = self.model.chat(request).await?;
68 let content = response.message.content().to_string();
69
70 let parsed: Value = serde_json::from_str(content.trim()).map_err(|_| {
72 SynapseError::Retriever(format!("Failed to parse self-query response: {content}"))
73 })?;
74
75 let search_query = parsed["query"].as_str().unwrap_or(query).to_string();
76
77 let filters = parsed["filters"]
78 .as_array()
79 .map(|arr| {
80 arr.iter()
81 .filter_map(|f| {
82 let field = f["field"].as_str()?.to_string();
83 let op = f["op"].as_str().unwrap_or("eq").to_string();
84 let value = f["value"].clone();
85 if self.field_info.iter().any(|fi| fi.name == field) {
87 Some(Filter { field, op, value })
88 } else {
89 None
90 }
91 })
92 .collect()
93 })
94 .unwrap_or_default();
95
96 Ok((search_query, filters))
97 }
98}
99
100#[derive(Debug, Clone)]
101struct Filter {
102 field: String,
103 op: String,
104 value: Value,
105}
106
107fn apply_filter(doc: &Document, filter: &Filter) -> bool {
108 let meta_value = match doc.metadata.get(&filter.field) {
109 Some(v) => v,
110 None => return false,
111 };
112
113 match filter.op.as_str() {
114 "eq" => meta_value == &filter.value,
115 "contains" => {
116 if let (Some(mv), Some(fv)) = (meta_value.as_str(), filter.value.as_str()) {
117 mv.contains(fv)
118 } else {
119 false
120 }
121 }
122 "gt" => compare_values(meta_value, &filter.value).is_some_and(|c| c > 0),
123 "gte" => compare_values(meta_value, &filter.value).is_some_and(|c| c >= 0),
124 "lt" => compare_values(meta_value, &filter.value).is_some_and(|c| c < 0),
125 "lte" => compare_values(meta_value, &filter.value).is_some_and(|c| c <= 0),
126 _ => true, }
128}
129
130fn compare_values(a: &Value, b: &Value) -> Option<i32> {
131 match (a.as_f64(), b.as_f64()) {
132 (Some(av), Some(bv)) => {
133 if av > bv {
134 Some(1)
135 } else if av < bv {
136 Some(-1)
137 } else {
138 Some(0)
139 }
140 }
141 _ => match (a.as_str(), b.as_str()) {
142 (Some(av), Some(bv)) => Some(av.cmp(bv) as i32),
143 _ => None,
144 },
145 }
146}
147
148#[async_trait]
149impl Retriever for SelfQueryRetriever {
150 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
151 let (search_query, filters) = self.parse_query(query).await?;
152
153 let docs = self.base.retrieve(&search_query, top_k * 2).await?;
154
155 let filtered: Vec<Document> = if filters.is_empty() {
156 docs
157 } else {
158 docs.into_iter()
159 .filter(|doc| filters.iter().all(|f| apply_filter(doc, f)))
160 .collect()
161 };
162
163 Ok(filtered.into_iter().take(top_k).collect())
164 }
165}