use std::collections::HashSet;
use std::time::Duration;
use crate::llm::{ChatClient, ChatError, ChatMessage, strip_code_fences};
use crate::text::nfd;
use super::error::ExpansionError;
use super::types::{ExpansionBody, RecallDistillationBody};
/// Timeout (30 seconds) used by the constructors in this file that do not
/// take an explicit `timeout` argument.
pub const DEFAULT_EXPANSION_TIMEOUT: Duration = Duration::from_secs(30);
/// Value passed as the second argument of every `ChatClient::complete` call
/// issued from this file (presumably the sampling temperature, going by the
/// constant's name — confirm against `ChatClient`).
const EXPANSION_TEMPERATURE: f32 = 0.0;
/// System message sent by `ExpansionClient::expand_with_intent`.
///
/// Asks the model for a JSON object with a `queries` array of reformulations,
/// and to honour an optional "Query intent:" line in the user message.
const SYSTEM_PROMPT: &str = "Return only valid JSON of the form \
{\"queries\":[\"...\"]}. Generate 2 to 4 short search reformulations. \
Do not repeat the original query. Prefer terse, concrete terms that \
would help Obsidian search. If the user message includes a \
\"Query intent:\" line, every reformulation must stay consistent with \
that intent and avoid unrelated senses of the original query.";
/// System message sent by `ExpansionClient::distill_recall_prompt`.
///
/// Asks the model for a JSON object with `search_query`, `phrases` and
/// `identifiers` fields extracted from the user prompt.
const DISTILLER_SYSTEM_PROMPT: &str = "Return only valid JSON of the form \
{\"search_query\":\"...\",\"phrases\":[\"...\"],\"identifiers\":[\"...\"]}. \
Given a user prompt for an Obsidian memory system, extract the retrieval \
intent. Return one compact semantic query plus concrete project, decision, \
concept, person, place, artifact, path, tag, wikilink, and identifier \
phrases worth searching. Ignore tool chatter, code blocks, logs, \
boilerplate, and unrelated implementation detail.";
/// Client that asks a chat model for search-query reformulations and for
/// distilled recall queries.
///
/// It is a thin wrapper around a [`ChatClient`]; all requests go through
/// that inner client.
#[derive(Debug, Clone)]
pub struct ExpansionClient {
// Chat client used for every completion request made by this type.
chat: ChatClient,
}
impl ExpansionClient {
#[must_use]
pub const fn from_chat(chat: ChatClient) -> Self {
Self { chat }
}
pub fn new(
base_url: impl Into<String>,
model: impl Into<String>,
) -> Result<Self, ExpansionError> {
Self::with_timeout(base_url, model, DEFAULT_EXPANSION_TIMEOUT)
}
pub fn with_timeout(
base_url: impl Into<String>,
model: impl Into<String>,
timeout: Duration,
) -> Result<Self, ExpansionError> {
Self::with_timeout_and_max_tokens(base_url, model, timeout, None)
}
pub fn with_max_tokens(
base_url: impl Into<String>,
model: impl Into<String>,
max_tokens: Option<u32>,
) -> Result<Self, ExpansionError> {
Self::with_timeout_and_max_tokens(base_url, model, DEFAULT_EXPANSION_TIMEOUT, max_tokens)
}
pub fn with_timeout_and_max_tokens(
base_url: impl Into<String>,
model: impl Into<String>,
timeout: Duration,
max_tokens: Option<u32>,
) -> Result<Self, ExpansionError> {
let chat = ChatClient::with_timeout_and_max_tokens(base_url, model, timeout, max_tokens)
.map_err(ExpansionError::from)?;
Ok(Self { chat })
}
pub fn with_no_timeout_and_max_tokens(
base_url: impl Into<String>,
model: impl Into<String>,
max_tokens: Option<u32>,
) -> Result<Self, ExpansionError> {
let chat = ChatClient::with_no_timeout_and_max_tokens(base_url, model, max_tokens)
.map_err(ExpansionError::from)?;
Ok(Self { chat })
}
pub fn expand(&self, query: &str, n_variants: u8) -> Result<Vec<String>, ExpansionError> {
self.expand_with_intent(query, None, n_variants)
}
pub fn expand_with_intent(
&self,
query: &str,
intent: Option<&str>,
n_variants: u8,
) -> Result<Vec<String>, ExpansionError> {
let user_content = build_user_message(query, intent);
let messages = vec![
ChatMessage::new("system", SYSTEM_PROMPT),
ChatMessage::new("user", user_content),
];
let content = match self.chat.complete(messages, EXPANSION_TEMPERATURE) {
Ok(content) => content,
Err(ChatError::MalformedResponse) => return Ok(vec![]),
Err(err) => return Err(ExpansionError::from(err)),
};
let cleaned = strip_code_fences(&content);
let expansion: ExpansionBody = match serde_json::from_str(&cleaned) {
Ok(e) => e,
Err(_) => return Ok(vec![]),
};
Ok(normalize_queries(query, expansion.queries, n_variants))
}
pub fn distill_recall_prompt(
&self,
prompt_view: &str,
extraction_hints: &[String],
) -> Result<Option<RecallDistillationBody>, ExpansionError> {
let mut user_content = String::from("Prompt view:\n");
user_content.push_str(prompt_view);
if !extraction_hints.is_empty() {
user_content.push_str("\n\nExtraction hints:\n");
for hint in extraction_hints {
user_content.push_str("- ");
user_content.push_str(hint);
user_content.push('\n');
}
}
let messages = vec![
ChatMessage::new("system", DISTILLER_SYSTEM_PROMPT),
ChatMessage::new("user", user_content),
];
let content = match self.chat.complete(messages, EXPANSION_TEMPERATURE) {
Ok(content) => content,
Err(ChatError::MalformedResponse) => return Ok(None),
Err(err) => return Err(ExpansionError::from(err)),
};
let cleaned = strip_code_fences(&content);
let mut body: RecallDistillationBody = match serde_json::from_str(&cleaned) {
Ok(body) => body,
Err(_) => return Ok(None),
};
let search_query = body.search_query.trim().to_owned();
body.search_query = search_query;
body.phrases = normalize_items(body.phrases, 12);
body.identifiers = normalize_items(body.identifiers, 12);
if body.search_query.is_empty() {
Ok(None)
} else {
Ok(Some(body))
}
}
}
/// Builds the user message for a query-expansion request.
///
/// The message always starts with `Query: {query}`. When `intent` is present
/// and non-empty after trimming, a second `Query intent: {intent}` line with
/// the trimmed intent is appended.
fn build_user_message(query: &str, intent: Option<&str>) -> String {
    match intent.map(str::trim) {
        Some(intent) if !intent.is_empty() => {
            format!("Query: {query}\nQuery intent: {intent}")
        }
        _ => format!("Query: {query}"),
    }
}
/// Cleans up model-provided reformulations of `original`.
///
/// Each candidate is trimmed; empty candidates are dropped. Candidates are
/// compared by their `nfd::normalize`d, lowercased form: anything equal to the
/// (trimmed) original query or to an earlier candidate is skipped. The
/// surviving candidates are returned trimmed, in input order, and capped at
/// `limit` entries.
///
/// A `limit` of 0 yields an empty vector.
fn normalize_queries(original: &str, queries: Vec<String>, limit: u8) -> Vec<String> {
    let normalized_original = nfd::normalize(original.trim()).to_lowercase();
    let mut seen: HashSet<String> = HashSet::new();

    queries
        .into_iter()
        .filter_map(|candidate| {
            let trimmed = candidate.trim();
            if trimmed.is_empty() {
                return None;
            }
            let key = nfd::normalize(trimmed).to_lowercase();
            // Allocate the owned copy only for candidates that are kept.
            (key != normalized_original && seen.insert(key)).then(|| trimmed.to_owned())
        })
        // `take` enforces the cap before the first item is pulled, so a zero
        // limit produces nothing.
        .take(usize::from(limit))
        .collect()
}
/// Trims, de-duplicates and caps a list of model-provided items.
///
/// Each item is trimmed; empty items are dropped. Items are de-duplicated by
/// their `nfd::normalize`d, lowercased form, keeping the first occurrence.
/// The surviving items are returned trimmed, in input order, and capped at
/// `limit` entries.
///
/// A `limit` of 0 yields an empty vector, and no storage is reserved up
/// front, so arbitrarily large limits are safe.
fn normalize_items(items: Vec<String>, limit: usize) -> Vec<String> {
    let mut seen: HashSet<String> = HashSet::new();

    items
        .into_iter()
        .filter_map(|item| {
            let trimmed = item.trim();
            if trimmed.is_empty() {
                return None;
            }
            let key = nfd::normalize(trimmed).to_lowercase();
            // Allocate the owned copy only for items that are kept.
            seen.insert(key).then(|| trimmed.to_owned())
        })
        // `take` enforces the cap before the first item is pulled, so a zero
        // limit produces nothing.
        .take(limit)
        .collect()
}
// Unit tests live in a separate `tests` module file and are compiled only
// for test builds.
#[cfg(test)]
mod tests;