1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use tokio::sync::RwLock;
7
8use roboticus_core::RiskLevel;
9
10use crate::obsidian::ObsidianVault;
11use crate::tools::{Tool, ToolContext, ToolError, ToolResult};
12
13pub struct ObsidianReadTool {
18 vault: Arc<RwLock<ObsidianVault>>,
19}
20
21impl ObsidianReadTool {
22 pub fn new(vault: Arc<RwLock<ObsidianVault>>) -> Self {
23 Self { vault }
24 }
25}
26
27#[async_trait]
28impl Tool for ObsidianReadTool {
29 fn name(&self) -> &str {
30 "obsidian_read"
31 }
32
33 fn description(&self) -> &str {
34 "Read a note from the user's Obsidian vault by path or title. \
35 Returns the note content with frontmatter metadata, tags, and backlink count."
36 }
37
38 fn risk_level(&self) -> RiskLevel {
39 RiskLevel::Safe
40 }
41
42 fn parameters_schema(&self) -> Value {
43 serde_json::json!({
44 "type": "object",
45 "properties": {
46 "path": {
47 "type": "string",
48 "description": "Relative path to the note within the vault (e.g. 'folder/note.md')"
49 },
50 "title": {
51 "type": "string",
52 "description": "Note title to search for (case-insensitive wikilink resolution)"
53 }
54 }
55 })
56 }
57
58 async fn execute(
59 &self,
60 params: Value,
61 _ctx: &ToolContext,
62 ) -> std::result::Result<ToolResult, ToolError> {
63 let vault = self.vault.read().await;
64
65 let note = if let Some(path) = params.get("path").and_then(|v| v.as_str()) {
66 if path.contains("..") || std::path::Path::new(path).is_absolute() {
67 return Err(ToolError {
68 message: "path must be relative and must not contain '..'".into(),
69 });
70 }
71 vault.get_note(path).cloned()
72 } else if let Some(title) = params.get("title").and_then(|v| v.as_str()) {
73 vault
74 .resolve_wikilink(title)
75 .and_then(|p| vault.get_note(&p.to_string_lossy()).cloned())
76 } else {
77 return Err(ToolError {
78 message: "either 'path' or 'title' parameter is required".into(),
79 });
80 };
81
82 match note {
83 Some(note) => {
84 let rel_path = note
85 .path
86 .strip_prefix(&vault.root)
87 .unwrap_or(¬e.path)
88 .to_string_lossy()
89 .to_string();
90
91 let backlink_count = vault.backlinks_for(&rel_path).len();
92 let uri = vault.obsidian_uri(&rel_path);
93
94 let metadata = serde_json::json!({
95 "path": rel_path,
96 "title": note.title,
97 "tags": note.tags,
98 "backlink_count": backlink_count,
99 "obsidian_uri": uri,
100 "frontmatter": note.frontmatter,
101 "created_at": note.created_at,
102 "modified_at": note.modified_at,
103 });
104
105 Ok(ToolResult {
106 output: note.content,
107 metadata: Some(metadata),
108 })
109 }
110 None => Err(ToolError {
111 message: "note not found in vault".into(),
112 }),
113 }
114 }
115}
116
117pub struct ObsidianWriteTool {
122 vault: Arc<RwLock<ObsidianVault>>,
123}
124
125impl ObsidianWriteTool {
126 pub fn new(vault: Arc<RwLock<ObsidianVault>>) -> Self {
127 Self { vault }
128 }
129}
130
131#[async_trait]
132impl Tool for ObsidianWriteTool {
133 fn name(&self) -> &str {
134 "obsidian_write"
135 }
136
137 fn description(&self) -> &str {
138 "Write a document to the user's Obsidian vault. This is the preferred destination \
139 for producing documents, reports, notes, and any persistent written output. \
140 Returns the file path and an obsidian:// URI the user can click to open it."
141 }
142
143 fn risk_level(&self) -> RiskLevel {
144 RiskLevel::Caution
145 }
146
147 fn parameters_schema(&self) -> Value {
148 serde_json::json!({
149 "type": "object",
150 "properties": {
151 "path": {
152 "type": "string",
153 "description": "Relative path for the note (e.g. 'projects/report.md'). \
154 If no folder prefix, writes to the default agent folder."
155 },
156 "content": {
157 "type": "string",
158 "description": "Markdown content for the note"
159 },
160 "tags": {
161 "type": "array",
162 "items": { "type": "string" },
163 "description": "Tags to include in YAML frontmatter"
164 },
165 "template": {
166 "type": "string",
167 "description": "Name of an Obsidian template to apply before writing"
168 },
169 "frontmatter": {
170 "type": "object",
171 "description": "Additional YAML frontmatter fields"
172 }
173 },
174 "required": ["path", "content"]
175 })
176 }
177
178 async fn execute(
179 &self,
180 params: Value,
181 _ctx: &ToolContext,
182 ) -> std::result::Result<ToolResult, ToolError> {
183 let path = params
184 .get("path")
185 .and_then(|v| v.as_str())
186 .ok_or_else(|| ToolError {
187 message: "missing 'path' parameter".into(),
188 })?;
189
190 let content = params
191 .get("content")
192 .and_then(|v| v.as_str())
193 .ok_or_else(|| ToolError {
194 message: "missing 'content' parameter".into(),
195 })?;
196
197 let mut vault = self.vault.write().await;
198
199 let final_content =
201 if let Some(template_name) = params.get("template").and_then(|v| v.as_str()) {
202 let mut vars = HashMap::new();
203 vars.insert("title".into(), path_to_title(path));
204 vars.insert("content".into(), content.to_string());
205
206 match vault.apply_template(template_name, &vars) {
207 Ok(rendered) => rendered,
208 Err(e) => {
209 return Err(ToolError {
210 message: format!("template error: {e}"),
211 });
212 }
213 }
214 } else {
215 content.to_string()
216 };
217
218 let fm = {
220 let mut obj = if let Some(Value::Object(m)) = params.get("frontmatter") {
221 serde_json::Value::Object(m.clone())
222 } else {
223 serde_json::json!({})
224 };
225
226 if let Some(Value::Array(arr)) = params.get("tags")
227 && let Some(map) = obj.as_object_mut()
228 {
229 map.insert("tags".into(), Value::Array(arr.clone()));
230 }
231
232 Some(obj)
233 };
234
235 match vault.write_note(path, &final_content, fm) {
236 Ok(abs_path) => {
237 let rel = abs_path
238 .strip_prefix(&vault.root)
239 .unwrap_or(&abs_path)
240 .to_string_lossy()
241 .to_string();
242 let uri = vault.obsidian_uri(&rel);
243
244 Ok(ToolResult {
245 output: format!("Note written to {rel}\n\nOpen in Obsidian: {uri}"),
246 metadata: Some(serde_json::json!({
247 "path": rel,
248 "absolute_path": abs_path.display().to_string(),
249 "obsidian_uri": uri,
250 })),
251 })
252 }
253 Err(e) => Err(ToolError {
254 message: format!("failed to write note: {e}"),
255 }),
256 }
257 }
258}
259
/// Derives a note title from a path: the final component without its last
/// extension (`"projects/report.md"` -> `"report"`).
///
/// When the path has no usable file stem (e.g. `""` or `".."`), the input is
/// returned unchanged.
fn path_to_title(path: &str) -> String {
    match std::path::Path::new(path)
        .file_stem()
        .and_then(std::ffi::OsStr::to_str)
    {
        Some(stem) => stem.to_owned(),
        None => path.to_owned(),
    }
}
267
268pub struct ObsidianSearchTool {
273 vault: Arc<RwLock<ObsidianVault>>,
274}
275
276impl ObsidianSearchTool {
277 pub fn new(vault: Arc<RwLock<ObsidianVault>>) -> Self {
278 Self { vault }
279 }
280}
281
282#[async_trait]
283impl Tool for ObsidianSearchTool {
284 fn name(&self) -> &str {
285 "obsidian_search"
286 }
287
288 fn description(&self) -> &str {
289 "Search the user's Obsidian vault by content query, tags, or folder. \
290 Returns matching notes with titles, paths, tags, and relevance scores."
291 }
292
293 fn risk_level(&self) -> RiskLevel {
294 RiskLevel::Safe
295 }
296
297 fn parameters_schema(&self) -> Value {
298 serde_json::json!({
299 "type": "object",
300 "properties": {
301 "query": {
302 "type": "string",
303 "description": "Full-text search query"
304 },
305 "tags": {
306 "type": "array",
307 "items": { "type": "string" },
308 "description": "Filter by tags (notes must have at least one matching tag)"
309 },
310 "folder": {
311 "type": "string",
312 "description": "Restrict search to a specific folder within the vault"
313 },
314 "limit": {
315 "type": "integer",
316 "description": "Maximum number of results (default 10)"
317 }
318 }
319 })
320 }
321
322 async fn execute(
323 &self,
324 params: Value,
325 _ctx: &ToolContext,
326 ) -> std::result::Result<ToolResult, ToolError> {
327 let vault = self.vault.read().await;
328
329 let limit = params.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
330
331 let query = params.get("query").and_then(|v| v.as_str());
332 let tags: Vec<String> = params
333 .get("tags")
334 .and_then(|v| v.as_array())
335 .map(|arr| {
336 arr.iter()
337 .filter_map(|v| v.as_str().map(|s| s.to_string()))
338 .collect()
339 })
340 .unwrap_or_default();
341 let folder = params.get("folder").and_then(|v| v.as_str());
342
343 if query.is_none() && tags.is_empty() && folder.is_none() {
344 return Err(ToolError {
345 message: "at least one of 'query', 'tags', or 'folder' is required".into(),
346 });
347 }
348
349 let mut results: Vec<Value> = Vec::new();
350
351 if let Some(q) = query {
352 let search_results = vault.search_by_content(q, limit);
353 for (key, note, score) in search_results {
354 if let Some(f) = folder
355 && !key.starts_with(f)
356 {
357 continue;
358 }
359
360 if !tags.is_empty()
361 && !tags.iter().any(|t| {
362 note.tags
363 .iter()
364 .any(|nt| nt.to_lowercase() == t.to_lowercase())
365 })
366 {
367 continue;
368 }
369
370 results.push(serde_json::json!({
371 "path": key,
372 "title": note.title,
373 "tags": note.tags,
374 "relevance": score,
375 "obsidian_uri": vault.obsidian_uri(key),
376 "preview": truncate_content(¬e.content, 200),
377 }));
378
379 if results.len() >= limit {
380 break;
381 }
382 }
383 } else {
384 let mut matching: Vec<(&str, &crate::obsidian::ObsidianNote)> = if !tags.is_empty() {
386 let tag_results: Vec<_> = tags
387 .iter()
388 .flat_map(|t| {
389 vault
390 .search_by_tag(t)
391 .into_iter()
392 .map(|n| n.title.clone())
393 .collect::<Vec<_>>()
394 })
395 .collect();
396
397 vault
398 .notes_in_folder(folder.unwrap_or(""))
399 .into_iter()
400 .filter(|(_, n)| tag_results.contains(&n.title))
401 .collect()
402 } else if let Some(f) = folder {
403 vault.notes_in_folder(f)
404 } else {
405 Vec::new()
406 };
407
408 matching.truncate(limit);
409
410 for (key, note) in matching {
411 results.push(serde_json::json!({
412 "path": key,
413 "title": note.title,
414 "tags": note.tags,
415 "obsidian_uri": vault.obsidian_uri(key),
416 "preview": truncate_content(¬e.content, 200),
417 }));
418 }
419 }
420
421 let output = serde_json::to_string_pretty(&serde_json::json!({
422 "count": results.len(),
423 "results": results,
424 }))
425 .unwrap_or_else(|_| "[]".into());
426
427 Ok(ToolResult {
428 output,
429 metadata: Some(serde_json::json!({ "result_count": results.len() })),
430 })
431 }
432}
433
/// Shortens `s` to at most `max` bytes (never splitting a UTF-8 character) and
/// appends `"..."` when anything was cut; shorter strings are returned as-is.
fn truncate_content(s: &str, max: usize) -> String {
    if s.len() > max {
        // Walk down from `max` to the nearest char boundary; index 0 always is one.
        let cut = (0..=max)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0);
        let mut preview = String::with_capacity(cut + 3);
        preview.push_str(&s[..cut]);
        preview.push_str("...");
        preview
    } else {
        s.to_owned()
    }
}
442
#[cfg(test)]
mod tests {
    //! Tests for the Obsidian read/write/search tools, each run against a fresh
    //! vault created in a temporary directory.

    use super::*;
    use crate::obsidian::ObsidianVault;
    use roboticus_core::InputAuthority;
    use roboticus_core::config::ObsidianConfig;
    use std::fs;
    use tempfile::TempDir;

    /// Tool context for a creator-authority session rooted at the current directory.
    fn test_ctx() -> ToolContext {
        ToolContext {
            session_id: "test-session".into(),
            agent_id: "test-agent".into(),
            agent_name: "test-agent".into(),
            authority: InputAuthority::Creator,
            workspace_root: std::env::current_dir().unwrap(),
            tool_allowed_paths: Vec::new(),
            channel: None,
            db: None,
            sandbox: crate::tools::ToolSandboxSnapshot::default(),
        }
    }

    /// Creates a temp directory with `.obsidian/`, `roboticus/` and one note
    /// (`existing.md`, tagged `test`), builds an indexed vault over it, and
    /// returns the directory guard with the shared vault handle. The `TempDir`
    /// must stay alive for the duration of the test.
    fn setup_vault() -> (TempDir, Arc<RwLock<ObsidianVault>>) {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path().to_path_buf();

        for folder in [".obsidian", "roboticus"] {
            fs::create_dir(root.join(folder)).unwrap();
        }
        fs::write(
            root.join("existing.md"),
            "---\ntags:\n - test\n---\n\nExisting note content about Rust",
        )
        .unwrap();

        let config = ObsidianConfig {
            enabled: true,
            vault_path: Some(root),
            index_on_start: true,
            ..Default::default()
        };
        let vault = ObsidianVault::from_config(&config).unwrap();

        (tmp, Arc::new(RwLock::new(vault)))
    }

    /// Executes `tool` with `params` under a fresh test context.
    async fn run<T: Tool>(tool: &T, params: Value) -> std::result::Result<ToolResult, ToolError> {
        tool.execute(params, &test_ctx()).await
    }

    #[tokio::test]
    async fn read_tool_by_path() {
        let (_dir, vault) = setup_vault();
        let tool = ObsidianReadTool::new(vault);

        let res = run(&tool, serde_json::json!({ "path": "existing.md" }))
            .await
            .unwrap();

        assert!(res.output.contains("Existing note content"));
        let meta = res.metadata.unwrap();
        assert_eq!(meta["title"], "existing");
    }

    #[tokio::test]
    async fn read_tool_by_title() {
        let (_dir, vault) = setup_vault();
        let tool = ObsidianReadTool::new(vault);

        let res = run(&tool, serde_json::json!({ "title": "existing" }))
            .await
            .unwrap();

        assert!(res.output.contains("Existing note content"));
    }

    #[tokio::test]
    async fn read_tool_not_found() {
        let (_dir, vault) = setup_vault();
        let tool = ObsidianReadTool::new(vault);

        let failure = run(&tool, serde_json::json!({ "path": "nonexistent.md" }))
            .await
            .unwrap_err();

        assert!(failure.message.contains("not found"));
    }

    #[tokio::test]
    async fn read_tool_missing_params() {
        let (_dir, vault) = setup_vault();
        let tool = ObsidianReadTool::new(vault);

        let failure = run(&tool, serde_json::json!({})).await.unwrap_err();

        assert!(failure.message.contains("required"));
    }

    #[tokio::test]
    async fn write_tool_creates_note() {
        let (dir, vault) = setup_vault();
        let tool = ObsidianWriteTool::new(vault);
        let params = serde_json::json!({
            "path": "new-note",
            "content": "Hello from the write tool",
            "tags": ["test", "automated"]
        });

        let res = run(&tool, params).await.unwrap();

        assert!(res.output.contains("Note written to"));
        assert!(res.output.contains("obsidian://"));

        let meta = res.metadata.unwrap();
        let uri = meta["obsidian_uri"].as_str().unwrap();
        assert!(uri.starts_with("obsidian://"));

        let written = dir.path().join("roboticus/new-note.md");
        assert!(written.exists());
        let on_disk = fs::read_to_string(&written).unwrap();
        assert!(on_disk.contains("Hello from the write tool"));
        assert!(on_disk.contains("created_by"));
    }

    #[tokio::test]
    async fn write_tool_missing_content() {
        let (_dir, vault) = setup_vault();
        let tool = ObsidianWriteTool::new(vault);

        let failure = run(&tool, serde_json::json!({ "path": "test" }))
            .await
            .unwrap_err();

        assert!(failure.message.contains("content"));
    }

    #[tokio::test]
    async fn search_tool_by_query() {
        let (_dir, vault) = setup_vault();
        let tool = ObsidianSearchTool::new(vault);

        let res = run(&tool, serde_json::json!({ "query": "Rust" }))
            .await
            .unwrap();

        let parsed: Value = serde_json::from_str(&res.output).unwrap();
        assert!(parsed["count"].as_u64().unwrap() >= 1);
    }

    #[tokio::test]
    async fn search_tool_by_tag() {
        let (_dir, vault) = setup_vault();
        let tool = ObsidianSearchTool::new(vault);

        let res = run(&tool, serde_json::json!({ "tags": ["test"] }))
            .await
            .unwrap();

        let parsed: Value = serde_json::from_str(&res.output).unwrap();
        assert!(parsed["count"].as_u64().unwrap() >= 1);
    }

    #[tokio::test]
    async fn search_tool_no_params() {
        let (_dir, vault) = setup_vault();
        let tool = ObsidianSearchTool::new(vault);

        let failure = run(&tool, serde_json::json!({})).await.unwrap_err();

        assert!(failure.message.contains("required"));
    }
}