Skip to main content

synaptic_retrieval/
self_query.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::Value;
5use synaptic_core::{ChatModel, ChatRequest, Message, SynapseError};
6
7use crate::{Document, Retriever};
8
/// Describes a metadata field for the LLM to understand available filters.
///
/// Each entry is rendered into the self-query prompt as
/// `- {name} ({field_type}): {description}`. Filters returned by the model are
/// kept only when their `field` equals one of the declared `name`s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MetadataFieldInfo {
    /// Metadata key; looked up verbatim in each document's metadata when filtering.
    pub name: String,
    /// Human-readable explanation of the field, shown to the model.
    pub description: String,
    /// Free-form type hint shown to the model (for example `"string"` or `"integer"`).
    pub field_type: String,
}

impl MetadataFieldInfo {
    /// Builds a field description from anything convertible into `String`,
    /// so callers can pass string literals directly.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        field_type: impl Into<String>,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            field_type: field_type.into(),
        }
    }
}
16
17/// Uses a ChatModel to parse a user query into a structured query + metadata filters,
18/// then applies those filters to results from a base retriever.
19pub struct SelfQueryRetriever {
20    base: Arc<dyn Retriever>,
21    model: Arc<dyn ChatModel>,
22    field_info: Vec<MetadataFieldInfo>,
23}
24
25impl SelfQueryRetriever {
26    pub fn new(
27        base: Arc<dyn Retriever>,
28        model: Arc<dyn ChatModel>,
29        field_info: Vec<MetadataFieldInfo>,
30    ) -> Self {
31        Self {
32            base,
33            model,
34            field_info,
35        }
36    }
37
38    fn build_prompt(&self, query: &str) -> String {
39        let fields_desc = self
40            .field_info
41            .iter()
42            .map(|f| format!("- {} ({}): {}", f.name, f.field_type, f.description))
43            .collect::<Vec<_>>()
44            .join("\n");
45
46        format!(
47            r#"Given the following user query, extract a search query and any metadata filters.
48
49Available metadata fields:
50{fields_desc}
51
52Respond with a JSON object with two keys:
53- "query": the text query to search for (string)
54- "filters": an array of filter objects, each with "field", "op" (one of "eq", "gt", "lt", "gte", "lte", "contains"), and "value"
55
56If no filters apply, use an empty array.
57
58User query: {query}
59
60Respond with ONLY the JSON object, no explanation."#
61        )
62    }
63
64    async fn parse_query(&self, query: &str) -> Result<(String, Vec<Filter>), SynapseError> {
65        let prompt = self.build_prompt(query);
66        let request = ChatRequest::new(vec![Message::human(prompt)]);
67        let response = self.model.chat(request).await?;
68        let content = response.message.content().to_string();
69
70        // Try to parse as JSON
71        let parsed: Value = serde_json::from_str(content.trim()).map_err(|_| {
72            SynapseError::Retriever(format!("Failed to parse self-query response: {content}"))
73        })?;
74
75        let search_query = parsed["query"].as_str().unwrap_or(query).to_string();
76
77        let filters = parsed["filters"]
78            .as_array()
79            .map(|arr| {
80                arr.iter()
81                    .filter_map(|f| {
82                        let field = f["field"].as_str()?.to_string();
83                        let op = f["op"].as_str().unwrap_or("eq").to_string();
84                        let value = f["value"].clone();
85                        // Only include filters for known fields
86                        if self.field_info.iter().any(|fi| fi.name == field) {
87                            Some(Filter { field, op, value })
88                        } else {
89                            None
90                        }
91                    })
92                    .collect()
93            })
94            .unwrap_or_default();
95
96        Ok((search_query, filters))
97    }
98}
99
100#[derive(Debug, Clone)]
101struct Filter {
102    field: String,
103    op: String,
104    value: Value,
105}
106
107fn apply_filter(doc: &Document, filter: &Filter) -> bool {
108    let meta_value = match doc.metadata.get(&filter.field) {
109        Some(v) => v,
110        None => return false,
111    };
112
113    match filter.op.as_str() {
114        "eq" => meta_value == &filter.value,
115        "contains" => {
116            if let (Some(mv), Some(fv)) = (meta_value.as_str(), filter.value.as_str()) {
117                mv.contains(fv)
118            } else {
119                false
120            }
121        }
122        "gt" => compare_values(meta_value, &filter.value).is_some_and(|c| c > 0),
123        "gte" => compare_values(meta_value, &filter.value).is_some_and(|c| c >= 0),
124        "lt" => compare_values(meta_value, &filter.value).is_some_and(|c| c < 0),
125        "lte" => compare_values(meta_value, &filter.value).is_some_and(|c| c <= 0),
126        _ => true, // unknown op passes through
127    }
128}
129
130fn compare_values(a: &Value, b: &Value) -> Option<i32> {
131    match (a.as_f64(), b.as_f64()) {
132        (Some(av), Some(bv)) => {
133            if av > bv {
134                Some(1)
135            } else if av < bv {
136                Some(-1)
137            } else {
138                Some(0)
139            }
140        }
141        _ => match (a.as_str(), b.as_str()) {
142            (Some(av), Some(bv)) => Some(av.cmp(bv) as i32),
143            _ => None,
144        },
145    }
146}
147
148#[async_trait]
149impl Retriever for SelfQueryRetriever {
150    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
151        let (search_query, filters) = self.parse_query(query).await?;
152
153        let docs = self.base.retrieve(&search_query, top_k * 2).await?;
154
155        let filtered: Vec<Document> = if filters.is_empty() {
156            docs
157        } else {
158            docs.into_iter()
159                .filter(|doc| filters.iter().all(|f| apply_filter(doc, f)))
160                .collect()
161        };
162
163        Ok(filtered.into_iter().take(top_k).collect())
164    }
165}