use rmcp::model::{CallToolResult, Content, ErrorCode};
use schemars::JsonSchema;
use serde::Deserialize;
use crate::error::VaultError;
use crate::models::SearchField;
use crate::vault::Vault;
/// Parameters accepted by the `search_text` tool.
#[derive(Deserialize, JsonSchema, Default)]
pub struct SearchTextParams {
/// Query string, forwarded unchanged to the vault text search.
pub query: String,
/// Context length forwarded to the vault search; `search_text` uses 100 when omitted.
#[serde(default)]
pub context_length: Option<usize>,
/// Upper bound on the number of results; `search_text` uses 20 when omitted.
#[serde(default)]
pub max_results: Option<usize>,
/// Requests fuzzy matching; treated as `false` when omitted. Setting it to `true`
/// routes the query through `Vault::search_text_with_options`.
#[serde(default)]
pub fuzzy: Option<bool>,
/// Optional list of fields to search. When present, the query is routed through
/// `Vault::search_text_with_options`.
#[serde(default)]
pub fields: Option<Vec<SearchField>>,
}
/// Runs a text search over the vault and returns the hits as pretty-printed JSON.
///
/// Omitted options fall back to `context_length = 100`, `max_results = 20` and
/// `fuzzy = false`. When fuzzy matching is requested or a field filter is given,
/// the query goes through `Vault::search_text_with_options`, which receives
/// `max_results` directly; otherwise `Vault::search_text` is used and the result
/// list is truncated to `max_results` here.
///
/// # Errors
///
/// Propagates any error returned by the vault search, and reports a JSON
/// serialization failure as `VaultError::Other`.
pub async fn search_text(
    vault: &Vault,
    params: SearchTextParams,
) -> Result<CallToolResult, rmcp::ErrorData> {
    let context_length = params.context_length.unwrap_or(100);
    let max_results = params.max_results.unwrap_or(20);
    let fuzzy = params.fuzzy.unwrap_or(false);

    let results = if fuzzy || params.fields.is_some() {
        vault.search_text_with_options(
            &params.query,
            context_length,
            max_results,
            fuzzy,
            params.fields.as_deref(),
        )?
    } else {
        // The plain search has no limit parameter, so the cap is applied afterwards.
        let all = vault.search_text(&params.query, context_length)?;
        all.into_iter().take(max_results).collect()
    };

    let json = serde_json::to_string_pretty(&results)
        .map_err(|e| VaultError::Other(format!("JSON serialization failed: {e}")))?;
    Ok(CallToolResult::success(vec![Content::text(json)]))
}
/// Parameters accepted by the `search_regex` tool.
#[derive(Deserialize, JsonSchema, Default)]
pub struct SearchRegexParams {
/// Regular expression pattern, forwarded unchanged to `Vault::search_regex`.
pub pattern: String,
/// Context length forwarded to the vault search; `search_regex` uses 100 when omitted.
#[serde(default)]
pub context_length: Option<usize>,
/// Upper bound on the number of results; `search_regex` uses 20 when omitted.
#[serde(default)]
pub max_results: Option<usize>,
}
/// Runs a regular-expression search over the vault and returns the hits as
/// pretty-printed JSON.
///
/// Omitted options fall back to `context_length = 100` and `max_results = 20`.
/// The vault search itself is unbounded; the result list is truncated to
/// `max_results` afterwards.
///
/// # Errors
///
/// Propagates any error returned by `Vault::search_regex` (the tests in this file
/// expect an invalid pattern to surface as an error), and reports a JSON
/// serialization failure as `VaultError::Other`.
pub async fn search_regex(
    vault: &Vault,
    params: SearchRegexParams,
) -> Result<CallToolResult, rmcp::ErrorData> {
    let context_length = params.context_length.unwrap_or(100);
    let max_results = params.max_results.unwrap_or(20);

    let results = vault.search_regex(&params.pattern, context_length)?;
    let limited: Vec<_> = results.into_iter().take(max_results).collect();

    let json = serde_json::to_string_pretty(&limited)
        .map_err(|e| VaultError::Other(format!("JSON serialization failed: {e}")))?;
    Ok(CallToolResult::success(vec![Content::text(json)]))
}
/// Parameters accepted by the `search_tag` tool.
#[derive(Deserialize, JsonSchema, Default)]
pub struct SearchTagParams {
/// Tag to look up. A single leading `#` is stripped by `search_tag` before the lookup.
pub tag: String,
/// Selects prefix lookup (`Vault::search_by_tag_prefix`) when `true` and exact lookup
/// (`Vault::search_by_tag`) when `false`; treated as `true` when omitted.
#[serde(default)]
pub include_nested: Option<bool>,
}
/// Looks up notes by tag and returns the result as pretty-printed JSON.
///
/// A single leading `#` on the requested tag is ignored. With `include_nested`
/// (default `true`) the lookup goes through `Vault::search_by_tag_prefix`;
/// otherwise `Vault::search_by_tag` is used.
///
/// # Errors
///
/// Propagates any error returned by the vault lookup, and reports a JSON
/// serialization failure as `VaultError::Other`.
pub async fn search_tag(
    vault: &Vault,
    params: SearchTagParams,
) -> Result<CallToolResult, rmcp::ErrorData> {
    // Accept both "#tag" and "tag" spellings from the caller.
    let tag = params.tag.strip_prefix('#').unwrap_or(&params.tag);
    let include_nested = params.include_nested.unwrap_or(true);

    let results = if include_nested {
        vault.search_by_tag_prefix(tag)?
    } else {
        vault.search_by_tag(tag)?
    };

    let json = serde_json::to_string_pretty(&results)
        .map_err(|e| VaultError::Other(format!("JSON serialization failed: {e}")))?;
    Ok(CallToolResult::success(vec![Content::text(json)]))
}
/// Comparison applied by `search_frontmatter`; deserialized from the snake_case
/// variant names (`eq`, `contains`, `exists`).
#[derive(Debug, Clone, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum FrontmatterOperator {
/// Default operator. Requires a `value`; dispatched to `Vault::search_frontmatter`.
#[default]
Eq,
/// Requires a `value`; dispatched to `Vault::search_frontmatter_contains`.
Contains,
/// Needs no `value`; dispatched to `Vault::search_frontmatter_exists`.
Exists,
}
/// Parameters accepted by the `search_frontmatter` tool.
#[derive(Deserialize, JsonSchema, Default)]
pub struct SearchFrontmatterParams {
/// Name of the frontmatter field to inspect.
pub field: String,
/// Value to compare against. Required for the `eq` and `contains` operators;
/// not read for `exists`.
#[serde(default)]
pub value: Option<serde_json::Value>,
/// Comparison to apply; `eq` when omitted.
#[serde(default)]
pub operator: FrontmatterOperator,
}
/// Searches notes by a frontmatter field and returns the result as pretty-printed JSON.
///
/// The operator selects the vault query:
/// * `Exists` calls `Vault::search_frontmatter_exists` and does not read `value`.
/// * `Eq` calls `Vault::search_frontmatter` with the supplied `value`.
/// * `Contains` calls `Vault::search_frontmatter_contains` with the supplied `value`.
///
/// # Errors
///
/// Returns an `INVALID_PARAMS` error when `value` is missing for `Eq` or `Contains`,
/// propagates any error returned by the vault query, and reports a JSON
/// serialization failure as `VaultError::Other`.
pub async fn search_frontmatter(
    vault: &Vault,
    params: SearchFrontmatterParams,
) -> Result<CallToolResult, rmcp::ErrorData> {
    // Both value-taking operators reject a missing value in the same way; only the
    // message differs.
    let missing_value = |message: &'static str| {
        rmcp::ErrorData::new(
            ErrorCode::INVALID_PARAMS,
            message,
            None::<serde_json::Value>,
        )
    };

    let results = match params.operator {
        FrontmatterOperator::Exists => vault.search_frontmatter_exists(&params.field)?,
        FrontmatterOperator::Eq => {
            let value = params
                .value
                .ok_or_else(|| missing_value("'value' is required for 'eq' operator"))?;
            vault.search_frontmatter(&params.field, &value)?
        }
        FrontmatterOperator::Contains => {
            let value = params
                .value
                .ok_or_else(|| missing_value("'value' is required for 'contains' operator"))?;
            vault.search_frontmatter_contains(&params.field, &value)?
        }
    };

    let json = serde_json::to_string_pretty(&results)
        .map_err(|e| VaultError::Other(format!("JSON serialization failed: {e}")))?;
    Ok(CallToolResult::success(vec![Content::text(json)]))
}
/// Prefetch count passed to `Vault::search_hybrid` when lexical prefetch is requested.
#[cfg(feature = "embeddings")]
const DEFAULT_PREFETCH_COUNT: usize = 50;
/// Context length passed to `extract_match_context` when building a snippet around a
/// query-word match.
#[cfg(feature = "embeddings")]
const SNIPPET_CONTEXT_LEN: usize = 100;
/// Maximum number of characters taken by `body_preview` when no query word matches.
#[cfg(feature = "embeddings")]
const SNIPPET_FALLBACK_CHARS: usize = 200;
/// Parameters accepted by the `search_semantic` tool.
#[derive(Deserialize, JsonSchema, Default)]
pub struct SearchSemanticParams {
/// Query text, forwarded to the vault's semantic or hybrid search.
pub query: String,
/// Number of hits requested from the vault; 10 when omitted.
#[serde(default)]
pub top_k: Option<usize>,
/// When `true`, each result carries the full note text in `content` instead of a
/// `snippet`; treated as `false` when omitted.
#[serde(default)]
pub include_content: Option<bool>,
/// When `true`, the query goes through `Vault::search_hybrid` instead of
/// `Vault::search_semantic`; treated as `false` when omitted.
#[serde(default)]
pub lexical_prefetch: Option<bool>,
/// Weight passed to `Vault::search_hybrid`, clamped to `[0.0, 1.0]`. Falls back to
/// the handler's `default_alpha` argument when omitted.
#[serde(default)]
pub alpha: Option<f32>,
}
/// One entry in the JSON array returned by `search_semantic`.
#[cfg(feature = "embeddings")]
#[derive(serde::Serialize, JsonSchema)]
struct SemanticSearchResult {
/// Path of the matching note, as returned by the vault search.
path: std::path::PathBuf,
/// Note title from its metadata; empty when the metadata lookup fails.
title: String,
/// Score reported by the vault search for this hit.
score: f32,
/// Note tags from its metadata; empty when the metadata lookup fails.
tags: Vec<String>,
/// Short excerpt of the note; left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
snippet: Option<String>,
/// Full note text; left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
}
/// Runs a semantic (or hybrid) search over the vault and returns the hits as
/// pretty-printed JSON.
///
/// Omitted options fall back to `top_k = 10`, `include_content = false` and
/// `lexical_prefetch = false`. `alpha` falls back to `default_alpha`, and the chosen
/// value is clamped to `[0.0, 1.0]`. With `lexical_prefetch` the query goes through
/// `Vault::search_hybrid`; otherwise `Vault::search_semantic` is used.
///
/// Each hit is enriched with the note's title and tags (empty when the metadata
/// lookup fails). With `include_content` the full note text is attached; otherwise a
/// snippet is built around the first case-insensitive occurrence of any
/// whitespace-separated query word, falling back to the start of the note body when
/// no word matches. A note that cannot be read yields neither content nor snippet.
///
/// # Errors
///
/// Returns `INVALID_REQUEST` when the vault has no embeddings, propagates any error
/// returned by the vault search, and reports a JSON serialization failure as
/// `VaultError::Other`.
#[cfg(feature = "embeddings")]
pub async fn search_semantic(
    vault: &Vault,
    params: SearchSemanticParams,
    default_alpha: f32,
) -> Result<CallToolResult, rmcp::ErrorData> {
    if !vault.has_embeddings() {
        return Err(rmcp::ErrorData::new(
            ErrorCode::INVALID_REQUEST,
            "Embeddings are not enabled. Set OBSIDIAN_EMBEDDINGS=true and build with --features embeddings.",
            None::<serde_json::Value>,
        ));
    }

    let top_k = params.top_k.unwrap_or(10);
    let include_content = params.include_content.unwrap_or(false);
    let lexical_prefetch = params.lexical_prefetch.unwrap_or(false);
    let alpha = params.alpha.unwrap_or(default_alpha).clamp(0.0, 1.0);

    let hits = if lexical_prefetch {
        vault.search_hybrid(&params.query, top_k, DEFAULT_PREFETCH_COUNT, alpha)?
    } else {
        vault.search_semantic(&params.query, top_k)?
    };

    // The word matcher is only needed for snippets, so it is compiled once up front
    // and skipped entirely when full content is requested.
    let word_re = if !include_content {
        let pattern: String = params
            .query
            .split_whitespace()
            .map(regex::escape)
            .collect::<Vec<_>>()
            .join("|");
        if pattern.is_empty() {
            None
        } else {
            regex::Regex::new(&format!("(?i){pattern}")).ok()
        }
    } else {
        None
    };

    let mut results = Vec::with_capacity(hits.len());
    for (path, score) in hits {
        // Metadata is best-effort: a failed lookup leaves title and tags empty.
        let meta = vault.get_note_metadata(&path).ok();
        let title = meta.as_ref().map(|m| m.title.clone()).unwrap_or_default();
        let tags = meta.as_ref().map(|m| m.tags.clone()).unwrap_or_default();

        let (content, snippet) = if include_content {
            (vault.read_note(&path).ok(), None)
        } else {
            let snip = vault.read_note(&path).ok().map(|text| {
                if let Some(ref re) = word_re {
                    if let Some(m) = re.find(&text) {
                        let (ctx, _, _, _) = crate::vault::index::extract_match_context(
                            &text,
                            m.start(),
                            m.end(),
                            SNIPPET_CONTEXT_LEN,
                        );
                        return ctx;
                    }
                }
                // No query word occurs literally in the note: show the start of its body.
                body_preview(&text, SNIPPET_FALLBACK_CHARS)
            });
            (None, snip)
        };

        results.push(SemanticSearchResult {
            path,
            title,
            score,
            tags,
            snippet,
            content,
        });
    }

    let json = serde_json::to_string_pretty(&results)
        .map_err(|e| VaultError::Other(format!("JSON serialization failed: {e}")))?;
    Ok(CallToolResult::success(vec![Content::text(json)]))
}
#[cfg(feature = "embeddings")]
/// Returns up to `max_chars` characters from the start of a note's body.
///
/// If `content` begins with `---` and a later line begins with `---`, everything up
/// to and including that closing line is skipped. Without a closing fence the text
/// is used from the very beginning. Leading whitespace of the remaining text is
/// trimmed before the characters are taken.
fn body_preview(content: &str, max_chars: usize) -> String {
    // Byte offset where the body starts; stays 0 unless a closed fence pair is found.
    let mut body_start = 0;

    if let Some(after_open) = content.strip_prefix("---") {
        if let Some(rel) = after_open.find("\n---") {
            // `rel` is relative to the text after the opening fence: add 3 for that
            // fence and 4 for the "\n---" that was just matched.
            let fence_end = rel + 7;
            // Skip the remainder of the closing fence line, if there is one.
            body_start = match content[fence_end..].find('\n') {
                Some(newline) => fence_end + newline + 1,
                None => fence_end,
            };
        }
    }

    content[body_start..]
        .trim_start()
        .chars()
        .take(max_chars)
        .collect()
}
/// Fallback handler compiled when the `embeddings` feature is disabled.
///
/// Ignores all of its arguments and always fails with an `INVALID_REQUEST` error
/// telling the caller how to rebuild with semantic search enabled.
#[cfg(not(feature = "embeddings"))]
pub async fn search_semantic(
    _vault: &Vault,
    _params: SearchSemanticParams,
    _default_alpha: f32,
) -> Result<CallToolResult, rmcp::ErrorData> {
    let unavailable = rmcp::ErrorData::new(
        ErrorCode::INVALID_REQUEST,
        "Semantic search is not available. This binary was compiled without the 'embeddings' feature. Rebuild with: cargo build --features embeddings",
        None::<serde_json::Value>,
    );
    Err(unavailable)
}
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use crate::config::Config;

    /// Builds a config rooted at `vault_root` with the watcher and embeddings turned
    /// off; `tantivy` toggles the Tantivy-backed index.
    fn config_for(vault_root: &Path, tantivy: bool) -> Config {
        Config {
            vault_path: vault_root.to_path_buf(),
            watch: false,
            log_level: "error".into(),
            tantivy,
            embeddings: false,
            embeddings_model: String::new(),
            hybrid_alpha: 0.25,
        }
    }

    /// Creates the `.obsidian` directory inside `dir`.
    fn create_test_vault(dir: &Path) {
        std::fs::create_dir_all(dir.join(".obsidian")).unwrap();
    }

    /// Opens a vault in a fresh temporary directory and writes `notes`
    /// (`(relative path, content)` pairs) into it in order.
    ///
    /// The `TempDir` is returned alongside the vault so the caller keeps it alive.
    async fn seeded_vault(tantivy: bool, notes: &[(&str, &str)]) -> (tempfile::TempDir, Vault) {
        let dir = tempfile::tempdir().unwrap();
        create_test_vault(dir.path());
        let vault = Vault::open(&config_for(dir.path(), tantivy))
            .await
            .unwrap();
        for &(name, body) in notes {
            vault.write_note(Path::new(name), body).unwrap();
        }
        (dir, vault)
    }

    /// Returns the text of the first content item of a tool result.
    fn extract_text(result: &CallToolResult) -> &str {
        result.content[0]
            .as_text()
            .expect("expected text content")
            .text
            .as_str()
    }

    /// Parses the tool result's text as a JSON array.
    fn parse_hits(result: &CallToolResult) -> Vec<serde_json::Value> {
        serde_json::from_str(extract_text(result)).unwrap()
    }

    /// Four-note vault without the Tantivy index.
    async fn setup_search_vault() -> (tempfile::TempDir, Vault) {
        seeded_vault(
            false,
            &[
                (
                    "rust.md",
                    "---\ntags: [lang, systems]\nstatus: stable\n---\n# Rust\nRust is a systems language.\n",
                ),
                (
                    "python.md",
                    "---\ntags: [lang, scripting]\nstatus: in progress\n---\n# Python\nPython is dynamic.\n",
                ),
                (
                    "notes.md",
                    "# Notes\nSome random notes about #inbox stuff.\n\n#inbox/read #inbox/todo\n",
                ),
                ("empty.md", "# Empty\nNothing interesting here.\n"),
            ],
        )
        .await
    }

    /// Three-note vault with the Tantivy index enabled.
    async fn setup_tantivy_vault() -> (tempfile::TempDir, Vault) {
        seeded_vault(
            true,
            &[
                (
                    "rust.md",
                    "---\ntags: [lang, systems]\nstatus: stable\n---\n# Rust\nRust is a systems programming language.\n",
                ),
                (
                    "python.md",
                    "---\ntags: [lang, scripting]\nstatus: in progress\n---\n# Python\nPython is a dynamic scripting language.\n",
                ),
                (
                    "cooking.md",
                    "# Cooking Tips\nHow to make a great pasta dish.\n",
                ),
            ],
        )
        .await
    }

    // ---- search_text -------------------------------------------------------

    #[tokio::test]
    async fn search_text_finds_match() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchTextParams {
            query: "systems".into(),
            ..Default::default()
        };
        let response = search_text(&vault, params).await.unwrap();
        let body = extract_text(&response);
        assert!(body.contains("rust.md"));
        assert!(!body.contains("python.md"));
    }

    #[tokio::test]
    async fn search_text_limits_results() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchTextParams {
            query: "is".into(),
            max_results: Some(1),
            ..Default::default()
        };
        let response = search_text(&vault, params).await.unwrap();
        let hits = parse_hits(&response);
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn search_text_empty_query_returns_empty() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchTextParams {
            query: String::new(),
            ..Default::default()
        };
        let response = search_text(&vault, params).await.unwrap();
        let hits = parse_hits(&response);
        assert!(hits.is_empty());
    }

    // ---- search_regex ------------------------------------------------------

    #[tokio::test]
    async fn search_regex_valid_pattern() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchRegexParams {
            pattern: r"(?i)python".into(),
            ..Default::default()
        };
        let response = search_regex(&vault, params).await.unwrap();
        assert!(extract_text(&response).contains("python.md"));
    }

    #[tokio::test]
    async fn search_regex_invalid_pattern_returns_error() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchRegexParams {
            pattern: "[invalid".into(),
            ..Default::default()
        };
        let outcome = search_regex(&vault, params).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn search_regex_limits_results() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchRegexParams {
            pattern: r"\w+".into(),
            max_results: Some(2),
            ..Default::default()
        };
        let response = search_regex(&vault, params).await.unwrap();
        let hits = parse_hits(&response);
        assert!(hits.len() <= 2);
    }

    // ---- search_tag --------------------------------------------------------

    #[tokio::test]
    async fn search_tag_exact() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchTagParams {
            tag: "inbox".into(),
            include_nested: Some(false),
        };
        let response = search_tag(&vault, params).await.unwrap();
        assert!(extract_text(&response).contains("notes.md"));
        let hits = parse_hits(&response);
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn search_tag_include_nested() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchTagParams {
            tag: "inbox".into(),
            include_nested: Some(true),
        };
        let response = search_tag(&vault, params).await.unwrap();
        assert!(extract_text(&response).contains("notes.md"));
    }

    #[tokio::test]
    async fn search_tag_strips_hash_prefix() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchTagParams {
            tag: "#lang".into(),
            include_nested: Some(false),
        };
        let response = search_tag(&vault, params).await.unwrap();
        let hits = parse_hits(&response);
        assert_eq!(hits.len(), 2);
    }

    // ---- search_frontmatter ------------------------------------------------

    #[tokio::test]
    async fn search_frontmatter_eq() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchFrontmatterParams {
            field: "status".into(),
            value: Some(serde_json::json!("stable")),
            operator: FrontmatterOperator::Eq,
        };
        let response = search_frontmatter(&vault, params).await.unwrap();
        assert!(extract_text(&response).contains("rust.md"));
        let hits = parse_hits(&response);
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn search_frontmatter_eq_array_contains() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchFrontmatterParams {
            field: "tags".into(),
            value: Some(serde_json::json!("systems")),
            operator: FrontmatterOperator::Eq,
        };
        let response = search_frontmatter(&vault, params).await.unwrap();
        assert!(extract_text(&response).contains("rust.md"));
    }

    #[tokio::test]
    async fn search_frontmatter_contains_substring() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchFrontmatterParams {
            field: "status".into(),
            value: Some(serde_json::json!("progress")),
            operator: FrontmatterOperator::Contains,
        };
        let response = search_frontmatter(&vault, params).await.unwrap();
        assert!(extract_text(&response).contains("python.md"));
        let hits = parse_hits(&response);
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn search_frontmatter_exists() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchFrontmatterParams {
            field: "status".into(),
            value: None,
            operator: FrontmatterOperator::Exists,
        };
        let response = search_frontmatter(&vault, params).await.unwrap();
        let hits = parse_hits(&response);
        assert_eq!(hits.len(), 2);
    }

    #[tokio::test]
    async fn search_frontmatter_exists_missing_field() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchFrontmatterParams {
            field: "nonexistent".into(),
            value: None,
            operator: FrontmatterOperator::Exists,
        };
        let response = search_frontmatter(&vault, params).await.unwrap();
        let hits = parse_hits(&response);
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn search_frontmatter_eq_without_value_errors() {
        let (_dir, vault) = setup_search_vault().await;
        let params = SearchFrontmatterParams {
            field: "status".into(),
            value: None,
            operator: FrontmatterOperator::Eq,
        };
        let outcome = search_frontmatter(&vault, params).await;
        assert!(outcome.is_err());
    }

    // ---- search_text with the Tantivy index ---------------------------------

    #[tokio::test]
    async fn search_text_tantivy_returns_scores() {
        let (_dir, vault) = setup_tantivy_vault().await;
        let params = SearchTextParams {
            query: "systems".into(),
            ..Default::default()
        };
        let response = search_text(&vault, params).await.unwrap();
        let body = extract_text(&response);
        let hits = parse_hits(&response);
        assert!(!hits.is_empty());
        assert!(body.contains("rust.md"));
        assert!(hits[0].get("score").is_some());
        assert!(hits[0]["score"].as_f64().unwrap() > 0.0);
    }

    #[tokio::test]
    async fn search_text_tantivy_ranked_descending() {
        let (_dir, vault) = setup_tantivy_vault().await;
        let params = SearchTextParams {
            query: "language".into(),
            ..Default::default()
        };
        let response = search_text(&vault, params).await.unwrap();
        let hits = parse_hits(&response);
        // The ordering is only checked when at least two hits came back.
        if hits.len() >= 2 {
            let first = hits[0]["score"].as_f64().unwrap();
            let second = hits[1]["score"].as_f64().unwrap();
            assert!(
                first >= second,
                "results should be sorted by score descending"
            );
        }
    }

    #[tokio::test]
    async fn search_text_tantivy_fuzzy() {
        let (_dir, vault) = setup_tantivy_vault().await;
        let params = SearchTextParams {
            query: "pyhton".into(),
            fuzzy: Some(true),
            ..Default::default()
        };
        let response = search_text(&vault, params).await.unwrap();
        let body = extract_text(&response);
        assert!(
            body.contains("python.md"),
            "fuzzy should match 'pyhton' to 'python'"
        );
    }

    #[tokio::test]
    async fn search_text_tantivy_field_filter() {
        let (_dir, vault) = setup_tantivy_vault().await;
        let params = SearchTextParams {
            query: "cooking".into(),
            fields: Some(vec![SearchField::Title]),
            ..Default::default()
        };
        let response = search_text(&vault, params).await.unwrap();
        let body = extract_text(&response);
        assert!(
            body.contains("cooking.md"),
            "title search for 'cooking' should find cooking.md"
        );
    }

    #[tokio::test]
    async fn search_text_tantivy_context_snippets() {
        let (_dir, vault) = setup_tantivy_vault().await;
        let params = SearchTextParams {
            query: "pasta".into(),
            context_length: Some(50),
            ..Default::default()
        };
        let response = search_text(&vault, params).await.unwrap();
        let hits = parse_hits(&response);
        assert!(!hits.is_empty());
        let matches = hits[0]["matches"].as_array().unwrap();
        assert!(!matches.is_empty(), "should have context matches");
        assert!(matches[0]["context"].as_str().unwrap().contains("pasta"));
    }

    // ---- SearchSemanticParams deserialization --------------------------------

    #[test]
    fn semantic_params_defaults() {
        let parsed: SearchSemanticParams = serde_json::from_str(r#"{"query": "test"}"#).unwrap();
        assert_eq!(parsed.query, "test");
        assert!(parsed.alpha.is_none());
        assert!(parsed.lexical_prefetch.is_none());
        assert!(parsed.top_k.is_none());
    }

    #[test]
    fn semantic_params_with_alpha() {
        let raw = r#"{"query": "q", "alpha": 0.7, "lexical_prefetch": true}"#;
        let parsed: SearchSemanticParams = serde_json::from_str(raw).unwrap();
        assert!((parsed.alpha.unwrap() - 0.7).abs() < f32::EPSILON);
        assert_eq!(parsed.lexical_prefetch, Some(true));
    }

    // ---- body_preview (embeddings builds only) --------------------------------

    #[cfg(feature = "embeddings")]
    #[test]
    fn body_preview_strips_frontmatter() {
        let note = "---\ntags: [a]\n---\nHello world";
        assert_eq!(super::body_preview(note, 100), "Hello world");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn body_preview_no_frontmatter() {
        let note = "# Title\nSome body text";
        assert_eq!(super::body_preview(note, 100), "# Title\nSome body text");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn body_preview_truncates() {
        let note = "---\nk: v\n---\nABCDEFGHIJ";
        assert_eq!(super::body_preview(note, 5), "ABCDE");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn body_preview_empty_content() {
        assert_eq!(super::body_preview("", 100), "");
    }

    #[cfg(feature = "embeddings")]
    #[test]
    fn body_preview_unclosed_frontmatter() {
        let note = "---\ntags: [a]\nNo closing delimiter here";
        let shown = super::body_preview(note, 200);
        assert!(shown.contains("tags:"));
    }
}