use serde_json::{json, Value};
use super::{graph, project_context, search, HandlerContext};
/// Retrieval intent inferred from (or forced onto) a `memory_smart_retrieve` query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Intent {
    /// Plain text search for matching memories.
    Lookup,
    /// Graph expansion around a seed memory, combined with a text search.
    Exploration,
    /// Workspace/project-level context.
    Context,
    /// The query appears to ask how two things are connected.
    Path,
}

impl Intent {
    /// Returns `(intent name, strategy name)` for this variant in one place so the two
    /// labels cannot drift apart.
    fn labels(self) -> (&'static str, &'static str) {
        match self {
            Self::Lookup => ("lookup", "memory_search"),
            Self::Exploration => ("exploration", "memory_related+memory_search"),
            Self::Context => ("context", "memory_get_project_context"),
            Self::Path => ("path", "memory_find_path"),
        }
    }

    /// Stable lowercase name of the intent, as reported in `intents_used`.
    fn as_str(self) -> &'static str {
        let (name, _) = self.labels();
        name
    }

    /// Name of the backing strategy (`+`-joined when several are combined), as reported
    /// in `strategies_called`.
    fn strategy_name(self) -> &'static str {
        let (_, strategy) = self.labels();
        strategy
    }
}
/// Heuristically classifies a free-text query into the retrieval intents it expresses.
///
/// The query is trimmed and lowercased, then checked against English and Portuguese
/// keyword lists by plain substring matching. Detected intents are returned in a fixed
/// order (`Path`, `Exploration`, `Context`), and `Intent::Lookup` is always appended last
/// so a plain search runs as a fallback.
pub fn classify(query: &str) -> Vec<Intent> {
    // A path query joins two things with a conjunction ("A and B", "entre A e B")...
    const PATH_JOINERS: &[&str] = &[" and ", " e ", " entre "];
    // ...and asks how they connect; stems such as "relat"/"relac" cover several inflections.
    const PATH_TERMS: &[&str] = &[
        "connect", "relat", "link", "path", "conect", "relac", "caminho",
    ];
    const EXPLORATION_TERMS: &[&str] = &[
        "related to",
        "similar to",
        "connected to",
        "relacionad",
        "similar a",
        "ligado a",
    ];
    const CONTEXT_TERMS: &[&str] = &[
        "status",
        "overview",
        "what's going on",
        "what is going on",
        "resumo",
        "o que está acontecendo",
        // The same phrase typed without the accent, which is common in Portuguese input.
        "o que esta acontecendo",
        "contexto do projeto",
    ];

    let q = query.trim().to_lowercase();
    let mentions_any = |terms: &[&str]| terms.iter().any(|term| q.contains(*term));

    let mut intents = Vec::with_capacity(4);
    if mentions_any(PATH_JOINERS) && mentions_any(PATH_TERMS) {
        intents.push(Intent::Path);
    }
    if mentions_any(EXPLORATION_TERMS) {
        intents.push(Intent::Exploration);
    }
    if mentions_any(CONTEXT_TERMS) {
        intents.push(Intent::Context);
    }
    // Nothing above produces `Lookup`, so append it unconditionally as the fallback.
    intents.push(Intent::Lookup);
    intents
}
fn extract_id(entry: &Value) -> Option<i64> {
entry
.get("id")
.or_else(|| entry.get("memory_id"))
.or_else(|| entry.get("memory").and_then(|m| m.get("id")))
.and_then(|v| v.as_i64())
}
fn extract_list(value: &Value, keys: &[&str]) -> Vec<Value> {
if let Some(arr) = value.as_array() {
return arr.clone();
}
for key in keys {
if let Some(arr) = value.get(*key).and_then(|v| v.as_array()) {
return arr.clone();
}
}
Vec::new()
}
/// Runs `search::memory_search` for `query` and returns its result list.
///
/// `workspace` is forwarded only when provided. The list is read from a top-level array or
/// the `results` field of the response; any other shape yields an empty `Vec`.
fn call_search(
    ctx: &HandlerContext,
    query: &str,
    limit: u64,
    workspace: Option<&str>,
) -> Vec<Value> {
    let request = match workspace {
        Some(ws) => json!({ "query": query, "limit": limit, "workspace": ws }),
        None => json!({ "query": query, "limit": limit }),
    };
    extract_list(&search::memory_search(ctx, request), &["results"])
}
/// Removes an exploration marker ("related to ", "relacionadas a ", ...) and everything
/// before it, returning the trimmed remainder as the topic to search for.
///
/// Markers are tried in list order and matched ASCII-case-insensitively; the first marker in
/// the list that occurs anywhere in `query` wins (not the earliest occurrence). The original
/// casing of the remainder is preserved. When no marker matches, or a marker is followed
/// only by whitespace, `query` is returned unchanged so callers never search for "".
fn strip_intent_markers(query: &str) -> String {
    const MARKERS: &[&str] = &[
        "related to ",
        "similar to ",
        "connected to ",
        "relacionado a ",
        "relacionado ao ",
        "relacionado aos ",
        "relacionada a ",
        "relacionadas a ",
        "relacionados a ",
        "similar a ",
        "ligado a ",
        "ligada a ",
        "ligados a ",
        "ligadas a ",
    ];
    // Match against the original bytes instead of `query.to_lowercase()`: Unicode
    // lowercasing can change byte lengths (e.g. 'İ' becomes "i\u{307}", 2 -> 3 bytes), so an
    // offset found in the lowered copy may be misaligned in `query` or fall inside a
    // multi-byte char and panic. Every marker is ASCII, so a case-insensitive hit covers only
    // ASCII bytes and the position right after it is always a char boundary.
    let haystack = query.as_bytes();
    for marker in MARKERS {
        let needle = marker.as_bytes();
        let hit = haystack
            .windows(needle.len())
            .position(|window| window.eq_ignore_ascii_case(needle));
        if let Some(start) = hit {
            let topic = query[start + needle.len()..].trim();
            return if topic.is_empty() {
                query.to_string()
            } else {
                topic.to_string()
            };
        }
    }
    query.to_string()
}
/// Finds memories related to the topic of `query` via the memory graph.
///
/// The exploration marker is stripped from `query`, and the single best search hit for the
/// remaining topic becomes the seed for `graph::memory_related`. Returns an empty `Vec` when
/// no seed with a numeric id is found. `workspace` only scopes the seed search; it is not
/// forwarded to `memory_related`.
fn call_related(
    ctx: &HandlerContext,
    query: &str,
    limit: u64,
    workspace: Option<&str>,
) -> Vec<Value> {
    let topic = strip_intent_markers(query);
    let seed_id = match call_search(ctx, &topic, 1, workspace)
        .first()
        .and_then(extract_id)
    {
        Some(id) => id,
        None => return Vec::new(),
    };
    let response = graph::memory_related(ctx, json!({ "id": seed_id, "limit": limit }));
    extract_list(&response, &["related", "results"])
}
/// Fetches project context for `workspace` (or `"default"` when none is given) and returns
/// its entry list, read from a top-level array or the `memories`/`results` fields.
fn call_context(ctx: &HandlerContext, workspace: Option<&str>) -> Vec<Value> {
    let request = json!({ "workspace": workspace.unwrap_or("default") });
    extract_list(
        &project_context::get_project_context(ctx, request),
        &["memories", "results"],
    )
}
/// Resolves which intents `memory_smart_retrieve` should run.
///
/// A `force_intents` array is honoured when it names at least one known intent. Unknown
/// names and non-string items are ignored, and duplicates are dropped (first occurrence
/// keeps its position). Otherwise, including when `force_intents` is absent or names nothing
/// usable, the intents come from `classify(query)`, so a request never resolves to zero
/// intents.
fn resolve_intents(params: &Value, query: &str) -> Vec<Intent> {
    let Some(names) = params.get("force_intents").and_then(|v| v.as_array()) else {
        return classify(query);
    };
    let mut forced: Vec<Intent> = Vec::new();
    for name in names.iter().filter_map(|v| v.as_str()) {
        let intent = match name {
            "lookup" => Intent::Lookup,
            "exploration" => Intent::Exploration,
            "context" => Intent::Context,
            "path" => Intent::Path,
            _ => continue,
        };
        if !forced.contains(&intent) {
            forced.push(intent);
        }
    }
    if forced.is_empty() {
        classify(query)
    } else {
        forced
    }
}

/// Retrieves memories for a free-text `query` by routing it to one or more strategies.
///
/// Parameters (`params` object):
/// - `query` (string, required, must not be blank)
/// - `limit` (integer, default 10, clamped to `1..=100`): maximum number of merged results
/// - `workspace` (string, optional): forwarded to search and project context
/// - `force_intents` (array of strings, optional): overrides `classify`; see `resolve_intents`
///
/// Returns `{ "results", "intents_used", "strategies_called" }`. Results from the intents
/// are concatenated in intent order, de-duplicated by memory id (entries without an id are
/// always kept), and cut off at `limit`. Returns `{ "error": ... }` when `query` is missing
/// or blank.
pub fn memory_smart_retrieve(ctx: &HandlerContext, params: Value) -> Value {
    let query = match params.get("query").and_then(|v| v.as_str()) {
        Some(q) if !q.trim().is_empty() => q,
        _ => return json!({ "error": "missing or empty `query` parameter" }),
    };
    let limit = params
        .get("limit")
        .and_then(|v| v.as_u64())
        .unwrap_or(10)
        .clamp(1, 100);
    let cap = limit as usize;
    let workspace = params.get("workspace").and_then(|v| v.as_str());
    let intents = resolve_intents(&params, query);

    let mut merged: Vec<Value> = Vec::new();
    let mut seen_ids: std::collections::HashSet<i64> = std::collections::HashSet::new();
    let mut strategies_called: Vec<&'static str> = Vec::new();
    for intent in &intents {
        let entries = match intent {
            Intent::Lookup => call_search(ctx, query, limit, workspace),
            Intent::Exploration => {
                let mut combined = call_related(ctx, query, limit, workspace);
                combined.extend(call_search(
                    ctx,
                    &strip_intent_markers(query),
                    limit,
                    workspace,
                ));
                combined
            }
            Intent::Context => call_context(ctx, workspace),
            // No path-finding backend is wired into this handler, so nothing is called;
            // skip it rather than reporting `memory_find_path` in `strategies_called`.
            Intent::Path => continue,
        };
        let strategy = intent.strategy_name();
        if !strategies_called.contains(&strategy) {
            strategies_called.push(strategy);
        }
        for entry in entries {
            let is_new = match extract_id(&entry) {
                Some(id) => seen_ids.insert(id),
                None => true,
            };
            if is_new {
                merged.push(entry);
            }
            if merged.len() >= cap {
                break;
            }
        }
        if merged.len() >= cap {
            break;
        }
    }
    json!({
        "results": merged,
        "intents_used": intents.iter().map(|i| i.as_str()).collect::<Vec<_>>(),
        "strategies_called": strategies_called,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lookup_is_default() {
        let intents = classify("rust async runtime");
        assert_eq!(intents, vec![Intent::Lookup]);
    }

    #[test]
    fn exploration_detected() {
        let intents = classify("things related to async");
        assert!(intents.contains(&Intent::Exploration));
        assert!(intents.contains(&Intent::Lookup));
    }

    #[test]
    fn context_detected() {
        let intents = classify("what is the status of the project");
        assert!(intents.contains(&Intent::Context));
    }

    #[test]
    fn path_detected() {
        let intents = classify("how are tokio and async connected");
        assert!(intents.contains(&Intent::Path));
    }

    #[test]
    fn portuguese_exploration() {
        let intents = classify("memórias relacionadas a engram");
        assert!(intents.contains(&Intent::Exploration));
    }

    #[test]
    fn portuguese_path() {
        let intents = classify("como engram e tokio estão conectados");
        assert!(intents.contains(&Intent::Path));
    }

    // `force_intents` is parsed inside `memory_smart_retrieve`, which needs a
    // `HandlerContext`; this test only covers the classifier's fallback for a query that
    // matches no keywords.
    #[test]
    fn unmatched_query_yields_lookup_only() {
        let intents = classify("x");
        assert_eq!(intents, vec![Intent::Lookup]);
    }

    #[test]
    fn lookup_is_always_last() {
        let intents = classify("how are tokio and async connected");
        assert_eq!(intents.last(), Some(&Intent::Lookup));
    }

    #[test]
    fn intent_names_are_stable() {
        assert_eq!(Intent::Lookup.as_str(), "lookup");
        assert_eq!(Intent::Exploration.as_str(), "exploration");
        assert_eq!(Intent::Context.as_str(), "context");
        assert_eq!(Intent::Path.as_str(), "path");
    }

    #[test]
    fn strip_removes_english_marker() {
        assert_eq!(strip_intent_markers("things related to async"), "async");
        assert_eq!(strip_intent_markers("Things Related To Async"), "Async");
    }

    #[test]
    fn strip_removes_portuguese_marker() {
        assert_eq!(strip_intent_markers("memórias relacionadas a engram"), "engram");
    }

    #[test]
    fn strip_without_marker_is_identity() {
        assert_eq!(strip_intent_markers("rust async runtime"), "rust async runtime");
    }
}