// roboticus_agent/obsidian_tools.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use tokio::sync::RwLock;
7
8use roboticus_core::RiskLevel;
9
10use crate::obsidian::ObsidianVault;
11use crate::tools::{Tool, ToolContext, ToolError, ToolResult};
12
13// ---------------------------------------------------------------------------
14// ObsidianReadTool
15// ---------------------------------------------------------------------------
16
17pub struct ObsidianReadTool {
18    vault: Arc<RwLock<ObsidianVault>>,
19}
20
21impl ObsidianReadTool {
22    pub fn new(vault: Arc<RwLock<ObsidianVault>>) -> Self {
23        Self { vault }
24    }
25}
26
27#[async_trait]
28impl Tool for ObsidianReadTool {
29    fn name(&self) -> &str {
30        "obsidian_read"
31    }
32
33    fn description(&self) -> &str {
34        "Read a note from the user's Obsidian vault by path or title. \
35         Returns the note content with frontmatter metadata, tags, and backlink count."
36    }
37
38    fn risk_level(&self) -> RiskLevel {
39        RiskLevel::Safe
40    }
41
42    fn parameters_schema(&self) -> Value {
43        serde_json::json!({
44            "type": "object",
45            "properties": {
46                "path": {
47                    "type": "string",
48                    "description": "Relative path to the note within the vault (e.g. 'folder/note.md')"
49                },
50                "title": {
51                    "type": "string",
52                    "description": "Note title to search for (case-insensitive wikilink resolution)"
53                }
54            }
55        })
56    }
57
58    async fn execute(
59        &self,
60        params: Value,
61        _ctx: &ToolContext,
62    ) -> std::result::Result<ToolResult, ToolError> {
63        let vault = self.vault.read().await;
64
65        let note = if let Some(path) = params.get("path").and_then(|v| v.as_str()) {
66            if path.contains("..") || std::path::Path::new(path).is_absolute() {
67                return Err(ToolError {
68                    message: "path must be relative and must not contain '..'".into(),
69                });
70            }
71            vault.get_note(path).cloned()
72        } else if let Some(title) = params.get("title").and_then(|v| v.as_str()) {
73            vault
74                .resolve_wikilink(title)
75                .and_then(|p| vault.get_note(&p.to_string_lossy()).cloned())
76        } else {
77            return Err(ToolError {
78                message: "either 'path' or 'title' parameter is required".into(),
79            });
80        };
81
82        match note {
83            Some(note) => {
84                let rel_path = note
85                    .path
86                    .strip_prefix(&vault.root)
87                    .unwrap_or(&note.path)
88                    .to_string_lossy()
89                    .to_string();
90
91                let backlink_count = vault.backlinks_for(&rel_path).len();
92                let uri = vault.obsidian_uri(&rel_path);
93
94                let metadata = serde_json::json!({
95                    "path": rel_path,
96                    "title": note.title,
97                    "tags": note.tags,
98                    "backlink_count": backlink_count,
99                    "obsidian_uri": uri,
100                    "frontmatter": note.frontmatter,
101                    "created_at": note.created_at,
102                    "modified_at": note.modified_at,
103                });
104
105                Ok(ToolResult {
106                    output: note.content,
107                    metadata: Some(metadata),
108                })
109            }
110            None => Err(ToolError {
111                message: "note not found in vault".into(),
112            }),
113        }
114    }
115}
116
117// ---------------------------------------------------------------------------
118// ObsidianWriteTool
119// ---------------------------------------------------------------------------
120
121pub struct ObsidianWriteTool {
122    vault: Arc<RwLock<ObsidianVault>>,
123}
124
125impl ObsidianWriteTool {
126    pub fn new(vault: Arc<RwLock<ObsidianVault>>) -> Self {
127        Self { vault }
128    }
129}
130
131#[async_trait]
132impl Tool for ObsidianWriteTool {
133    fn name(&self) -> &str {
134        "obsidian_write"
135    }
136
137    fn description(&self) -> &str {
138        "Write a document to the user's Obsidian vault. This is the preferred destination \
139         for producing documents, reports, notes, and any persistent written output. \
140         Returns the file path and an obsidian:// URI the user can click to open it."
141    }
142
143    fn risk_level(&self) -> RiskLevel {
144        RiskLevel::Caution
145    }
146
147    fn parameters_schema(&self) -> Value {
148        serde_json::json!({
149            "type": "object",
150            "properties": {
151                "path": {
152                    "type": "string",
153                    "description": "Relative path for the note (e.g. 'projects/report.md'). \
154                                    If no folder prefix, writes to the default agent folder."
155                },
156                "content": {
157                    "type": "string",
158                    "description": "Markdown content for the note"
159                },
160                "tags": {
161                    "type": "array",
162                    "items": { "type": "string" },
163                    "description": "Tags to include in YAML frontmatter"
164                },
165                "template": {
166                    "type": "string",
167                    "description": "Name of an Obsidian template to apply before writing"
168                },
169                "frontmatter": {
170                    "type": "object",
171                    "description": "Additional YAML frontmatter fields"
172                }
173            },
174            "required": ["path", "content"]
175        })
176    }
177
178    async fn execute(
179        &self,
180        params: Value,
181        _ctx: &ToolContext,
182    ) -> std::result::Result<ToolResult, ToolError> {
183        let path = params
184            .get("path")
185            .and_then(|v| v.as_str())
186            .ok_or_else(|| ToolError {
187                message: "missing 'path' parameter".into(),
188            })?;
189
190        let content = params
191            .get("content")
192            .and_then(|v| v.as_str())
193            .ok_or_else(|| ToolError {
194                message: "missing 'content' parameter".into(),
195            })?;
196
197        let mut vault = self.vault.write().await;
198
199        // Apply template if specified
200        let final_content =
201            if let Some(template_name) = params.get("template").and_then(|v| v.as_str()) {
202                let mut vars = HashMap::new();
203                vars.insert("title".into(), path_to_title(path));
204                vars.insert("content".into(), content.to_string());
205
206                match vault.apply_template(template_name, &vars) {
207                    Ok(rendered) => rendered,
208                    Err(e) => {
209                        return Err(ToolError {
210                            message: format!("template error: {e}"),
211                        });
212                    }
213                }
214            } else {
215                content.to_string()
216            };
217
218        // Build frontmatter
219        let fm = {
220            let mut obj = if let Some(Value::Object(m)) = params.get("frontmatter") {
221                serde_json::Value::Object(m.clone())
222            } else {
223                serde_json::json!({})
224            };
225
226            if let Some(Value::Array(arr)) = params.get("tags")
227                && let Some(map) = obj.as_object_mut()
228            {
229                map.insert("tags".into(), Value::Array(arr.clone()));
230            }
231
232            Some(obj)
233        };
234
235        match vault.write_note(path, &final_content, fm) {
236            Ok(abs_path) => {
237                let rel = abs_path
238                    .strip_prefix(&vault.root)
239                    .unwrap_or(&abs_path)
240                    .to_string_lossy()
241                    .to_string();
242                let uri = vault.obsidian_uri(&rel);
243
244                Ok(ToolResult {
245                    output: format!("Note written to {rel}\n\nOpen in Obsidian: {uri}"),
246                    metadata: Some(serde_json::json!({
247                        "path": rel,
248                        "absolute_path": abs_path.display().to_string(),
249                        "obsidian_uri": uri,
250                    })),
251                })
252            }
253            Err(e) => Err(ToolError {
254                message: format!("failed to write note: {e}"),
255            }),
256        }
257    }
258}
259
/// Derives a note title from a path: the final component without its extension.
///
/// Falls back to the whole input when the path has no file stem (for example an empty
/// string or `..`) or when the stem is not valid UTF-8.
fn path_to_title(path: &str) -> String {
    let stem = std::path::Path::new(path).file_stem();
    match stem.and_then(|name| name.to_str()) {
        Some(title) => title.to_owned(),
        None => path.to_owned(),
    }
}
267
268// ---------------------------------------------------------------------------
269// ObsidianSearchTool
270// ---------------------------------------------------------------------------
271
272pub struct ObsidianSearchTool {
273    vault: Arc<RwLock<ObsidianVault>>,
274}
275
276impl ObsidianSearchTool {
277    pub fn new(vault: Arc<RwLock<ObsidianVault>>) -> Self {
278        Self { vault }
279    }
280}
281
282#[async_trait]
283impl Tool for ObsidianSearchTool {
284    fn name(&self) -> &str {
285        "obsidian_search"
286    }
287
288    fn description(&self) -> &str {
289        "Search the user's Obsidian vault by content query, tags, or folder. \
290         Returns matching notes with titles, paths, tags, and relevance scores."
291    }
292
293    fn risk_level(&self) -> RiskLevel {
294        RiskLevel::Safe
295    }
296
297    fn parameters_schema(&self) -> Value {
298        serde_json::json!({
299            "type": "object",
300            "properties": {
301                "query": {
302                    "type": "string",
303                    "description": "Full-text search query"
304                },
305                "tags": {
306                    "type": "array",
307                    "items": { "type": "string" },
308                    "description": "Filter by tags (notes must have at least one matching tag)"
309                },
310                "folder": {
311                    "type": "string",
312                    "description": "Restrict search to a specific folder within the vault"
313                },
314                "limit": {
315                    "type": "integer",
316                    "description": "Maximum number of results (default 10)"
317                }
318            }
319        })
320    }
321
322    async fn execute(
323        &self,
324        params: Value,
325        _ctx: &ToolContext,
326    ) -> std::result::Result<ToolResult, ToolError> {
327        let vault = self.vault.read().await;
328
329        let limit = params.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize;
330
331        let query = params.get("query").and_then(|v| v.as_str());
332        let tags: Vec<String> = params
333            .get("tags")
334            .and_then(|v| v.as_array())
335            .map(|arr| {
336                arr.iter()
337                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
338                    .collect()
339            })
340            .unwrap_or_default();
341        let folder = params.get("folder").and_then(|v| v.as_str());
342
343        if query.is_none() && tags.is_empty() && folder.is_none() {
344            return Err(ToolError {
345                message: "at least one of 'query', 'tags', or 'folder' is required".into(),
346            });
347        }
348
349        let mut results: Vec<Value> = Vec::new();
350
351        if let Some(q) = query {
352            let search_results = vault.search_by_content(q, limit);
353            for (key, note, score) in search_results {
354                if let Some(f) = folder
355                    && !key.starts_with(f)
356                {
357                    continue;
358                }
359
360                if !tags.is_empty()
361                    && !tags.iter().any(|t| {
362                        note.tags
363                            .iter()
364                            .any(|nt| nt.to_lowercase() == t.to_lowercase())
365                    })
366                {
367                    continue;
368                }
369
370                results.push(serde_json::json!({
371                    "path": key,
372                    "title": note.title,
373                    "tags": note.tags,
374                    "relevance": score,
375                    "obsidian_uri": vault.obsidian_uri(key),
376                    "preview": truncate_content(&note.content, 200),
377                }));
378
379                if results.len() >= limit {
380                    break;
381                }
382            }
383        } else {
384            // Tag-only or folder-only search
385            let mut matching: Vec<(&str, &crate::obsidian::ObsidianNote)> = if !tags.is_empty() {
386                let tag_results: Vec<_> = tags
387                    .iter()
388                    .flat_map(|t| {
389                        vault
390                            .search_by_tag(t)
391                            .into_iter()
392                            .map(|n| n.title.clone())
393                            .collect::<Vec<_>>()
394                    })
395                    .collect();
396
397                vault
398                    .notes_in_folder(folder.unwrap_or(""))
399                    .into_iter()
400                    .filter(|(_, n)| tag_results.contains(&n.title))
401                    .collect()
402            } else if let Some(f) = folder {
403                vault.notes_in_folder(f)
404            } else {
405                Vec::new()
406            };
407
408            matching.truncate(limit);
409
410            for (key, note) in matching {
411                results.push(serde_json::json!({
412                    "path": key,
413                    "title": note.title,
414                    "tags": note.tags,
415                    "obsidian_uri": vault.obsidian_uri(key),
416                    "preview": truncate_content(&note.content, 200),
417                }));
418            }
419        }
420
421        let output = serde_json::to_string_pretty(&serde_json::json!({
422            "count": results.len(),
423            "results": results,
424        }))
425        .unwrap_or_else(|_| "[]".into());
426
427        Ok(ToolResult {
428            output,
429            metadata: Some(serde_json::json!({ "result_count": results.len() })),
430        })
431    }
432}
433
/// Returns `s` unchanged when it is at most `max` bytes long; otherwise returns the longest
/// prefix of at most `max` bytes that ends on a UTF-8 character boundary, followed by `...`.
fn truncate_content(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_owned();
    }

    // Walk back from `max` to the nearest character boundary so a multi-byte character is
    // never split; index 0 is always a boundary, so the search cannot fail.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    let mut preview = String::with_capacity(cut + 3);
    preview.push_str(&s[..cut]);
    preview.push_str("...");
    preview
}
442
443// ---------------------------------------------------------------------------
444// Tests
445// ---------------------------------------------------------------------------
446
#[cfg(test)]
mod tests {
    use super::*;
    use crate::obsidian::ObsidianVault;
    use roboticus_core::InputAuthority;
    use roboticus_core::config::ObsidianConfig;
    use std::fs;
    use tempfile::TempDir;

    /// Tool context with creator authority, rooted at the current working directory.
    fn test_ctx() -> ToolContext {
        let workspace_root = std::env::current_dir().unwrap();
        ToolContext {
            session_id: "test-session".into(),
            agent_id: "test-agent".into(),
            agent_name: "test-agent".into(),
            authority: InputAuthority::Creator,
            workspace_root,
            tool_allowed_paths: Vec::new(),
            channel: None,
            db: None,
            sandbox: crate::tools::ToolSandboxSnapshot::default(),
        }
    }

    /// Creates a temporary vault containing one tagged note (`existing.md`) and an empty
    /// `roboticus` folder. The `TempDir` is returned so the vault outlives the test body.
    fn setup_vault() -> (TempDir, Arc<RwLock<ObsidianVault>>) {
        let dir = TempDir::new().unwrap();
        let root = dir.path().to_path_buf();

        for folder in [".obsidian", "roboticus"] {
            fs::create_dir(root.join(folder)).unwrap();
        }
        fs::write(
            root.join("existing.md"),
            "---\ntags:\n  - test\n---\n\nExisting note content about Rust",
        )
        .unwrap();

        let config = ObsidianConfig {
            enabled: true,
            vault_path: Some(root),
            index_on_start: true,
            ..Default::default()
        };
        let vault = ObsidianVault::from_config(&config).unwrap();

        (dir, Arc::new(RwLock::new(vault)))
    }

    /// Extracts the `count` field from the search tool's JSON output.
    fn reported_count(output: &str) -> u64 {
        let parsed: Value = serde_json::from_str(output).unwrap();
        parsed["count"].as_u64().unwrap()
    }

    #[tokio::test]
    async fn read_tool_by_path() {
        let (_dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianReadTool::new(vault);
        let args = serde_json::json!({ "path": "existing.md" });

        let result = tool.execute(args, &ctx).await.unwrap();

        assert!(result.output.contains("Existing note content"));
        let meta = result.metadata.unwrap();
        assert_eq!(meta["title"], "existing");
    }

    #[tokio::test]
    async fn read_tool_by_title() {
        let (_dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianReadTool::new(vault);
        let args = serde_json::json!({ "title": "existing" });

        let result = tool.execute(args, &ctx).await.unwrap();

        assert!(result.output.contains("Existing note content"));
    }

    #[tokio::test]
    async fn read_tool_not_found() {
        let (_dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianReadTool::new(vault);
        let args = serde_json::json!({ "path": "nonexistent.md" });

        let err = tool.execute(args, &ctx).await.unwrap_err();

        assert!(err.message.contains("not found"));
    }

    #[tokio::test]
    async fn read_tool_missing_params() {
        let (_dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianReadTool::new(vault);

        let err = tool.execute(serde_json::json!({}), &ctx).await.unwrap_err();

        assert!(err.message.contains("required"));
    }

    #[tokio::test]
    async fn write_tool_creates_note() {
        let (dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianWriteTool::new(vault);
        let args = serde_json::json!({
            "path": "new-note",
            "content": "Hello from the write tool",
            "tags": ["test", "automated"]
        });

        let result = tool.execute(args, &ctx).await.unwrap();

        assert!(result.output.contains("Note written to"));
        assert!(result.output.contains("obsidian://"));

        let meta = result.metadata.unwrap();
        let uri = meta["obsidian_uri"].as_str().unwrap();
        assert!(uri.starts_with("obsidian://"));

        // A bare file name lands in the default agent folder with an `.md` extension.
        let written = dir.path().join("roboticus/new-note.md");
        assert!(written.exists());
        let content = fs::read_to_string(&written).unwrap();
        assert!(content.contains("Hello from the write tool"));
        assert!(content.contains("created_by"));
    }

    #[tokio::test]
    async fn write_tool_missing_content() {
        let (_dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianWriteTool::new(vault);
        let args = serde_json::json!({ "path": "test" });

        let err = tool.execute(args, &ctx).await.unwrap_err();

        assert!(err.message.contains("content"));
    }

    #[tokio::test]
    async fn search_tool_by_query() {
        let (_dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianSearchTool::new(vault);
        let args = serde_json::json!({ "query": "Rust" });

        let result = tool.execute(args, &ctx).await.unwrap();

        assert!(reported_count(&result.output) >= 1);
    }

    #[tokio::test]
    async fn search_tool_by_tag() {
        let (_dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianSearchTool::new(vault);
        let args = serde_json::json!({ "tags": ["test"] });

        let result = tool.execute(args, &ctx).await.unwrap();

        assert!(reported_count(&result.output) >= 1);
    }

    #[tokio::test]
    async fn search_tool_no_params() {
        let (_dir, vault) = setup_vault();
        let ctx = test_ctx();
        let tool = ObsidianSearchTool::new(vault);

        let err = tool.execute(serde_json::json!({}), &ctx).await.unwrap_err();

        assert!(err.message.contains("required"));
    }
}