Skip to main content

walrus_daemon/hook/
mod.rs

//! Stateful Hook implementation for the daemon.
//!
//! [`DaemonHook`] composes skill, MCP, and OS sub-hooks plus external extension
//! services. Memory is handled by an external extension service.
//! `on_build_agent` first lets extension services enrich the agent config, then
//! (for agents that declare skills, MCP servers, or members) computes the tool
//! whitelist and appends a `<scope>` block to the system prompt.
//! `on_register_tools` registers MCP, OS, skill, and task tools, followed by any
//! extension-service tools.
//! `dispatch_tool` routes every agent tool call by name — the single
//! entry point from `event.rs`.
9
10use crate::{
11    ext::hub::DownloadRegistry,
12    hook::{mcp::McpHandler, os::PermissionConfig, skill::SkillHandler, task::TaskRegistry},
13    service::ServiceRegistry,
14};
15use compact_str::CompactString;
16use std::{collections::BTreeMap, sync::Arc};
17use tokio::sync::Mutex;
18use wcore::{AgentConfig, AgentEvent, Hook, ToolRegistry, model::Message};
19
20pub mod mcp;
21pub mod os;
22pub mod skill;
23pub mod task;
24
25/// Per-agent scope for dispatch enforcement. Empty vecs = unrestricted.
26#[derive(Default)]
27pub(crate) struct AgentScope {
28    pub(crate) tools: Vec<CompactString>,
29    pub(crate) members: Vec<String>,
30    pub(crate) skills: Vec<String>,
31    pub(crate) mcps: Vec<String>,
32}
33
34pub struct DaemonHook {
35    pub skills: SkillHandler,
36    pub mcp: McpHandler,
37    pub tasks: Arc<Mutex<TaskRegistry>>,
38    pub downloads: Arc<Mutex<DownloadRegistry>>,
39    pub permissions: PermissionConfig,
40    /// Whether the daemon is running as the `walrus` OS user (sandbox active).
41    pub sandboxed: bool,
42    /// Per-agent scope maps, populated during load_agents.
43    pub(crate) scopes: BTreeMap<CompactString, AgentScope>,
44    /// External extension service registry (tools + queries).
45    pub(crate) registry: Option<Arc<ServiceRegistry>>,
46}
47
/// Core OS tools (file read / write / edit and shell) that every scoped agent
/// is whitelisted for. In sandbox mode they also skip the permission check.
const BASE_TOOLS: &[&str] = &["read", "write", "edit", "bash"];

/// Skill tools, whitelisted for agents that declare at least one skill.
const SKILL_TOOLS: &[&str] = &["search_skill", "load_skill"];

/// MCP tools, whitelisted for agents that declare at least one MCP server.
const MCP_TOOLS: &[&str] = &["search_mcp", "call_mcp_tool"];

/// Task-delegation tools, whitelisted for agents that declare members.
const TASK_TOOLS: &[&str] = &[
    "spawn_task",
    "check_tasks",
    "create_task",
    "ask_user",
    "await_tasks",
];
66
67impl DaemonHook {
68    /// Create a new DaemonHook with the given backends.
69    pub fn new(
70        skills: SkillHandler,
71        mcp: McpHandler,
72        tasks: Arc<Mutex<TaskRegistry>>,
73        downloads: Arc<Mutex<DownloadRegistry>>,
74        permissions: PermissionConfig,
75        sandboxed: bool,
76        registry: Option<Arc<ServiceRegistry>>,
77    ) -> Self {
78        Self {
79            skills,
80            mcp,
81            tasks,
82            downloads,
83            permissions,
84            sandboxed,
85            scopes: BTreeMap::new(),
86            registry,
87        }
88    }
89
90    /// Register an agent's scope for dispatch enforcement.
91    pub(crate) fn register_scope(&mut self, name: CompactString, config: &AgentConfig) {
92        self.scopes.insert(
93            name,
94            AgentScope {
95                tools: config.tools.clone(),
96                members: config.members.clone(),
97                skills: config.skills.clone(),
98                mcps: config.mcps.clone(),
99            },
100        );
101    }
102
103    /// Check tool permission. Returns `Some(denied_message)` if denied,
104    /// `None` if allowed.
105    async fn check_perm(
106        &self,
107        name: &str,
108        args: &str,
109        agent: &str,
110        task_id: Option<u64>,
111    ) -> Option<String> {
112        // OS tools bypass permission when running in sandbox mode.
113        if self.sandboxed && BASE_TOOLS.contains(&name) {
114            return None;
115        }
116        use crate::hook::os::ToolPermission;
117        match self.permissions.resolve(agent, name) {
118            ToolPermission::Deny => Some(format!("permission denied: {name}")),
119            ToolPermission::Ask => {
120                if let Some(tid) = task_id {
121                    let summary = if args.len() > 200 {
122                        format!("{}…", &args[..200])
123                    } else {
124                        args.to_string()
125                    };
126                    let question = format!("{name}: {summary}");
127                    let rx = self.tasks.lock().await.block(tid, question);
128                    if let Some(rx) = rx {
129                        match rx.await {
130                            Ok(resp) if resp == "denied" => {
131                                return Some(format!("permission denied: {name}"));
132                            }
133                            Err(_) => {
134                                return Some(format!("permission denied: {name} (inbox dropped)"));
135                            }
136                            _ => {} // approved → proceed
137                        }
138                    }
139                }
140                // No task_id → can't block, treat as Allow.
141                None
142            }
143            ToolPermission::Allow => None,
144        }
145    }
146
147    /// Dispatch to an external extension service if the tool is registered.
148    /// Returns `None` if the tool is not in the registry (fall through to in-process).
149    async fn dispatch_external(
150        &self,
151        name: &str,
152        args: &str,
153        agent: &str,
154        task_id: Option<u64>,
155    ) -> Option<String> {
156        self.registry
157            .as_ref()?
158            .dispatch_tool(name, args, agent, task_id)
159            .await
160    }
161
162    /// Route a tool call by name to the appropriate handler.
163    ///
164    /// This is the single dispatch entry point — `event.rs` calls this
165    /// and never matches on tool names itself. Unrecognised names are
166    /// forwarded to the MCP bridge after a warn-level log.
167    pub async fn dispatch_tool(
168        &self,
169        name: &str,
170        args: &str,
171        agent: &str,
172        task_id: Option<u64>,
173    ) -> String {
174        if let Some(denied) = self.check_perm(name, args, agent, task_id).await {
175            return denied;
176        }
177        // Dispatch enforcement: reject tools not in the agent's whitelist.
178        if let Some(scope) = self.scopes.get(agent)
179            && !scope.tools.is_empty()
180            && !scope.tools.iter().any(|t| t.as_str() == name)
181        {
182            return format!("tool not available: {name}");
183        }
184        match name {
185            "search_mcp" => self.dispatch_search_mcp(args, agent).await,
186            "call_mcp_tool" => self.dispatch_call_mcp_tool(args, agent).await,
187            "search_skill" => self.dispatch_search_skill(args, agent).await,
188            "load_skill" => self.dispatch_load_skill(args, agent).await,
189            "read" => self.dispatch_read(args).await,
190            "write" => self.dispatch_write(args).await,
191            "edit" => self.dispatch_edit(args).await,
192            "bash" => self.dispatch_bash(args).await,
193            "spawn_task" => self.dispatch_spawn_task(args, agent, task_id).await,
194            "check_tasks" => self.dispatch_check_tasks(args).await,
195            "create_task" => self.dispatch_create_task(args, agent).await,
196            "ask_user" => self.dispatch_ask_user(args, task_id).await,
197            "await_tasks" => self.dispatch_await_tasks(args, task_id).await,
198            // External extension services, then MCP bridge as final fallback.
199            name => {
200                if let Some(result) = self.dispatch_external(name, args, agent, task_id).await {
201                    return result;
202                }
203                tracing::debug!(tool = name, "forwarding tool to MCP bridge");
204                let bridge = self.mcp.bridge().await;
205                bridge.call(name, args).await
206            }
207        }
208    }
209}
210
211impl Hook for DaemonHook {
212    fn on_build_agent(&self, config: AgentConfig) -> AgentConfig {
213        // Delegate to extension services first (prompt enrichment).
214        let mut config = match self.registry {
215            Some(ref registry) => registry.on_build_agent(config),
216            None => config,
217        };
218
219        // Walrus agent (empty scoping) gets all tools, no scope injection.
220        let has_scoping =
221            !config.skills.is_empty() || !config.mcps.is_empty() || !config.members.is_empty();
222        if !has_scoping {
223            return config;
224        }
225
226        // Compute tool whitelist — base tools + external service tools always included.
227        let mut whitelist: Vec<CompactString> =
228            BASE_TOOLS.iter().map(|&s| CompactString::from(s)).collect();
229        if let Some(ref registry) = self.registry {
230            for tool_name in registry.tools.keys() {
231                whitelist.push(CompactString::from(tool_name.as_str()));
232            }
233        }
234        let mut scope_lines = Vec::new();
235
236        // Skill tools if skills non-empty.
237        if !config.skills.is_empty() {
238            for &t in SKILL_TOOLS {
239                whitelist.push(CompactString::from(t));
240            }
241            scope_lines.push(format!("skills: {}", config.skills.join(", ")));
242        }
243
244        // MCP tools if mcps non-empty.
245        if !config.mcps.is_empty() {
246            for &t in MCP_TOOLS {
247                whitelist.push(CompactString::from(t));
248            }
249            // Also include tools from named MCP servers.
250            let mcp_servers = tokio::task::block_in_place(|| {
251                tokio::runtime::Handle::current().block_on(self.mcp.list())
252            });
253            let mut mcp_info = Vec::new();
254            for (server_name, tool_names) in &mcp_servers {
255                if config.mcps.iter().any(|m| m == server_name.as_str()) {
256                    for tn in tool_names {
257                        whitelist.push(tn.clone());
258                    }
259                    mcp_info.push(format!(
260                        "  - {}: {}",
261                        server_name,
262                        tool_names
263                            .iter()
264                            .map(|t| t.as_str())
265                            .collect::<Vec<_>>()
266                            .join(", ")
267                    ));
268                }
269            }
270            if !mcp_info.is_empty() {
271                scope_lines.push(format!("mcp servers:\n{}", mcp_info.join("\n")));
272            }
273        }
274
275        // Task tools if members non-empty.
276        if !config.members.is_empty() {
277            for &t in TASK_TOOLS {
278                whitelist.push(CompactString::from(t));
279            }
280            scope_lines.push(format!("members: {}", config.members.join(", ")));
281        }
282
283        // Inject scope info into system prompt.
284        if !scope_lines.is_empty() {
285            let scope_block = format!("\n\n<scope>\n{}\n</scope>", scope_lines.join("\n"));
286            config.system_prompt.push_str(&scope_block);
287        }
288
289        config.tools = whitelist;
290        config
291    }
292
293    fn on_compact(&self, agent: &str, prompt: &mut String) {
294        if let Some(ref registry) = self.registry {
295            registry.on_compact(agent, prompt);
296        }
297    }
298
299    fn on_before_run(
300        &self,
301        agent: &str,
302        history: &[wcore::model::Message],
303    ) -> Vec<wcore::model::Message> {
304        match self.registry {
305            Some(ref registry) => registry.on_before_run(agent, history),
306            None => Vec::new(),
307        }
308    }
309
310    async fn on_register_tools(&self, tools: &mut ToolRegistry) {
311        self.mcp.on_register_tools(tools).await;
312        tools.insert_all(os::tool::tools());
313        tools.insert_all(skill::tool::tools());
314        tools.insert_all(task::tool::tools());
315        if let Some(ref registry) = self.registry {
316            registry.on_register_tools(tools).await;
317        }
318    }
319
320    fn on_after_run(&self, agent: &str, history: &[Message], system_prompt: &str) {
321        if let Some(ref registry) = self.registry {
322            registry.on_after_run(agent, history, system_prompt);
323        }
324    }
325
326    fn on_event(&self, agent: &str, event: &AgentEvent) {
327        match event {
328            AgentEvent::TextDelta(text) => {
329                tracing::trace!(%agent, text_len = text.len(), "agent text delta");
330            }
331            AgentEvent::ThinkingDelta(text) => {
332                tracing::trace!(%agent, text_len = text.len(), "agent thinking delta");
333            }
334            AgentEvent::ToolCallsStart(calls) => {
335                tracing::debug!(%agent, count = calls.len(), "agent tool calls started");
336            }
337            AgentEvent::ToolResult { call_id, .. } => {
338                tracing::debug!(%agent, %call_id, "agent tool result");
339            }
340            AgentEvent::ToolCallsComplete => {
341                tracing::debug!(%agent, "agent tool calls complete");
342            }
343            AgentEvent::Done(response) => {
344                tracing::info!(
345                    %agent,
346                    iterations = response.iterations,
347                    stop_reason = ?response.stop_reason,
348                    "agent run complete"
349                );
350                // Track token usage on the active task for this agent.
351                let (prompt, completion) = response.steps.iter().fold((0u64, 0u64), |(p, c), s| {
352                    (
353                        p + u64::from(s.response.usage.prompt_tokens),
354                        c + u64::from(s.response.usage.completion_tokens),
355                    )
356                });
357                if (prompt > 0 || completion > 0)
358                    && let Ok(mut registry) = self.tasks.try_lock()
359                {
360                    let tid = registry
361                        .list(Some(agent), Some(task::TaskStatus::InProgress), None)
362                        .first()
363                        .map(|t| t.id);
364                    if let Some(tid) = tid {
365                        registry.add_tokens(tid, prompt, completion);
366                    }
367                }
368            }
369        }
370    }
371}